/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/ir/importexport/graphdef_import.h"

#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
#include "mlir/IR/Threading.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/ir/dialect.h"
#include "tensorflow/core/ir/importexport/convert_attributes.h"
#include "tensorflow/core/ir/importexport/convert_types.h"
#include "tensorflow/core/ir/importexport/functiondef_import.h"
#include "tensorflow/core/ir/ops.h"
#include "tensorflow/core/ir/types/dialect.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"

using tensorflow::DataType;
using tensorflow::DataTypeVector;
using tensorflow::FullTypeDef;
using tensorflow::FunctionDef;
using tensorflow::FunctionLibraryDefinition;
using tensorflow::Graph;
using tensorflow::GraphDebugInfo;
using tensorflow::GraphDef;
using tensorflow::NodeDef;
using tensorflow::OpDef;
using tensorflow::OpRegistrationData;
using tensorflow::OpRegistry;
using tensorflow::Status;
using tensorflow::StatusOr;
using tensorflow::StringPiece;
using tensorflow::TensorId;
using tensorflow::VersionDef;
using tensorflow::errors::InvalidArgument;
using tensorflow::errors::NotFound;

namespace mlir {
namespace tfg {
namespace {
// This class implements an importer for GraphDef directly to TFG.
class GraphDefImporter {
 public:
  // Initialize the importer.
  GraphDefImporter(TFGraphDialect *dialect, const OpRegistry &registry,
                   const GraphDebugInfo &debug_info)
      : ctx_(dialect->getContext()),
        dialect_(dialect),
        b_(ctx_),
        registry_(registry),
        debug_info_(debug_info),
        unknown_loc_(UnknownLoc::get(ctx_)),
        placeholder_state_(unknown_loc_, "tfg._mlir_placeholder") {
    placeholder_state_.addTypes(dialect_->getControlType());
  }

  // Convert a GraphDef to an MLIR module.
  StatusOr<OwningOpRef<ModuleOp>> ConvertGraphDef(const GraphDef &graph);

 private:
  // Convert a function. This function must be thread-safe.
  Status ConvertFunctionDef(
      GraphFuncOp func_op,
      const absl::flat_hash_map<StringPiece, StringPiece> &gradient_map,
      const FunctionDef &function);

  // A result ID representing an output of `node`. E.g.
  // "foo" -> {0, "foo", ""}
  // "foo:2" -> {2, "foo", ""}
  // "foo:output:0" -> {0, "foo", "output"}
  struct ResultId {
    // The result or result segment index.
    int index;
    // The name of the parent node.
    StringRef node;
    // An optional result segment name.
    StringRef output;

    // Returns true if the result ID references the control token.
    bool IsControl() const { return index == tensorflow::Graph::kControlSlot; }
  };

  // An unresolved backedge.
  struct Backedge {
    // The edge name and index.
    ResultId id;
    // The OpOperand to resolve.
    OpOperand *operand;
  };

  // Cached info about the result of an operation.
  struct ResultInfo {
    // This flag is true if the results of the operation have been resolved; the
    // operation has been created and its `data` and `control` results have been
    // populated. If false, the placeholder should be used.
    bool resolved = false;
    // The control result.
    Value control;
    // All data results.
    ValueRange data;
    // Data results organized by output name.
    absl::flat_hash_map<StringPiece, ValueRange> outputs;
    // A list of unresolved backedges.
    std::vector<Backedge> backedges;
  };

  // State when converting a list of nodes.
  class ConversionState
      : public absl::flat_hash_map<StringPiece, std::unique_ptr<ResultInfo>> {
   public:
    // Create a conversion state with a placeholder value. Put the placeholder
    // in the block so that it is owned.
    explicit ConversionState(Block *block,
                             const OperationState &placeholder_state)
        : placeholder_op_(
              OpBuilder::atBlockBegin(block).create(placeholder_state)),
          placeholder_(placeholder_op_->getResult(0)) {}

    // Get the placeholder value.
    Value GetPlaceholder() { return placeholder_; }

    // Finalize the conversion. The placeholder is destroyed.
    void Finalize() { placeholder_op_->erase(); }

   private:
    // The placeholder operation.
    Operation *placeholder_op_;
    // The placeholder value.
    Value placeholder_;
  };
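
  // For illustration of how the placeholder is used (hypothetical nodes):
  // a GraphDef's nodes are not necessarily topologically sorted, so given
  //
  //   node { name: "a" op: "Identity" input: "b" }
  //   node { name: "b" op: "Const" }
  //
  // converting "a" first wires its operand to the placeholder and records a
  // `Backedge`; once "b" is converted, the backedge is resolved to b's data
  // result and the placeholder loses that use.
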
  // Convert a list of nodes to operations.
  Status ConvertNodes(
      OpBuilder &builder, ConversionState &s,
      const tensorflow::protobuf::RepeatedPtrField<NodeDef> &nodes,
      Block *block);
  // Convert a node to an operation.
  Status ConvertNodeDef(OpBuilder &builder, ConversionState &s,
                        const NodeDef &node);
  // Resolve a data result reference.
  static StatusOr<Value> ResolveDataResult(const ResultId &id,
                                           ResultInfo *info);

  // Get a named result.
  struct Result {
    Value control;
    Value data;
    ResultId id;
    ResultInfo *info = nullptr;
  };
  StatusOr<Result> GetResult(ConversionState &s, StringPiece name);

  // Convert TF datatypes to unranked MLIR tensor types.
  Status ConvertDataTypesToUnrankedTensorTypes(const DataTypeVector &dtypes,
                                               SmallVectorImpl<Type> &results);
  // Extracts the actual data types from `attrs` based on its definition in
  // `arg_def` and converts them to unranked tensors. Returns the number of
  // added types.
  //
  // TODO(jeffniu): This is a re-implementation of `ArgNumType` in
  // `core/framework/function.cc` on `NamedAttrList` because the default
  // attributes need to be added. Find a way to do this in one pass.
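  //
  // For illustration (hypothetical attributes): an ArgDef with
  // `number_attr: "N"` and `type_attr: "T"`, given attrs {N = 3, T = f32},
  // appends three tensor<*xf32> types and returns 3.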
  StatusOr<unsigned> ArgNumType(const NamedAttrList &attrs,
                                const OpDef::ArgDef &arg_def,
                                SmallVectorImpl<Type> &types);
  // Convert function attributes to MLIR attributes.
  Status ConvertFunctionAttributes(
      const absl::flat_hash_map<StringPiece, StringPiece> &gradient_map,
      const FunctionDef &function, GraphFuncOp op, NamedAttrList &attrs);
  // Convert function argument attributes to MLIR attributes.
  Status ConvertArgumentAttributes(const OpDef::ArgDef &def,
                                   NamedAttrList &attrs);
  // Create a location for a node.
  Location ConvertLocation(const NodeDef &node);
  // Convert the location of a node from the debug info. If it has no debug
  // info, return a NameLoc.
  Location ConvertLocation(StringRef node_name, StringRef func_name);

  // The MLIR context.
  MLIRContext *ctx_;
  // Reference to the TFG dialect.
  TFGraphDialect *dialect_;
  // The builder instance.
  Builder b_;
  // The TF op registry to use.
  const OpRegistry &registry_;
  // The debug info about the graph.
  const GraphDebugInfo &debug_info_;
  // Cached unknown location.
  Location unknown_loc_;
  // Operation state for creating placeholder ops.
  OperationState placeholder_state_;

  // Map of function OpDefs.
  absl::flat_hash_map<StringPiece, const OpDef *> function_op_defs_;
};
}  // namespace

// Convert a VersionDef to an MLIR version attribute.
static VersionAttr ConvertVersionAttr(MLIRContext *context,
                                      const VersionDef &version) {
  ArrayRef<int32_t> bad_consumers(version.bad_consumers().data(),
                                  version.bad_consumers().size());
  return VersionAttr::get(context, version.producer(), version.min_consumer(),
                          bad_consumers);
}

// Returns true if the function is a generic function, i.e. it contains
// placeholder attributes.
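//
// For illustration, a generic function's nodes carry attribute values that are
// unbound placeholders rather than concrete values, e.g. (abbreviated
// textproto):
//
//   node_def {
//     name: "id"
//     op: "Identity"
//     attr { key: "T" value { placeholder: "T" } }
//   }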
//
// TODO(jeffniu): Having to iterate over every function just to check for
// placeholder attributes is slow. Since most functions are not generic, we can
// speculate by converting all functions as non-generic until we see a
// placeholder attribute, bail out, and fall back to the generic function
// converter.
static bool IsGenericFunction(const FunctionDef &fdef) {
  for (const NodeDef &node : fdef.node_def())
    for (const auto &named_attr : node.attr())
      if (!named_attr.second.placeholder().empty()) return true;

  return false;
}

StatusOr<OwningOpRef<ModuleOp>> GraphDefImporter::ConvertGraphDef(
    const GraphDef &graph) {
  // Create the module.
  OwningOpRef<ModuleOp> module = ModuleOp::create(unknown_loc_);

  // Create the graph op.
  auto builder = OpBuilder::atBlockBegin(module->getBody());
  auto graph_op = builder.create<GraphOp>(
      module->getLoc(), ConvertVersionAttr(ctx_, graph.versions()));
  graph_op.nodes().push_back(new Block);

  // Populate the function op defs.
  function_op_defs_.reserve(graph.library().function_size());
  for (const FunctionDef &function : graph.library().function()) {
    function_op_defs_.emplace(function.signature().name(),
                              &function.signature());
  }

  // Build a map from function name to gradient function name.
  absl::flat_hash_map<StringPiece, StringPiece> gradient_map;
  gradient_map.reserve(graph.library().gradient_size());
  for (const tensorflow::GradientDef &gradient : graph.library().gradient())
    gradient_map.emplace(gradient.function_name(), gradient.gradient_func());

  // Convert the graph.
  ConversionState s(&graph_op.nodes().front(), placeholder_state_);
  TF_RETURN_IF_ERROR(
      ConvertNodes(builder, s, graph.node(), &graph_op.nodes().front()));

  // A function to convert a generic or non-generic function.
  const auto convert_func = [this, &gradient_map](GraphFuncOp func_op,
                                                  const FunctionDef &function) {
    if (IsGenericFunction(function)) {
      // Generic functions aren't on the hot path so just call the old
      // importer.
      OpBuilder builder(ctx_);
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          ConvertGenericFunction(func_op, function, builder),
          "While importing generic function: ", function.signature().name());
    } else {
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          ConvertFunctionDef(func_op, gradient_map, function),
          "While importing function: ", function.signature().name());
    }
    return ::tensorflow::OkStatus();
  };

  // TODO(jeffniu): Don't import functions in parallel if there are too few (how
  // few?) or if the functions are too small (how small?).
  if (ctx_->isMultithreadingEnabled()) {
    ctx_->enterMultiThreadedExecution();
    auto exit =
        llvm::make_scope_exit([this] { ctx_->exitMultiThreadedExecution(); });

    // Prepare the arguments for the parallel for-each.
    struct Argument {
      GraphFuncOp func;
      const FunctionDef &def;
      Status status;
    };
    std::vector<Argument> args;
    args.reserve(graph.library().function_size());
    for (const FunctionDef &function : graph.library().function()) {
      args.push_back(
          Argument{builder.create<GraphFuncOp>(unknown_loc_), function});
    }
    const auto process_func = [&convert_func](Argument &arg) {
      arg.status = convert_func(arg.func, arg.def);
      return success(arg.status.ok());
    };

    // Execute the imports in parallel.
    if (failed(failableParallelForEach(ctx_, args, process_func))) {
      Status result;
      for (const Argument &arg : args) {
        result.Update(arg.status);
      }
      return result;
    }
  } else {
    // Convert the functions.
    for (const FunctionDef &function : graph.library().function()) {
      auto func_op = builder.create<GraphFuncOp>(unknown_loc_);
      TF_RETURN_IF_ERROR(convert_func(func_op, function));
    }
  }

  return module;
}

Status GraphDefImporter::ConvertFunctionAttributes(
    const absl::flat_hash_map<StringPiece, StringPiece> &gradient_map,
    const FunctionDef &function, GraphFuncOp op, NamedAttrList &attrs) {
  // Import the function attributes with a `tf.` prefix to match the current
  // infrastructure expectations.
  for (const auto &name_attr : function.attr()) {
    if (name_attr.first.empty()) {
      return InvalidArgument("Function ", function.signature().name(),
                             " has an empty attr name");
    }
    // TODO(b/230143351): `ConvertAttributeValue` is a little slow due to
    // `ConvertTensorProto` and `ConvertTensorShapeProto`.
    TF_ASSIGN_OR_RETURN(Attribute attr,
                        ConvertAttributeValue(name_attr.second, b_));
    attrs.append(absl::StrCat("tf.", name_attr.first), attr);
  }

  // Convert the first-class attributes.
  const tensorflow::OpDef &signature = function.signature();
  if (signature.name().empty())
    return InvalidArgument("Function without a name");
  attrs.append(op.sym_nameAttrName(), b_.getStringAttr(signature.name()));

  if (!signature.description().empty()) {
    attrs.append(op.descriptionAttrName(),
                 b_.getStringAttr(signature.description()));
  }
  if (signature.is_stateful())
    attrs.append(op.is_statefulAttrName(), b_.getUnitAttr());
  auto grad_it = gradient_map.find(signature.name());
  if (grad_it != gradient_map.end()) {
    StringPiece name = grad_it->second;
    attrs.append(op.gradientAttrName(),
                 FlatSymbolRefAttr::get(ctx_, {name.data(), name.size()}));
  }

  // The resource_arg_unique_id field is a list of `pair<int, int>`; for now we
  // import it as two parallel arrays of integers.
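  // For example, entries {0: 1, 2: 1} become keys = [0, 2] and
  // values = [1, 1].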
  if (function.resource_arg_unique_id_size()) {
    SmallVector<int32_t> resource_arg_unique_ids_keys;
    SmallVector<int32_t> resource_arg_unique_ids_values;
    resource_arg_unique_ids_keys.reserve(
        function.resource_arg_unique_id_size());
    resource_arg_unique_ids_values.reserve(
        function.resource_arg_unique_id_size());
    for (const auto &unique_id : function.resource_arg_unique_id()) {
      resource_arg_unique_ids_keys.push_back(unique_id.first);
      resource_arg_unique_ids_values.push_back(unique_id.second);
    }
    attrs.append(op.resource_arg_unique_ids_keysAttrName(),
                 b_.getI32TensorAttr(resource_arg_unique_ids_keys));
    attrs.append(op.resource_arg_unique_ids_valuesAttrName(),
                 b_.getI32TensorAttr(resource_arg_unique_ids_values));
  }
  return ::tensorflow::OkStatus();
}

Status GraphDefImporter::ConvertArgumentAttributes(const OpDef::ArgDef &def,
                                                   NamedAttrList &attrs) {
  attrs.append(dialect_->getTfgNameAttrIdentifier(),
               b_.getStringAttr(def.name()));
  if (!def.description().empty()) {
    attrs.append(dialect_->getTfgDescriptionAttrIdentifier(),
                 b_.getStringAttr(def.description()));
  }
  if (def.is_ref())
    attrs.append(dialect_->getTfgIsRefAttrIdentifier(), b_.getUnitAttr());
  if (def.handle_data_size()) {
    TF_ASSIGN_OR_RETURN(Attribute handle_data,
                        ConvertHandleData(b_, def.handle_data()));
    attrs.append(dialect_->getTfgHandleDataAttrIdentifier(), handle_data);
  }
  if (def.has_experimental_full_type()) {
    TF_ASSIGN_OR_RETURN(tf_type::FullTypeAttr full_type,
                        ConvertAttribute(def.experimental_full_type(), b_));
    attrs.append(dialect_->getTfgFullTypeAttrIdentifier(), full_type);
  }
  return ::tensorflow::OkStatus();
}

Location GraphDefImporter::ConvertLocation(const NodeDef &node) {
  if (!node.has_experimental_debug_info()) return unknown_loc_;

  const auto &debug_info = node.experimental_debug_info();
  const auto &original_nodes = debug_info.original_node_names();
  const auto &original_funcs = debug_info.original_func_names();
  if (original_nodes.empty()) return unknown_loc_;

  SmallVector<Location> node_locs;
  node_locs.reserve(original_nodes.size());
  for (auto &it : llvm::enumerate(original_nodes)) {
    std::string func_name =
        it.index() < original_funcs.size() ? original_funcs[it.index()] : "";
    node_locs.push_back(ConvertLocation(it.value(), func_name));
  }
  return b_.getFusedLoc(node_locs);
}

// This is a re-implementation of GetLocation in `import.cc`.
Location GraphDefImporter::ConvertLocation(StringRef node_name,
                                           StringRef func_name) {
  // Concatenate the node name with the function name to match how the key is
  // formed in Python.
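  // E.g. node "foo" in function "bar" produces the key "foo@bar".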
  std::string debug_info_key = (node_name + "@" + func_name).str();
  std::string name_loc = func_name.empty() ? node_name.str() : debug_info_key;
  auto name_loc_id = b_.getStringAttr(name_loc);

  SmallVector<Location> locs;
  const auto &traces = debug_info_.traces();
  // Try to find a stack trace to convert to locations.
  auto it = traces.find(debug_info_key);
  if (it != traces.end()) {
    const auto &trace = it->second;
    locs.reserve(trace.file_line_cols_size());
    for (const auto &loc : trace.file_line_cols()) {
      auto file_name = b_.getStringAttr(debug_info_.files(loc.file_index()));
      locs.push_back(FileLineColLoc::get(file_name, loc.line(), loc.col()));
    }
  }

  if (locs.empty()) return NameLoc::get(name_loc_id);

  // Use the first location to generate a name location.
  Location node_name_loc = NameLoc::get(name_loc_id, locs.front());
  // Generate a stack trace using the remaining locations.
  ArrayRef<Location> callsite_locs = llvm::makeArrayRef(locs).drop_front();
  return callsite_locs.empty() ? node_name_loc
                               : CallSiteLoc::get(node_name_loc, callsite_locs);
}

StatusOr<Value> GraphDefImporter::ResolveDataResult(const ResultId &id,
                                                    ResultInfo *info) {
  if (id.output.empty()) {
    if (id.index >= info->data.size()) {
      return InvalidArgument("Result #", id.index, " of node '", id.node.str(),
                             "' is out of bounds");
    }
    return info->data[id.index];
  }

  auto it = info->outputs.find({id.output.data(), id.output.size()});
  if (it == info->outputs.end()) {
    return InvalidArgument("Node '", id.node.str(), "' has no output called '",
                           id.output.str(), "'");
  }
  if (id.index >= it->second.size()) {
    return InvalidArgument("Result #", id.index, " of segment '", id.node.str(),
                           ":", id.output.str(), "' is out of bounds");
  }
  return it->second[id.index];
}

StatusOr<GraphDefImporter::Result> GraphDefImporter::GetResult(
    ConversionState &s, StringPiece name) {
  TensorId tensor_id = tensorflow::ParseTensorName(name);
  ResultId id{tensor_id.index()};
  std::tie(id.node, id.output) =
      StringRef(tensor_id.node().data(), tensor_id.node().size()).split(':');
  std::unique_ptr<ResultInfo> &info = s[{id.node.data(), id.node.size()}];
  if (!info) {
    info = std::make_unique<ResultInfo>();
  }

  // If the result is unresolved, return the placeholder.
  if (!info->resolved) {
    if (id.IsControl()) {
      return Result{s.GetPlaceholder(), Value(), id, info.get()};
    }
    return Result{Value(), s.GetPlaceholder(), id, info.get()};
  }

  // If the result is the control token, return it.
  if (id.IsControl()) {
    return Result{info->control, Value()};
  }

  TF_ASSIGN_OR_RETURN(Value value, ResolveDataResult(id, info.get()));
  return Result{Value(), value};
}

Status GraphDefImporter::ConvertFunctionDef(
    GraphFuncOp func_op,
    const absl::flat_hash_map<StringPiece, StringPiece> &gradient_map,
    const FunctionDef &function) {
  const OpDef &signature = function.signature();
  // TODO(jeffniu): Does the name need to be mangled?

  func_op.body().push_back(new Block);
  Block *body = &func_op.body().front();
  auto builder = OpBuilder::atBlockBegin(func_op.getBody());

  // Convert the attributes.
  NamedAttrList func_attrs;
  TF_RETURN_IF_ERROR(
      ConvertFunctionAttributes(gradient_map, function, func_op, func_attrs));

  SmallVector<Attribute> arg_attrs, res_attrs, control_ret_attrs;
  SmallVector<Type> arg_types, res_types;

  // Convert the arguments and argument attributes.
  for (auto &it : llvm::enumerate(signature.input_arg())) {
    Type dtype;
    TF_RETURN_IF_ERROR(ConvertDataType(it.value().type(), b_, &dtype));
    BlockArgument data =
        body->addArgument(UnrankedTensorType::get(dtype), unknown_loc_);
    BlockArgument ctl =
        body->addArgument(dialect_->getControlType(), data.getLoc());

    NamedAttrList attrs;
    TF_RETURN_IF_ERROR(ConvertArgumentAttributes(it.value(), attrs));
    auto attr_it = function.arg_attr().find(it.index());
    if (attr_it != function.arg_attr().end()) {
      for (const auto &name_attr : attr_it->second.attr()) {
        TF_ASSIGN_OR_RETURN(Attribute attr,
                            ConvertAttributeValue(name_attr.second, b_));
        attrs.append("tf." + name_attr.first, attr);
      }
    }

    arg_attrs.append({attrs.getDictionary(ctx_), b_.getDictionaryAttr({})});
    arg_types.append({data.getType(), ctl.getType()});
  }

  // Iterate over the arguments again and map them. We have to add them first
  // otherwise the ranges will be invalidated.
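  // Block arguments are interleaved as (data0, ctl0, data1, ctl1, ...), so
  // argument #i's data value is at index 2 * i and its control token at
  // index 2 * i + 1.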
  ConversionState s(body, placeholder_state_);
  for (const auto &it : llvm::enumerate(signature.input_arg())) {
    s.emplace(
        it.value().name(),
        new ResultInfo{/*resolved=*/true, body->getArgument(it.index() * 2 + 1),
                       body->getArguments().slice(it.index() * 2, 1)});
  }
  TF_RETURN_IF_ERROR(ConvertNodes(builder, s, function.node_def(), body));

  // Convert the results and the result attributes.
  SmallVector<Value> return_operands;
  return_operands.reserve(signature.output_arg_size() +
                          signature.control_output_size());
  for (const OpDef::ArgDef &def : function.signature().output_arg()) {
    Type dtype;
    TF_RETURN_IF_ERROR(ConvertDataType(def.type(), b_, &dtype));
    NamedAttrList attrs;
    TF_RETURN_IF_ERROR(ConvertArgumentAttributes(def, attrs));
    res_attrs.push_back(attrs.getDictionary(ctx_));
    res_types.push_back(UnrankedTensorType::get(dtype));

    auto ret_it = function.ret().find(def.name());
    if (ret_it == function.ret().end()) {
      return InvalidArgument("Output '", def.name(),
                             "' was not found in 'ret'");
    }
    TF_ASSIGN_OR_RETURN(Result result, GetResult(s, ret_it->second));
    if (result.info)
      return InvalidArgument("Return '", ret_it->second, "' was not found");
    if (result.control)
      return InvalidArgument("Unexpected control result: ", ret_it->second);
    return_operands.push_back(result.data);
  }

  // Convert the control results.
  for (const std::string &control_ret : signature.control_output()) {
    auto ret_it = function.control_ret().find(control_ret);
    if (ret_it == function.control_ret().end()) {
      return InvalidArgument("Control output '", control_ret,
                             "' was not found in 'control_ret'");
    }
    std::unique_ptr<ResultInfo> &result = s[ret_it->second];
    if (!result || !result->resolved) {
      return InvalidArgument("Control return ", ret_it->second,
                             " was not found");
    }
    return_operands.push_back(result->control);
    control_ret_attrs.push_back(b_.getDictionaryAttr(NamedAttribute(
        dialect_->getTfgNameAttrIdentifier(), b_.getStringAttr(control_ret))));
  }
  builder.create<ReturnOp>(unknown_loc_, return_operands,
                           b_.getArrayAttr(control_ret_attrs));

  // Finalize the function attributes.
  func_attrs.append(func_op.arg_attrsAttrName(), b_.getArrayAttr(arg_attrs));
  func_attrs.append(func_op.res_attrsAttrName(), b_.getArrayAttr(res_attrs));
  func_attrs.append(func_op.function_typeAttrName(),
                    TypeAttr::get(b_.getFunctionType(arg_types, res_types)));
  func_op->setAttrs(func_attrs.getDictionary(ctx_));

  return ::tensorflow::OkStatus();
}

Status GraphDefImporter::ConvertNodes(
    OpBuilder &builder, ConversionState &s,
    const tensorflow::protobuf::RepeatedPtrField<NodeDef> &nodes,
    Block *block) {
  OpBuilder::InsertionGuard ig(builder);
  builder.setInsertionPointToStart(block);
  for (const NodeDef &node : nodes) {
    TF_RETURN_IF_ERROR(ConvertNodeDef(builder, s, node));
  }

  // If the placeholder has remaining uses, then an input is missing.
  if (TF_PREDICT_FALSE(!s.GetPlaceholder().use_empty())) {
    // Stringify a result ID.
    const auto id_to_str = [](const ResultId &id) {
      std::string name = id.node.str();
      if (id.IsControl()) return absl::StrCat("^", name);
      if (id.output.empty())
        return id.index ? absl::StrCat(id.node.str(), ":", id.index) : name;
      return absl::StrCat(name, ":", id.output.str(), ":", id.index);
    };
    // Gather all missing input edges.
    std::vector<std::pair<std::string, std::string>> missing_edges;
    for (const ResultInfo &info :
         llvm::make_pointee_range(llvm::make_second_range(s))) {
      if (info.backedges.empty()) continue;
      const Backedge &edge = info.backedges.front();
      missing_edges.emplace_back(id_to_str(edge.id),
                                 TFOp(edge.operand->getOwner()).name().str());
    }
    assert(!missing_edges.empty() &&
           "placeholder had remaining uses but found no unresolved backedges");
    // Destroy the invalid IR.
    block->erase();
    // Report the missing edges in alphabetical order.
    llvm::sort(missing_edges);
    std::string error_message;
    llvm::raw_string_ostream os(error_message);
    llvm::interleave(
        missing_edges, os,
        [&](const auto &edge) {
          os << "Non-existent input " << edge.first << " in node "
             << edge.second;
        },
        "\n");
    return InvalidArgument(std::move(os.str()));
  }
  // The placeholder has no uses and should not acquire any more uses. Safely
  // delete it from the IR.
  s.Finalize();

  return ::tensorflow::OkStatus();
}

StatusOr<unsigned> GraphDefImporter::ArgNumType(const NamedAttrList &attrs,
                                                const OpDef::ArgDef &arg_def,
                                                SmallVectorImpl<Type> &types) {
  // Check whether a type list attribute is specified.
  if (!arg_def.type_list_attr().empty()) {
    if (auto v =
            attrs.get(arg_def.type_list_attr()).dyn_cast_or_null<ArrayAttr>()) {
      for (Attribute attr : v) {
        if (auto dtype = attr.dyn_cast<TypeAttr>()) {
          types.push_back(UnrankedTensorType::get(dtype.getValue()));
        } else {
          return InvalidArgument("Expected '", arg_def.type_list_attr(),
                                 "' to be a list of types");
        }
      }
      return v.size();
    }
    return NotFound("Type attr not found: ", arg_def.type_list_attr());
  }

  unsigned num = 1;
  // Check whether a number attribute is specified.
  if (!arg_def.number_attr().empty()) {
    if (auto v =
            attrs.get(arg_def.number_attr()).dyn_cast_or_null<IntegerAttr>()) {
      num = v.getValue().getZExtValue();
    } else {
      return NotFound("Number attr not found: ", arg_def.number_attr());
    }
  }

  // Check for a type or type attribute.
  Type dtype;
  if (arg_def.type() != DataType::DT_INVALID) {
    TF_RETURN_IF_ERROR(ConvertDataType(arg_def.type(), b_, &dtype));
  } else if (arg_def.type_attr().empty()) {
    return InvalidArgument("Arg '", arg_def.name(),
                           "' has invalid type and no type attribute");
  } else {
    if (auto v = attrs.get(arg_def.type_attr()).dyn_cast_or_null<TypeAttr>()) {
      dtype = v.getValue();
    } else {
      return NotFound("Type attr not found: ", arg_def.type_attr());
    }
  }
  types.append(num, UnrankedTensorType::get(dtype));
  return num;
}

Status GraphDefImporter::ConvertNodeDef(OpBuilder &builder, ConversionState &s,
                                        const NodeDef &node) {
  VLOG(4) << "Importing: " << node.name();
  if (node.op().empty())
    return InvalidArgument("Node ", node.name(), " has an empty op name");

  OperationState state(ConvertLocation(node), absl::StrCat("tfg.", node.op()));

  // The GraphImporter does light shape inference, but here we will defer all of
  // that to the shape inference pass.
  const OpDef *op_def;
  const OpRegistrationData *op_reg_data = nullptr;
  if ((op_reg_data = registry_.LookUp(node.op()))) {
    op_def = &op_reg_data->op_def;
  } else {
    auto it = function_op_defs_.find(node.op());
    if (it == function_op_defs_.end())
      return InvalidArgument("Unable to find OpDef for ", node.op());
    op_def = it->second;
  }

  // Import the attributes. Reserve `+3` for `device`, `name`, and `fulltype`.
  state.attributes.reserve(node.attr_size() + 3);
  if (!node.device().empty()) {
    state.addAttribute(dialect_->getDeviceAttrIdentifier(),
                       b_.getStringAttr(node.device()));
  }
  if (!node.name().empty()) {
    state.addAttribute(dialect_->getNameAttrIdentifier(),
                       b_.getStringAttr(node.name()));
  }

  // If the op doesn't have a FullType, try to infer one.
  const auto add_full_type = [&](const FullTypeDef &full_type_def) {
    TF_ASSIGN_OR_RETURN(tf_type::FullTypeAttr full_type,
                        ConvertAttribute(full_type_def, b_));
    state.addAttribute(dialect_->getFullTypeAttrIdentifier(), full_type);
    return ::tensorflow::OkStatus();
  };
  if (node.has_experimental_type()) {
    TF_RETURN_IF_ERROR(add_full_type(node.experimental_type()));
  } else if (op_reg_data && op_reg_data->type_ctor) {
    FullTypeDef full_type_def;
    TF_RETURN_IF_ERROR(
        tensorflow::full_type::SpecializeType(node, *op_def, full_type_def));
    TF_RETURN_IF_ERROR(add_full_type(full_type_def));
  }

  for (auto &name_attr : node.attr()) {
    if (name_attr.first.empty())
      return InvalidArgument("Node ", node.name(), " has an empty attr name");
    TF_ASSIGN_OR_RETURN(Attribute attr,
                        ConvertAttributeValue(name_attr.second, b_));
    state.addAttribute(name_attr.first, attr);
  }

  // Add missing default attributes.
  for (const auto &attr_def : op_def->attr()) {
    if (attr_def.has_default_value() &&
        !state.attributes.get(attr_def.name())) {
      TF_ASSIGN_OR_RETURN(Attribute attr,
                          ConvertAttributeValue(attr_def.default_value(), b_));
      state.addAttribute(attr_def.name(), attr);
    }
  }

  // Get the result types. Ops can have multiple named results. Track the
  // segment sizes.
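  // E.g. an OpDef output `values: N * T` with N = 3 contributes one segment
  // (start, 3) spanning three data results.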
  SmallVector<std::pair<unsigned, unsigned>> result_segments;
  result_segments.reserve(op_def->output_arg_size());
  state.types.reserve(op_def->output_arg_size() + 1);
  for (const OpDef::ArgDef &def : op_def->output_arg()) {
    unsigned index = state.types.size();
    TF_ASSIGN_OR_RETURN(unsigned size,
                        ArgNumType(state.attributes, def, state.types));
    result_segments.emplace_back(index, size);
  }
  state.types.push_back(dialect_->getControlType());

  // Collect the operands. Set backedges to a placeholder and resolve them
  // later.
  state.operands.reserve(node.input_size());
  SmallVector<Value> control_operands;
  struct BackedgeResolution {
    ResultInfo *info;
    size_t operand_index;
    ResultId id;
  };
  SmallVector<BackedgeResolution> unresolved_data_operands,
      unresolved_control_operands;
  for (const std::string &input : node.input()) {
    TF_ASSIGN_OR_RETURN(Result result, GetResult(s, input));
    if (result.control) {
      if (result.info) {
        unresolved_control_operands.push_back(BackedgeResolution{
            result.info, control_operands.size(), result.id});
      }
      control_operands.push_back(result.control);
    } else {
      if (result.info) {
        unresolved_data_operands.push_back(
            BackedgeResolution{result.info, state.operands.size(), result.id});
      }
      state.operands.push_back(result.data);
    }
  }
  unsigned num_data_operands = state.operands.size();
  state.addOperands(control_operands);

  // Create the op and record any unresolved operands.
  Operation *op = builder.create(state);
  for (const BackedgeResolution &r : unresolved_data_operands) {
    r.info->backedges.push_back(
        Backedge{r.id, &op->getOpOperand(r.operand_index)});
  }
  for (const BackedgeResolution &r : unresolved_control_operands) {
    r.info->backedges.push_back(
        Backedge{r.id, &op->getOpOperand(num_data_operands + r.operand_index)});
  }

  std::unique_ptr<ResultInfo> &info = s[node.name()];
  if (!info) {
    info = std::make_unique<ResultInfo>();
  }
  info->resolved = true;
  info->control = *std::prev(op->result_end());
  info->data = op->getResults().drop_back();
  for (auto it : llvm::zip(result_segments, op_def->output_arg())) {
    const std::pair<unsigned, unsigned> &segment = std::get<0>(it);
    info->outputs.emplace(std::get<1>(it).name(),
                          info->data.slice(segment.first, segment.second));
  }

  // Resolve any associated backedges.
  for (const Backedge &backedge : info->backedges) {
    Value value;
    if (backedge.id.IsControl()) {
      value = info->control;
    } else {
      TF_ASSIGN_OR_RETURN(value, ResolveDataResult(backedge.id, info.get()));
    }
    backedge.operand->set(value);
  }
  info->backedges.clear();

  return ::tensorflow::OkStatus();
}

Status GraphDefImporter::ConvertDataTypesToUnrankedTensorTypes(
    const DataTypeVector &dtypes, SmallVectorImpl<Type> &results) {
  Type dtype;
  for (DataType tf_dtype : dtypes) {
    TF_RETURN_IF_ERROR(ConvertDataType(tf_dtype, b_, &dtype));
    results.push_back(UnrankedTensorType::get(dtype));
  }
  return ::tensorflow::OkStatus();
}

StatusOr<OwningOpRef<ModuleOp>> ImportGraphDef(MLIRContext *context,
                                               const GraphDebugInfo &debug_info,
                                               const GraphDef &graph_def) {
  GraphDefImporter importer(context->getOrLoadDialect<TFGraphDialect>(),
                            *OpRegistry::Global(), debug_info);
  return importer.ConvertGraphDef(graph_def);
}
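
// Example usage, as a minimal sketch (assumes `graph_def` is a GraphDef
// already populated by the caller):
//
//   MLIRContext context;
//   GraphDebugInfo debug_info;
//   StatusOr<OwningOpRef<ModuleOp>> module =
//       ImportGraphDef(&context, debug_info, graph_def);
//   if (module.ok()) (*module)->dump();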

StatusOr<OwningOpRef<ModuleOp>> ImportGraphAndFunctionsToMlir(
    MLIRContext *context, const GraphDebugInfo &debug_info, const Graph &graph,
    const FunctionLibraryDefinition &flib_def) {
  // TODO(b/231723721): This conversion path is slow because both the graph and
  // the function library are converted to GraphDef.
  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);
  *graph_def.mutable_library() = flib_def.ToProto();
  return ImportGraphDef(context, debug_info, graph_def);
}

}  // namespace tfg
}  // namespace mlir