xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
17 
18 #include "llvm/ADT/StringSet.h"
19 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
20 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/device_factory.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
25 #include "tensorflow/core/grappler/grappler_item.h"
26 #include "tensorflow/core/grappler/grappler_item_builder.h"
27 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
28 #include "tensorflow/core/protobuf/meta_graph.pb.h"
29 
30 namespace tensorflow {
31 namespace {
32 
33 // Returns the ops that should use node name if shared_name is empty.
GetOpsUsingNodeName()34 const llvm::StringSet<>& GetOpsUsingNodeName() {
35   static auto* const ops =
36       new llvm::StringSet<>({"VariableV2", "Variable", "BatchFunction"});
37   return *ops;
38 }
39 
40 // Returns the set of ops that we want to generate shared_names for them if
41 // empty.
GetSharedNameGenerationCompatibleOps()42 const llvm::StringSet<>& GetSharedNameGenerationCompatibleOps() {
43   return GetOpsUsingNodeName();
44 }
45 
46 }  // namespace
47 
GenerateResourceSharedNameIfEmpty(GraphDef & gdef,const OpRegistryInterface * default_registry)48 Status GenerateResourceSharedNameIfEmpty(
49     GraphDef& gdef, const OpRegistryInterface* default_registry) {
50   auto is_resource_op_with_empty_shared_name = [](const NodeDef& node_def,
51                                                   const OpDef& op_def) {
52     if (!GetSharedNameGenerationCompatibleOps().contains(op_def.name())) {
53       // If this op is not in the allowlist, then it is likely a custom op.
54       // Currently for these ops, we are relying on its "use_node_name_sharing"
55       // to decide whether it is valid to generate shared_names. If the OpDef
56       // has "use_node_name_sharing" field, then it is valid to use node names
57       // as shared names.
58       if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
59                        [](const auto& attr_def) {
60                          return attr_def.name() == "use_node_name_sharing" &&
61                                 attr_def.type() == "bool";
62                        }))
63         return false;
64     }
65 
66     if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
67                      [](const auto& attr_def) {
68                        return attr_def.name() == "shared_name" &&
69                               attr_def.type() == "string";
70                      }))
71       return false;
72 
73     auto iter = node_def.attr().find("shared_name");
74     if (iter == node_def.attr().end()) return true;
75     return iter->second.s().empty();
76   };
77 
78   FunctionDefLibrary* library = gdef.mutable_library();
79   auto flib_def = library ? std::make_unique<FunctionLibraryDefinition>(
80                                 default_registry, *library)
81                           : std::make_unique<FunctionLibraryDefinition>(
82                                 default_registry, FunctionDefLibrary());
83 
84   if (library) {
85     // Upgrade nodes in the functions.
86     for (FunctionDef& fdef : *library->mutable_function()) {
87       auto func_name = fdef.signature().name();
88       for (auto& node_def : *fdef.mutable_node_def()) {
89         const OpDef* op_def = nullptr;
90         // With lazy loading, some functions might not be executed, thus we skip
91         // the node if the op is not registered.
92         if (flib_def->LookUpOpDef(node_def.op(), &op_def).ok() &&
93             is_resource_op_with_empty_shared_name(node_def, *op_def)) {
94           // TODO(b/197144710): improve the shared_name attr, each op may use
95           // the shared_name differently.
96           if (GetOpsUsingNodeName().contains(op_def->name())) {
97             // Use the node name for such ops as the shared_name according to
98             // the document of variable ops.
99             (*node_def.mutable_attr())["shared_name"].set_s(node_def.name());
100           } else {
101             // Use the concat of function name and node name for such ops in a
102             // function as the shared_name. "@" is used as the separator because
103             // it is not allowed in the function name or the node name.
104             (*node_def.mutable_attr())["shared_name"].set_s(
105                 absl::StrCat(node_def.name(), "@", func_name));
106           }
107         }
108       }
109     }
110   }
111 
112   // Upgrade nodes in the GraphDef.
113   for (auto& node_def : *gdef.mutable_node()) {
114     const OpDef* op_def = nullptr;
115     TF_RETURN_IF_ERROR(flib_def->LookUpOpDef(node_def.op(), &op_def));
116     // TODO(b/197144710): improve the shared_name attr, each op may use the
117     // shared_name differently.
118     if (is_resource_op_with_empty_shared_name(node_def, *op_def)) {
119       (*node_def.mutable_attr())["shared_name"].set_s(node_def.name());
120     }
121   }
122 
123   return OkStatus();
124 }
125 
IsCompiledNode(const Node * n)126 bool IsCompiledNode(const Node* n) {
127   return n->attrs().Find(tensorflow::kTpuReplicateAttr) ||
128          n->attrs().Find(tensorflow::kCompileDeviceTypeAttr);
129 }
130 
UpgradeLegacyGraph(Graph * graph,FunctionLibraryDefinition * flib_def,bool restrict_functionalization_to_compiled_nodes)131 Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
132                           bool restrict_functionalization_to_compiled_nodes) {
133   NodeFilter node_filter = restrict_functionalization_to_compiled_nodes
134                                ? IsCompiledNode
135                                : NodeFilter{};
136   TF_RETURN_WITH_CONTEXT_IF_ERROR(
137       FunctionalizeControlFlow(graph, flib_def, node_filter,
138                                /*include_functions=*/true),
139       "Failed to functionalize Control Flow V1 ops. Consider using Control "
140       "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/tf/"
141       "compat/v1/enable_control_flow_v2.");
142   return OkStatus();
143 }
144 
145 }  // namespace tensorflow
146