xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/Support/CommandLine.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
23 #include "mlir/IR/Builders.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
26 #include "mlir/IR/Location.h"  // from @llvm-project
27 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
28 #include "mlir/IR/Matchers.h"  // from @llvm-project
29 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
30 #include "mlir/Pass/Pass.h"  // from @llvm-project
31 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
32 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
33 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
34 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
36 
37 namespace mlir {
38 namespace TFL {
39 namespace {
40 #define GEN_PASS_CLASSES
41 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
42 
43 // This pass outlines the cond/body region of the TFL WhileOp into functions and
44 // replaces the regions with calls to these outlined functions.
45 class WhileOutlinePass : public WhileOutlinePassBase<WhileOutlinePass> {
46  public:
47   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WhileOutlinePass)
WhileOutlinePass()48   explicit WhileOutlinePass() {}
49 
50  private:
51   void runOnOperation() override;
52 
53   // Outlines the regions of the WhileOp's cond and body and insert function
54   // calls instead,
55   void OutlineWhile(WhileOp while_op);
56 
57   // Get unique name by using the loc to name mapping.
58   std::string GetName(Operation* op, StringRef suffix);
59 
60   tensorflow::OpOrArgLocNameMapper mapper_;
61 };
62 
GetName(Operation * op,StringRef suffix)63 std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
64   return (mapper_.GetUniqueName(op) + suffix).str();
65 }
66 
67 // Returns whether the WhileOp is already outlined (e.g., only consists of calls
68 // to functions).
IsAlreadyOutlined(WhileOp while_op)69 bool IsAlreadyOutlined(WhileOp while_op) {
70   auto just_call = [](Region& region) {
71     auto it = region.front().begin();
72     if (!isa<func::CallOp>(*it)) return false;
73     ++it;
74     if (!isa<YieldOp>(*it)) return false;
75     return true;
76   };
77   return just_call(while_op.body()) && just_call(while_op.cond());
78 }
79 
IsCompatibleTypeWithTFLCastOp(Type type)80 bool IsCompatibleTypeWithTFLCastOp(Type type) {
81   auto elemType = getElementTypeOrSelf(type);
82   // F32 and BF16 types are allowed.
83   if (elemType.isBF16() || elemType.isF32()) return true;
84 
85   // I1, I8 I16, I32, I64 types are allowed.
86   if (elemType.isInteger(1) || elemType.isInteger(8) ||
87       elemType.isInteger(16) || elemType.isInteger(32) ||
88       elemType.isInteger(64))
89     return true;
90 
91   // Complex<F<32>> is allowed.
92   if (elemType.isa<ComplexType>() &&
93       elemType.cast<ComplexType>().getElementType().isF32())
94     return true;
95 
96   // QUINT8 and UI8 are allowed.
97   if (elemType.isa<TF::Quint8Type>() ||
98       (elemType.isInteger(8) && elemType.cast<IntegerType>().isUnsigned()))
99     return true;
100 
101   return false;
102 }
103 
CreateOutlineFunc(StringRef name,Region & region,bool passthru_extra_args,int num_loop_carried,const llvm::SetVector<Value> & extern_values,const SmallVectorImpl<Type> & types,Location loc)104 func::FuncOp CreateOutlineFunc(StringRef name, Region& region,
105                                bool passthru_extra_args, int num_loop_carried,
106                                const llvm::SetVector<Value>& extern_values,
107                                const SmallVectorImpl<Type>& types,
108                                Location loc) {
109   MLIRContext* context = loc.getContext();
110   OpBuilder builder(context);
111   FunctionType type;
112   if (passthru_extra_args) {
113     type = FunctionType::get(context, types, types);
114   } else {
115     SmallVector<Type, 4> result_types;
116     auto operands = region.front().getTerminator()->getOperandTypes();
117     result_types.append(operands.begin(), operands.end());
118     type = FunctionType::get(context, types, result_types);
119   }
120 
121   auto outlined_func = builder.create<func::FuncOp>(loc, name, type);
122   outlined_func.getBody().takeBody(region);
123   Region& func_region = outlined_func.getBody();
124 
125   // Replace all external uses with block args and update uses.
126   llvm::SmallVector<Value, 4> new_args;
127   new_args.reserve(extern_values.size());
128   Block& block = func_region.front();
129   for (Value value : extern_values) {
130     auto arg = block.addArgument(value.getType(), loc);
131     replaceAllUsesInRegionWith(value, arg, func_region);
132     new_args.push_back(arg);
133   }
134 
135   // Replace yield op with return.
136   Operation* yield_op = outlined_func.getBody().front().getTerminator();
137   OpBuilder b(yield_op);
138   llvm::SmallVector<Value, 4> args;
139   auto loop_carried_yield_operands =
140       yield_op->getOperands().take_front(num_loop_carried);
141   args.reserve(loop_carried_yield_operands.size() + new_args.size());
142   if (passthru_extra_args) {
143     // Add operands of yield to the return, inserting casts if needed.
144     for (auto it : llvm::zip_first(loop_carried_yield_operands, types)) {
145       auto value = std::get<0>(it);
146       auto type = std::get<1>(it);
147       if (value.getType() == type) {
148         args.push_back(value);
149       } else {
150         if (IsCompatibleTypeWithTFLCastOp(value.getType()) &&
151             IsCompatibleTypeWithTFLCastOp(type)) {
152           auto cast = b.create<CastOp>(yield_op->getLoc(), type, value);
153           args.push_back(cast);
154         } else {
155           auto cast = b.create<TF::CastOp>(yield_op->getLoc(), type, value);
156           args.push_back(cast);
157         }
158       }
159     }
160     args.append(new_args.begin(), new_args.end());
161   } else {
162     args.append(yield_op->operand_begin(), yield_op->operand_end());
163   }
164   b.create<func::ReturnOp>(yield_op->getLoc(), args);
165   yield_op->erase();
166   SymbolTable(region.getParentOfType<ModuleOp>()).insert(outlined_func);
167   outlined_func.setPrivate();
168   return outlined_func;
169 }
170 
171 // Replace region with call to outline function.
ReplaceRegionWithCall(StringRef name,Region & region,bool passthru_extra_args,int num_loop_carried,const llvm::SetVector<Value> & extern_values,const SmallVectorImpl<Type> & types,Location loc)172 void ReplaceRegionWithCall(StringRef name, Region& region,
173                            bool passthru_extra_args, int num_loop_carried,
174                            const llvm::SetVector<Value>& extern_values,
175                            const SmallVectorImpl<Type>& types, Location loc) {
176   auto func = CreateOutlineFunc(name, region, passthru_extra_args,
177                                 num_loop_carried, extern_values, types, loc);
178   OpBuilder b(region);
179   // The body of the region is empty/has been outlined into the function.
180   auto block = b.createBlock(&region);
181   SmallVector<Value, 4> new_operands;
182   new_operands.reserve(types.size());
183   for (Type t : llvm::makeArrayRef(types).drop_back(extern_values.size()))
184     new_operands.push_back(block->addArgument(t, loc));
185   for (Value v : extern_values) new_operands.push_back(v);
186   auto call = b.create<func::CallOp>(loc, func, new_operands);
187   b.create<YieldOp>(loc, call.getResults());
188 }
189 
OutlineWhile(WhileOp while_op)190 void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
191   OpBuilder builder(&getContext());
192   // Collect external values used.
193   llvm::SetVector<Value> extern_values;
194 
195   // The basic block arguments correspond to values that are loop carried, while
196   // all those post are loop independent. Initialize extern_values with while_op
197   // not loop carried operands.
198   auto num_loop_carried = while_op.cond().getNumArguments();
199   auto not_carried_operands =
200       while_op.getOperands().drop_front(num_loop_carried);
201   extern_values.insert(not_carried_operands.begin(),
202                        not_carried_operands.end());
203   auto old_extern_values_size = extern_values.size();
204 
205   llvm::SmallVector<Region*, 2> regions{&while_op.cond(), &while_op.body()};
206   for (const auto& it : llvm::enumerate(regions)) {
207     llvm::SetVector<Value> region_extern_values;
208     getUsedValuesDefinedAbove(*it.value(), region_extern_values);
209 
210     // Sink down constants into the functions.
211     for (auto extern_value : region_extern_values) {
212       if (!matchPattern(extern_value, m_Constant())) {
213         extern_values.insert(extern_value);
214         continue;
215       }
216       // Add constant at start of region.
217       auto const_builder =
218           OpBuilder(&it.value()->front(), it.value()->front().begin());
219       auto const_value = const_builder.clone(*extern_value.getDefiningOp());
220       replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
221                                  *it.value());
222     }
223   }
224 
225   bool has_extra_extern_values = old_extern_values_size != extern_values.size();
226   // If an extern value is already an operand post the loop carried operands,
227   // then it need not be passed in again.
228   // Compute all the extra operands that have to be added to the while.
229   llvm::SetVector<Value> extra_operands;
230   if (has_extra_extern_values) {
231     auto new_extern =
232         extern_values.getArrayRef().drop_front(old_extern_values_size);
233     extra_operands.insert(new_extern.begin(), new_extern.end());
234   }
235 
236   // Skip if already just calls.
237   if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return;
238 
239   // Collect new types.
240   SmallVector<Type, 4> types;
241   types.reserve(extra_operands.size() + while_op.getNumOperands());
242   for (Type type : while_op.cond().getArgumentTypes()) types.push_back(type);
243   for (Value operand : extern_values) types.push_back(operand.getType());
244 
245   // Create outline function from region. Optional pass extra arguments through
246   // to yield.
247   ReplaceRegionWithCall(GetName(while_op.getOperation(), "_cond"),
248                         while_op.cond(), false, num_loop_carried, extern_values,
249                         types, while_op.getLoc());
250   ReplaceRegionWithCall(GetName(while_op.getOperation(), "_body"),
251                         while_op.body(), true, num_loop_carried, extern_values,
252                         types, while_op.getLoc());
253 
254   // If there are extern values used then the result type of the while has to
255   // change, so replace with new while op.
256   if (extra_operands.empty()) return;
257 
258   const int operands_size = while_op.getNumOperands() + extra_operands.size();
259   SmallVector<Value, 4> operands;
260   operands.reserve(operands_size);
261   operands.append(while_op.getOperands().begin(), while_op.getOperands().end());
262   operands.append(extra_operands.begin(), extra_operands.end());
263   SmallVector<Type, 4> new_types;
264   new_types.reserve(operands_size);
265   new_types.append(while_op.getResultTypes().begin(),
266                    while_op.getResultTypes().end());
267   for (auto extra_operand : extra_operands)
268     new_types.push_back(extra_operand.getType());
269 
270   auto new_while_op = OpBuilder(while_op).create<WhileOp>(
271       while_op.getLoc(), new_types, operands, while_op->getAttrs());
272   new_while_op.cond().takeBody(while_op.cond());
273   new_while_op.body().takeBody(while_op.body());
274   while_op.replaceAllUsesWith(
275       new_while_op.getResults().take_front(while_op.getNumResults()));
276   while_op.erase();
277 }
278 
runOnOperation()279 void WhileOutlinePass::runOnOperation() {
280   getOperation().walk(
281       [&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
282 }
283 }  // namespace
284 
285 // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
CreateWhileOutlinePass()286 std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {
287   return std::make_unique<WhileOutlinePass>();
288 }
289 
290 }  // namespace TFL
291 }  // namespace mlir
292