1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/container/inlined_vector.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/types/optional.h"
27 #include "llvm/ADT/ArrayRef.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/Support/Casting.h"
33 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
34 #include "mlir/IR/Attributes.h"  // from @llvm-project
35 #include "mlir/IR/Builders.h"  // from @llvm-project
36 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
37 #include "mlir/IR/Location.h"  // from @llvm-project
38 #include "mlir/IR/Operation.h"  // from @llvm-project
39 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
40 #include "mlir/IR/Types.h"  // from @llvm-project
41 #include "mlir/Pass/Pass.h"  // from @llvm-project
42 #include "mlir/Pass/PassManager.h"  // from @llvm-project
43 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
44 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
45 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
48 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
49 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
50 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
51 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
52 #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
53 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
54 #include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h"
55 #include "tensorflow/compiler/mlir/utils/name_utils.h"
56 #include "tensorflow/compiler/xla/status_macros.h"
57 #include "tensorflow/core/framework/graph.pb.h"
58 #include "tensorflow/core/framework/graph_to_functiondef.h"
59 #include "tensorflow/core/framework/node_def.pb.h"
60 #include "tensorflow/core/framework/node_def_util.h"
61 #include "tensorflow/core/framework/op.h"
62 #include "tensorflow/core/framework/types.pb.h"
63 #include "tensorflow/core/framework/versions.pb.h"
64 #include "tensorflow/core/graph/algorithm.h"
65 #include "tensorflow/core/graph/graph.h"
66 #include "tensorflow/core/graph/tensor_id.h"
67 #include "tensorflow/core/lib/core/errors.h"
68 #include "tensorflow/core/lib/core/status.h"
69 
70 namespace tensorflow {
71 using llvm::dyn_cast;
72 using llvm::isa;
73 using mlir::BlockArgument;
74 using mlir::Dialect;
75 using mlir::Operation;
76 using mlir::SymbolTable;
77 using mlir::Value;
78 using mlir::func::FuncOp;
79 using stream_executor::port::StatusOr;
80 
81 namespace {
82 
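// Names of MLIR attributes that receive special handling during export rather
// than being converted verbatim into node or function attributes.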
83 constexpr char kDeviceAttr[] = "tf.device";
84 constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
85 constexpr char kEntryFuncAttr[] = "tf.entry_function";
86 constexpr char kAliasingAttr[] = "tf.aliasing_output";
87 
88 // OpOrArgLocNameMapper that legalizes the returned name.
89 class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
90  private:
91   std::string GetName(OpOrVal op_or_val) override {
92     std::string name = OpOrArgLocNameMapper::GetName(op_or_val);
93     assert(!name.empty() && "expected non-empty name");
94     mlir::LegalizeNodeName(name);
95     return name;
96   }
97 };
98 
99 // Finds the first inner op if `op` is a tf_executor.island. Otherwise `op` is
100 // returned.
101 Operation* GetIslandInnerOpOrSelf(mlir::Operation* op) {
102   auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
103   if (island) return &island.GetBody().front();
104   return op;
105 }
106 
107 // Stateful helper class to export a function into a Graph.
108 class Exporter {
109  public:
110   // Converts the given Module to a Graph. The given module should only contain
111   // one entry function, which is identified by name "main". This entry function
112   // is converted to the base of the graph. The rest of the functions are
113   // converted to library functions in that graph.
114   static Status Convert(mlir::ModuleOp module, const GraphExportConfig& configs,
115                         std::unique_ptr<Graph>* graph,
116                         FunctionLibraryDefinition* flib_def,
117                         absl::flat_hash_set<Node*>* control_ret_nodes);
118 
119   // Converts a given FuncOp to a FunctionDef and adds it to the function
120   // definition library.
121   static Status ConvertLibFunction(
122       const GraphExportConfig& configs, const Dialect* tf_dialect,
123       const SymbolTable& symbol_table, FuncOp function,
124       FunctionDefLibrary* flib, llvm::SmallDenseSet<FuncOp>& visited_functions);
125 
126   // Converts the given FuncOp to a Graph. The arguments and returns of
127   // the function are added to the graph with special op names kArgOp and kRetOp.
128   // Later on, this graph can be converted to a function definition and added to
129   // another graph.
130   static StatusOr<std::unique_ptr<Graph>> Convert(
131       const GraphExportConfig& configs, const Dialect* tf_dialect,
132       const SymbolTable& symbol_table, FuncOp function,
133       FunctionDefLibrary* flib, llvm::SmallDenseSet<FuncOp>& visited_functions,
134       absl::flat_hash_set<Node*>* control_ret_nodes);
135 
136  private:
137   explicit Exporter(Graph* graph, const Dialect* tf_dialect)
138       : graph_(graph), tf_dialect_(tf_dialect) {}
139 
140   Status AddArgumentNode(BlockArgument arg, unsigned index,
141                          llvm::StringRef name);
142   Status AddFetchNode(FuncOp function, mlir::tf_executor::FetchOp fetch,
143                       llvm::ArrayRef<llvm::StringRef> names);
144   Status AddInstructionNode(Operation* inst);
145   Status AddEdge(Operation* inst);
146 
147   StatusOr<std::unique_ptr<NodeDef>> GetArgumentNode(BlockArgument arg,
148                                                      unsigned index,
149                                                      llvm::StringRef name);
150   StatusOr<std::unique_ptr<NodeDef>> GetReturnNode(FuncOp function,
151                                                    Value operand,
152                                                    unsigned index,
153                                                    llvm::StringRef name);
154   Status GetControlRetNodes(mlir::tf_executor::FetchOp fetch,
155                             absl::flat_hash_set<Node*>* control_ret_nodes);
156   // Adds one edge between src_node and dst_node. If it is not a control edge,
157   // the index identifies which operand of the dst_node the edge connects to.
158   Status AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index);
159 
160   Graph* graph_;
161   LegalizedOpOrValLocNameMapper op_to_name_;
162   absl::flat_hash_map<Operation*, Node*> nodes_;
163   llvm::DenseMap<BlockArgument, Node*> args_;
164   // A single return operation can return multiple results, and each of them
165   // is converted to a separate node in the graph.
166   typedef absl::InlinedVector<Node*, 4> NodeVector;
167   absl::flat_hash_map<Operation*, NodeVector> returns_;
168   const mlir::Dialect* tf_dialect_;
169 };
170 
171 StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
172     BlockArgument arg, unsigned index, llvm::StringRef name) {
173   auto func = arg.getParentRegion()->getParentOfType<FuncOp>();
174 
175   auto node_def = std::make_unique<NodeDef>();
176   if (!name.empty())
177     node_def->set_name(std::string(ParseTensorName(name.str()).node()));
178   else
179     node_def->set_name(
180         std::string(op_to_name_.GetUniqueName(func.getName().str())));
181 
182   node_def->set_op(FunctionLibraryDefinition::kArgOp);
183 
184   mlir::TensorType arg_type = arg.getType().cast<mlir::TensorType>();
185   if (auto resource_type =
186           arg_type.getElementType().dyn_cast<mlir::TF::ResourceType>()) {
187     llvm::ArrayRef<mlir::TensorType> subtypes = resource_type.getSubtypes();
188     if (!subtypes.empty()) {
189       AttrValue handle_dtypes_attr;
190       AttrValue handle_shapes_attr;
191       for (mlir::TensorType subtype : subtypes) {
192         DataType dtype;
193         TF_RETURN_IF_ERROR(ConvertToDataType(subtype.getElementType(), &dtype));
194         handle_dtypes_attr.mutable_list()->add_type(dtype);
195 
196         SetTensorShapeProto(subtype,
197                             handle_shapes_attr.mutable_list()->add_shape());
198       }
199 
200       (*node_def->mutable_attr())["_handle_dtypes"] = handle_dtypes_attr;
201       (*node_def->mutable_attr())["_handle_shapes"] = handle_shapes_attr;
202     }
203   }
204 
205   TF_RETURN_IF_ERROR(
206       SetShapeAttribute("_output_shapes", arg_type, node_def->mutable_attr()));
207 
208   DataType dtype;
209   TF_RETURN_IF_ERROR(ConvertToDataType(arg_type.getElementType(), &dtype));
210   AttrValue type_attr;
211   type_attr.set_type(dtype);
212   (*node_def->mutable_attr())["T"] = type_attr;
213 
214   AttrValue index_attr;
215   index_attr.set_i(index);
216   (*node_def->mutable_attr())["index"] = index_attr;
217 
218   if (auto device_attr =
219           func.getArgAttrOfType<mlir::StringAttr>(index, kDeviceAttr))
220     *node_def->mutable_device() = device_attr.getValue().str();
221 
222   llvm::ArrayRef<mlir::NamedAttribute> func_arg_i_attrs =
223       func.getArgAttrs(index);
224   absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr,
225                                                             kAliasingAttr};
226   TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore,
227                                        /*remove_ref_type=*/false,
228                                        node_def->mutable_attr()));
229 
230   return node_def;
231 }
232 
233 StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
234     FuncOp function, Value operand, unsigned index, llvm::StringRef name) {
235   auto node_def = std::make_unique<NodeDef>();
236   if (!name.empty())
237     node_def->set_name(std::string(ParseTensorName(name.str()).node()));
238   else
239     node_def->set_name(
240         std::string(op_to_name_.GetUniqueName(function.getName().str())));
241 
242   node_def->set_op(FunctionLibraryDefinition::kRetOp);
243   DataType dtype;
244   TF_RETURN_IF_ERROR(ConvertToDataType(
245       operand.getType().cast<mlir::TensorType>().getElementType(), &dtype));
246   AttrValue type_attr;
247   type_attr.set_type(dtype);
248   (*node_def->mutable_attr())["T"] = type_attr;
249   AttrValue index_attr;
250   index_attr.set_i(index);
251   (*node_def->mutable_attr())["index"] = index_attr;
252 
253   if (auto device_attr =
254           function.getResultAttrOfType<mlir::StringAttr>(index, kDeviceAttr))
255     *node_def->mutable_device() = device_attr.getValue().str();
256 
257   llvm::ArrayRef<mlir::NamedAttribute> func_res_i_attrs =
258       function.getResultAttrs(index);
259   absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr};
260   TF_RETURN_IF_ERROR(ConvertAttributes(func_res_i_attrs, attrs_to_ignore,
261                                        /*remove_ref_type=*/false,
262                                        node_def->mutable_attr()));
263 
264   return node_def;
265 }
266 
267 Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
268                                      unsigned dst_index) {
269   if (auto input_result = src.dyn_cast<mlir::OpResult>()) {
270     auto* input_inst = GetIslandInnerOpOrSelf(input_result.getOwner());
271     // Replaces the input node with NextIteration sink if it is a NextIteration
272     // source.
273     if (auto next_iter_source =
274             llvm::dyn_cast<mlir::tf_executor::NextIterationSourceOp>(
275                 input_inst))
276       input_inst = next_iter_source.GetSink();
277 
278     auto node_it = nodes_.find(input_inst);
279     TF_RET_CHECK(node_it != nodes_.end())
280         << "Use of OpResult encountered before def!";
281     if (input_result.getType().isa<mlir::tf_executor::ControlType>()) {
282       graph_->AddControlEdge(node_it->second, dst_node);
283     } else {
284       graph_->AddEdge(node_it->second, input_result.getResultNumber(), dst_node,
285                       dst_index);
286     }
287     return OkStatus();
288   }
289 
290   auto input_arg = src.cast<BlockArgument>();
291   auto input_node_it = args_.find(input_arg);
292   TF_RET_CHECK(input_node_it != args_.end())
293       << "Use of BlockArgument encountered before def!";
294   // For argument, there is only one result output, so the index is always 0.
295   graph_->AddEdge(input_node_it->second, 0, dst_node, dst_index);
296   return OkStatus();
297 }
298 
299 Status Exporter::AddEdge(Operation* inst) {
300   // For tf_executor.fetch, add only its data edges. Control edges are captured
301   // later.
302   if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
303     for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
304       Value operand = operand_and_idx.value();
305       if (operand.getType().isa<mlir::tf_executor::ControlType>()) break;
306 
307       auto* dst_node = returns_[fetch][operand_and_idx.index()];
308       TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand, dst_node, 0));
309     }
310 
311     return OkStatus();
312   }
313 
314   // For tf_executor.NextIteration.Sink, skip its token operand and add data and
315   // control edges with their index offset by 1.
316   if (auto next_iter_sink =
317           llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(inst)) {
318     auto* dst_node = nodes_[inst];
319     TF_RETURN_IF_ERROR(
320         AddEdgeBetweenNodes(next_iter_sink.input(), dst_node, 0));
321     for (auto control_and_idx : llvm::enumerate(next_iter_sink.controlInputs()))
322       TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(control_and_idx.value(), dst_node,
323                                              control_and_idx.index() + 1));
324 
325     return OkStatus();
326   }
327 
328   // For tf_executor.NextIteration.Source, the op can be skipped as it is
329   // assumed to have no operands.
330   if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
331     assert(inst->getNumOperands() == 0);
332     return OkStatus();
333   }
334 
335   Operation* op = GetIslandInnerOpOrSelf(inst);
336   auto* dst_node = nodes_[op];
337   int operand_offset = 0;
338   // For tf_executor.island, add data edges from its wrapped op before control
339   // edges.
340   if (auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
341     for (auto operand_and_idx : llvm::enumerate(op->getOperands()))
342       TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
343                                              operand_and_idx.index()));
344 
345     operand_offset = op->getNumOperands();
346   }
347 
348   // For all other ops (including tf_executor.island), add remaining edges.
349   for (auto operand_and_idx : llvm::enumerate(inst->getOperands()))
350     TF_RETURN_IF_ERROR(
351         AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
352                             operand_and_idx.index() + operand_offset));
353 
354   return OkStatus();
355 }
356 
357 Status Exporter::AddInstructionNode(Operation* inst) {
358   std::unique_ptr<NodeDef> node_def;
359   auto name = op_to_name_.GetUniqueName(inst);
360   // Convert registered TF ops to NodeDef. Only registered ops are handled to
361   // ensure that PopulateDerivedAttrs adds the correct attributes.
362   TF_ASSIGN_OR_RETURN(node_def,
363                       ConvertTFDialectOpToNodeDef(
364                           inst, name, /*ignore_unregistered_attrs=*/false));
365 
366   TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(*node_def));
367   DCHECK(node != nullptr);
368   nodes_[inst] = node;
369   return OkStatus();
370 }
371 
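// Returns true if `arg` is an argument of the entry function, i.e. the
// function named "main".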
372 bool IsEntryFunctionArg(BlockArgument arg) {
373   return arg.getParentRegion()->getParentOfType<FuncOp>().getName() == "main";
374 }
375 
376 // Creates an argument node from a block argument. If a name is supplied, that
377 // name will be used instead of generating a unique name.
378 Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index,
379                                  llvm::StringRef name) {
380   TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name));
381   TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(*node_def));
382   args_[arg] = node;
383   return OkStatus();
384 }
385 
386 // Creates return nodes per operand of a FetchOp. If names is supplied, those
387 // names will be used per node in order instead of generating a unique name.
388 Status Exporter::AddFetchNode(FuncOp function, mlir::tf_executor::FetchOp fetch,
389                               llvm::ArrayRef<llvm::StringRef> names) {
390   auto& return_nodes = returns_[fetch];
391   for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
392     if (operand_and_idx.value().getType().isa<mlir::tf_executor::ControlType>())
393       break;
394 
395     TF_ASSIGN_OR_RETURN(
396         auto node_def,
397         GetReturnNode(function, operand_and_idx.value(),
398                       operand_and_idx.index(),
399                       names.empty() ? "" : names[operand_and_idx.index()]));
400     TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(*node_def));
401     return_nodes.push_back(node);
402   }
403   return OkStatus();
404 }
405 
406 // Collects control ret Nodes based on tf_executor.graph's associated
407 // tf_executor.fetch control inputs.
408 Status Exporter::GetControlRetNodes(
409     mlir::tf_executor::FetchOp fetch,
410     absl::flat_hash_set<Node*>* control_ret_nodes) {
411   for (Value fetch_operand : fetch.getOperands()) {
412     if (fetch_operand.getType().isa<mlir::tf_executor::ControlType>()) {
413       Operation* defining_op =
414           GetIslandInnerOpOrSelf(fetch_operand.getDefiningOp());
415       auto node_it = nodes_.find(defining_op);
416       TF_RET_CHECK(node_it != nodes_.end());
417       control_ret_nodes->insert(node_it->second);
418     }
419   }
420   return OkStatus();
421 }
422 
423 StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
424     const GraphExportConfig& configs, const Dialect* tf_dialect,
425     const SymbolTable& symbol_table, FuncOp function, FunctionDefLibrary* flib,
426     llvm::SmallDenseSet<FuncOp>& visited_functions,
427     absl::flat_hash_set<Node*>* control_ret_nodes) {
428   mlir::Block& block = function.front();
429 
430   // Extract input & output names if set.
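  // These come from the tf.entry_function attribute, whose "inputs" and
  // "outputs" entries are comma-separated lists of tensor names, e.g.
  // {inputs = "input0,input1", outputs = "output0"}.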
431   llvm::SmallVector<llvm::StringRef, 2> input_names;
432   llvm::SmallVector<llvm::StringRef, 2> output_names;
433   llvm::SmallVector<llvm::StringRef, 2> unique_output_names;
434   auto dict_attr =
435       function->getAttrOfType<mlir::DictionaryAttr>(kEntryFuncAttr);
436   if (dict_attr) {
437     TF_RET_CHECK(dict_attr.get("inputs").isa<mlir::StringAttr>())
438         << "inputs missing in entry function attribute";
439     TF_RET_CHECK(dict_attr.get("outputs").isa<mlir::StringAttr>())
440         << "outputs missing in entry function attribute";
441     dict_attr.get("inputs").cast<mlir::StringAttr>().getValue().split(
442         input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
443     dict_attr.get("outputs").cast<mlir::StringAttr>().getValue().split(
444         output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
445   }
446 
447   auto graph = std::make_unique<Graph>(OpRegistry::Global());
448 
449   // Extract version info.
450   VersionDef versions;
451   auto module = function->getParentOfType<mlir::ModuleOp>();
452   if (mlir::succeeded(ExtractTfVersions(module, &versions))) {
453     graph->set_versions(versions);
454   }
455 
456   Exporter exporter(graph.get(), tf_dialect);
457 
458   auto graph_op = llvm::cast<mlir::tf_executor::GraphOp>(block.front());
459 
460   // Set input and output names and increment the use counter for them to help
461   // generate unique names.
462   if (!output_names.empty()) {
463     const int num_data_results = graph_op.getNumResults();
464     const int64_t output_names_size = output_names.size();
465     TF_RET_CHECK(output_names_size == num_data_results)
466         << "output names (" << output_names.size()
467         << ") != terminator operands (" << num_data_results << ")";
468     llvm::DenseMap<Operation*, llvm::StringRef> output_op_to_name;
469     llvm::StringMap<Operation*> name_to_op;
470     for (const auto& it : llvm::enumerate(graph_op.GetFetch().getOperands())) {
471       // Skip control rets.
472       const int64_t index = it.index();
473       if (index >= num_data_results) break;
474       // TODO(jpienaar): If there is a result index specified, ensure only one
475       // and that it matches the result index of the op.
476       std::string name(output_names[index]);
477       auto tensor_id = ParseTensorName(name);
478       std::string tensor_id_node(tensor_id.node());
479       assert(!tensor_id_node.empty() && "expected non-empty name");
480       mlir::LegalizeNodeName(tensor_id_node);
481 
482       // Ensure name does not get reused.
483       unique_output_names.push_back(
484           exporter.op_to_name_.GetUniqueName(tensor_id_node));
485     }
486   }
487 
488   if (!input_names.empty()) {
489     TF_RET_CHECK(input_names.size() == block.getNumArguments());
490     for (const auto& it : llvm::enumerate(function.getArguments())) {
491       // TODO(lyandy): Update when changing feed/fetch import.
492       std::string name(input_names[it.index()]);
493       assert(!name.empty() && "expected non-empty name");
494       mlir::LegalizeNodeName(name);
495       auto tensor_id = ParseTensorName(name);
496       TF_RET_CHECK(tensor_id.index() == 0)
497           << "input port designation not supported";
498       // Only assign the input name to the argument's user if the main graph did
499       // not have its _Arg nodes lifted into the function's arguments.
500       // Ensure name does not get reused.
501       (void)exporter.op_to_name_.GetUniqueName(name);
502     }
503   }
504 
505   // Adds nodes for basic block (function) arguments.
506   for (auto it : llvm::enumerate(block.getArguments())) {
507     int index = it.index();
508     auto arg = it.value();
509     mlir::Type type = arg.getType();
510     if (!type.isa<mlir::TensorType>()) {
511       return errors::InvalidArgument(
512           "FuncOps arguments must have tensor types. Found ",
513           mlir::debugString(type), " in function ", function.getName().str());
514     }
515 
516     TF_RETURN_IF_ERROR(exporter.AddArgumentNode(
517         arg, index, !input_names.empty() ? input_names[index] : ""));
518   }
519 
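  // Converts the function referenced by `name` into the function library, if
  // it resolves to a FuncOp in the module, and merges the library accumulated
  // so far into the graph.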
520   auto convert_called_function = [&](llvm::StringRef name) {
521     auto func = symbol_table.lookup<FuncOp>(name);
522     if (func != nullptr) {
523       TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, symbol_table,
524                                             func, flib, visited_functions));
525       // TODO(prakalps): Optimize to only add the requested function to graph
526       // library rather than all the functions exported so far.
527       TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
528     }
529     return OkStatus();
530   };
531 
532   // Adds nodes for operations.
533   for (Operation& inst : graph_op.GetBody()) {
534     for (auto type : inst.getResultTypes())
535       if (!type.isa<mlir::TensorType, mlir::tf_executor::ControlType,
536                     mlir::tf_executor::TokenType>())
537         return errors::InvalidArgument(
538             "Values must be of tensor type, TensorFlow control type, or "
539             "TensorFlow token type. Found ",
540             mlir::debugString(type));
541 
542     if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
543       // Skip tf_executor.NextIteration.Source as associated
544       // tf_executor.NextIteration.Sink will be used instead.
545       continue;
546     } else if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
547       TF_RETURN_IF_ERROR(
548           exporter.AddFetchNode(function, fetch, unique_output_names));
549     } else if (auto island =
550                    llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
551       Operation& inner_op = island.GetBody().front();
552       auto op_name = GetTensorFlowOpName(inner_op.getName().getStringRef());
553       if (op_name.ok()) {
554         // If it is a TF Control dialect specific op, look up the custom
555         // operation in the module and convert that first, then add it to the
556         // function definition library.
557         // TODO(prakalps): If two functions have cyclic dependence, this will
558         // introduce an infinite loop.
559         TF_RETURN_IF_ERROR(convert_called_function(op_name.ValueOrDie().str()));
560       }
561 
562       if (IsLegacyCallInstruction(&inner_op)) {
563         TF_RETURN_IF_ERROR(convert_called_function(
564             inner_op.getAttrOfType<mlir::SymbolRefAttr>("f")
565                 .getLeafReference()
566                 .getValue()));
567       }
568 
569       TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inner_op));
570     } else {
571       TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inst));
572     }
573   }
574   // Adds edges between the argument, operation and return nodes.
575   for (Operation& inst : graph_op.GetBody()) {
576     TF_RETURN_IF_ERROR(exporter.AddEdge(&inst));
577   }
578   // Fixes the edges between the inserted nodes and special "_SOURCE" and
579   // "_SINK".
580   FixupSourceAndSinkEdges(graph.get());
581 
582   TF_RETURN_IF_ERROR(
583       exporter.GetControlRetNodes(graph_op.GetFetch(), control_ret_nodes));
584 
585   return graph;
586 }
587 
588 Status Exporter::ConvertLibFunction(
589     const GraphExportConfig& configs, const Dialect* tf_dialect,
590     const SymbolTable& symbol_table, FuncOp function, FunctionDefLibrary* flib,
591     llvm::SmallDenseSet<FuncOp>& visited_functions) {
592   // Return early if the function has already been exported.
593   bool is_new_function = visited_functions.insert(function).second;
594   if (!is_new_function) return OkStatus();
595 
596   auto function_name = function.getName().str();
597 
598   // TODO(fengliuai): use a small flib_def to reduce overhead
599   absl::flat_hash_set<Node*> control_ret_nodes;
600   TF_ASSIGN_OR_RETURN(
601       auto sub_graph,
602       Exporter::Convert(configs, tf_dialect, symbol_table, function, flib,
603                         visited_functions, &control_ret_nodes));
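  // Passed to GraphToFunctionDef below: nodes collected in `control_ret_nodes`
  // are reported by name so they become control returns of the generated
  // FunctionDef.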
604   const auto control_ret = [&](const Node* n) -> std::optional<string> {
605     return control_ret_nodes.contains(n)
606                ? absl::make_optional<string>(n->name())
607                : std::nullopt;
608   };
609   FunctionDef func_def;
610   TF_RETURN_IF_ERROR(
611       GraphToFunctionDef(*sub_graph, function_name, control_ret, &func_def));
612 
613   // The node defs in the FunctionDef might contain debug info which was added
614   // by GraphToFunctionDef. Remove it when debug info export is disabled to
615   // avoid failing the roundtrip test.
616   if (!configs.export_debug_info) {
617     for (auto& node_def : *func_def.mutable_node_def()) {
618       node_def.clear_experimental_debug_info();
619     }
620   }
621 
622   // Checks for the gradient attribute. If present, converts the gradient
623   // function and populates the GradientDef.
624   auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
625   if (auto attr =
626           function->getAttrOfType<mlir::FlatSymbolRefAttr>(grad_string)) {
627     auto grad_func = symbol_table.lookup<FuncOp>(attr.getValue());
628     TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, symbol_table,
629                                           grad_func, flib, visited_functions));
630     GradientDef grad;
631     grad.set_function_name(function_name);
632     grad.set_gradient_func(grad_func.getName().str());
633     *flib->add_gradient() = grad;
634   }
635 
636   auto stateful_string = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
637   if (auto attr = function->getAttrOfType<mlir::UnitAttr>(stateful_string)) {
638     func_def.mutable_signature()->set_is_stateful(true);
639   }
640 
641   // Ignore the gradient and is_stateful attributes on the function as they have
642   // been handled above. Ignore the entry func attribute as it is an MLIR
643   // metadata attribute and is not required in the function definition.
644   absl::flat_hash_set<absl::string_view> attrs_to_ignore = {
645       grad_string.data(), stateful_string.data(), kEntryFuncAttr};
646   llvm::SmallVector<mlir::NamedAttribute, 8> funcAttrs(
647       function->getDialectAttrs());
648   TF_RETURN_IF_ERROR(ConvertAttributes(funcAttrs, attrs_to_ignore,
649                                        /*remove_ref_type=*/false,
650                                        func_def.mutable_attr()));
651 
652   for (int i = 0, e = function.getNumArguments(); i < e; ++i) {
653     if (auto resource_arg_unique_id_attr =
654             function.getArgAttrOfType<mlir::IntegerAttr>(
655                 i, kResourceArgUniqueIdAttr)) {
656       (*func_def.mutable_resource_arg_unique_id())[i] =
657           resource_arg_unique_id_attr.getInt();
658     }
659   }
660 
661   (*flib->add_function()) = std::move(func_def);
662   return OkStatus();
663 }
664 
665 Status Exporter::Convert(mlir::ModuleOp module,
666                          const GraphExportConfig& configs,
667                          std::unique_ptr<Graph>* graph,
668                          FunctionLibraryDefinition* flib_def,
669                          absl::flat_hash_set<Node*>* control_ret_nodes) {
670   mlir::StringAttr entry_func_id =
671       mlir::StringAttr::get(module.getContext(), "main");
672   std::optional<FuncOp> entry_func;
673   FunctionDefLibrary flib;
674   llvm::SmallDenseSet<FuncOp> visited_functions;
675   auto tf_dialect = module.getContext()->getLoadedDialect("tf");
676   // Construct SymbolTable to enable cheap function lookups. The cost
677   // of constructing the table is offset by the number of queries.
678   SymbolTable symbol_table(module);
679   for (auto function : module.getOps<FuncOp>()) {
680     if (function.isExternal())
681       return errors::FailedPrecondition("External functions not supported");
682 
683     if (function.getName() == entry_func_id &&
684         !configs.export_entry_func_to_flib) {
685       entry_func.emplace(function);
686     } else {
687       TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, symbol_table,
688                                             function, &flib,
689                                             visited_functions));
690     }
691   }
692 
693   if (!configs.export_entry_func_to_flib) {
694     if (!entry_func.has_value())
695       return errors::FailedPrecondition(
696           "entry function `main` must be present");
697 
698     // Updates the graph and the function library definition.
699     TF_ASSIGN_OR_RETURN(
700         *graph,
701         Exporter::Convert(configs, tf_dialect, symbol_table, entry_func.value(),
702                           &flib, visited_functions, control_ret_nodes));
703     // Add FunctionDefs and GradientDefs of MLIR functions to graph's function
704     // library. If duplicate FunctionDefs already exist (can happen if exporter
705     // had already added some FunctionDefs to the library to support legacy
706     // calls), they are ignored.
707     TF_RETURN_IF_ERROR(graph->get()->AddFunctionLibrary(flib));
708   }
709 
710   for (auto& func_def : flib.function()) {
711     TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def));
712   }
713   for (auto& grad_def : flib.gradient()) {
714     TF_RETURN_IF_ERROR(flib_def->AddGradientDef(grad_def));
715   }
716   return OkStatus();
717 }
718 }  // namespace
719 
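// Converts the module to a Graph: the entry function `main` becomes the graph
// body (unless configs.export_entry_func_to_flib is set) and the remaining
// functions are added to `flib_def`. Nodes fetched via control edges are
// collected in `control_ret_nodes`.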
720 Status ConvertMlirToGraph(mlir::ModuleOp module,
721                           const GraphExportConfig& configs,
722                           std::unique_ptr<Graph>* graph,
723                           FunctionLibraryDefinition* flib_def,
724                           absl::flat_hash_set<Node*>* control_ret_nodes) {
725   mlir::StatusScopedDiagnosticHandler sh(module.getContext());
726   if (failed(VerifyExportSuitable(module))) return sh.ConsumeStatus();
727   return sh.Combine(
728       Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes));
729 }
730 
731 Status ConvertMlirToGraph(mlir::ModuleOp module,
732                           const GraphExportConfig& configs,
733                           std::unique_ptr<Graph>* graph,
734                           FunctionLibraryDefinition* flib_def) {
735   absl::flat_hash_set<Node*> control_ret_nodes;
736   return ConvertMlirToGraph(module, configs, graph, flib_def,
737                             &control_ret_nodes);
738 }
739 
740 StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
741     mlir::ModuleOp module, const GraphExportConfig& configs) {
742   FunctionLibraryDefinition flib_def(OpRegistry::Global(),
743                                      FunctionDefLibrary());
744   std::unique_ptr<Graph> graph;
745   TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def));
746 
747   // If the entry function is exported to flib, then no graph is constructed.
748   // Construct one in that case.
749   if (configs.export_entry_func_to_flib) {
750     graph = std::make_unique<Graph>(OpRegistry::Global());
751     // TODO(hinsu): Avoid Proto -> Memory -> Proto conversion here.
752     FunctionDefLibrary flib = flib_def.ToProto();
753     TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(flib));
754   }
755 
756   auto graphdef = std::make_unique<GraphDef>();
757   graph->ToGraphDef(graphdef.get());
758   if (!configs.export_library) graphdef->clear_library();
759   if (!configs.export_shapes) {
760     for (auto& node_def : *graphdef->mutable_node()) {
761       node_def.mutable_attr()->erase("shape");
762     }
763   }
764   if (!configs.export_debug_info) {
765     for (auto& node_def : *graphdef->mutable_node()) {
766       node_def.clear_experimental_debug_info();
767     }
768   }
769   return graphdef;
770 }
771 
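// Converts `func` (and any functions it references, such as an associated
// gradient) to FunctionDefs and copies the definition matching `func`'s name
// into `function_def`.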
772 stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
773     FuncOp func, const GraphExportConfig& configs, FunctionDef* function_def) {
774   Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf");
775   FunctionDefLibrary flib;
776   llvm::SmallDenseSet<FuncOp> visited_functions;
777   // Construct SymbolTable to enable cheap function lookups. The cost
778   // of constructing the table is offset by the number of queries. Even
779   // though this converts only one function, that function may have an
780   // associated gradient, which would result in a lookup. This could be made
781   // lazy if building the full table proves too broad.
782   SymbolTable symbol_table(func->getParentOfType<mlir::ModuleOp>());
783   TF_RETURN_IF_ERROR(Exporter::ConvertLibFunction(
784       configs, tf_dialect, symbol_table, func, &flib, visited_functions));
785   for (auto& func_def : flib.function()) {
786     if (func_def.signature().name() == func.getName()) {
787       *function_def = func_def;
788       return OkStatus();
789     }
790   }
791   return errors::InvalidArgument(
792       "Function couldn't be found in the FunctionDefLibrary after converting "
793       "from MLIR");
794 }
795 
796 }  // namespace tensorflow
797