1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
26 #include "mlir/IR/Types.h"  // from @llvm-project
27 #include "mlir/IR/Value.h"  // from @llvm-project
28 #include "mlir/Pass/Pass.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
33 
34 namespace mlir {
35 namespace TF {
36 namespace {
37 
38 // Location attribute.
39 constexpr StringRef kClassAttr = "_class";
40 constexpr StringRef kSharedNameAttr = "shared_name";
41 constexpr StringRef kLocationPrefix = "loc:@";
42 
43 // A pass that converts readonly reference variables to the corresponding
44 // resource variables.
45 //
46 // It converts (VariableV2 -> Identity) to (VarHandle -> ReadVariable).
47 //
48 // For the background, this pass is a part of hoisting VariableV2 ops by
49 // re-using the pipeline for hoisting (VarHandle -> ReadVariable) cases, which
50 //  can be done by the following passes:
51 //  - Capturing resource values into global tensors (importing saved model).
52 //  - Promoting VarHandle ops to function input/outputs.
53 //  - Freezing global tensor pass.
54 //
55 // This path assumes that all the VariableV2 ops is read-only via verifying the
56 // heuristic method that assumes that all the users of them is Identity op,
57 // fed directly.
58 class ConvertReadonlyReferenceVariablesToResourceVariablesPass
59     : public ConvertReadonlyReferenceVariablesToResourceVariablesPassBase<
60           ConvertReadonlyReferenceVariablesToResourceVariablesPass> {
61   void runOnOperation() override;
62 };
63 
64 // Parse node name from "_class" or "shared_name" attributes.
GetNodeNameFromClassAttrOrSharedNameAttr(Operation * op)65 StringRef GetNodeNameFromClassAttrOrSharedNameAttr(Operation *op) {
66   // Parse node name from the `shared_name` attribute first. The variable v2 op
67   // relies on the share name to look up from the TensorFlow's resource manager.
68   StringAttr shared_name_attr = op->getAttrOfType<StringAttr>(kSharedNameAttr);
69   if (shared_name_attr) {
70     auto shared_name = StringRef(shared_name_attr.getValue());
71     if (!shared_name.empty()) {
72       return shared_name;
73     }
74   }
75   // Attempt to parse "_class" attribute if there is no "shared_name"
76   // attribute.
77   ArrayAttr classes_attr = op->getAttrOfType<ArrayAttr>(kClassAttr);
78   if (!classes_attr) {
79     // Attempt to parse "_class" from the IdentityOp that follows VariableV2.
80     // For read-only reference variables, IdentityOp should be the only user of
81     // VariableV2.
82     auto identity_op = op->getUsers().begin();
83     classes_attr = identity_op->getAttrOfType<ArrayAttr>(kClassAttr);
84     if (!classes_attr) {
85       op->emitOpError() << "has no '_class' and 'shared_name' attributes";
86       return StringRef();
87     }
88   }
89 
90   StringRef result;
91   for (Attribute class_attr : classes_attr) {
92     StringRef node_name = class_attr.cast<StringAttr>().getValue();
93     if (!node_name.startswith(kLocationPrefix)) {
94       continue;
95     }
96     if (!result.empty()) {
97       // Invalid case since there are multiple loc:@ attributes.
98       op->emitOpError()
99           << "expects only one named location in '_class' attribute, but got "
100           << classes_attr;
101       return StringRef();
102     }
103     result = node_name.drop_front(kLocationPrefix.size());
104   }
105   if (result.empty()) {
106     op->emitOpError() << "expects variable name in '_class' attribute, but got "
107                       << classes_attr;
108   }
109   return result;
110 }
111 
112 void ConvertReadonlyReferenceVariablesToResourceVariablesPass::
runOnOperation()113     runOnOperation() {
114   func::FuncOp func = getOperation();
115 
116   OpBuilder builder(func.getContext());
117   SmallVector<VariableV2Op, 4> variable_v2s_to_replace;
118 
119   // Checks all the VariableV2 ops is read-only via verifying the heuristic
120   // method that assumes that all the users of them is Identity op, feeded
121   // directly.
122   auto read_only_vars_fn = [&variable_v2s_to_replace](
123                                VariableV2Op variable_v2_op) {
124     if (variable_v2_op.getResult().use_empty()) {
125       // Erase the op when there is no user.
126       variable_v2_op.erase();
127       return mlir::WalkResult::advance();
128     }
129     if (!all_of(variable_v2_op.getResult().getUsers(), [&variable_v2_op](
130                                                            Operation *user) {
131           if (!isa<IdentityOp>(user)) {
132             variable_v2_op.emitOpError()
133                 << "expects all users to be 'tf.Identity', but got user "
134                 << user->getName();
135             return false;
136           }
137           return true;
138         })) {
139       return mlir::WalkResult::interrupt();
140     }
141     variable_v2s_to_replace.push_back(variable_v2_op);
142     return mlir::WalkResult::advance();
143   };
144 
145   WalkResult walk_res = func.walk(read_only_vars_fn);
146   if (walk_res.wasInterrupted()) return signalPassFailure();
147 
148   for (VariableV2Op variable_v2_op : variable_v2s_to_replace) {
149     builder.setInsertionPoint(variable_v2_op);
150     ShapedType shaped_type =
151         variable_v2_op.getResult().getType().cast<ShapedType>();
152     TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
153     StringAttr device_attr =
154         variable_v2_op->getAttrOfType<StringAttr>("device");
155     if (!device_attr) device_attr = builder.getStringAttr("");
156     StringRef variable_name =
157         GetNodeNameFromClassAttrOrSharedNameAttr(variable_v2_op);
158     if (variable_name.empty()) {
159       return signalPassFailure();
160     }
161     VarHandleOp var_handle_op = builder.create<VarHandleOp>(
162         variable_v2_op.getLoc(),
163         ArrayRef<Type>{RankedTensorType::get(
164             {}, TF::ResourceType::get(ArrayRef<TensorType>{tensor_type},
165                                       builder.getContext()))},
166         ArrayRef<Value>{},
167         ArrayRef<NamedAttribute>{
168             builder.getNamedAttr("device", device_attr),
169             builder.getNamedAttr("container", variable_v2_op.containerAttr()),
170             builder.getNamedAttr("shared_name",
171                                  builder.getStringAttr(variable_name))});
172     for (Operation *user :
173          make_early_inc_range(variable_v2_op.getResult().getUsers())) {
174       builder.setInsertionPoint(user);
175       ReadVariableOp read_variable_op = builder.create<ReadVariableOp>(
176           user->getLoc(), ArrayRef<Type>{tensor_type},
177           ArrayRef<Value>{var_handle_op});
178       user->getResult(0).replaceAllUsesWith(read_variable_op.getResult());
179       user->erase();
180     }
181     variable_v2_op.erase();
182   }
183 }
184 
185 }  // namespace
186 
187 std::unique_ptr<OperationPass<func::FuncOp>>
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass()188 CreateConvertReadonlyReferenceVariablesToResourceVariablesPass() {
189   return std::make_unique<
190       ConvertReadonlyReferenceVariablesToResourceVariablesPass>();
191 }
192 
193 }  // namespace TF
194 
195 }  // namespace mlir
196