/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
17
18 #include "llvm/ADT/StringSet.h"
19 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
20 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/device_factory.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
25 #include "tensorflow/core/grappler/grappler_item.h"
26 #include "tensorflow/core/grappler/grappler_item_builder.h"
27 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
28 #include "tensorflow/core/protobuf/meta_graph.pb.h"
29
30 namespace tensorflow {
31 namespace {
32
33 // Returns the ops that should use node name if shared_name is empty.
GetOpsUsingNodeName()34 const llvm::StringSet<>& GetOpsUsingNodeName() {
35 static auto* const ops =
36 new llvm::StringSet<>({"VariableV2", "Variable", "BatchFunction"});
37 return *ops;
38 }
39
40 // Returns the set of ops that we want to generate shared_names for them if
41 // empty.
GetSharedNameGenerationCompatibleOps()42 const llvm::StringSet<>& GetSharedNameGenerationCompatibleOps() {
43 return GetOpsUsingNodeName();
44 }
45
46 } // namespace
47
GenerateResourceSharedNameIfEmpty(GraphDef & gdef,const OpRegistryInterface * default_registry)48 Status GenerateResourceSharedNameIfEmpty(
49 GraphDef& gdef, const OpRegistryInterface* default_registry) {
50 auto is_resource_op_with_empty_shared_name = [](const NodeDef& node_def,
51 const OpDef& op_def) {
52 if (!GetSharedNameGenerationCompatibleOps().contains(op_def.name())) {
53 // If this op is not in the allowlist, then it is likely a custom op.
54 // Currently for these ops, we are relying on its "use_node_name_sharing"
55 // to decide whether it is valid to generate shared_names. If the OpDef
56 // has "use_node_name_sharing" field, then it is valid to use node names
57 // as shared names.
58 if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
59 [](const auto& attr_def) {
60 return attr_def.name() == "use_node_name_sharing" &&
61 attr_def.type() == "bool";
62 }))
63 return false;
64 }
65
66 if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
67 [](const auto& attr_def) {
68 return attr_def.name() == "shared_name" &&
69 attr_def.type() == "string";
70 }))
71 return false;
72
73 auto iter = node_def.attr().find("shared_name");
74 if (iter == node_def.attr().end()) return true;
75 return iter->second.s().empty();
76 };
77
78 FunctionDefLibrary* library = gdef.mutable_library();
79 auto flib_def = library ? std::make_unique<FunctionLibraryDefinition>(
80 default_registry, *library)
81 : std::make_unique<FunctionLibraryDefinition>(
82 default_registry, FunctionDefLibrary());
83
84 if (library) {
85 // Upgrade nodes in the functions.
86 for (FunctionDef& fdef : *library->mutable_function()) {
87 auto func_name = fdef.signature().name();
88 for (auto& node_def : *fdef.mutable_node_def()) {
89 const OpDef* op_def = nullptr;
90 // With lazy loading, some functions might not be executed, thus we skip
91 // the node if the op is not registered.
92 if (flib_def->LookUpOpDef(node_def.op(), &op_def).ok() &&
93 is_resource_op_with_empty_shared_name(node_def, *op_def)) {
94 // TODO(b/197144710): improve the shared_name attr, each op may use
95 // the shared_name differently.
96 if (GetOpsUsingNodeName().contains(op_def->name())) {
97 // Use the node name for such ops as the shared_name according to
98 // the document of variable ops.
99 (*node_def.mutable_attr())["shared_name"].set_s(node_def.name());
100 } else {
101 // Use the concat of function name and node name for such ops in a
102 // function as the shared_name. "@" is used as the separator because
103 // it is not allowed in the function name or the node name.
104 (*node_def.mutable_attr())["shared_name"].set_s(
105 absl::StrCat(node_def.name(), "@", func_name));
106 }
107 }
108 }
109 }
110 }
111
112 // Upgrade nodes in the GraphDef.
113 for (auto& node_def : *gdef.mutable_node()) {
114 const OpDef* op_def = nullptr;
115 TF_RETURN_IF_ERROR(flib_def->LookUpOpDef(node_def.op(), &op_def));
116 // TODO(b/197144710): improve the shared_name attr, each op may use the
117 // shared_name differently.
118 if (is_resource_op_with_empty_shared_name(node_def, *op_def)) {
119 (*node_def.mutable_attr())["shared_name"].set_s(node_def.name());
120 }
121 }
122
123 return OkStatus();
124 }
125
IsCompiledNode(const Node * n)126 bool IsCompiledNode(const Node* n) {
127 return n->attrs().Find(tensorflow::kTpuReplicateAttr) ||
128 n->attrs().Find(tensorflow::kCompileDeviceTypeAttr);
129 }
130
UpgradeLegacyGraph(Graph * graph,FunctionLibraryDefinition * flib_def,bool restrict_functionalization_to_compiled_nodes)131 Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
132 bool restrict_functionalization_to_compiled_nodes) {
133 NodeFilter node_filter = restrict_functionalization_to_compiled_nodes
134 ? IsCompiledNode
135 : NodeFilter{};
136 TF_RETURN_WITH_CONTEXT_IF_ERROR(
137 FunctionalizeControlFlow(graph, flib_def, node_filter,
138 /*include_functions=*/true),
139 "Failed to functionalize Control Flow V1 ops. Consider using Control "
140 "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/tf/"
141 "compat/v1/enable_control_flow_v2.");
142 return OkStatus();
143 }
144
145 } // namespace tensorflow
146