1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h"
17 
18 #include <utility>
19 #include <vector>
20 
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/BitVector.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/Casting.h"
29 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
30 #include "mlir/IR/Builders.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
32 #include "mlir/IR/Location.h"  // from @llvm-project
33 #include "mlir/IR/Value.h"  // from @llvm-project
34 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
35 #include "mlir/Pass/Pass.h"  // from @llvm-project
36 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
39 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
40 #include "tensorflow/compiler/mlir/tensorflow/utils/session_utils.h"
41 #include "tensorflow/core/framework/resource_var.h"
42 #include "tensorflow/core/framework/types.pb.h"
43 #include "tensorflow/core/public/session.h"
44 
45 namespace mlir {
46 namespace tf_saved_model {
47 namespace {
48 
49 // Build and returns ElementsAttr which holds the data in 'tensor'.
GetTensorValueAsElementsAttr(const tensorflow::Tensor & tensor,OpBuilder builder)50 ElementsAttr GetTensorValueAsElementsAttr(const tensorflow::Tensor& tensor,
51                                           OpBuilder builder) {
52   tensorflow::StatusOr<ElementsAttr> tensor_attr_or =
53       tensorflow::ConvertTensor(tensor, &builder);
54   if (!tensor_attr_or.ok()) return nullptr;
55   return tensor_attr_or.ValueOrDie();
56 }
57 
58 // Creates a constant op that holds 'tensor_elements'.
GetConstOpFromElementsAttr(ElementsAttr tensor_elements,OpBuilder builder,Location loc)59 TF::ConstOp GetConstOpFromElementsAttr(ElementsAttr tensor_elements,
60                                        OpBuilder builder, Location loc) {
61   return builder.create<TF::ConstOp>(loc, tensor_elements);
62 }
63 
64 // Returns ElementsAttr which has the value held by 'resource_tensor'.
GetTensorValueAsElementsAttr(TF::VarHandleOp var_handle_op,const tensorflow::Tensor & resource_tensor,const tensorflow::DeviceMgr * mgr,OpBuilder builder)65 ElementsAttr GetTensorValueAsElementsAttr(
66     TF::VarHandleOp var_handle_op, const tensorflow::Tensor& resource_tensor,
67     const tensorflow::DeviceMgr* mgr, OpBuilder builder) {
68   if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) {
69     return GetTensorValueAsElementsAttr(resource_tensor, builder);
70   }
71 
72   auto handle = resource_tensor.scalar<tensorflow::ResourceHandle>()();
73   auto* var_ptr = tf_saved_model::GetVariableFromSession(var_handle_op,
74                                                          handle.device(), mgr);
75   if (!var_ptr) {
76     return nullptr;
77   }
78   tensorflow::core::RefCountPtr<tensorflow::Var> var(var_ptr);
79   auto* tensor = var_ptr->tensor();
80 
81   return GetTensorValueAsElementsAttr(*tensor, builder);
82 }
83 
84 // Replace usage of 'read_variable_op' with 'value'.
PropagateUsage(TF::ReadVariableOp read_variable_op,ElementsAttr value)85 void PropagateUsage(TF::ReadVariableOp read_variable_op, ElementsAttr value) {
86   OpBuilder builder(read_variable_op);
87   read_variable_op->getResult(0).replaceAllUsesWith(
88       GetConstOpFromElementsAttr(value, builder, read_variable_op->getLoc()));
89 }
90 
// Propagates a resource usage across the graph where
// 'user_op' uses a resource and is passed to this op at 'argument_index'.
// This resource should be replaced by 'value'.
// Output params:
// - work_list: Is updated with new regions to process that is called
//   by 'user_op';
// - arguments_to_erase: Captures updates to the graph - which arguments
//   to remove from the op;
void PropagateUsage(
    Operation* user_op, int argument_index, ElementsAttr value,
    llvm::SmallVector<std::pair<Region*, int>, 4>* work_list,
    llvm::MapVector<Operation*, llvm::SmallVector<unsigned int, 4>>*
        arguments_to_erase) {
  if (auto read_variable_op = dyn_cast<TF::ReadVariableOp>(user_op)) {
    // Record the read op with an empty argument list (operator[] inserts a
    // default entry) so it is deleted in the cleanup pass, then replace its
    // result with the constant value.
    (*arguments_to_erase)[read_variable_op];
    PropagateUsage(read_variable_op, value);
  } else if (auto call = dyn_cast<CallOpInterface>(user_op)) {
    // Erase the operand from the call site and, when the callee is visible,
    // the corresponding argument of the callee; queue the callee's body.
    (*arguments_to_erase)[call].push_back(argument_index);
    if (auto func = dyn_cast<func::FuncOp>(call.resolveCallable())) {
      (*arguments_to_erase)[func].push_back(argument_index);
      work_list->push_back(std::make_pair(&func.getRegion(), argument_index));
    }
  } else if (auto if_op = dyn_cast<TF::IfOp>(user_op)) {
    (*arguments_to_erase)[if_op].push_back(argument_index);
    for (auto callee : {if_op.then_function(), if_op.else_function()}) {
      // Branch functions do not receive the IfOp's condition operand, so
      // their argument indices are shifted down by one.
      (*arguments_to_erase)[callee].push_back(argument_index - 1);
      work_list->push_back(
          std::make_pair(&callee.getBody(), argument_index - 1));
    }
  } else if (auto if_op = dyn_cast<TF::IfRegionOp>(user_op)) {
    // Region-based If: only queue the branch regions for processing; no
    // callee function signature to update.
    (*arguments_to_erase)[if_op].push_back(argument_index);
    for (auto callee : {&if_op.then_branch(), &if_op.else_branch()}) {
      work_list->push_back(std::make_pair(callee, argument_index));
    }
  } else if (auto while_op = dyn_cast<TF::WhileOp>(user_op)) {
    // While ops pass operands through 1:1, so cond/body use the same index.
    (*arguments_to_erase)[while_op].push_back(argument_index);
    for (auto callee : {while_op.cond_function(), while_op.body_function()}) {
      (*arguments_to_erase)[callee].push_back(argument_index);
      work_list->push_back(std::make_pair(&callee.getBody(), argument_index));
    }
  } else if (auto while_op = dyn_cast<TF::WhileRegionOp>(user_op)) {
    (*arguments_to_erase)[while_op].push_back(argument_index);
    for (auto callee : {&while_op.cond(), &while_op.body()}) {
      work_list->push_back(std::make_pair(callee, argument_index));
    }
  }
}
138 
139 // An override that takes region.
PropagateUsage(Region * region,ElementsAttr value,int argument_index,llvm::SmallVector<std::pair<Region *,int>,4> * work_list,llvm::MapVector<Operation *,llvm::SmallVector<unsigned int,4>> * arguments_to_erase)140 void PropagateUsage(
141     Region* region, ElementsAttr value, int argument_index,
142     llvm::SmallVector<std::pair<Region*, int>, 4>* work_list,
143     llvm::MapVector<Operation*, llvm::SmallVector<unsigned int, 4>>*
144         arguments_to_erase) {
145   auto arg = region->getArgument(argument_index);
146   for (auto& usage : arg.getUses()) {
147     auto* user_op = usage.getOwner();
148     int operand_index = usage.getOperandNumber();
149     PropagateUsage(user_op, operand_index, value, work_list,
150                    arguments_to_erase);
151   }
152 }
153 
154 // Traces usage of 'var_handle_op' and replaces it's usage with constant value
155 // 'value'.
156 // All op operands updates are captured in 'arguments_to_erase'.
ReplaceVarWithConstant(TF::VarHandleOp var_handle_op,ElementsAttr value,llvm::MapVector<Operation *,llvm::SmallVector<unsigned int,4>> * arguments_to_erase)157 void ReplaceVarWithConstant(
158     TF::VarHandleOp var_handle_op, ElementsAttr value,
159     llvm::MapVector<Operation*, llvm::SmallVector<unsigned int, 4>>*
160         arguments_to_erase) {
161   llvm::SmallVector<std::pair<Region*, int>, 4> work_list;
162   for (auto& usage : var_handle_op->getUses()) {
163     auto* user_op = usage.getOwner();
164     int operand_index = usage.getOperandNumber();
165     PropagateUsage(user_op, operand_index, value, &work_list,
166                    arguments_to_erase);
167   }
168   // Container to mark visited regions to avoid infinite loop.
169   llvm::DenseSet<std::pair<Region*, int>> visited;
170   while (!work_list.empty()) {
171     auto work_item = work_list.pop_back_val();
172     if (visited.contains(work_item)) continue;
173     PropagateUsage(work_item.first, value, work_item.second, &work_list,
174                    arguments_to_erase);
175     visited.insert(work_item);
176   }
177 }
178 
179 // Helper that returns the FuncOp that is the SessionInit function which
180 // will be called to initialize all resources.
181 // Returns nullptr if no function is found.
GetSessionInitializerFunc(ModuleOp module)182 func::FuncOp GetSessionInitializerFunc(ModuleOp module) {
183   auto session_init_op = tf_saved_model::GetSessionInitializerOp(module);
184   SymbolTable symbol_table(module);
185   if (session_init_op && !session_init_op.initializers().empty()) {
186     func::FuncOp init_func_op = symbol_table.lookup<mlir::func::FuncOp>(
187         session_init_op.initializers()[0].cast<FlatSymbolRefAttr>().getValue());
188     return init_func_op;
189   }
190   return nullptr;
191 }
192 
193 // Returns ID for identifying a resource.
GetResourceKey(Operation * op)194 std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef> GetResourceKey(
195     Operation* op) {
196   llvm::StringRef device;
197   if (auto attr = op->getAttrOfType<mlir::StringAttr>("device")) {
198     device = attr.getValue();
199   }
200 
201   llvm::StringRef container;
202   if (auto attr = op->getAttrOfType<mlir::StringAttr>("container")) {
203     container = attr.getValue();
204   }
205 
206   llvm::StringRef shared_name;
207   if (auto attr = op->getAttrOfType<mlir::StringAttr>("shared_name")) {
208     shared_name = attr.getValue();
209   }
210 
211   return std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>{
212       device, container, shared_name};
213 }
214 
215 // Remove the initialization of the variables in 'var_handle_ops' from
216 // the session init function 'sesion_init_func'
RemoveVariablesInitializations(const llvm::SmallVector<TF::VarHandleOp,4> & var_handle_ops,func::FuncOp sesion_init_func)217 void RemoveVariablesInitializations(
218     const llvm::SmallVector<TF::VarHandleOp, 4>& var_handle_ops,
219     func::FuncOp sesion_init_func) {
220   // We identify the variables using (device, container, shared_name) of the
221   // resource. Capture them here and use them to identify the useless
222   // initializations.
223   llvm::SetVector<std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>>
224       variables;
225   for (auto var_handle_op : var_handle_ops)
226     variables.insert(GetResourceKey(var_handle_op));
227 
228   llvm::SmallVector<Operation*, 4> work_list;
229   for (auto var_handle_op : sesion_init_func.getOps<TF::VarHandleOp>()) {
230     if (variables.count(GetResourceKey(var_handle_op)))
231       work_list.push_back(var_handle_op);
232   }
233 
234   // Capture list of ops to be erased by traversing usage starting from
235   // the VarHandle ops.
236   llvm::SetVector<Operation*> erase_list;
237   while (!work_list.empty()) {
238     auto* operation = work_list.pop_back_val();
239     erase_list.insert(operation);
240     for (auto& use : operation->getUses()) {
241       if (erase_list.count(use.getOwner())) continue;
242       work_list.push_back(use.getOwner());
243     }
244   }
245 
246   for (auto* op : erase_list) {
247     op->dropAllUses();
248     op->erase();
249   }
250 }
251 
// Updates terminator op arguments of 'func' after removing arguments
// specified in 'arguments_to_erase'.
// 'erase_indices' is resized to the terminator's operand count and, on
// return, marks which terminator operands were fed by erased arguments.
// Those operands are removed from the terminator when it is a
// func.return / tf.Yield.
template <typename T>
void UpdateTerminatorArguments(
    T& func, const llvm::SmallVector<unsigned, 4>& arguments_to_erase,
    llvm::BitVector& erase_indices) {
  auto terminator = func.front().getTerminator();
  int num_operands = terminator->getNumOperands();
  erase_indices.resize(num_operands);
  for (auto arg_index : arguments_to_erase) {
    auto argument = func.getArgument(arg_index);
    // Record which terminator operands are directly the erased argument.
    for (auto& use : argument.getUses()) {
      if (llvm::isa<func::ReturnOp, TF::YieldOp>(use.getOwner())) {
        int operand_index = use.getOperandNumber();
        erase_indices.set(operand_index);
      }
    }
    // Detach all remaining uses so the argument can be erased safely later.
    func.getArgument(arg_index).dropAllUses();
  }
  // Only known terminator kinds are rewritten; anything else is left as-is.
  if (llvm::isa<func::ReturnOp, TF::YieldOp>(func.front().getTerminator())) {
    terminator->eraseOperands(erase_indices);
  }
}
275 
// Updates 'while_op' signatures based on which arguments should be removed
// in 'arguments_to_erase'.
// Builds and returns a new while op carrying only the surviving operands
// (typed per 'argument_types'); uses of the old op's surviving results are
// redirected to the new op. The caller is responsible for erasing
// 'while_op' afterwards.
template <typename T, typename U>
T GetUpdatedWhileOp(T while_op, const U& argument_types,
                    const llvm::SmallVector<unsigned, 4>& arguments_to_erase) {
  OpBuilder builder(while_op);
  llvm::SmallVector<Type, 4> new_operand_types;
  llvm::SmallVector<Value> new_operands;
  auto operands = while_op->getOperands();
  const int num_operands = while_op->getNumOperands();
  // Mark dropped positions in a bit vector for O(1) membership tests.
  llvm::BitVector skip_indices(num_operands);
  for (int i : arguments_to_erase) skip_indices.set(i);
  for (int i = 0; i < num_operands; ++i) {
    if (!skip_indices.test(i)) {
      new_operand_types.emplace_back(argument_types[i]);
      new_operands.emplace_back(operands[i]);
    }
  }
  auto new_while_op = builder.create<T>(while_op->getLoc(), new_operand_types,
                                        new_operands, while_op->getAttrs());
  // Remap each surviving result of the old op to the next result of the new
  // op; erased results get no replacement (their uses were removed earlier).
  int new_index = 0;
  for (int i = 0; i < num_operands; ++i) {
    if (!skip_indices.test(i)) {
      while_op->getResult(i).replaceAllUsesWith(
          new_while_op->getResult(new_index++));
    }
  }
  return new_while_op;
}
305 
306 }  // namespace
307 
// Freezes read-only variables in 'module' into constants, fetching their
// current values from 'session'. Returns failure() if the device manager or
// the variables' values cannot be obtained from the session.
LogicalResult FreezeVariables(ModuleOp module, tensorflow::Session* session) {
  const tensorflow::DeviceMgr* mgr = nullptr;
  auto status = session->LocalDeviceManager(&mgr);
  if (!status.ok()) {
    module->emitError("failed to fetch device manager: " +
                      status.error_message());
    return failure();
  }

  // The session initializer is treated specially: writes it performs are
  // initializations, so it is skipped when looking for read-only variables.
  func::FuncOp session_init_func = GetSessionInitializerFunc(module);

  TF::ResourceAnalyzer analyzer(module, /*skip_session_init=*/true);
  llvm::SmallVector<TF::VarHandleOp, 4> variables;
  // Capture list of all read only variables.
  for (auto func : module.getOps<func::FuncOp>()) {
    if (func == session_init_func) continue;
    for (auto var_handle_op : func.getOps<TF::VarHandleOp>()) {
      if (!analyzer.IsPotentiallyWritten(var_handle_op.resource())) {
        variables.push_back(var_handle_op);
      }
    }
  }

  // Fetch the values to replace the VarHandleOps with.
  auto resource_tensors_or =
      tf_saved_model::GetResourcesFromSession(variables, session);
  if (!resource_tensors_or.ok()) {
    module->emitError(resource_tensors_or.status().message().data());
    return failure();
  }

  auto* context = module.getContext();
  OpBuilder builder(context);
  // Note: We can't modify the graph while navigating through it, as erasing
  // invalidates pointers.
  // So instead we capture all the updates in the below map, and then
  // process them after.

  // Container to hold all update actions on ops.
  // Key: Operation to update.
  // Value: optional list of arguments to delete from this op.
  // Note that we use MapVector because we want to iterate on the same order
  // of insertion.
  llvm::MapVector<Operation*, llvm::SmallVector<unsigned int, 4>>
      arguments_to_erase;
  for (auto variable_value_pair :
       llvm::zip(variables, resource_tensors_or.value())) {
    auto var_handle_op = std::get<0>(variable_value_pair);
    builder.setInsertionPointAfterValue(var_handle_op);
    auto elements_attr = GetTensorValueAsElementsAttr(
        var_handle_op, std::get<1>(variable_value_pair), mgr, builder);
    ReplaceVarWithConstant(var_handle_op, elements_attr, &arguments_to_erase);
  }

  // All updates to different ops are captured in 'arguments_to_erase'.
  // Now loop on them and based on each item type update accordingly.
  for (auto& items : arguments_to_erase) {
    auto* user_op = items.first;
    auto& args_to_erase = items.second;
    if (auto func = dyn_cast<func::FuncOp>(user_op)) {
      // To update a function we will need to:
      // 1) Remove the unused arguments from the function itself.
      // 2) Remove any returns that are not needed from the function terminator
      // op in the function. 3) Update function result to match the terminator.
      llvm::BitVector result_indices_to_erase;
      UpdateTerminatorArguments(func, args_to_erase, result_indices_to_erase);
      llvm::BitVector args_to_erase_bit_vector(func.getNumArguments());
      for (auto i : args_to_erase) args_to_erase_bit_vector.set(i);
      func.eraseArguments(args_to_erase_bit_vector);
      // Copy the terminator-derived bits into a vector sized to the
      // function's result count before erasing the matching results.
      llvm::BitVector indices_to_erase(func.getNumResults());
      const int indices_to_erase_size = result_indices_to_erase.size();
      for (int i = 0; i < indices_to_erase_size; ++i)
        if (result_indices_to_erase.test(i)) indices_to_erase.set(i);
      func.eraseResults(indices_to_erase);
    } else if (auto read_var = dyn_cast<TF::ReadVariableOp>(user_op)) {
      // Read variables was already replaced by constant op. Just remove the op.
      read_var->erase();
    } else if (auto while_op = dyn_cast<TF::WhileOp>(user_op)) {
      // Rebuild the while op without the erased operands; the cond
      // function's argument types describe the surviving operand types.
      GetUpdatedWhileOp<TF::WhileOp>(
          while_op, while_op.cond_function().getArgumentTypes(), args_to_erase);
      while_op->erase();
    } else if (auto while_op = dyn_cast<TF::WhileRegionOp>(user_op)) {
      // Region-based while: rebuild the op, move the old regions into it,
      // then strip the erased block arguments from both regions.
      auto new_while_op = GetUpdatedWhileOp(
          while_op, while_op.cond().getArgumentTypes(), args_to_erase);
      new_while_op.cond().takeBody(while_op.cond());
      new_while_op.body().takeBody(while_op.body());
      llvm::BitVector erase_indices;
      UpdateTerminatorArguments(new_while_op.body(), args_to_erase,
                                erase_indices);
      llvm::BitVector body_bit_vector(
          new_while_op.body().front().getNumArguments());
      for (auto i : args_to_erase) body_bit_vector.set(i);
      new_while_op.body().front().eraseArguments(body_bit_vector);
      llvm::BitVector cond_bit_vector(
          new_while_op.cond().front().getNumArguments());
      for (auto i : args_to_erase) cond_bit_vector.set(i);
      new_while_op.cond().front().eraseArguments(cond_bit_vector);
      while_op->erase();
    } else {
      // Generic op (calls, If, etc.): just drop the recorded operands.
      llvm::BitVector erase_indices(user_op->getNumOperands());
      for (auto operand_index : args_to_erase) {
        erase_indices.set(operand_index);
      }
      user_op->eraseOperands(erase_indices);
    }
  }

  // Remove initialization of unused variables.
  if (session_init_func)
    RemoveVariablesInitializations(variables, session_init_func);

  // Remove the unused VarHandleOp.
  for (auto var_handle_op : variables) {
    if (var_handle_op) var_handle_op->erase();
  }
  return success();
}
425 
426 }  // namespace tf_saved_model
427 }  // namespace mlir
428