1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/Support/CommandLine.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
23 #include "mlir/IR/Builders.h" // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
26 #include "mlir/IR/Location.h" // from @llvm-project
27 #include "mlir/IR/MLIRContext.h" // from @llvm-project
28 #include "mlir/IR/Matchers.h" // from @llvm-project
29 #include "mlir/IR/SymbolTable.h" // from @llvm-project
30 #include "mlir/Pass/Pass.h" // from @llvm-project
31 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
32 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
33 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
34 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
36
37 namespace mlir {
38 namespace TFL {
39 namespace {
40 #define GEN_PASS_CLASSES
41 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
42
43 // This pass outlines the cond/body region of the TFL WhileOp into functions and
44 // replaces the regions with calls to these outlined functions.
45 class WhileOutlinePass : public WhileOutlinePassBase<WhileOutlinePass> {
46 public:
47 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WhileOutlinePass)
WhileOutlinePass()48 explicit WhileOutlinePass() {}
49
50 private:
51 void runOnOperation() override;
52
53 // Outlines the regions of the WhileOp's cond and body and insert function
54 // calls instead,
55 void OutlineWhile(WhileOp while_op);
56
57 // Get unique name by using the loc to name mapping.
58 std::string GetName(Operation* op, StringRef suffix);
59
60 tensorflow::OpOrArgLocNameMapper mapper_;
61 };
62
GetName(Operation * op,StringRef suffix)63 std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
64 return (mapper_.GetUniqueName(op) + suffix).str();
65 }
66
67 // Returns whether the WhileOp is already outlined (e.g., only consists of calls
68 // to functions).
IsAlreadyOutlined(WhileOp while_op)69 bool IsAlreadyOutlined(WhileOp while_op) {
70 auto just_call = [](Region& region) {
71 auto it = region.front().begin();
72 if (!isa<func::CallOp>(*it)) return false;
73 ++it;
74 if (!isa<YieldOp>(*it)) return false;
75 return true;
76 };
77 return just_call(while_op.body()) && just_call(while_op.cond());
78 }
79
IsCompatibleTypeWithTFLCastOp(Type type)80 bool IsCompatibleTypeWithTFLCastOp(Type type) {
81 auto elemType = getElementTypeOrSelf(type);
82 // F32 and BF16 types are allowed.
83 if (elemType.isBF16() || elemType.isF32()) return true;
84
85 // I1, I8 I16, I32, I64 types are allowed.
86 if (elemType.isInteger(1) || elemType.isInteger(8) ||
87 elemType.isInteger(16) || elemType.isInteger(32) ||
88 elemType.isInteger(64))
89 return true;
90
91 // Complex<F<32>> is allowed.
92 if (elemType.isa<ComplexType>() &&
93 elemType.cast<ComplexType>().getElementType().isF32())
94 return true;
95
96 // QUINT8 and UI8 are allowed.
97 if (elemType.isa<TF::Quint8Type>() ||
98 (elemType.isInteger(8) && elemType.cast<IntegerType>().isUnsigned()))
99 return true;
100
101 return false;
102 }
103
CreateOutlineFunc(StringRef name,Region & region,bool passthru_extra_args,int num_loop_carried,const llvm::SetVector<Value> & extern_values,const SmallVectorImpl<Type> & types,Location loc)104 func::FuncOp CreateOutlineFunc(StringRef name, Region& region,
105 bool passthru_extra_args, int num_loop_carried,
106 const llvm::SetVector<Value>& extern_values,
107 const SmallVectorImpl<Type>& types,
108 Location loc) {
109 MLIRContext* context = loc.getContext();
110 OpBuilder builder(context);
111 FunctionType type;
112 if (passthru_extra_args) {
113 type = FunctionType::get(context, types, types);
114 } else {
115 SmallVector<Type, 4> result_types;
116 auto operands = region.front().getTerminator()->getOperandTypes();
117 result_types.append(operands.begin(), operands.end());
118 type = FunctionType::get(context, types, result_types);
119 }
120
121 auto outlined_func = builder.create<func::FuncOp>(loc, name, type);
122 outlined_func.getBody().takeBody(region);
123 Region& func_region = outlined_func.getBody();
124
125 // Replace all external uses with block args and update uses.
126 llvm::SmallVector<Value, 4> new_args;
127 new_args.reserve(extern_values.size());
128 Block& block = func_region.front();
129 for (Value value : extern_values) {
130 auto arg = block.addArgument(value.getType(), loc);
131 replaceAllUsesInRegionWith(value, arg, func_region);
132 new_args.push_back(arg);
133 }
134
135 // Replace yield op with return.
136 Operation* yield_op = outlined_func.getBody().front().getTerminator();
137 OpBuilder b(yield_op);
138 llvm::SmallVector<Value, 4> args;
139 auto loop_carried_yield_operands =
140 yield_op->getOperands().take_front(num_loop_carried);
141 args.reserve(loop_carried_yield_operands.size() + new_args.size());
142 if (passthru_extra_args) {
143 // Add operands of yield to the return, inserting casts if needed.
144 for (auto it : llvm::zip_first(loop_carried_yield_operands, types)) {
145 auto value = std::get<0>(it);
146 auto type = std::get<1>(it);
147 if (value.getType() == type) {
148 args.push_back(value);
149 } else {
150 if (IsCompatibleTypeWithTFLCastOp(value.getType()) &&
151 IsCompatibleTypeWithTFLCastOp(type)) {
152 auto cast = b.create<CastOp>(yield_op->getLoc(), type, value);
153 args.push_back(cast);
154 } else {
155 auto cast = b.create<TF::CastOp>(yield_op->getLoc(), type, value);
156 args.push_back(cast);
157 }
158 }
159 }
160 args.append(new_args.begin(), new_args.end());
161 } else {
162 args.append(yield_op->operand_begin(), yield_op->operand_end());
163 }
164 b.create<func::ReturnOp>(yield_op->getLoc(), args);
165 yield_op->erase();
166 SymbolTable(region.getParentOfType<ModuleOp>()).insert(outlined_func);
167 outlined_func.setPrivate();
168 return outlined_func;
169 }
170
171 // Replace region with call to outline function.
ReplaceRegionWithCall(StringRef name,Region & region,bool passthru_extra_args,int num_loop_carried,const llvm::SetVector<Value> & extern_values,const SmallVectorImpl<Type> & types,Location loc)172 void ReplaceRegionWithCall(StringRef name, Region& region,
173 bool passthru_extra_args, int num_loop_carried,
174 const llvm::SetVector<Value>& extern_values,
175 const SmallVectorImpl<Type>& types, Location loc) {
176 auto func = CreateOutlineFunc(name, region, passthru_extra_args,
177 num_loop_carried, extern_values, types, loc);
178 OpBuilder b(region);
179 // The body of the region is empty/has been outlined into the function.
180 auto block = b.createBlock(®ion);
181 SmallVector<Value, 4> new_operands;
182 new_operands.reserve(types.size());
183 for (Type t : llvm::makeArrayRef(types).drop_back(extern_values.size()))
184 new_operands.push_back(block->addArgument(t, loc));
185 for (Value v : extern_values) new_operands.push_back(v);
186 auto call = b.create<func::CallOp>(loc, func, new_operands);
187 b.create<YieldOp>(loc, call.getResults());
188 }
189
OutlineWhile(WhileOp while_op)190 void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
191 OpBuilder builder(&getContext());
192 // Collect external values used.
193 llvm::SetVector<Value> extern_values;
194
195 // The basic block arguments correspond to values that are loop carried, while
196 // all those post are loop independent. Initialize extern_values with while_op
197 // not loop carried operands.
198 auto num_loop_carried = while_op.cond().getNumArguments();
199 auto not_carried_operands =
200 while_op.getOperands().drop_front(num_loop_carried);
201 extern_values.insert(not_carried_operands.begin(),
202 not_carried_operands.end());
203 auto old_extern_values_size = extern_values.size();
204
205 llvm::SmallVector<Region*, 2> regions{&while_op.cond(), &while_op.body()};
206 for (const auto& it : llvm::enumerate(regions)) {
207 llvm::SetVector<Value> region_extern_values;
208 getUsedValuesDefinedAbove(*it.value(), region_extern_values);
209
210 // Sink down constants into the functions.
211 for (auto extern_value : region_extern_values) {
212 if (!matchPattern(extern_value, m_Constant())) {
213 extern_values.insert(extern_value);
214 continue;
215 }
216 // Add constant at start of region.
217 auto const_builder =
218 OpBuilder(&it.value()->front(), it.value()->front().begin());
219 auto const_value = const_builder.clone(*extern_value.getDefiningOp());
220 replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
221 *it.value());
222 }
223 }
224
225 bool has_extra_extern_values = old_extern_values_size != extern_values.size();
226 // If an extern value is already an operand post the loop carried operands,
227 // then it need not be passed in again.
228 // Compute all the extra operands that have to be added to the while.
229 llvm::SetVector<Value> extra_operands;
230 if (has_extra_extern_values) {
231 auto new_extern =
232 extern_values.getArrayRef().drop_front(old_extern_values_size);
233 extra_operands.insert(new_extern.begin(), new_extern.end());
234 }
235
236 // Skip if already just calls.
237 if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return;
238
239 // Collect new types.
240 SmallVector<Type, 4> types;
241 types.reserve(extra_operands.size() + while_op.getNumOperands());
242 for (Type type : while_op.cond().getArgumentTypes()) types.push_back(type);
243 for (Value operand : extern_values) types.push_back(operand.getType());
244
245 // Create outline function from region. Optional pass extra arguments through
246 // to yield.
247 ReplaceRegionWithCall(GetName(while_op.getOperation(), "_cond"),
248 while_op.cond(), false, num_loop_carried, extern_values,
249 types, while_op.getLoc());
250 ReplaceRegionWithCall(GetName(while_op.getOperation(), "_body"),
251 while_op.body(), true, num_loop_carried, extern_values,
252 types, while_op.getLoc());
253
254 // If there are extern values used then the result type of the while has to
255 // change, so replace with new while op.
256 if (extra_operands.empty()) return;
257
258 const int operands_size = while_op.getNumOperands() + extra_operands.size();
259 SmallVector<Value, 4> operands;
260 operands.reserve(operands_size);
261 operands.append(while_op.getOperands().begin(), while_op.getOperands().end());
262 operands.append(extra_operands.begin(), extra_operands.end());
263 SmallVector<Type, 4> new_types;
264 new_types.reserve(operands_size);
265 new_types.append(while_op.getResultTypes().begin(),
266 while_op.getResultTypes().end());
267 for (auto extra_operand : extra_operands)
268 new_types.push_back(extra_operand.getType());
269
270 auto new_while_op = OpBuilder(while_op).create<WhileOp>(
271 while_op.getLoc(), new_types, operands, while_op->getAttrs());
272 new_while_op.cond().takeBody(while_op.cond());
273 new_while_op.body().takeBody(while_op.body());
274 while_op.replaceAllUsesWith(
275 new_while_op.getResults().take_front(while_op.getNumResults()));
276 while_op.erase();
277 }
278
runOnOperation()279 void WhileOutlinePass::runOnOperation() {
280 getOperation().walk(
281 [&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
282 }
283 } // namespace
284
285 // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
CreateWhileOutlinePass()286 std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {
287 return std::make_unique<WhileOutlinePass>();
288 }
289
290 } // namespace TFL
291 } // namespace mlir
292