/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace TFL {
namespace tac {
namespace {

// Subgraph here is an intermediate data structure that holds a group of ops:
// the ops within it share the same "target" and are topologically sorted.
// Each subgraph is later used to populate a generated func op.
// The subgraphs must not create cyclic dependencies among themselves,
// so we should not have:
//     subgraph1
//             \
//            subgraph2
//            /
//       subgraph1
struct Subgraph {
  // All ops must be inserted in topological order.
  llvm::SetVector<Operation*> all_ops;
  int subgraph_id;
  InferenceDeviceType inference_device_type;
};

// This will exclude arguments & consts & quantize/dequantize ops.
inline bool IsNonConstQuantizeOp(Operation* op) {
  return IsNonConstOp(op) && NotTFLQuantDequantizeOp(op) && !IsTerminatorOp(op);
}

// This pass groups ops (non-const TFL dialect ops) that have the same
// target together and raises them as FuncOps.
// See the following example:
//
//     op1 (GPU)
//       \       op2 (GPU)
//       \        |
//        \      op3 (GPU)
//         \     /
//         op4 (CPU)
//
// will be raised as 3 subgraphs:
// Subgraph 1: {op1}, GPU -> Func_1_GPU
// Subgraph 2: {op2, op3}, GPU -> Func_2_GPU
// Subgraph 3: {op4}, CPU -> Func_3_CPU
//
// MainFunc:
//   %0 = call @Func_1_GPU
//   %1 = call @Func_2_GPU
//   %2 = call @Func_3_CPU(%0, %1)
class RaiseTargetSubgraphsPass
    : public mlir::PassWrapper<RaiseTargetSubgraphsPass,
                               mlir::OperationPass<ModuleOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RaiseTargetSubgraphsPass)

 private:
  llvm::StringRef getArgument() const final {
    return "tfl-raise-target-subgraphs";
  }
  llvm::StringRef getDescription() const final {
    return "This pass groups target-annotated TFL ops together and raises "
           "them as functions.";
  }
  void runOnOperation() override;

  void RaiseTargetSubgraphsForBlock(Block* block, OpBuilder* builder,
                                    ModuleOp module);

  void ExtractSubgraphToFunc(Subgraph* subgraph, OpBuilder* builder,
                             ModuleOp module);

  func::FuncOp BuildFuncOp(Subgraph* subgraph, OpBuilder* builder,
                           ModuleOp module_op, SmallVector<Value, 4>* inputs,
                           SmallVector<Value, 4>* outputs,
                           InferenceDeviceType* inference_device_type);

  int subgraph_count_ = 0;
};

// Collects the input arguments for the given set of ops.
// See the example:
//
//   value1  value2
//    \     /
//      op1
//        \     value3
//        \   /
//         op2
//         |
//         op3
//
//  Then the inputs will be {value1, value2, value3}
void CollectInputs(const llvm::SetVector<Operation*>& all_ops,
                   SmallVector<Value, 4>* inputs) {
  for (Operation* op : all_ops) {
    for (Value input : op->getOperands()) {
      Operation* input_op = input.getDefiningOp();
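      // A value with no defining op (i.e. a block argument) or whose defining
      // op lies outside the subgraph becomes an input of the raised function.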
      const bool input_within_subgraph =
          (input_op && all_ops.count(input_op) == 1);
      if (!input_within_subgraph) {
        inputs->push_back(input);
      }
    }
  }
}

// Collects the output values for the given set of ops.
// See the example:
//
//      op1
//      /    \
//   value1   \
//           op2
//           |  \
//         op3  value2
//         |
//       value3
//
//  Then the outputs will be {value1, value2, value3}
void CollectOutputs(const llvm::SetVector<Operation*>& all_ops,
                    SmallVector<Value, 4>* outputs) {
  for (Operation* op : all_ops) {
    for (Value output : op->getResults()) {
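      // A result becomes an output of the raised function if any of its users
      // lies outside the set of subgraph ops.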
      bool output_consumed_outside_subgraph = false;
      for (Operation* consumer : output.getUsers()) {
        if (all_ops.count(consumer) == 0) {
          output_consumed_outside_subgraph = true;
        }
      }
      if (output_consumed_outside_subgraph) {
        outputs->push_back(output);
      }
    }
  }
}

void BuildTypes(const SmallVector<Value, 4>& values,
                SmallVector<Type, 4>* types) {
  for (auto value : values) {
    types->push_back(value.getType());
  }
}

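// Builds the interface name ("func_<id>") and the function name
// ("func_<id>_<hardware>_<inference type>") for a subgraph. For example,
// subgraph 1 targeting GPU float execution would get a name along the lines of
// "func_1_GPU_FLOAT" (illustrative; the exact suffix is whatever
// GetInferenceString returns for the inference type).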
void GetFunctionName(const Subgraph& subgraph, std::string* function_name,
                     std::string* interface_name) {
  *interface_name = absl::StrCat("func_", std::to_string(subgraph.subgraph_id));
  *function_name = absl::StrCat(
      (*interface_name), "_", subgraph.inference_device_type.hardware, "_",
      GetInferenceString(subgraph.inference_device_type.inference_type));
}

func::FuncOp RaiseTargetSubgraphsPass::BuildFuncOp(
    Subgraph* subgraph, OpBuilder* builder, ModuleOp module_op,
    SmallVector<Value, 4>* inputs, SmallVector<Value, 4>* outputs,
    InferenceDeviceType* inference_device_type) {
  CollectInputs(subgraph->all_ops, inputs);
  CollectOutputs(subgraph->all_ops, outputs);

  SmallVector<Type, 4> input_types;
  SmallVector<Type, 4> return_types;

  BuildTypes(*inputs, &input_types);
  BuildTypes(*outputs, &return_types);

  FunctionType function_type =
      builder->getFunctionType(input_types, return_types);

  SmallVector<NamedAttribute, 4> attrs;
  // Function name.
  std::string function_name;
  std::string interface_name;
  GetFunctionName(*subgraph, &function_name, &interface_name);
  attrs.push_back(builder->getNamedAttr(
      kInterfaceNameAttr, builder->getStringAttr(interface_name)));

  // Inference device type.
  attrs.push_back(builder->getNamedAttr(
      kDevice,
      builder->getStringAttr(subgraph->inference_device_type.hardware)));
  attrs.push_back(builder->getNamedAttr(
      kInferenceType, builder->getStringAttr(GetInferenceString(
                          subgraph->inference_device_type.inference_type))));
  *inference_device_type = subgraph->inference_device_type;

  func::FuncOp new_func =
      func::FuncOp::create(builder->getUnknownLoc(), function_name,
                           function_type, llvm::makeArrayRef(attrs));
  new_func.setPrivate();

  new_func.addEntryBlock();

  // Function argument mapping.
  llvm::DenseMap<Value, int> function_argument_mapping;
  for (int i = 0; i < inputs->size(); ++i) {
    function_argument_mapping.insert({(*inputs)[i], i});
  }

  OpBuilder function_builder(new_func.getBody());

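  // Clone every op of the subgraph into the new function body, remembering the
  // mapping from each original result to the corresponding cloned result so
  // that operands can be remapped below.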
  llvm::DenseMap<Operation*, Operation*> op_cloned_op_mapping;
  llvm::DenseMap<Value, Value> output_cloned_op_output_mapping;
  for (Operation* op : subgraph->all_ops) {
    Operation* cloned_op = function_builder.clone(*op);
    op_cloned_op_mapping.insert({op, cloned_op});
    for (int i = 0; i < op->getNumResults(); ++i) {
      Value op_output = op->getResult(i);
      Value cloned_op_output = cloned_op->getResult(i);
      output_cloned_op_output_mapping.insert({op_output, cloned_op_output});
    }
  }

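  // Remap the cloned ops' operands: a value coming from outside the subgraph
  // maps to the corresponding function argument, while a value produced inside
  // the subgraph maps to the matching cloned result.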
  for (Operation* op : subgraph->all_ops) {
    Operation* cloned_op = op_cloned_op_mapping.find(op)->second;
    for (int i = 0; i < op->getNumOperands(); ++i) {
      Value input = op->getOperand(i);
      Value cloned_op_input;
      // If the input is actually a function argument.
      if (function_argument_mapping.count(input) > 0) {
        int function_argument = function_argument_mapping.find(input)->second;
        cloned_op_input = new_func.getArgument(function_argument);
      } else {
        // The input is defined within the subgraph.
        cloned_op_input = output_cloned_op_output_mapping.find(input)->second;
      }
      cloned_op->setOperand(i, cloned_op_input);
    }
  }

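  // Return the cloned values that correspond to the subgraph outputs.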
  SmallVector<Value, 4> final_outputs;
  for (auto output : *outputs) {
    auto cloned_output = output_cloned_op_output_mapping.find(output)->second;
    final_outputs.push_back(cloned_output);
  }
  function_builder.create<mlir::func::ReturnOp>(new_func.getLoc(),
                                                final_outputs);

  module_op.push_back(new_func);
  return new_func;
}

void RaiseTargetSubgraphsPass::ExtractSubgraphToFunc(Subgraph* subgraph,
                                                     OpBuilder* builder,
                                                     ModuleOp module) {
  SmallVector<Value, 4> func_inputs;
  SmallVector<Value, 4> func_outputs;

  InferenceDeviceType inference_device_type;
  func::FuncOp func = BuildFuncOp(subgraph, builder, module, &func_inputs,
                                  &func_outputs, &inference_device_type);

  // We just use the location of the last op in the subgraph as the location
  // for the call_op.
  Operation* last_output = subgraph->all_ops.back();

  // TODO(renjieliu): we should add func attributes to the call op.
  builder->setInsertionPoint(last_output);
  auto call_op =
      builder->create<func::CallOp>(last_output->getLoc(), func, func_inputs);

  auto interface_name = GetInterFaceName(func);

  // Set the call op attributes: interface_name, device, inference type.
  call_op->setAttr(kInterfaceNameAttr,
                   builder->getStringAttr(interface_name.getValue()));
  call_op->setAttr(kDevice,
                   builder->getStringAttr(inference_device_type.hardware));
  call_op->setAttr(kInferenceType, builder->getStringAttr(GetInferenceString(
                                       inference_device_type.inference_type)));

  // Rewire the outputs.
  if (call_op.getNumResults() != func_outputs.size()) {
    module.emitError("the constructed func op has mismatched returns");
    signalPassFailure();
  }

  for (int i = 0; i < func_outputs.size(); ++i) {
    Value output = func_outputs[i];
    output.replaceAllUsesWith(call_op.getResult(i));
  }

  // Clear the subgraph.
  // Those ops should be removed.
  for (auto* op : subgraph->all_ops) {
    op->dropAllDefinedValueUses();
    op->dropAllReferences();
    op->erase();
  }
}

// TODO(renjieliu): We may need to consider ops with side effects: we may leave
// those ops alone when building the subgraph.
void RaiseTargetSubgraphsPass::RaiseTargetSubgraphsForBlock(Block* block,
                                                            OpBuilder* builder,
                                                            ModuleOp module) {
  // This is a very naive implementation:
  // it greedily groups adjacent ops that have the same inference device type
  // into a subgraph.
  llvm::DenseMap<int, Subgraph> all_subgraphs;
  llvm::Optional<InferenceDeviceType> previous_device_type = llvm::None;
  int current_subgraph_id = -1;
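  // Walk the block in order: ineligible ops (consts, quantize/dequantize,
  // terminators, returns and calls) are skipped, and a new subgraph is started
  // whenever an eligible op's device type does not match the previous one.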
  for (auto& op : *block) {
    if (IsNonConstQuantizeOp(&op) && !IsTerminatorOp(&op) &&
        !llvm::isa<func::ReturnOp, func::FuncOp, CallOpInterface>(op)) {
      auto current_device_type = GetInferenceDeviceTypeForOp(&op);
      if (!(current_device_type.has_value() &&
            current_device_type == previous_device_type)) {
        // We should start a new subgraph.
        Subgraph new_subgraph;
        new_subgraph.inference_device_type = current_device_type.getValue();
        new_subgraph.subgraph_id = subgraph_count_++;
        all_subgraphs.insert({new_subgraph.subgraph_id, new_subgraph});
        current_subgraph_id = new_subgraph.subgraph_id;
      }
      previous_device_type = current_device_type;
      all_subgraphs.find(current_subgraph_id)->second.all_ops.insert(&op);
    }
  }

  // Create a FuncOp for each subgraph and replace the current uses with calls
  // to it.
  for (auto& subgraph : all_subgraphs) {
    ExtractSubgraphToFunc(&subgraph.second, builder, module);
  }
}

void RaiseTargetSubgraphsPass::runOnOperation() {
  auto module = getOperation();
  SmallVector<func::FuncOp, 16> funcs(module.getOps<func::FuncOp>());
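  // The func ops are copied into a local vector first, since new functions are
  // added to the module while the existing blocks are being processed.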
  for (auto func : funcs) {
    for (auto& block : func) {
      auto builder = OpBuilder::atBlockBegin(&block);
      RaiseTargetSubgraphsForBlock(&block, &builder, module);
    }
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateRaiseTargetSubgraphsPass() {
  return std::make_unique<RaiseTargetSubgraphsPass>();
}

static PassRegistration<RaiseTargetSubgraphsPass> pass;
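// With this registration the pass can be requested by its argument string
// (e.g. -tfl-raise-target-subgraphs) in an mlir-opt-style driver; the exact
// tool depends on the build target.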

}  // namespace tac
}  // namespace TFL
}  // namespace mlir