1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/ir/importexport/graphdef_import.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "absl/container/flat_hash_map.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/ScopeExit.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "mlir/IR/Builders.h" // from @llvm-project
26 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
27 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
28 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
29 #include "mlir/IR/Location.h" // from @llvm-project
30 #include "mlir/IR/MLIRContext.h" // from @llvm-project
31 #include "mlir/IR/OperationSupport.h" // from @llvm-project
32 #include "mlir/IR/OwningOpRef.h" // from @llvm-project
33 #include "mlir/IR/Threading.h" // from @llvm-project
34 #include "mlir/Support/LLVM.h" // from @llvm-project
35 #include "mlir/Support/LogicalResult.h" // from @llvm-project
36 #include "tensorflow/core/framework/full_type.pb.h"
37 #include "tensorflow/core/framework/function.h"
38 #include "tensorflow/core/framework/function.pb.h"
39 #include "tensorflow/core/framework/graph.pb.h"
40 #include "tensorflow/core/framework/node_def.pb.h"
41 #include "tensorflow/core/framework/node_def_util.h"
42 #include "tensorflow/core/framework/op.h"
43 #include "tensorflow/core/framework/op_def.pb.h"
44 #include "tensorflow/core/framework/op_def_builder.h"
45 #include "tensorflow/core/framework/versions.pb.h"
46 #include "tensorflow/core/graph/graph.h"
47 #include "tensorflow/core/graph/tensor_id.h"
48 #include "tensorflow/core/ir/dialect.h"
49 #include "tensorflow/core/ir/importexport/convert_attributes.h"
50 #include "tensorflow/core/ir/importexport/convert_types.h"
51 #include "tensorflow/core/ir/importexport/functiondef_import.h"
52 #include "tensorflow/core/ir/ops.h"
53 #include "tensorflow/core/ir/types/dialect.h"
54 #include "tensorflow/core/platform/errors.h"
55 #include "tensorflow/core/platform/statusor.h"
56 #include "tensorflow/core/platform/stringpiece.h"
57 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
58
59 using tensorflow::DataType;
60 using tensorflow::DataTypeVector;
61 using tensorflow::FullTypeDef;
62 using tensorflow::FunctionDef;
63 using tensorflow::FunctionLibraryDefinition;
64 using tensorflow::Graph;
65 using tensorflow::GraphDebugInfo;
66 using tensorflow::GraphDef;
67 using tensorflow::NodeDef;
68 using tensorflow::OpDef;
69 using tensorflow::OpRegistrationData;
70 using tensorflow::OpRegistry;
71 using tensorflow::Status;
72 using tensorflow::StatusOr;
73 using tensorflow::StringPiece;
74 using tensorflow::TensorId;
75 using tensorflow::VersionDef;
76 using tensorflow::errors::InvalidArgument;
77 using tensorflow::errors::NotFound;
78
79 namespace mlir {
80 namespace tfg {
81 namespace {
// This class implements an importer for GraphDef directly to TFG.
class GraphDefImporter {
 public:
  // Initialize the importer.
  //
  // `dialect` supplies the MLIR context and cached identifiers, `registry` is
  // used to look up OpDefs of registered (non-function) ops, and `debug_info`
  // provides the stack traces used to build source locations.
  GraphDefImporter(TFGraphDialect *dialect, const OpRegistry &registry,
                   const GraphDebugInfo &debug_info)
      : ctx_(dialect->getContext()),
        dialect_(dialect),
        b_(ctx_),
        registry_(registry),
        debug_info_(debug_info),
        unknown_loc_(UnknownLoc::get(ctx_)),
        // Pre-built operation state for creating "tfg._mlir_placeholder" ops,
        // which temporarily stand in for not-yet-resolved result values.
        placeholder_state_(unknown_loc_, "tfg._mlir_placeholder") {
    placeholder_state_.addTypes(dialect_->getControlType());
  }

  // Convert a GraphDef to MLIR module.
  StatusOr<OwningOpRef<ModuleOp>> ConvertGraphDef(const GraphDef &graph);

 private:
  // Convert a function. This function must be thread-safe.
  Status ConvertFunctionDef(
      GraphFuncOp func_op,
      const absl::flat_hash_map<StringPiece, StringPiece> &gradient_map,
      const FunctionDef &function);

  // A result ID representing an output of `node`. E.g.
  //   "foo"          -> {0, "foo", ""}
  //   "foo:2"        -> {2, "foo", ""}
  //   "foo:output:0" -> {0, "foo", "output"}
  struct ResultId {
    // The result or result segment index.
    int index;
    // The name of the parent node.
    StringRef node;
    // An optional result segment name.
    StringRef output;

    // Returns true if the result ID references the control token.
    bool IsControl() const { return index == tensorflow::Graph::kControlSlot; }
  };

  // An unresolved backedge.
  struct Backedge {
    // The edge name and index.
    ResultId id;
    // The OpOperand to resolve.
    OpOperand *operand;
  };

  // Cached info about the result of an operation.
  struct ResultInfo {
    // This flag is true if the results of the operation have been resolved; the
    // operation has been created and its `data` and `control` results have been
    // populated. If false, the placeholder should be used.
    bool resolved = false;
    // The control result.
    Value control;
    // All data results.
    ValueRange data;
    // Data results organized by output name.
    absl::flat_hash_map<StringPiece, ValueRange> outputs;
    // A list of unresolved backedges.
    std::vector<Backedge> backedges;
  };

  // State when converting a list of nodes: maps each node name to the
  // (possibly not-yet-resolved) info about its results.
  class ConversionState
      : public absl::flat_hash_map<StringPiece, std::unique_ptr<ResultInfo>> {
   public:
    // Create a conversion state with a placeholder value. Put the placeholder
    // in the block so that it is owned.
    explicit ConversionState(Block *block,
                             const OperationState &placeholder_state)
        : placeholder_op_(
              OpBuilder::atBlockBegin(block).create(placeholder_state)),
          placeholder_(placeholder_op_->getResult(0)) {}

    // Get the placeholder value.
    Value GetPlaceholder() { return placeholder_; }

    // Finalize the conversion. The placeholder is destroyed.
    void Finalize() { placeholder_op_->erase(); }

   private:
    // The placeholder operation.
    Operation *placeholder_op_;
    // The placeholder value.
    Value placeholder_;
  };
  // Convert a list of nodes to operations.
  Status ConvertNodes(
      OpBuilder &builder, ConversionState &s,
      const tensorflow::protobuf::RepeatedPtrField<NodeDef> &nodes,
      Block *block);
  // Convert a node to an operation.
  Status ConvertNodeDef(OpBuilder &builder, ConversionState &s,
                        const NodeDef &node);
  // Resolve a data result reference.
  static StatusOr<Value> ResolveDataResult(const ResultId &id,
                                           ResultInfo *info);

  // Get a named result.
  struct Result {
    Value control;
    Value data;
    ResultId id;
    ResultInfo *info = nullptr;
  };
  StatusOr<Result> GetResult(ConversionState &s, StringPiece name);

  // Convert TF datatypes to unranked MLIR tensor types.
  Status ConvertDataTypesToUnrankedTensorTypes(const DataTypeVector &dtypes,
                                               SmallVectorImpl<Type> &results);
  // Extracts the actual data types from `attrs` based on its definition in
  // `arg_def` and converts them to unranked tensors. Returns the number of
  // added types.
  //
  // TODO(jeffniu): This is a re-implementation of `ArgNumType` in
  // `core/framework/function.cc` on `NamedAttrList` because the default
  // attributes need to be added. Find a way to do this in one pass.
  StatusOr<unsigned> ArgNumType(const NamedAttrList &attrs,
                                const OpDef::ArgDef &arg_def,
                                SmallVectorImpl<Type> &types);
  // Convert function attributes to MLIR attributes.
  Status ConvertFunctionAttributes(
      const absl::flat_hash_map<StringPiece, StringPiece> &gradient_map,
      const FunctionDef &function, GraphFuncOp op, NamedAttrList &attrs);
  // Convert function argument attributes to MLIR attributes.
  Status ConvertArgumentAttributes(const OpDef::ArgDef &def,
                                   NamedAttrList &attrs);
  // Create a location for a node.
  Location ConvertLocation(const NodeDef &node);
  // Convert the location of a node from the debug info. If it has no debug
  // info, return a NameLoc.
  Location ConvertLocation(StringRef node_name, StringRef func_name);

  // The MLIR context.
  MLIRContext *ctx_;
  // Reference to the TFG dialect.
  TFGraphDialect *dialect_;
  // The builder instance.
  Builder b_;
  // The TF op registry to use.
  const OpRegistry &registry_;
  // The debug info about the graph.
  const GraphDebugInfo &debug_info_;
  // Cached unknown location.
  Location unknown_loc_;
  // Operation state for creating placeholder ops.
  OperationState placeholder_state_;

  // Map of function OpDefs.
  absl::flat_hash_map<StringPiece, const OpDef *> function_op_defs_;
};
237 } // namespace
238
239 // Convert a VersionDef to an MLIR version attribute.
ConvertVersionAttr(MLIRContext * context,const VersionDef & version)240 static VersionAttr ConvertVersionAttr(MLIRContext *context,
241 const VersionDef &version) {
242 ArrayRef<int32_t> bad_consumers(version.bad_consumers().data(),
243 version.bad_consumers().size());
244 return VersionAttr::get(context, version.producer(), version.min_consumer(),
245 bad_consumers);
246 }
247
248 // Returns true if the function is a generic function, i.e. it contains
249 // placeholder attributes.
250 //
251 // TODO(jeffniu): Having to iterate over every function just to check for
252 // placeholder attributes is slow. Since most functions are not generic, we can
253 // speculate by converting all functions as non-generic until we see a
254 // placeholder attribute, bail out, and fall back to the generic function
255 // converter.
IsGenericFunction(const FunctionDef & fdef)256 static bool IsGenericFunction(const FunctionDef &fdef) {
257 for (const NodeDef &node : fdef.node_def())
258 for (const auto &named_attr : node.attr())
259 if (!named_attr.second.placeholder().empty()) return true;
260
261 return false;
262 }
263
StatusOr<OwningOpRef<ModuleOp>> GraphDefImporter::ConvertGraphDef(
    const GraphDef &graph) {
  // Create the module.
  OwningOpRef<ModuleOp> module = ModuleOp::create(unknown_loc_);

  // Create the graph op. Its single region gets an explicit block to hold the
  // converted nodes.
  auto builder = OpBuilder::atBlockBegin(module->getBody());
  auto graph_op = builder.create<GraphOp>(
      module->getLoc(), ConvertVersionAttr(ctx_, graph.versions()));
  graph_op.nodes().push_back(new Block);

  // Populate the function op defs so nodes that call library functions can
  // find their signatures (see ConvertNodeDef).
  function_op_defs_.reserve(graph.library().function_size());
  for (const FunctionDef &function : graph.library().function()) {
    function_op_defs_.emplace(function.signature().name(),
                              &function.signature());
  }

  // Build a map from function name to gradient function name.
  absl::flat_hash_map<StringPiece, StringPiece> gradient_map;
  gradient_map.reserve(graph.library().gradient_size());
  for (const tensorflow::GradientDef &gradient : graph.library().gradient())
    gradient_map.emplace(gradient.function_name(), gradient.gradient_func());

  // Convert the graph.
  ConversionState s(&graph_op.nodes().front(), placeholder_state_);
  TF_RETURN_IF_ERROR(
      ConvertNodes(builder, s, graph.node(), &graph_op.nodes().front()));

  // A function to convert a generic or non-generic function.
  const auto convert_func = [this, &gradient_map](GraphFuncOp func_op,
                                                  const FunctionDef &function) {
    if (IsGenericFunction(function)) {
      // Generic functions aren't on the hot path so just call the old
      // importer.
      OpBuilder builder(ctx_);
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          ConvertGenericFunction(func_op, function, builder),
          "While importing generic function: ", function.signature().name());
    } else {
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          ConvertFunctionDef(func_op, gradient_map, function),
          "While importing function: ", function.signature().name());
    }
    return ::tensorflow::OkStatus();
  };

  // TODO(jeffniu): Don't import functions in parallel if there are too few (how
  // few?) or if the functions are too small (how small?).
  if (ctx_->isMultithreadingEnabled()) {
    ctx_->enterMultiThreadedExecution();
    auto exit =
        llvm::make_scope_exit([this] { ctx_->exitMultiThreadedExecution(); });

    // Prepare the arguments to parallel for each. The empty GraphFuncOps are
    // created serially up front; only the body conversion runs in parallel.
    struct Argument {
      GraphFuncOp func;
      const FunctionDef &def;
      Status status;
    };
    std::vector<Argument> args;
    args.reserve(graph.library().function_size());
    for (const FunctionDef &function : graph.library().function()) {
      args.push_back(
          Argument{builder.create<GraphFuncOp>(unknown_loc_), function});
    }
    const auto process_func = [&convert_func](Argument &arg) {
      // Record the per-function status so all failures can be reported, not
      // just the first one.
      arg.status = convert_func(arg.func, arg.def);
      return success(arg.status.ok());
    };

    // Execute the imports in parallel.
    if (failed(failableParallelForEach(ctx_, args, process_func))) {
      // Aggregate every failed status into a single result.
      Status result;
      for (const Argument &arg : args) {
        result.Update(arg.status);
      }
      return result;
    }
  } else {
    // Convert the functions serially.
    for (const FunctionDef &function : graph.library().function()) {
      auto func_op = builder.create<GraphFuncOp>(unknown_loc_);
      TF_RETURN_IF_ERROR(convert_func(func_op, function));
    }
  }

  return module;
}
353
// Converts the attributes of `function` into MLIR attributes appended to
// `attrs`: user attributes get a `tf.` prefix, and the first-class signature
// fields (name, description, statefulness, gradient, resource arg IDs) map to
// dedicated GraphFuncOp attributes.
Status GraphDefImporter::ConvertFunctionAttributes(
    const absl::flat_hash_map<StringPiece, StringPiece> &gradient_map,
    const FunctionDef &function, GraphFuncOp op, NamedAttrList &attrs) {
  // Import the function attributes with a `tf.` prefix to match the current
  // infrastructure expectations.
  for (const auto &name_attr : function.attr()) {
    if (name_attr.first.empty()) {
      return InvalidArgument("Function ", function.signature().name(),
                             " has an empty attr name");
    }
    // TODO(b/230143351): `ConvertAttributeValue` is a little slow due to
    // `ConvertTensorProto` and `ConvertTensorShapeProto`.
    TF_ASSIGN_OR_RETURN(Attribute attr,
                        ConvertAttributeValue(name_attr.second, b_));
    attrs.append(absl::StrCat("tf.", name_attr.first), attr);
  }

  // Convert the first-class attributes.
  const tensorflow::OpDef &signature = function.signature();
  if (signature.name().empty())
    return InvalidArgument("Function without a name");
  attrs.append(op.sym_nameAttrName(), b_.getStringAttr(signature.name()));

  if (!signature.description().empty()) {
    attrs.append(op.descriptionAttrName(),
                 b_.getStringAttr(signature.description()));
  }
  if (signature.is_stateful())
    attrs.append(op.is_statefulAttrName(), b_.getUnitAttr());
  // If a gradient function was registered for this function, reference it by
  // symbol.
  auto grad_it = gradient_map.find(signature.name());
  if (grad_it != gradient_map.end()) {
    StringPiece name = grad_it->second;
    attrs.append(op.gradientAttrName(),
                 FlatSymbolRefAttr::get(ctx_, {name.data(), name.size()}));
  }

  // The resource_arg_unique_id is a list of `pair<int, int>`, we import it
  // as two arrays of integer right now.
  if (function.resource_arg_unique_id_size()) {
    SmallVector<int32_t> resource_arg_unique_ids_keys;
    SmallVector<int32_t> resource_arg_unique_ids_values;
    resource_arg_unique_ids_keys.reserve(
        function.resource_arg_unique_id_size());
    resource_arg_unique_ids_values.reserve(
        function.resource_arg_unique_id_size());
    for (const auto &unique_id : function.resource_arg_unique_id()) {
      resource_arg_unique_ids_keys.push_back(unique_id.first);
      resource_arg_unique_ids_values.push_back(unique_id.second);
    }
    attrs.append(op.resource_arg_unique_ids_keysAttrName(),
                 b_.getI32TensorAttr(resource_arg_unique_ids_keys));
    attrs.append(op.resource_arg_unique_ids_valuesAttrName(),
                 b_.getI32TensorAttr(resource_arg_unique_ids_values));
  }
  return ::tensorflow::OkStatus();
}
410
// Converts the metadata of one signature argument (`def`) into dialect-keyed
// MLIR attributes appended to `attrs`: name, optional description, ref-ness,
// resource handle data, and an optional full type.
Status GraphDefImporter::ConvertArgumentAttributes(const OpDef::ArgDef &def,
                                                   NamedAttrList &attrs) {
  attrs.append(dialect_->getTfgNameAttrIdentifier(),
               b_.getStringAttr(def.name()));
  if (!def.description().empty()) {
    attrs.append(dialect_->getTfgDescriptionAttrIdentifier(),
                 b_.getStringAttr(def.description()));
  }
  // Ref-ness is encoded as a unit attribute: present iff the arg is a ref.
  if (def.is_ref())
    attrs.append(dialect_->getTfgIsRefAttrIdentifier(), b_.getUnitAttr());
  if (def.handle_data_size()) {
    TF_ASSIGN_OR_RETURN(Attribute handle_data,
                        ConvertHandleData(b_, def.handle_data()));
    attrs.append(dialect_->getTfgHandleDataAttrIdentifier(), handle_data);
  }
  if (def.has_experimental_full_type()) {
    TF_ASSIGN_OR_RETURN(tf_type::FullTypeAttr full_type,
                        ConvertAttribute(def.experimental_full_type(), b_));
    attrs.append(dialect_->getTfgFullTypeAttrIdentifier(), full_type);
  }
  return ::tensorflow::OkStatus();
}
433
ConvertLocation(const NodeDef & node)434 Location GraphDefImporter::ConvertLocation(const NodeDef &node) {
435 if (!node.has_experimental_debug_info()) return unknown_loc_;
436
437 const auto &debug_info = node.experimental_debug_info();
438 const auto &original_nodes = debug_info.original_node_names();
439 const auto &original_funcs = debug_info.original_func_names();
440 if (original_nodes.empty()) return unknown_loc_;
441
442 SmallVector<Location> node_locs;
443 node_locs.reserve(original_nodes.size());
444 for (auto &it : llvm::enumerate(original_nodes)) {
445 std::string func_name =
446 it.index() < original_funcs.size() ? original_funcs[it.index()] : "";
447 node_locs.push_back(ConvertLocation(it.value(), func_name));
448 }
449 return b_.getFusedLoc(node_locs);
450 }
451
452 // This is a re-implementation of GetLocation in `import.cc`.
ConvertLocation(StringRef node_name,StringRef func_name)453 Location GraphDefImporter::ConvertLocation(StringRef node_name,
454 StringRef func_name) {
455 // Concatenate the node name with the function name to match how the key is
456 // formed in Python.
457 std::string debug_info_key = (node_name + "@" + func_name).str();
458 std::string name_loc = func_name.empty() ? node_name.str() : debug_info_key;
459 auto name_loc_id = b_.getStringAttr(name_loc);
460
461 SmallVector<Location> locs;
462 const auto &traces = debug_info_.traces();
463 // Try to find a stack trace to convert to locations.
464 auto it = traces.find(debug_info_key);
465 if (it != traces.end()) {
466 const auto &trace = it->second;
467 locs.reserve(trace.file_line_cols_size());
468 for (const auto &loc : trace.file_line_cols()) {
469 auto file_name = b_.getStringAttr(debug_info_.files(loc.file_index()));
470 locs.push_back(FileLineColLoc::get(file_name, loc.line(), loc.col()));
471 }
472 }
473
474 if (locs.empty()) return NameLoc::get(name_loc_id);
475
476 // Use the first location to generate a name location.
477 Location node_name_loc = NameLoc::get(name_loc_id, locs.front());
478 // Generate a stack trace using the remaining locations.
479 ArrayRef<Location> callsite_locs = llvm::makeArrayRef(locs).drop_front();
480 return callsite_locs.empty() ? node_name_loc
481 : CallSiteLoc::get(node_name_loc, callsite_locs);
482 }
483
ResolveDataResult(const ResultId & id,ResultInfo * info)484 StatusOr<Value> GraphDefImporter::ResolveDataResult(const ResultId &id,
485 ResultInfo *info) {
486 if (id.output.empty()) {
487 if (id.index >= info->data.size()) {
488 return InvalidArgument("Result #", id.index, " of node '", id.node.str(),
489 "' is out of bounds");
490 }
491 return info->data[id.index];
492 }
493
494 auto it = info->outputs.find({id.output.data(), id.output.size()});
495 if (it == info->outputs.end()) {
496 return InvalidArgument("Node '", id.node.str(), "' has no output called '",
497 id.output.str(), "'");
498 }
499 if (id.index >= it->second.size()) {
500 return InvalidArgument("Result #", id.index, " of segment '", id.node.str(),
501 ":", id.output.str(), "' is out of bounds");
502 }
503 return it->second[id.index];
504 }
505
// Looks up the result named by `name` (a NodeDef input string such as
// "node", "node:2", or "node:segment:0"). If the producing node has not been
// converted yet, returns the placeholder value together with a non-null
// `info` pointer so the caller can register a backedge.
StatusOr<GraphDefImporter::Result> GraphDefImporter::GetResult(
    ConversionState &s, StringPiece name) {
  TensorId tensor_id = tensorflow::ParseTensorName(name);
  ResultId id{tensor_id.index()};
  // The node portion may itself be "node:segment" when a named result segment
  // is referenced; split it into the node name and the segment name.
  std::tie(id.node, id.output) =
      StringRef(tensor_id.node().data(), tensor_id.node().size()).split(':');
  // operator[] default-constructs an entry for nodes we haven't seen yet.
  std::unique_ptr<ResultInfo> &info = s[{id.node.data(), id.node.size()}];
  if (!info) {
    info = std::make_unique<ResultInfo>();
  }

  // If the result is unresolved, return the placeholder. The non-null `info`
  // in the Result signals that a backedge must be recorded.
  if (!info->resolved) {
    if (id.IsControl()) {
      return Result{s.GetPlaceholder(), Value(), id, info.get()};
    }
    return Result{Value(), s.GetPlaceholder(), id, info.get()};
  }

  // If the result is the control token, return it.
  if (id.IsControl()) {
    return Result{info->control, Value()};
  }

  TF_ASSIGN_OR_RETURN(Value value, ResolveDataResult(id, info.get()));
  return Result{Value(), value};
}
533
// Converts a non-generic FunctionDef into the given (empty) GraphFuncOp:
// builds the body block with interleaved data/control arguments, converts the
// nodes, wires up the data and control results, and sets the function
// attributes. Must be thread-safe (see ConvertGraphDef's parallel path).
Status GraphDefImporter::ConvertFunctionDef(
    GraphFuncOp func_op,
    const absl::flat_hash_map<StringPiece, StringPiece> &gradient_map,
    const FunctionDef &function) {
  const OpDef &signature = function.signature();
  // TODO(jeffniu): Does the name need to be mangled?

  func_op.body().push_back(new Block);
  Block *body = &func_op.body().front();
  auto builder = OpBuilder::atBlockBegin(func_op.getBody());

  // Convert the attributes.
  NamedAttrList func_attrs;
  TF_RETURN_IF_ERROR(
      ConvertFunctionAttributes(gradient_map, function, func_op, func_attrs));

  SmallVector<Attribute> arg_attrs, res_attrs, control_ret_attrs;
  SmallVector<Type> arg_types, res_types;

  // Convert the arguments and argument attributes. Each signature input
  // becomes a pair of block arguments: the data tensor followed by its
  // control token.
  for (auto &it : llvm::enumerate(signature.input_arg())) {
    Type dtype;
    TF_RETURN_IF_ERROR(ConvertDataType(it.value().type(), b_, &dtype));
    BlockArgument data =
        body->addArgument(UnrankedTensorType::get(dtype), unknown_loc_);
    BlockArgument ctl =
        body->addArgument(dialect_->getControlType(), data.getLoc());

    NamedAttrList attrs;
    TF_RETURN_IF_ERROR(ConvertArgumentAttributes(it.value(), attrs));
    // Merge in any per-argument attributes, prefixed with `tf.`.
    auto attr_it = function.arg_attr().find(it.index());
    if (attr_it != function.arg_attr().end()) {
      for (const auto &name_attr : attr_it->second.attr()) {
        TF_ASSIGN_OR_RETURN(Attribute attr,
                            ConvertAttributeValue(name_attr.second, b_));
        attrs.append("tf." + name_attr.first, attr);
      }
    }

    // The control argument carries no attributes of its own.
    arg_attrs.append({attrs.getDictionary(ctx_), b_.getDictionaryAttr({})});
    arg_types.append({data.getType(), ctl.getType()});
  }

  // Iterate over the arguments again and map them. We have to add them first
  // otherwise the ranges will be invalidated.
  ConversionState s(body, placeholder_state_);
  for (const auto &it : llvm::enumerate(signature.input_arg())) {
    // Argument 2*i is the data value, 2*i+1 is its control token.
    s.emplace(
        it.value().name(),
        new ResultInfo{/*resolved=*/true, body->getArgument(it.index() * 2 + 1),
                       body->getArguments().slice(it.index() * 2, 1)});
  }
  TF_RETURN_IF_ERROR(ConvertNodes(builder, s, function.node_def(), body));

  // Convert the results and the result attributes.
  SmallVector<Value> return_operands;
  return_operands.reserve(signature.output_arg_size() +
                          signature.control_output_size());
  for (const OpDef::ArgDef &def : function.signature().output_arg()) {
    Type dtype;
    TF_RETURN_IF_ERROR(ConvertDataType(def.type(), b_, &dtype));
    NamedAttrList attrs;
    TF_RETURN_IF_ERROR(ConvertArgumentAttributes(def, attrs));
    res_attrs.push_back(attrs.getDictionary(ctx_));
    res_types.push_back(UnrankedTensorType::get(dtype));

    // `ret` maps each output name to the node result that produces it.
    auto ret_it = function.ret().find(def.name());
    if (ret_it == function.ret().end()) {
      return InvalidArgument("Output '", def.name(),
                             "' was not found in 'ret'");
    }
    TF_ASSIGN_OR_RETURN(Result result, GetResult(s, ret_it->second));
    // A non-null `info` means the producer was never converted.
    if (result.info)
      return InvalidArgument("Return '", ret_it->second, "' was not found");
    if (result.control)
      return InvalidArgument("Unexpected control result: ", ret_it->second);
    return_operands.push_back(result.data);
  }

  // Convert the control results.
  for (const std::string &control_ret : signature.control_output()) {
    auto ret_it = function.control_ret().find(control_ret);
    if (ret_it == function.control_ret().end()) {
      return InvalidArgument("Control output '", control_ret,
                             "' was not found in 'control_ret'");
    }
    std::unique_ptr<ResultInfo> &result = s[ret_it->second];
    if (!result || !result->resolved) {
      return InvalidArgument("Control return ", ret_it->second,
                             " was not found");
    }
    return_operands.push_back(result->control);
    control_ret_attrs.push_back(b_.getDictionaryAttr(NamedAttribute(
        dialect_->getTfgNameAttrIdentifier(), b_.getStringAttr(control_ret))));
  }
  builder.create<ReturnOp>(unknown_loc_, return_operands,
                           b_.getArrayAttr(control_ret_attrs));

  // Finalize the function attributes.
  func_attrs.append(func_op.arg_attrsAttrName(), b_.getArrayAttr(arg_attrs));
  func_attrs.append(func_op.res_attrsAttrName(), b_.getArrayAttr(res_attrs));
  func_attrs.append(func_op.function_typeAttrName(),
                    TypeAttr::get(b_.getFunctionType(arg_types, res_types)));
  func_op->setAttrs(func_attrs.getDictionary(ctx_));

  return ::tensorflow::OkStatus();
}
641
// Converts `nodes` into operations appended to `block`. After all nodes are
// converted, any remaining use of the placeholder means a node referenced an
// input that does not exist; in that case the block is destroyed and a
// detailed error listing every missing edge is returned.
Status GraphDefImporter::ConvertNodes(
    OpBuilder &builder, ConversionState &s,
    const tensorflow::protobuf::RepeatedPtrField<NodeDef> &nodes,
    Block *block) {
  OpBuilder::InsertionGuard ig(builder);
  builder.setInsertionPointToStart(block);
  for (const NodeDef &node : nodes) {
    TF_RETURN_IF_ERROR(ConvertNodeDef(builder, s, node));
  }

  // If the placeholder has remaining uses, then an input is missing.
  if (TF_PREDICT_FALSE(!s.GetPlaceholder().use_empty())) {
    // Stringify a result ID back to NodeDef input syntax ("^node", "node:i",
    // or "node:segment:i").
    const auto id_to_str = [](const ResultId &id) {
      std::string name = id.node.str();
      if (id.IsControl()) return absl::StrCat("^", name);
      if (id.output.empty())
        return id.index ? absl::StrCat(id.node.str(), ":", id.index) : name;
      return absl::StrCat(name, ":", id.output.str(), ":", id.index);
    };
    // Gather all missing input edges as (input, consumer-node) pairs.
    std::vector<std::pair<std::string, std::string>> missing_edges;
    for (const ResultInfo &info :
         llvm::make_pointee_range(llvm::make_second_range(s))) {
      if (info.backedges.empty()) continue;
      const Backedge &edge = info.backedges.front();
      missing_edges.emplace_back(id_to_str(edge.id),
                                 TFOp(edge.operand->getOwner()).name().str());
    }
    assert(!missing_edges.empty() &&
           "placeholder had remaining uses but found no unresolved backedges");
    // Destroy the invalid IR.
    block->erase();
    // Report the missing edges in alphabetical order.
    llvm::sort(missing_edges);
    std::string error_message;
    llvm::raw_string_ostream os(error_message);
    llvm::interleave(
        missing_edges, os,
        [&](const auto &edge) {
          os << "Non-existent input " << edge.first << " in node "
             << edge.second;
        },
        "\n");
    return InvalidArgument(std::move(os.str()));
  }
  // The placeholder has no uses and should not acquire any more uses. Safely
  // delete it from the IR.
  s.Finalize();

  return ::tensorflow::OkStatus();
}
694
ArgNumType(const NamedAttrList & attrs,const OpDef::ArgDef & arg_def,SmallVectorImpl<Type> & types)695 StatusOr<unsigned> GraphDefImporter::ArgNumType(const NamedAttrList &attrs,
696 const OpDef::ArgDef &arg_def,
697 SmallVectorImpl<Type> &types) {
698 // Check whether a type list attribute is specified.
699 if (!arg_def.type_list_attr().empty()) {
700 if (auto v =
701 attrs.get(arg_def.type_list_attr()).dyn_cast_or_null<ArrayAttr>()) {
702 for (Attribute attr : v) {
703 if (auto dtype = attr.dyn_cast<TypeAttr>()) {
704 types.push_back(UnrankedTensorType::get(dtype.getValue()));
705 } else {
706 return InvalidArgument("Expected '", arg_def.type_list_attr(),
707 "' to be a list of types");
708 }
709 }
710 return v.size();
711 }
712 return NotFound("Type attr not found: ", arg_def.type_list_attr());
713 }
714
715 unsigned num = 1;
716 // Check whether a number attribute is specified.
717 if (!arg_def.number_attr().empty()) {
718 if (auto v =
719 attrs.get(arg_def.number_attr()).dyn_cast_or_null<IntegerAttr>()) {
720 num = v.getValue().getZExtValue();
721 } else {
722 return NotFound("Type attr not found: ", arg_def.number_attr());
723 }
724 }
725
726 // Check for a type or type attribute.
727 Type dtype;
728 if (arg_def.type() != DataType::DT_INVALID) {
729 TF_RETURN_IF_ERROR(ConvertDataType(arg_def.type(), b_, &dtype));
730 } else if (arg_def.type_attr().empty()) {
731 return InvalidArgument("Arg '", arg_def.name(),
732 "' has invalid type and no type attribute");
733 } else {
734 if (auto v = attrs.get(arg_def.type_attr()).dyn_cast_or_null<TypeAttr>()) {
735 dtype = v.getValue();
736 } else {
737 return NotFound("Type attr not found: ", arg_def.type_attr());
738 }
739 }
740 types.append(num, UnrankedTensorType::get(dtype));
741 return num;
742 }
743
// Converts a single NodeDef into a TFG (`tfg.*`) operation and records its
// results in the conversion state `s` so later nodes can reference them by
// name. Inputs that name not-yet-converted producers are wired to placeholder
// values and recorded as backedges, which are patched here once the producer
// op is created.
Status GraphDefImporter::ConvertNodeDef(OpBuilder &builder, ConversionState &s,
                                        const NodeDef &node) {
  VLOG(4) << "Importing: " << node.name();
  if (node.op().empty())
    return InvalidArgument("Node ", node.name(), " has an empty op name");

  // All imported operations live in the `tfg` dialect namespace.
  OperationState state(ConvertLocation(node), absl::StrCat("tfg.", node.op()));

  // The GraphImporter does light shape inference, but here we will defer all of
  // that to the shape inference pass.
  //
  // Look up the OpDef: first in the op registry, then among the functions of
  // the graph's library (a node's op may name a library function).
  const OpDef *op_def;
  const OpRegistrationData *op_reg_data = nullptr;
  if ((op_reg_data = registry_.LookUp(node.op()))) {
    op_def = &op_reg_data->op_def;
  } else {
    auto it = function_op_defs_.find(node.op());
    if (it == function_op_defs_.end())
      return InvalidArgument("Unable to find OpDef for ", node.op());
    op_def = it->second;
  }

  // Import the attributes. Reserve `+3` for `device`,`name`, and `fulltype`.
  state.attributes.reserve(node.attr_size() + 3);
  if (!node.device().empty()) {
    state.addAttribute(dialect_->getDeviceAttrIdentifier(),
                       b_.getStringAttr(node.device()));
  }
  if (!node.name().empty()) {
    state.addAttribute(dialect_->getNameAttrIdentifier(),
                       b_.getStringAttr(node.name()));
  }

  // If the op doesn't have a FullType, try to infer one.
  const auto add_full_type = [&](const FullTypeDef &full_type_def) {
    TF_ASSIGN_OR_RETURN(tf_type::FullTypeAttr full_type,
                        ConvertAttribute(full_type_def, b_));
    state.addAttribute(dialect_->getFullTypeAttrIdentifier(), full_type);
    return ::tensorflow::OkStatus();
  };
  if (node.has_experimental_type()) {
    TF_RETURN_IF_ERROR(add_full_type(node.experimental_type()));
  } else if (op_reg_data && op_reg_data->type_ctor) {
    // The registered op has a type constructor; use it to specialize a
    // FullType for this node.
    FullTypeDef full_type_def;
    TF_RETURN_IF_ERROR(
        tensorflow::full_type::SpecializeType(node, *op_def, full_type_def));
    TF_RETURN_IF_ERROR(add_full_type(full_type_def));
  }

  // Convert each NodeDef attribute value to an MLIR attribute.
  for (auto &name_attr : node.attr()) {
    if (name_attr.first.empty())
      return InvalidArgument("Node ", node.name(), " has an empty attr name");
    TF_ASSIGN_OR_RETURN(Attribute attr,
                        ConvertAttributeValue(name_attr.second, b_));
    state.addAttribute(name_attr.first, attr);
  }

  // Add missing default attributes.
  for (const auto &attr_def : op_def->attr()) {
    if (attr_def.has_default_value() &&
        !state.attributes.get(attr_def.name())) {
      TF_ASSIGN_OR_RETURN(Attribute attr,
                          ConvertAttributeValue(attr_def.default_value(), b_));
      state.addAttribute(attr_def.name(), attr);
    }
  }

  // Get the result types. Ops can have multiple named results. Track the
  // segment sizes: each entry is the (start index, length) of one named
  // output group within the flattened result list.
  SmallVector<std::pair<unsigned, unsigned>> result_segments;
  result_segments.reserve(op_def->output_arg_size());
  state.types.reserve(op_def->output_arg_size() + 1);
  for (const OpDef::ArgDef &def : op_def->output_arg()) {
    unsigned index = state.types.size();
    TF_ASSIGN_OR_RETURN(unsigned size,
                        ArgNumType(state.attributes, def, state.types));
    result_segments.emplace_back(index, size);
  }
  // Every imported op carries a trailing control-token result.
  state.types.push_back(dialect_->getControlType());

  // Collect the operands. Set backedges to a placeholder and resolve them
  // later.
  state.operands.reserve(node.input_size());
  SmallVector<Value> control_operands;
  struct BackedgeResolution {
    ResultInfo *info;      // Producer's (not yet resolved) result info.
    size_t operand_index;  // Operand slot to patch, within its segment.
    ResultId id;           // Which result of the producer is referenced.
  };
  SmallVector<BackedgeResolution> unresolved_data_operands,
      unresolved_control_operands;
  for (const std::string &input : node.input()) {
    TF_ASSIGN_OR_RETURN(Result result, GetResult(s, input));
    if (result.control) {
      // A non-null `info` means the producer has not been converted yet, so
      // this operand currently holds a placeholder.
      if (result.info) {
        unresolved_control_operands.push_back(BackedgeResolution{
            result.info, control_operands.size(), result.id});
      }
      control_operands.push_back(result.control);
    } else {
      if (result.info) {
        unresolved_data_operands.push_back(
            BackedgeResolution{result.info, state.operands.size(), result.id});
      }
      state.operands.push_back(result.data);
    }
  }
  // Control operands are appended after all data operands; remember the split
  // point so control backedge indices can be offset correctly below.
  unsigned num_data_operands = state.operands.size();
  state.addOperands(control_operands);

  // Create the op and record any unresolved operands.
  Operation *op = builder.create(state);
  for (const BackedgeResolution &r : unresolved_data_operands) {
    r.info->backedges.push_back(
        Backedge{r.id, &op->getOpOperand(r.operand_index)});
  }
  for (const BackedgeResolution &r : unresolved_control_operands) {
    r.info->backedges.push_back(
        Backedge{r.id, &op->getOpOperand(num_data_operands + r.operand_index)});
  }

  // Publish this node's results under its name so consumers (including ones
  // already converted, via backedges) can find them.
  std::unique_ptr<ResultInfo> &info = s[node.name()];
  if (!info) {
    info = std::make_unique<ResultInfo>();
  }
  info->resolved = true;
  // The control token is the last result; the rest are the data results.
  info->control = *std::prev(op->result_end());
  info->data = op->getResults().drop_back();
  // Map each named output group to its slice of the data results.
  for (auto it : llvm::zip(result_segments, op_def->output_arg())) {
    const std::pair<unsigned, unsigned> &segment = std::get<0>(it);
    info->outputs.emplace(std::get<1>(it).name(),
                          info->data.slice(segment.first, segment.second));
  }

  // Resolve any associated backedges.
  for (const Backedge &backedge : info->backedges) {
    Value value;
    if (backedge.id.IsControl()) {
      value = info->control;
    } else {
      TF_ASSIGN_OR_RETURN(value, ResolveDataResult(backedge.id, info.get()));
    }
    backedge.operand->set(value);
  }
  info->backedges.clear();

  return ::tensorflow::OkStatus();
}
891
ConvertDataTypesToUnrankedTensorTypes(const DataTypeVector & dtypes,SmallVectorImpl<Type> & results)892 Status GraphDefImporter::ConvertDataTypesToUnrankedTensorTypes(
893 const DataTypeVector &dtypes, SmallVectorImpl<Type> &results) {
894 Type dtype;
895 for (DataType tf_dtype : dtypes) {
896 TF_RETURN_IF_ERROR(ConvertDataType(tf_dtype, b_, &dtype));
897 results.push_back(UnrankedTensorType::get(dtype));
898 }
899 return ::tensorflow::OkStatus();
900 }
901
ImportGraphDef(MLIRContext * context,const GraphDebugInfo & debug_info,const GraphDef & graph_def)902 StatusOr<OwningOpRef<ModuleOp>> ImportGraphDef(MLIRContext *context,
903 const GraphDebugInfo &debug_info,
904 const GraphDef &graph_def) {
905 GraphDefImporter importer(context->getOrLoadDialect<TFGraphDialect>(),
906 *OpRegistry::Global(), debug_info);
907 return importer.ConvertGraphDef(graph_def);
908 }
909
ImportGraphAndFunctionsToMlir(MLIRContext * context,const GraphDebugInfo & debug_info,const Graph & graph,const FunctionLibraryDefinition & flib_def)910 StatusOr<OwningOpRef<ModuleOp>> ImportGraphAndFunctionsToMlir(
911 MLIRContext *context, const GraphDebugInfo &debug_info, const Graph &graph,
912 const FunctionLibraryDefinition &flib_def) {
913 // TODO(b/231723721): This conversion path is slow because both the graph and
914 // the function library are converted to GraphDef.
915 GraphDef graph_def;
916 graph.ToGraphDef(&graph_def);
917 *graph_def.mutable_library() = flib_def.ToProto();
918 return ImportGraphDef(context, debug_info, graph_def);
919 }
920
921 } // namespace tfg
922 } // namespace mlir
923