/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace TFL {
namespace tac {
namespace {

// A Subgraph here is an intermediate data structure that holds a group of ops:
// the ops within it share the same "target" and are kept in topological order.
// Each subgraph is later used to populate a generated func op.
// Subgraphs must not create cyclic dependencies between one another, i.e. we
// should not have:
//      subgraph1
//         \
//        subgraph2
//         /
//      subgraph1
struct Subgraph {
  // All ops must be inserted in their topological order.
  llvm::SetVector<Operation*> all_ops;
  int subgraph_id;
  InferenceDeviceType inference_device_type;
};
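
// Each Subgraph is later handed to ExtractSubgraphToFunc below, which outlines
// its ops into a private func::FuncOp and replaces them in the original block
// with a single func.call to that function.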

// This will exclude arguments & consts & quantize/dequantize ops.
inline bool IsNonConstQuantizeOp(Operation* op) {
  return IsNonConstOp(op) && NotTFLQuantDequantizeOp(op) && !IsTerminatorOp(op);
}

// This pass groups ops (non-const TFL dialect ops) that share the same target
// together and raises them as FuncOps.
// See the following example:
//
//   op1 (GPU)
//     \   op2 (GPU)
//      \    |
//       \  op3 (GPU)
//        \  /
//        op4 (CPU)
//
// will be raised as 3 subgraphs:
// Subgraph 1: {op1}, GPU -> Func_1_GPU
// Subgraph 2: {op2, op3}, GPU -> Func_2_GPU
// Subgraph 3: {op4}, CPU -> Func_3_CPU
//
// MainFunc:
//   %0 = call @Func_1_GPU
//   %1 = call @Func_2_GPU
//   %2 = call @Func_3_CPU(%0, %1)
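//
// Sketched in MLIR (illustrative only; the attribute keys written below as
// <kInterfaceNameAttr>, <kDevice> and <kInferenceType> stand for whatever
// string constants those symbols hold in common/subgraph.h and
// common/targets.h):
//
//   func.func private @func_1_GPU_FLOAT(...) -> (...)
//       attributes {<kInterfaceNameAttr> = "func_1", <kDevice> = "GPU",
//                   <kInferenceType> = "FLOAT"} { ... }
//   ...
//   %0 = func.call @func_1_GPU_FLOAT(...) {<kInterfaceNameAttr> = "func_1",
//                                           <kDevice> = "GPU", ...}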
class RaiseTargetSubgraphsPass
    : public mlir::PassWrapper<RaiseTargetSubgraphsPass,
                               mlir::OperationPass<ModuleOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RaiseTargetSubgraphsPass)

 private:
  llvm::StringRef getArgument() const final {
    return "tfl-raise-target-subgraphs";
  }
  llvm::StringRef getDescription() const final {
    return "This pass groups target-annotated TFL ops that share the same "
           "target and raises each group as a function.";
  }
  void runOnOperation() override;

  void RaiseTargetSubgraphsForBlock(Block* block, OpBuilder* builder,
                                    ModuleOp module);

  void ExtractSubgraphToFunc(Subgraph* subgraph, OpBuilder* builder,
                             ModuleOp module);

  func::FuncOp BuildFuncOp(Subgraph* subgraph, OpBuilder* builder,
                           ModuleOp module_op, SmallVector<Value, 4>* inputs,
                           SmallVector<Value, 4>* outputs,
                           InferenceDeviceType* inference_device_type);

  int subgraph_count_ = 0;
};
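
// Usage sketch (the opt driver name is an assumption; any MLIR opt tool that
// links in this pass works):
//   <some-opt-tool> --tfl-raise-target-subgraphs input.mlir
// or, programmatically:
//   pm.addPass(mlir::TFL::tac::CreateRaiseTargetSubgraphsPass());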

// This is to collect input arguments for the given set of ops.
// See the example:
//
//   value1  value2
//       \   /
//        op1
//         \   value3
//          \  /
//          op2
//           |
//          op3
//
// Then the arguments will be {value1, value2, value3}
void CollectInputs(const llvm::SetVector<Operation*>& all_ops,
                   SmallVector<Value, 4>* inputs) {
  for (Operation* op : all_ops) {
    for (Value input : op->getOperands()) {
      Operation* input_op = input.getDefiningOp();
      const bool input_within_subgraph =
          (input_op && all_ops.count(input_op) == 1);
      if (!input_within_subgraph) {
        inputs->push_back(input);
      }
    }
  }
}

// This is to collect the output values for the given set of ops.
// See the example:
//
//       op1
//      /   \
//  value1   \
//           op2
//           |  \
//          op3  value2
//           |
//         value3
//
// Then the outputs will be {value1, value2, value3}
void CollectOutputs(const llvm::SetVector<Operation*>& all_ops,
                    SmallVector<Value, 4>* outputs) {
  for (Operation* op : all_ops) {
    for (Value output : op->getResults()) {
      bool output_consumed_outside_subgraph = false;
      for (Operation* consumer : output.getUsers()) {
        if (all_ops.count(consumer) == 0) {
          output_consumed_outside_subgraph = true;
        }
      }
      if (output_consumed_outside_subgraph) {
        outputs->push_back(output);
      }
    }
  }
}

void BuildTypes(const SmallVector<Value, 4>& values,
                SmallVector<Type, 4>* types) {
  for (auto value : values) {
    types->push_back(value.getType());
  }
}

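// Derives the interface/function names for a subgraph. As a sketch: a
// subgraph with id 1 targeting GPU gets interface_name "func_1" and
// function_name "func_1_GPU_<inference>", where <inference> is whatever
// GetInferenceString returns for the subgraph's inference type.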
void GetFunctionName(const Subgraph& subgraph, std::string* function_name,
                     std::string* interface_name) {
  *interface_name = absl::StrCat("func_", std::to_string(subgraph.subgraph_id));
  *function_name = absl::StrCat(
      (*interface_name), "_", subgraph.inference_device_type.hardware, "_",
      GetInferenceString(subgraph.inference_device_type.inference_type));
}

func::FuncOp RaiseTargetSubgraphsPass::BuildFuncOp(
    Subgraph* subgraph, OpBuilder* builder, ModuleOp module_op,
    SmallVector<Value, 4>* inputs, SmallVector<Value, 4>* outputs,
    InferenceDeviceType* inference_device_type) {
  CollectInputs(subgraph->all_ops, inputs);
  CollectOutputs(subgraph->all_ops, outputs);

  SmallVector<Type, 4> input_types;
  SmallVector<Type, 4> return_types;

  BuildTypes(*inputs, &input_types);
  BuildTypes(*outputs, &return_types);

  FunctionType function_type =
      builder->getFunctionType(input_types, return_types);

  SmallVector<NamedAttribute, 4> attrs;
  // Function name.
  std::string function_name;
  std::string interface_name;
  GetFunctionName(*subgraph, &function_name, &interface_name);
  attrs.push_back(builder->getNamedAttr(
      kInterfaceNameAttr, builder->getStringAttr(interface_name)));

  // Inference device type.
  attrs.push_back(builder->getNamedAttr(
      kDevice,
      builder->getStringAttr(subgraph->inference_device_type.hardware)));
  attrs.push_back(builder->getNamedAttr(
      kInferenceType, builder->getStringAttr(GetInferenceString(
                          subgraph->inference_device_type.inference_type))));
  *inference_device_type = subgraph->inference_device_type;

  func::FuncOp new_func =
      func::FuncOp::create(builder->getUnknownLoc(), function_name,
                           function_type, llvm::makeArrayRef(attrs));
  new_func.setPrivate();

  new_func.addEntryBlock();

  // Function argument mapping.
  llvm::DenseMap<Value, int> function_argument_mapping;
  for (int i = 0; i < inputs->size(); ++i) {
    function_argument_mapping.insert({(*inputs)[i], i});
  }

  OpBuilder function_builder(new_func.getBody());

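  // Clone every op of the subgraph into the new function's entry block, and
  // remember the mapping from each original result to its cloned result so
  // the operands can be remapped afterwards.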
  llvm::DenseMap<Operation*, Operation*> op_cloned_op_mapping;
  llvm::DenseMap<Value, Value> output_cloned_op_output_mapping;
  for (Operation* op : subgraph->all_ops) {
    Operation* cloned_op = function_builder.clone(*op);
    op_cloned_op_mapping.insert({op, cloned_op});
    for (int i = 0; i < op->getNumResults(); ++i) {
      Value op_output = op->getResult(i);
      Value cloned_op_output = cloned_op->getResult(i);
      output_cloned_op_output_mapping.insert({op_output, cloned_op_output});
    }
  }

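  // Remap the cloned ops' operands: values captured from outside the subgraph
  // become function arguments, while values produced inside the subgraph are
  // replaced with the corresponding cloned results.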
  for (Operation* op : subgraph->all_ops) {
    Operation* cloned_op = op_cloned_op_mapping.find(op)->second;
    for (int i = 0; i < op->getNumOperands(); ++i) {
      Value input = op->getOperand(i);
      Value cloned_op_input;
      // If the input is actually a function argument.
      if (function_argument_mapping.count(input) > 0) {
        int function_argument = function_argument_mapping.find(input)->second;
        cloned_op_input = new_func.getArgument(function_argument);
      } else {
        // The input is defined within the subgraph.
        cloned_op_input = output_cloned_op_output_mapping.find(input)->second;
      }
      cloned_op->setOperand(i, cloned_op_input);
    }
  }

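  // Map the collected outputs to their cloned counterparts and terminate the
  // function body with a func.return of those values.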
  SmallVector<Value, 4> final_outputs;
  for (auto output : *outputs) {
    auto cloned_output = output_cloned_op_output_mapping.find(output)->second;
    final_outputs.push_back(cloned_output);
  }
  function_builder.create<mlir::func::ReturnOp>(new_func.getLoc(),
                                                final_outputs);

  module_op.push_back(new_func);
  return new_func;
}

void RaiseTargetSubgraphsPass::ExtractSubgraphToFunc(Subgraph* subgraph,
                                                     OpBuilder* builder,
                                                     ModuleOp module) {
  SmallVector<Value, 4> func_inputs;
  SmallVector<Value, 4> func_outputs;

  InferenceDeviceType inference_device_type;
  func::FuncOp func = BuildFuncOp(subgraph, builder, module, &func_inputs,
                                  &func_outputs, &inference_device_type);

  // We just use the location of the last op in the subgraph as the location
  // for the call op.
  Operation* last_output = subgraph->all_ops.back();

  // TODO(renjieliu): we should add func attributes to the call op.
  builder->setInsertionPoint(last_output);
  auto call_op =
      builder->create<func::CallOp>(last_output->getLoc(), func, func_inputs);

  auto interface_name = GetInterFaceName(func);

  // Set the call op attributes: interface_name, device and inference type.
  call_op->setAttr(kInterfaceNameAttr,
                   builder->getStringAttr(interface_name.getValue()));
  call_op->setAttr(kDevice,
                   builder->getStringAttr(inference_device_type.hardware));
  call_op->setAttr(kInferenceType, builder->getStringAttr(GetInferenceString(
                                       inference_device_type.inference_type)));

  // Rewire the outputs.
  if (call_op.getNumResults() != func_outputs.size()) {
    module.emitError("the constructed func op has mismatched returns");
    signalPassFailure();
  }

  for (int i = 0; i < func_outputs.size(); ++i) {
    Value output = func_outputs[i];
    output.replaceAllUsesWith(call_op.getResult(i));
  }

  // Clear the subgraph: these ops have been cloned into the new function and
  // should be removed from the original block.
  for (auto* op : subgraph->all_ops) {
    op->dropAllDefinedValueUses();
    op->dropAllReferences();
    op->erase();
  }
}

// TODO(renjieliu): We may need to consider side-effecting ops: we may leave
// those ops alone when building the subgraph.
void RaiseTargetSubgraphsPass::RaiseTargetSubgraphsForBlock(Block* block,
                                                            OpBuilder* builder,
                                                            ModuleOp module) {
  // This is a very naive implementation: it greedily groups adjacent ops that
  // have the same inference device type into a subgraph.
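  // For example, ops annotated (GPU, GPU, CPU, GPU) in block order become
  // three subgraphs {op1, op2}/GPU, {op3}/CPU, {op4}/GPU: the first and last
  // share a device type but are not adjacent, so they are not merged.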
  llvm::DenseMap<int, Subgraph> all_subgraphs;
  llvm::Optional<InferenceDeviceType> previous_device_type = llvm::None;
  int current_subgraph_id = -1;
  for (auto& op : *block) {
    if (IsNonConstQuantizeOp(&op) && !IsTerminatorOp(&op) &&
        !llvm::isa<func::ReturnOp, func::FuncOp, CallOpInterface>(op)) {
      auto current_device_type = GetInferenceDeviceTypeForOp(&op);
      if (!(current_device_type.has_value() &&
            current_device_type == previous_device_type)) {
        // We should start a new subgraph.
        Subgraph new_subgraph;
        new_subgraph.inference_device_type = current_device_type.getValue();
        new_subgraph.subgraph_id = subgraph_count_++;
        all_subgraphs.insert({new_subgraph.subgraph_id, new_subgraph});
        current_subgraph_id = new_subgraph.subgraph_id;
      }
      previous_device_type = current_device_type;
      all_subgraphs.find(current_subgraph_id)->second.all_ops.insert(&op);
    }
  }

  // Create a FuncOp for each subgraph and replace the original ops' uses.
  for (auto& subgraph : all_subgraphs) {
    ExtractSubgraphToFunc(&subgraph.second, builder, module);
  }
}

void RaiseTargetSubgraphsPass::runOnOperation() {
  auto module = getOperation();
  SmallVector<func::FuncOp, 16> funcs(module.getOps<func::FuncOp>());
  for (auto func : funcs) {
    for (auto& block : func) {
      auto builder = OpBuilder::atBlockBegin(&block);
      RaiseTargetSubgraphsForBlock(&block, &builder, module);
    }
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateRaiseTargetSubgraphsPass() {
  return std::make_unique<RaiseTargetSubgraphsPass>();
}

static PassRegistration<RaiseTargetSubgraphsPass> pass;

}  // namespace tac
}  // namespace TFL
}  // namespace mlir