xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering TensorFlow dialect's control flow to
17 // the XLA dialect.
18 
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <numeric>
23 #include <tuple>
24 
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
32 #include "mlir/IR/Operation.h"  // from @llvm-project
33 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
34 #include "mlir/IR/Types.h"  // from @llvm-project
35 #include "mlir/Pass/Pass.h"  // from @llvm-project
36 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
37 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
39 #include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
40 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
41 
42 using mlir::PassRegistration;
43 
44 namespace mlir {
45 namespace mhlo {
46 namespace {
47 class LegalizeTFControlFlow
48     : public LegalizeTFControlFlowBase<LegalizeTFControlFlow> {
49  public:
50   void runOnOperation() override;
51 };
52 }  // namespace
53 
54 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createLegalizeTFControlFlowPass()55 createLegalizeTFControlFlowPass() {
56   return std::make_unique<LegalizeTFControlFlow>();
57 }
58 
59 namespace {
60 
Detuple(Value tuple,ValueRange replace,OpBuilder * builder)61 void Detuple(Value tuple, ValueRange replace, OpBuilder* builder) {
62   // De-tuple the results of the xla hlo if result.
63   for (auto result_it : llvm::enumerate(replace)) {
64     auto get_tuple_value = builder->create<mhlo::GetTupleElementOp>(
65         result_it.value().getLoc(), tuple, result_it.index());
66     result_it.value().replaceAllUsesWith(get_tuple_value);
67   }
68 }
69 
70 // For mlir::IfOp or mlir::CaseOp, replace the uses of their region's block
71 // arguments with 'implicit_operands'. Here | 'implicit_operands' | == Number of
72 // arguments in any of the regions in IfOp or CaseOp.
ReplaceBlockArgumentsWithImplicitOperands(mlir::Operation * op,llvm::ArrayRef<mlir::Value> implicit_operands)73 void ReplaceBlockArgumentsWithImplicitOperands(
74     mlir::Operation* op, llvm::ArrayRef<mlir::Value> implicit_operands) {
75   assert((mlir::dyn_cast<mlir::mhlo::IfOp>(*op) ||
76           mlir::dyn_cast<mlir::mhlo::CaseOp>(*op)) &&
77          "Unexpected mlir op in ReplaceBlockArgumentsWithImplicitOperands!");
78 
79   for (auto& region : op->getRegions()) {
80     int implicit_operand_index = 0;
81     for (auto arg : region.getArguments()) {
82       assert(implicit_operand_index < implicit_operands.size());
83       arg.replaceAllUsesWith(implicit_operands[implicit_operand_index++]);
84     }
85 
86     region.front().eraseArguments(
87         llvm::to_vector(llvm::seq<unsigned>(0, region.getNumArguments())));
88   }
89 }
90 
91 // Imports the source region into the destination region. MHLO supports
92 // multiple arguments per branch and multiple returns which are individually
93 // tupled together during export to XLA. This tupling is needed as XLA if/while
94 // operation only supports one argument per branch and a single return value.
95 // `tuple_arg` allows any branch that requires additional arguments to have
96 // their values be tupled together. Similarly, `tuple_return` allows the results
97 // of the if/while operation to be tupled together.
ImportXlaRegion(mlir::func::FuncOp func,Region * dest_region,Location loc,bool tuple_return=true,bool tuple_arg=true)98 void ImportXlaRegion(mlir::func::FuncOp func, Region* dest_region, Location loc,
99                      bool tuple_return = true, bool tuple_arg = true) {
100   OpBuilder builder(dest_region);
101 
102   auto entry_block = builder.createBlock(dest_region);
103   func::CallOp result;
104   if (!tuple_arg) {
105     auto inputs = func.getFunctionType().getInputs();
106     auto args = entry_block->addArguments(
107         inputs, SmallVector<Location>(inputs.size(), loc));
108     ArrayRef<Value> callop_args(args.begin(), args.end());
109     result = builder.create<func::CallOp>(loc, func, callop_args);
110   } else {
111     auto tuple_arg = entry_block->addArgument(
112         builder.getTupleType(func.getFunctionType().getInputs()), loc);
113     llvm::SmallVector<Value, 4> detupled_args;
114     detupled_args.reserve(func.getNumArguments());
115 
116     for (int64_t i = 0, s = func.getNumArguments(); i < s; i++) {
117       auto extract = builder.create<GetTupleElementOp>(loc, tuple_arg, i);
118       detupled_args.push_back(extract);
119     }
120 
121     result = builder.create<func::CallOp>(loc, func, detupled_args);
122   }
123 
124   if (!tuple_return) {
125     builder.create<mhlo::ReturnOp>(loc, result.getResults());
126   } else {
127     auto tuple_op = builder.create<TupleOp>(loc, result.getResults());
128     builder.create<mhlo::ReturnOp>(loc, tuple_op.getResult());
129   }
130 }
131 
LowerIf(TF::IfOp op)132 void LowerIf(TF::IfOp op) {
133   Location loc = op.getLoc();
134   OpBuilder builder(op);
135 
136   SmallVector<Value, 3> inputs(op.input());
137 
138   // Create the new `mhlo.if` op.
139   auto if_op = builder.create<mhlo::IfOp>(loc, op.getResultTypes(), op.cond());
140 
141   // Import the regions for both the true and false cases. These regions
142   // must be updated to tuple the return results together and use the xla hlo
143   // return op.
144   ImportXlaRegion(op.then_function(), &if_op.true_branch(), loc,
145                   /*tuple_return=*/false, /*tuple_arg=*/false);
146   ImportXlaRegion(op.else_function(), &if_op.false_branch(), loc,
147                   /*tuple_return=*/false, /*tuple_arg=*/false);
148 
149   // Replace the uses of block-arguments of the IfOp with the
150   // implicit_operands.
151   ReplaceBlockArgumentsWithImplicitOperands(if_op.getOperation(), inputs);
152 
153   op->replaceAllUsesWith(if_op);
154   op.erase();
155 }
156 
LowerCase(TF::CaseOp op)157 void LowerCase(TF::CaseOp op) {
158   Location loc = op.getLoc();
159   OpBuilder builder(op);
160 
161   SmallVector<Value, 4> inputs(op.input());
162 
163   // Create the new `mhlo.case` op.
164   auto case_op = builder.create<mhlo::CaseOp>(
165       loc, op.getResultTypes(), op.branch_index(), op.branches().size());
166 
167   // Import the regions for all branches.
168   for (unsigned i = 0; i < op.num_branches(); ++i) {
169     mlir::func::FuncOp branch_func = op.branch_function(i);
170     ImportXlaRegion(branch_func, &case_op.branches()[i], loc,
171                     /*tuple_return=*/false, /*tuple_arg=*/false);
172   }
173 
174   // Replace the uses of block-arguments of the IfOp with the
175   // implicit_operands.
176   ReplaceBlockArgumentsWithImplicitOperands(case_op.getOperation(), inputs);
177 
178   op.replaceAllUsesWith(case_op);
179   op.erase();
180 }
181 
LowerWhile(TF::WhileOp op)182 void LowerWhile(TF::WhileOp op) {
183   Location loc = op.getLoc();
184   OpBuilder builder(op);
185 
186   // XLA prefers tuple arguments for control flow due to XLA not supporting
187   // multiple return values.
188   SmallVector<Value, 3> inputs(op.input());
189   builder.setInsertionPoint(op);
190 
191   // Create the new `mhlo.while` op with inputs.
192   auto while_op =
193       builder.create<mhlo::WhileOp>(loc, op.getResultTypes(), inputs);
194 
195   // Import the regions for both the cond and body.
196   ImportXlaRegion(op.body_function(), &while_op.body(), loc,
197                   /*tuple_return=*/false, /*tuple_arg=*/false);
198   ImportXlaRegion(op.cond_function(), &while_op.cond(), loc,
199                   /*tuple_return=*/false, /*tuple_arg=*/false);
200 
201   op->replaceAllUsesWith(while_op);
202   op.erase();
203 }
204 
205 // Replaces all block arguments of a block with a single block arg of Tuple
206 // type `tuple_type`. Single block arguments are removed and remapped to
207 // get_tuple_element(tuple_arg, index).
ReplaceBlockArgs(Block * block,Type tuple_type,OpBuilder * builder)208 void ReplaceBlockArgs(Block* block, Type tuple_type, OpBuilder* builder) {
209   auto tuple_arg = block->addArgument(tuple_type, block->getParent()->getLoc());
210   Detuple(tuple_arg, block->getArguments().drop_back(1), builder);
211   for (int i = block->getNumArguments() - 2; i >= 0; --i)
212     block->eraseArgument(i);
213 }
214 
215 // Replaces implicitly captured value uses with block arguments.
ReplaceImplicitInputs(Block * block,int offset,ArrayRef<Value> implicit_inputs)216 llvm::SmallVector<Value, 4> ReplaceImplicitInputs(
217     Block* block, int offset, ArrayRef<Value> implicit_inputs) {
218   llvm::SmallVector<Value, 4> implicit_input_elements;
219   implicit_input_elements.reserve(implicit_inputs.size());
220 
221   Region* region = block->getParent();
222 
223   for (auto& implicit_input : llvm::enumerate(implicit_inputs)) {
224     Value implicit_input_value = implicit_input.value();
225     BlockArgument arg = block->getArgument(implicit_input.index() + offset);
226     implicit_input_elements.emplace_back(arg);
227     for (auto& use :
228          llvm::make_early_inc_range(implicit_input_value.getUses())) {
229       if (!region->isAncestor(use.getOwner()->getParentRegion())) continue;
230       use.set(arg);
231     }
232   }
233 
234   return implicit_input_elements;
235 }
236 
237 // Replaces implicitly captured value uses with tuple block argument.
238 // get_tuple_element's are created to extract specific values. Values from
239 // get_tuple_element's are returned in the order of `implicit_inputs`.
ReplaceImplicitInputsWithTupleElements(Block * block,int offset,ArrayRef<Value> implicit_inputs,OpBuilder * builder)240 llvm::SmallVector<Value, 4> ReplaceImplicitInputsWithTupleElements(
241     Block* block, int offset, ArrayRef<Value> implicit_inputs,
242     OpBuilder* builder) {
243   llvm::SmallVector<Value, 4> implicit_input_elements;
244   implicit_input_elements.reserve(implicit_inputs.size());
245 
246   Region* region = block->getParent();
247   assert(block->getNumArguments() == 1);
248 
249   BlockArgument tuple_arg = block->getArgument(0);
250   for (auto& implicit_input : llvm::enumerate(implicit_inputs)) {
251     Value implicit_input_value = implicit_input.value();
252     auto get_tuple_element = builder->create<mhlo::GetTupleElementOp>(
253         implicit_input_value.getLoc(), tuple_arg,
254         implicit_input.index() + offset);
255     implicit_input_elements.emplace_back(get_tuple_element.getResult());
256     for (auto& use :
257          llvm::make_early_inc_range(implicit_input_value.getUses())) {
258       if (!region->isAncestor(use.getOwner()->getParentRegion())) continue;
259       use.set(get_tuple_element.getResult());
260     }
261   }
262 
263   return implicit_input_elements;
264 }
265 
266 // Finds and replaces implicitly captured value uses with tuple block argument.
267 // A tuple of implicitly captured values is also created and returned, for use
268 // as an operand to the associated mhlo control flow op.
TupleImplicitInputs(Region & region,Location loc,OpBuilder * builder)269 Value TupleImplicitInputs(Region& region, Location loc, OpBuilder* builder) {
270   llvm::SetVector<Value> implicit_inputs;
271   getUsedValuesDefinedAbove(region, region, implicit_inputs);
272   llvm::ArrayRef<Value> implicit_inputs_ref = implicit_inputs.getArrayRef();
273   Value tuple_input = builder->create<mhlo::TupleOp>(loc, implicit_inputs_ref);
274   Block& block = region.front();
275   // `tf.CaseRegion`/`tf.IfRegion` are expected to have no block arguments and
276   // instead all inputs used by their branch regions are implicitly captured
277   // from above.
278   assert(block.getNumArguments() == 0);
279   block.addArgument(tuple_input.getType(), loc);
280   builder->setInsertionPointToStart(&block);
281   ReplaceImplicitInputsWithTupleElements(&block, /*offset=*/0,
282                                          implicit_inputs_ref, builder);
283   return tuple_input;
284 }
285 
286 // Replaces block terminator (tf.Yield) with `mhlo.return`. Additional results
287 // can be returned if `extra_results` is not empty. If `tuple_return` is
288 // set, a tuple of the return values will be set as the terminator operand.
ReplaceTerminator(Block * block,ArrayRef<Value> extra_results,OpBuilder * builder,bool tuple_return=true)289 void ReplaceTerminator(Block* block, ArrayRef<Value> extra_results,
290                        OpBuilder* builder, bool tuple_return = true) {
291   Operation* terminator = block->getTerminator();
292   assert(isa<TF::YieldOp>(terminator));
293   Location loc = terminator->getLoc();
294 
295   builder->setInsertionPoint(terminator);
296   auto results = llvm::to_vector<4>(terminator->getOperands());
297   results.append(extra_results.begin(), extra_results.end());
298   if (tuple_return) {
299     auto tuple_results = builder->create<mhlo::TupleOp>(loc, results);
300     builder->create<mhlo::ReturnOp>(loc, tuple_results.getResult());
301   } else {
302     builder->create<mhlo::ReturnOp>(loc, results);
303   }
304 
305   terminator->erase();
306 }
307 
LowerIfRegion(TF::IfRegionOp op)308 void LowerIfRegion(TF::IfRegionOp op) {
309   Location loc = op.getLoc();
310   OpBuilder builder(op);
311 
312   builder.setInsertionPoint(op);
313   ReplaceTerminator(&op.then_branch().front(), /*extra_results=*/{}, &builder,
314                     /*tuple_return=*/false);
315 
316   builder.setInsertionPoint(op);
317   ReplaceTerminator(&op.else_branch().front(), /*extra_results=*/{}, &builder,
318                     /*tuple_return=*/false);
319 
320   // Create the new `mhlo.if` op and take ownership of regions from
321   // `tf.IfRegion` op.
322   builder.setInsertionPoint(op);
323   auto if_op = builder.create<mhlo::IfOp>(loc, op.getResultTypes(), op.cond());
324   if_op.true_branch().takeBody(op.then_branch());
325   if_op.false_branch().takeBody(op.else_branch());
326 
327   // Replace all uses of `op` results with that of `mhlo.IfOp`.
328   op->replaceAllUsesWith(if_op);
329 
330   op.erase();
331 }
332 
LowerCaseRegion(TF::CaseRegionOp op)333 void LowerCaseRegion(TF::CaseRegionOp op) {
334   Location loc = op.getLoc();
335   OpBuilder builder(op);
336 
337   for (Region& region : op.branches()) {
338     builder.setInsertionPoint(op);
339     ReplaceTerminator(&region.front(), /*extra_results=*/{}, &builder,
340                       /*tuple_return=*/false);
341   }
342 
343   // Create the new `mhlo.case` op and take ownership of regions from
344   // `tf.CaseRegion` op.
345   builder.setInsertionPoint(op);
346   auto case_op = builder.create<mhlo::CaseOp>(
347       loc, op.getResultTypes(), op.branch_index(), op.branches().size());
348   for (auto region : llvm::zip(case_op.branches(), op.branches()))
349     std::get<0>(region).takeBody(std::get<1>(region));
350 
351   // Replace all uses of `op` results with that of `mhlo.CaseOp`.
352   op.replaceAllUsesWith(case_op);
353   op.erase();
354 }
355 
LowerWhileRegion(TF::WhileRegionOp op)356 void LowerWhileRegion(TF::WhileRegionOp op) {
357   Location loc = op.getLoc();
358   OpBuilder builder(op);
359 
360   SmallVector<Value, 3> inputs(op.input());
361   const int inputs_size = inputs.size();
362   llvm::SetVector<Value> implicit_inputs;
363   getUsedValuesDefinedAbove(op.getOperation()->getRegions(), implicit_inputs);
364   inputs.append(implicit_inputs.begin(), implicit_inputs.end());
365 
366   builder.setInsertionPoint(op);
367 
368   // Create the new `mhlo.while` op with 'inputs'. Implicit inputs are also
369   // returned.
370   auto while_result_types = llvm::to_vector<4>(op.getResultTypes());
371   while_result_types.reserve(while_result_types.size() +
372                              implicit_inputs.size());
373   for (const auto& implicit_input : implicit_inputs)
374     while_result_types.emplace_back(implicit_input.getType());
375   auto while_op =
376       builder.create<mhlo::WhileOp>(loc, while_result_types, inputs);
377 
378   // Rewrite cond and associated block arguments and terminator. Ownership of
379   // cond region is transfered over from `tf.WhileRegion` to `mhlo.while`.
380   Region& cond = while_op.cond();
381   cond.takeBody(op.cond());
382   Block& cond_block = cond.front();
383   builder.setInsertionPointToStart(&cond_block);
384 
385   // Add args corresponding to 'implicit_inputs'.
386   for (const auto& implicit_input : implicit_inputs)
387     cond_block.addArgument(implicit_input.getType(), loc);
388   ReplaceImplicitInputs(&cond_block, inputs_size,
389                         implicit_inputs.getArrayRef());
390   // Cond always returns a single result of bool type.
391   ReplaceTerminator(&cond_block, /*extra_results=*/{}, &builder,
392                     /*tuple_return=*/false);
393 
394   // Rewrite body and associated block arguments and terminator. Ownership of
395   // body region is transfered over from `tf.WhileRegion` to `mhlo.while`.
396   Region& body = while_op.body();
397   body.takeBody(op.body());
398   Block& body_block = body.front();
399   builder.setInsertionPointToStart(&body_block);
400   // Add args corresponding to 'implicit_inputs'.
401   for (const auto& implicit_input : implicit_inputs)
402     body_block.addArgument(implicit_input.getType(), loc);
403   auto implicit_input_elements = ReplaceImplicitInputs(
404       &body_block, inputs_size, implicit_inputs.getArrayRef());
405   ReplaceTerminator(&body_block, implicit_input_elements, &builder, false);
406 
407   // Replace all uses of `op` results with that of `mhlo.while`.
408   builder.setInsertionPoint(op);
409   if (while_op.getNumResults() > 1) {
410     for (const auto& result_it : llvm::enumerate(op.getResults()))
411       result_it.value().replaceAllUsesWith(
412           while_op.getResult(result_it.index()));
413   } else {
414     op->replaceAllUsesWith(while_op);
415   }
416   op.erase();
417 }
418 }  // namespace
419 
runOnOperation()420 void LegalizeTFControlFlow::runOnOperation() {
421   getOperation().walk([&](Operation* op) {
422     if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
423       LowerWhile(while_op);
424       return;
425     }
426     if (auto while_region_op = dyn_cast<TF::WhileRegionOp>(op)) {
427       LowerWhileRegion(while_region_op);
428       return;
429     }
430     if (auto if_op = dyn_cast<TF::IfOp>(op)) {
431       LowerIf(if_op);
432       return;
433     }
434     if (auto if_region_op = dyn_cast<TF::IfRegionOp>(op)) {
435       LowerIfRegion(if_region_op);
436       return;
437     }
438     if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
439       LowerCase(case_op);
440       return;
441     }
442     if (auto case_region_op = dyn_cast<TF::CaseRegionOp>(op)) {
443       LowerCaseRegion(case_region_op);
444       return;
445     }
446   });
447 }
448 }  // namespace mhlo
449 }  // namespace mlir
450