1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h"

#include <string>
#include <utility>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/session_utils.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/public/session.h"
44
45 namespace mlir {
46 namespace tf_saved_model {
47 namespace {
48
49 // Build and returns ElementsAttr which holds the data in 'tensor'.
GetTensorValueAsElementsAttr(const tensorflow::Tensor & tensor,OpBuilder builder)50 ElementsAttr GetTensorValueAsElementsAttr(const tensorflow::Tensor& tensor,
51 OpBuilder builder) {
52 tensorflow::StatusOr<ElementsAttr> tensor_attr_or =
53 tensorflow::ConvertTensor(tensor, &builder);
54 if (!tensor_attr_or.ok()) return nullptr;
55 return tensor_attr_or.ValueOrDie();
56 }
57
58 // Creates a constant op that holds 'tensor_elements'.
GetConstOpFromElementsAttr(ElementsAttr tensor_elements,OpBuilder builder,Location loc)59 TF::ConstOp GetConstOpFromElementsAttr(ElementsAttr tensor_elements,
60 OpBuilder builder, Location loc) {
61 return builder.create<TF::ConstOp>(loc, tensor_elements);
62 }
63
64 // Returns ElementsAttr which has the value held by 'resource_tensor'.
GetTensorValueAsElementsAttr(TF::VarHandleOp var_handle_op,const tensorflow::Tensor & resource_tensor,const tensorflow::DeviceMgr * mgr,OpBuilder builder)65 ElementsAttr GetTensorValueAsElementsAttr(
66 TF::VarHandleOp var_handle_op, const tensorflow::Tensor& resource_tensor,
67 const tensorflow::DeviceMgr* mgr, OpBuilder builder) {
68 if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) {
69 return GetTensorValueAsElementsAttr(resource_tensor, builder);
70 }
71
72 auto handle = resource_tensor.scalar<tensorflow::ResourceHandle>()();
73 auto* var_ptr = tf_saved_model::GetVariableFromSession(var_handle_op,
74 handle.device(), mgr);
75 if (!var_ptr) {
76 return nullptr;
77 }
78 tensorflow::core::RefCountPtr<tensorflow::Var> var(var_ptr);
79 auto* tensor = var_ptr->tensor();
80
81 return GetTensorValueAsElementsAttr(*tensor, builder);
82 }
83
84 // Replace usage of 'read_variable_op' with 'value'.
PropagateUsage(TF::ReadVariableOp read_variable_op,ElementsAttr value)85 void PropagateUsage(TF::ReadVariableOp read_variable_op, ElementsAttr value) {
86 OpBuilder builder(read_variable_op);
87 read_variable_op->getResult(0).replaceAllUsesWith(
88 GetConstOpFromElementsAttr(value, builder, read_variable_op->getLoc()));
89 }
90
91 // Propagates a resource usage across the graph where
92 // 'user_op' uses a resource and is passed to this op at 'argument_index'.
93 // This resource should be replaced by 'value'.
94 // Output params:
95 // - work_list: Is updated with new regions to process that is called
96 // by 'user_op';
97 // - arguments_to_erase: Captures updates to the graph - which arguments
98 // to remove from the op;
PropagateUsage(Operation * user_op,int argument_index,ElementsAttr value,llvm::SmallVector<std::pair<Region *,int>,4> * work_list,llvm::MapVector<Operation *,llvm::SmallVector<unsigned int,4>> * arguments_to_erase)99 void PropagateUsage(
100 Operation* user_op, int argument_index, ElementsAttr value,
101 llvm::SmallVector<std::pair<Region*, int>, 4>* work_list,
102 llvm::MapVector<Operation*, llvm::SmallVector<unsigned int, 4>>*
103 arguments_to_erase) {
104 if (auto read_variable_op = dyn_cast<TF::ReadVariableOp>(user_op)) {
105 (*arguments_to_erase)[read_variable_op];
106 PropagateUsage(read_variable_op, value);
107 } else if (auto call = dyn_cast<CallOpInterface>(user_op)) {
108 (*arguments_to_erase)[call].push_back(argument_index);
109 if (auto func = dyn_cast<func::FuncOp>(call.resolveCallable())) {
110 (*arguments_to_erase)[func].push_back(argument_index);
111 work_list->push_back(std::make_pair(&func.getRegion(), argument_index));
112 }
113 } else if (auto if_op = dyn_cast<TF::IfOp>(user_op)) {
114 (*arguments_to_erase)[if_op].push_back(argument_index);
115 for (auto callee : {if_op.then_function(), if_op.else_function()}) {
116 (*arguments_to_erase)[callee].push_back(argument_index - 1);
117 work_list->push_back(
118 std::make_pair(&callee.getBody(), argument_index - 1));
119 }
120 } else if (auto if_op = dyn_cast<TF::IfRegionOp>(user_op)) {
121 (*arguments_to_erase)[if_op].push_back(argument_index);
122 for (auto callee : {&if_op.then_branch(), &if_op.else_branch()}) {
123 work_list->push_back(std::make_pair(callee, argument_index));
124 }
125 } else if (auto while_op = dyn_cast<TF::WhileOp>(user_op)) {
126 (*arguments_to_erase)[while_op].push_back(argument_index);
127 for (auto callee : {while_op.cond_function(), while_op.body_function()}) {
128 (*arguments_to_erase)[callee].push_back(argument_index);
129 work_list->push_back(std::make_pair(&callee.getBody(), argument_index));
130 }
131 } else if (auto while_op = dyn_cast<TF::WhileRegionOp>(user_op)) {
132 (*arguments_to_erase)[while_op].push_back(argument_index);
133 for (auto callee : {&while_op.cond(), &while_op.body()}) {
134 work_list->push_back(std::make_pair(callee, argument_index));
135 }
136 }
137 }
138
139 // An override that takes region.
PropagateUsage(Region * region,ElementsAttr value,int argument_index,llvm::SmallVector<std::pair<Region *,int>,4> * work_list,llvm::MapVector<Operation *,llvm::SmallVector<unsigned int,4>> * arguments_to_erase)140 void PropagateUsage(
141 Region* region, ElementsAttr value, int argument_index,
142 llvm::SmallVector<std::pair<Region*, int>, 4>* work_list,
143 llvm::MapVector<Operation*, llvm::SmallVector<unsigned int, 4>>*
144 arguments_to_erase) {
145 auto arg = region->getArgument(argument_index);
146 for (auto& usage : arg.getUses()) {
147 auto* user_op = usage.getOwner();
148 int operand_index = usage.getOperandNumber();
149 PropagateUsage(user_op, operand_index, value, work_list,
150 arguments_to_erase);
151 }
152 }
153
154 // Traces usage of 'var_handle_op' and replaces it's usage with constant value
155 // 'value'.
156 // All op operands updates are captured in 'arguments_to_erase'.
ReplaceVarWithConstant(TF::VarHandleOp var_handle_op,ElementsAttr value,llvm::MapVector<Operation *,llvm::SmallVector<unsigned int,4>> * arguments_to_erase)157 void ReplaceVarWithConstant(
158 TF::VarHandleOp var_handle_op, ElementsAttr value,
159 llvm::MapVector<Operation*, llvm::SmallVector<unsigned int, 4>>*
160 arguments_to_erase) {
161 llvm::SmallVector<std::pair<Region*, int>, 4> work_list;
162 for (auto& usage : var_handle_op->getUses()) {
163 auto* user_op = usage.getOwner();
164 int operand_index = usage.getOperandNumber();
165 PropagateUsage(user_op, operand_index, value, &work_list,
166 arguments_to_erase);
167 }
168 // Container to mark visited regions to avoid infinite loop.
169 llvm::DenseSet<std::pair<Region*, int>> visited;
170 while (!work_list.empty()) {
171 auto work_item = work_list.pop_back_val();
172 if (visited.contains(work_item)) continue;
173 PropagateUsage(work_item.first, value, work_item.second, &work_list,
174 arguments_to_erase);
175 visited.insert(work_item);
176 }
177 }
178
179 // Helper that returns the FuncOp that is the SessionInit function which
180 // will be called to initialize all resources.
181 // Returns nullptr if no function is found.
GetSessionInitializerFunc(ModuleOp module)182 func::FuncOp GetSessionInitializerFunc(ModuleOp module) {
183 auto session_init_op = tf_saved_model::GetSessionInitializerOp(module);
184 SymbolTable symbol_table(module);
185 if (session_init_op && !session_init_op.initializers().empty()) {
186 func::FuncOp init_func_op = symbol_table.lookup<mlir::func::FuncOp>(
187 session_init_op.initializers()[0].cast<FlatSymbolRefAttr>().getValue());
188 return init_func_op;
189 }
190 return nullptr;
191 }
192
193 // Returns ID for identifying a resource.
GetResourceKey(Operation * op)194 std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef> GetResourceKey(
195 Operation* op) {
196 llvm::StringRef device;
197 if (auto attr = op->getAttrOfType<mlir::StringAttr>("device")) {
198 device = attr.getValue();
199 }
200
201 llvm::StringRef container;
202 if (auto attr = op->getAttrOfType<mlir::StringAttr>("container")) {
203 container = attr.getValue();
204 }
205
206 llvm::StringRef shared_name;
207 if (auto attr = op->getAttrOfType<mlir::StringAttr>("shared_name")) {
208 shared_name = attr.getValue();
209 }
210
211 return std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>{
212 device, container, shared_name};
213 }
214
215 // Remove the initialization of the variables in 'var_handle_ops' from
216 // the session init function 'sesion_init_func'
RemoveVariablesInitializations(const llvm::SmallVector<TF::VarHandleOp,4> & var_handle_ops,func::FuncOp sesion_init_func)217 void RemoveVariablesInitializations(
218 const llvm::SmallVector<TF::VarHandleOp, 4>& var_handle_ops,
219 func::FuncOp sesion_init_func) {
220 // We identify the variables using (device, container, shared_name) of the
221 // resource. Capture them here and use them to identify the useless
222 // initializations.
223 llvm::SetVector<std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>>
224 variables;
225 for (auto var_handle_op : var_handle_ops)
226 variables.insert(GetResourceKey(var_handle_op));
227
228 llvm::SmallVector<Operation*, 4> work_list;
229 for (auto var_handle_op : sesion_init_func.getOps<TF::VarHandleOp>()) {
230 if (variables.count(GetResourceKey(var_handle_op)))
231 work_list.push_back(var_handle_op);
232 }
233
234 // Capture list of ops to be erased by traversing usage starting from
235 // the VarHandle ops.
236 llvm::SetVector<Operation*> erase_list;
237 while (!work_list.empty()) {
238 auto* operation = work_list.pop_back_val();
239 erase_list.insert(operation);
240 for (auto& use : operation->getUses()) {
241 if (erase_list.count(use.getOwner())) continue;
242 work_list.push_back(use.getOwner());
243 }
244 }
245
246 for (auto* op : erase_list) {
247 op->dropAllUses();
248 op->erase();
249 }
250 }
251
252 // Updates terminator op arguments of 'func' after removing arguments
253 // specified in 'arguments_to_erase'.
254 template <typename T>
UpdateTerminatorArguments(T & func,const llvm::SmallVector<unsigned,4> & arguments_to_erase,llvm::BitVector & erase_indices)255 void UpdateTerminatorArguments(
256 T& func, const llvm::SmallVector<unsigned, 4>& arguments_to_erase,
257 llvm::BitVector& erase_indices) {
258 auto terminator = func.front().getTerminator();
259 int num_operands = terminator->getNumOperands();
260 erase_indices.resize(num_operands);
261 for (auto arg_index : arguments_to_erase) {
262 auto argument = func.getArgument(arg_index);
263 for (auto& use : argument.getUses()) {
264 if (llvm::isa<func::ReturnOp, TF::YieldOp>(use.getOwner())) {
265 int operand_index = use.getOperandNumber();
266 erase_indices.set(operand_index);
267 }
268 }
269 func.getArgument(arg_index).dropAllUses();
270 }
271 if (llvm::isa<func::ReturnOp, TF::YieldOp>(func.front().getTerminator())) {
272 terminator->eraseOperands(erase_indices);
273 }
274 }
275
276 // Updates 'while_op' signatures based on which arguments should be removed
277 // in 'arguments_to_erase'.
278 template <typename T, typename U>
GetUpdatedWhileOp(T while_op,const U & argument_types,const llvm::SmallVector<unsigned,4> & arguments_to_erase)279 T GetUpdatedWhileOp(T while_op, const U& argument_types,
280 const llvm::SmallVector<unsigned, 4>& arguments_to_erase) {
281 OpBuilder builder(while_op);
282 llvm::SmallVector<Type, 4> new_operand_types;
283 llvm::SmallVector<Value> new_operands;
284 auto operands = while_op->getOperands();
285 const int num_operands = while_op->getNumOperands();
286 llvm::BitVector skip_indices(num_operands);
287 for (int i : arguments_to_erase) skip_indices.set(i);
288 for (int i = 0; i < num_operands; ++i) {
289 if (!skip_indices.test(i)) {
290 new_operand_types.emplace_back(argument_types[i]);
291 new_operands.emplace_back(operands[i]);
292 }
293 }
294 auto new_while_op = builder.create<T>(while_op->getLoc(), new_operand_types,
295 new_operands, while_op->getAttrs());
296 int new_index = 0;
297 for (int i = 0; i < num_operands; ++i) {
298 if (!skip_indices.test(i)) {
299 while_op->getResult(i).replaceAllUsesWith(
300 new_while_op->getResult(new_index++));
301 }
302 }
303 return new_while_op;
304 }
305
306 } // namespace
307
FreezeVariables(ModuleOp module,tensorflow::Session * session)308 LogicalResult FreezeVariables(ModuleOp module, tensorflow::Session* session) {
309 const tensorflow::DeviceMgr* mgr = nullptr;
310 auto status = session->LocalDeviceManager(&mgr);
311 if (!status.ok()) {
312 module->emitError("failed to fetch device manager: " +
313 status.error_message());
314 return failure();
315 }
316
317 func::FuncOp session_init_func = GetSessionInitializerFunc(module);
318
319 TF::ResourceAnalyzer analyzer(module, /*skip_session_init=*/true);
320 llvm::SmallVector<TF::VarHandleOp, 4> variables;
321 // Capture list of all read only variables.
322 for (auto func : module.getOps<func::FuncOp>()) {
323 if (func == session_init_func) continue;
324 for (auto var_handle_op : func.getOps<TF::VarHandleOp>()) {
325 if (!analyzer.IsPotentiallyWritten(var_handle_op.resource())) {
326 variables.push_back(var_handle_op);
327 }
328 }
329 }
330
331 // Fetch the values to replace the VarHandleOps with.
332 auto resource_tensors_or =
333 tf_saved_model::GetResourcesFromSession(variables, session);
334 if (!resource_tensors_or.ok()) {
335 module->emitError(resource_tensors_or.status().message().data());
336 return failure();
337 }
338
339 auto* context = module.getContext();
340 OpBuilder builder(context);
341 // Note: We can't modify the graph while navigating through it, as erasing
342 // invalidate pointers.
343 // So instead we capture all the updates in the below map, and then
344 // process them after.
345
346 // Container to hold all update actions on ops.
347 // Key: Operation to update.
348 // Value: optional list of arguments to delete from this op.
349 // Note that we use MapVector because we want to iterate on the same order
350 // of insertion.
351 llvm::MapVector<Operation*, llvm::SmallVector<unsigned int, 4>>
352 arguments_to_erase;
353 for (auto variable_value_pair :
354 llvm::zip(variables, resource_tensors_or.value())) {
355 auto var_handle_op = std::get<0>(variable_value_pair);
356 builder.setInsertionPointAfterValue(var_handle_op);
357 auto elements_attr = GetTensorValueAsElementsAttr(
358 var_handle_op, std::get<1>(variable_value_pair), mgr, builder);
359 ReplaceVarWithConstant(var_handle_op, elements_attr, &arguments_to_erase);
360 }
361
362 // All updates to different ops are captured in 'arguments_to_erase'.
363 // Now loop on them and based on each item type update accordingly.
364 for (auto& items : arguments_to_erase) {
365 auto* user_op = items.first;
366 auto& args_to_erase = items.second;
367 if (auto func = dyn_cast<func::FuncOp>(user_op)) {
368 // To update a function we will need to:
369 // 1) Remove the unused arguments from the function itself.
370 // 2) Remove any returns that are not needed from the function terminator
371 // op in the function. 3) Update function result to match the terminator.
372 llvm::BitVector result_indices_to_erase;
373 UpdateTerminatorArguments(func, args_to_erase, result_indices_to_erase);
374 llvm::BitVector args_to_erase_bit_vector(func.getNumArguments());
375 for (auto i : args_to_erase) args_to_erase_bit_vector.set(i);
376 func.eraseArguments(args_to_erase_bit_vector);
377 llvm::BitVector indices_to_erase(func.getNumResults());
378 const int indices_to_erase_size = result_indices_to_erase.size();
379 for (int i = 0; i < indices_to_erase_size; ++i)
380 if (result_indices_to_erase.test(i)) indices_to_erase.set(i);
381 func.eraseResults(indices_to_erase);
382 } else if (auto read_var = dyn_cast<TF::ReadVariableOp>(user_op)) {
383 // Read variables was already replaced by constant op. Just remove the op.
384 read_var->erase();
385 } else if (auto while_op = dyn_cast<TF::WhileOp>(user_op)) {
386 GetUpdatedWhileOp<TF::WhileOp>(
387 while_op, while_op.cond_function().getArgumentTypes(), args_to_erase);
388 while_op->erase();
389 } else if (auto while_op = dyn_cast<TF::WhileRegionOp>(user_op)) {
390 auto new_while_op = GetUpdatedWhileOp(
391 while_op, while_op.cond().getArgumentTypes(), args_to_erase);
392 new_while_op.cond().takeBody(while_op.cond());
393 new_while_op.body().takeBody(while_op.body());
394 llvm::BitVector erase_indices;
395 UpdateTerminatorArguments(new_while_op.body(), args_to_erase,
396 erase_indices);
397 llvm::BitVector body_bit_vector(
398 new_while_op.body().front().getNumArguments());
399 for (auto i : args_to_erase) body_bit_vector.set(i);
400 new_while_op.body().front().eraseArguments(body_bit_vector);
401 llvm::BitVector cond_bit_vector(
402 new_while_op.cond().front().getNumArguments());
403 for (auto i : args_to_erase) cond_bit_vector.set(i);
404 new_while_op.cond().front().eraseArguments(cond_bit_vector);
405 while_op->erase();
406 } else {
407 llvm::BitVector erase_indices(user_op->getNumOperands());
408 for (auto operand_index : args_to_erase) {
409 erase_indices.set(operand_index);
410 }
411 user_op->eraseOperands(erase_indices);
412 }
413 }
414
415 // Remove initialization of unused variables.
416 if (session_init_func)
417 RemoveVariablesInitializations(variables, session_init_func);
418
419 // Remove the unused VarHandleOp.
420 for (auto var_handle_op : variables) {
421 if (var_handle_op) var_handle_op->erase();
422 }
423 return success();
424 }
425
426 } // namespace tf_saved_model
427 } // namespace mlir
428