1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering TensorFlow dialect's control flow to
17 // the XLA dialect.
18
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <numeric>
23 #include <tuple>
24
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
32 #include "mlir/IR/Operation.h" // from @llvm-project
33 #include "mlir/IR/OperationSupport.h" // from @llvm-project
34 #include "mlir/IR/Types.h" // from @llvm-project
35 #include "mlir/Pass/Pass.h" // from @llvm-project
36 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
37 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
39 #include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
40 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
41
42 using mlir::PassRegistration;
43
44 namespace mlir {
45 namespace mhlo {
46 namespace {
47 class LegalizeTFControlFlow
48 : public LegalizeTFControlFlowBase<LegalizeTFControlFlow> {
49 public:
50 void runOnOperation() override;
51 };
52 } // namespace
53
54 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createLegalizeTFControlFlowPass()55 createLegalizeTFControlFlowPass() {
56 return std::make_unique<LegalizeTFControlFlow>();
57 }
58
59 namespace {
60
Detuple(Value tuple,ValueRange replace,OpBuilder * builder)61 void Detuple(Value tuple, ValueRange replace, OpBuilder* builder) {
62 // De-tuple the results of the xla hlo if result.
63 for (auto result_it : llvm::enumerate(replace)) {
64 auto get_tuple_value = builder->create<mhlo::GetTupleElementOp>(
65 result_it.value().getLoc(), tuple, result_it.index());
66 result_it.value().replaceAllUsesWith(get_tuple_value);
67 }
68 }
69
70 // For mlir::IfOp or mlir::CaseOp, replace the uses of their region's block
71 // arguments with 'implicit_operands'. Here | 'implicit_operands' | == Number of
72 // arguments in any of the regions in IfOp or CaseOp.
ReplaceBlockArgumentsWithImplicitOperands(mlir::Operation * op,llvm::ArrayRef<mlir::Value> implicit_operands)73 void ReplaceBlockArgumentsWithImplicitOperands(
74 mlir::Operation* op, llvm::ArrayRef<mlir::Value> implicit_operands) {
75 assert((mlir::dyn_cast<mlir::mhlo::IfOp>(*op) ||
76 mlir::dyn_cast<mlir::mhlo::CaseOp>(*op)) &&
77 "Unexpected mlir op in ReplaceBlockArgumentsWithImplicitOperands!");
78
79 for (auto& region : op->getRegions()) {
80 int implicit_operand_index = 0;
81 for (auto arg : region.getArguments()) {
82 assert(implicit_operand_index < implicit_operands.size());
83 arg.replaceAllUsesWith(implicit_operands[implicit_operand_index++]);
84 }
85
86 region.front().eraseArguments(
87 llvm::to_vector(llvm::seq<unsigned>(0, region.getNumArguments())));
88 }
89 }
90
91 // Imports the source region into the destination region. MHLO supports
92 // multiple arguments per branch and multiple returns which are individually
93 // tupled together during export to XLA. This tupling is needed as XLA if/while
94 // operation only supports one argument per branch and a single return value.
95 // `tuple_arg` allows any branch that requires additional arguments to have
96 // their values be tupled together. Similarly, `tuple_return` allows the results
97 // of the if/while operation to be tupled together.
ImportXlaRegion(mlir::func::FuncOp func,Region * dest_region,Location loc,bool tuple_return=true,bool tuple_arg=true)98 void ImportXlaRegion(mlir::func::FuncOp func, Region* dest_region, Location loc,
99 bool tuple_return = true, bool tuple_arg = true) {
100 OpBuilder builder(dest_region);
101
102 auto entry_block = builder.createBlock(dest_region);
103 func::CallOp result;
104 if (!tuple_arg) {
105 auto inputs = func.getFunctionType().getInputs();
106 auto args = entry_block->addArguments(
107 inputs, SmallVector<Location>(inputs.size(), loc));
108 ArrayRef<Value> callop_args(args.begin(), args.end());
109 result = builder.create<func::CallOp>(loc, func, callop_args);
110 } else {
111 auto tuple_arg = entry_block->addArgument(
112 builder.getTupleType(func.getFunctionType().getInputs()), loc);
113 llvm::SmallVector<Value, 4> detupled_args;
114 detupled_args.reserve(func.getNumArguments());
115
116 for (int64_t i = 0, s = func.getNumArguments(); i < s; i++) {
117 auto extract = builder.create<GetTupleElementOp>(loc, tuple_arg, i);
118 detupled_args.push_back(extract);
119 }
120
121 result = builder.create<func::CallOp>(loc, func, detupled_args);
122 }
123
124 if (!tuple_return) {
125 builder.create<mhlo::ReturnOp>(loc, result.getResults());
126 } else {
127 auto tuple_op = builder.create<TupleOp>(loc, result.getResults());
128 builder.create<mhlo::ReturnOp>(loc, tuple_op.getResult());
129 }
130 }
131
LowerIf(TF::IfOp op)132 void LowerIf(TF::IfOp op) {
133 Location loc = op.getLoc();
134 OpBuilder builder(op);
135
136 SmallVector<Value, 3> inputs(op.input());
137
138 // Create the new `mhlo.if` op.
139 auto if_op = builder.create<mhlo::IfOp>(loc, op.getResultTypes(), op.cond());
140
141 // Import the regions for both the true and false cases. These regions
142 // must be updated to tuple the return results together and use the xla hlo
143 // return op.
144 ImportXlaRegion(op.then_function(), &if_op.true_branch(), loc,
145 /*tuple_return=*/false, /*tuple_arg=*/false);
146 ImportXlaRegion(op.else_function(), &if_op.false_branch(), loc,
147 /*tuple_return=*/false, /*tuple_arg=*/false);
148
149 // Replace the uses of block-arguments of the IfOp with the
150 // implicit_operands.
151 ReplaceBlockArgumentsWithImplicitOperands(if_op.getOperation(), inputs);
152
153 op->replaceAllUsesWith(if_op);
154 op.erase();
155 }
156
LowerCase(TF::CaseOp op)157 void LowerCase(TF::CaseOp op) {
158 Location loc = op.getLoc();
159 OpBuilder builder(op);
160
161 SmallVector<Value, 4> inputs(op.input());
162
163 // Create the new `mhlo.case` op.
164 auto case_op = builder.create<mhlo::CaseOp>(
165 loc, op.getResultTypes(), op.branch_index(), op.branches().size());
166
167 // Import the regions for all branches.
168 for (unsigned i = 0; i < op.num_branches(); ++i) {
169 mlir::func::FuncOp branch_func = op.branch_function(i);
170 ImportXlaRegion(branch_func, &case_op.branches()[i], loc,
171 /*tuple_return=*/false, /*tuple_arg=*/false);
172 }
173
174 // Replace the uses of block-arguments of the IfOp with the
175 // implicit_operands.
176 ReplaceBlockArgumentsWithImplicitOperands(case_op.getOperation(), inputs);
177
178 op.replaceAllUsesWith(case_op);
179 op.erase();
180 }
181
LowerWhile(TF::WhileOp op)182 void LowerWhile(TF::WhileOp op) {
183 Location loc = op.getLoc();
184 OpBuilder builder(op);
185
186 // XLA prefers tuple arguments for control flow due to XLA not supporting
187 // multiple return values.
188 SmallVector<Value, 3> inputs(op.input());
189 builder.setInsertionPoint(op);
190
191 // Create the new `mhlo.while` op with inputs.
192 auto while_op =
193 builder.create<mhlo::WhileOp>(loc, op.getResultTypes(), inputs);
194
195 // Import the regions for both the cond and body.
196 ImportXlaRegion(op.body_function(), &while_op.body(), loc,
197 /*tuple_return=*/false, /*tuple_arg=*/false);
198 ImportXlaRegion(op.cond_function(), &while_op.cond(), loc,
199 /*tuple_return=*/false, /*tuple_arg=*/false);
200
201 op->replaceAllUsesWith(while_op);
202 op.erase();
203 }
204
205 // Replaces all block arguments of a block with a single block arg of Tuple
206 // type `tuple_type`. Single block arguments are removed and remapped to
207 // get_tuple_element(tuple_arg, index).
ReplaceBlockArgs(Block * block,Type tuple_type,OpBuilder * builder)208 void ReplaceBlockArgs(Block* block, Type tuple_type, OpBuilder* builder) {
209 auto tuple_arg = block->addArgument(tuple_type, block->getParent()->getLoc());
210 Detuple(tuple_arg, block->getArguments().drop_back(1), builder);
211 for (int i = block->getNumArguments() - 2; i >= 0; --i)
212 block->eraseArgument(i);
213 }
214
215 // Replaces implicitly captured value uses with block arguments.
ReplaceImplicitInputs(Block * block,int offset,ArrayRef<Value> implicit_inputs)216 llvm::SmallVector<Value, 4> ReplaceImplicitInputs(
217 Block* block, int offset, ArrayRef<Value> implicit_inputs) {
218 llvm::SmallVector<Value, 4> implicit_input_elements;
219 implicit_input_elements.reserve(implicit_inputs.size());
220
221 Region* region = block->getParent();
222
223 for (auto& implicit_input : llvm::enumerate(implicit_inputs)) {
224 Value implicit_input_value = implicit_input.value();
225 BlockArgument arg = block->getArgument(implicit_input.index() + offset);
226 implicit_input_elements.emplace_back(arg);
227 for (auto& use :
228 llvm::make_early_inc_range(implicit_input_value.getUses())) {
229 if (!region->isAncestor(use.getOwner()->getParentRegion())) continue;
230 use.set(arg);
231 }
232 }
233
234 return implicit_input_elements;
235 }
236
237 // Replaces implicitly captured value uses with tuple block argument.
238 // get_tuple_element's are created to extract specific values. Values from
239 // get_tuple_element's are returned in the order of `implicit_inputs`.
ReplaceImplicitInputsWithTupleElements(Block * block,int offset,ArrayRef<Value> implicit_inputs,OpBuilder * builder)240 llvm::SmallVector<Value, 4> ReplaceImplicitInputsWithTupleElements(
241 Block* block, int offset, ArrayRef<Value> implicit_inputs,
242 OpBuilder* builder) {
243 llvm::SmallVector<Value, 4> implicit_input_elements;
244 implicit_input_elements.reserve(implicit_inputs.size());
245
246 Region* region = block->getParent();
247 assert(block->getNumArguments() == 1);
248
249 BlockArgument tuple_arg = block->getArgument(0);
250 for (auto& implicit_input : llvm::enumerate(implicit_inputs)) {
251 Value implicit_input_value = implicit_input.value();
252 auto get_tuple_element = builder->create<mhlo::GetTupleElementOp>(
253 implicit_input_value.getLoc(), tuple_arg,
254 implicit_input.index() + offset);
255 implicit_input_elements.emplace_back(get_tuple_element.getResult());
256 for (auto& use :
257 llvm::make_early_inc_range(implicit_input_value.getUses())) {
258 if (!region->isAncestor(use.getOwner()->getParentRegion())) continue;
259 use.set(get_tuple_element.getResult());
260 }
261 }
262
263 return implicit_input_elements;
264 }
265
266 // Finds and replaces implicitly captured value uses with tuple block argument.
267 // A tuple of implicitly captured values is also created and returned, for use
268 // as an operand to the associated mhlo control flow op.
TupleImplicitInputs(Region & region,Location loc,OpBuilder * builder)269 Value TupleImplicitInputs(Region& region, Location loc, OpBuilder* builder) {
270 llvm::SetVector<Value> implicit_inputs;
271 getUsedValuesDefinedAbove(region, region, implicit_inputs);
272 llvm::ArrayRef<Value> implicit_inputs_ref = implicit_inputs.getArrayRef();
273 Value tuple_input = builder->create<mhlo::TupleOp>(loc, implicit_inputs_ref);
274 Block& block = region.front();
275 // `tf.CaseRegion`/`tf.IfRegion` are expected to have no block arguments and
276 // instead all inputs used by their branch regions are implicitly captured
277 // from above.
278 assert(block.getNumArguments() == 0);
279 block.addArgument(tuple_input.getType(), loc);
280 builder->setInsertionPointToStart(&block);
281 ReplaceImplicitInputsWithTupleElements(&block, /*offset=*/0,
282 implicit_inputs_ref, builder);
283 return tuple_input;
284 }
285
286 // Replaces block terminator (tf.Yield) with `mhlo.return`. Additional results
287 // can be returned if `extra_results` is not empty. If `tuple_return` is
288 // set, a tuple of the return values will be set as the terminator operand.
ReplaceTerminator(Block * block,ArrayRef<Value> extra_results,OpBuilder * builder,bool tuple_return=true)289 void ReplaceTerminator(Block* block, ArrayRef<Value> extra_results,
290 OpBuilder* builder, bool tuple_return = true) {
291 Operation* terminator = block->getTerminator();
292 assert(isa<TF::YieldOp>(terminator));
293 Location loc = terminator->getLoc();
294
295 builder->setInsertionPoint(terminator);
296 auto results = llvm::to_vector<4>(terminator->getOperands());
297 results.append(extra_results.begin(), extra_results.end());
298 if (tuple_return) {
299 auto tuple_results = builder->create<mhlo::TupleOp>(loc, results);
300 builder->create<mhlo::ReturnOp>(loc, tuple_results.getResult());
301 } else {
302 builder->create<mhlo::ReturnOp>(loc, results);
303 }
304
305 terminator->erase();
306 }
307
LowerIfRegion(TF::IfRegionOp op)308 void LowerIfRegion(TF::IfRegionOp op) {
309 Location loc = op.getLoc();
310 OpBuilder builder(op);
311
312 builder.setInsertionPoint(op);
313 ReplaceTerminator(&op.then_branch().front(), /*extra_results=*/{}, &builder,
314 /*tuple_return=*/false);
315
316 builder.setInsertionPoint(op);
317 ReplaceTerminator(&op.else_branch().front(), /*extra_results=*/{}, &builder,
318 /*tuple_return=*/false);
319
320 // Create the new `mhlo.if` op and take ownership of regions from
321 // `tf.IfRegion` op.
322 builder.setInsertionPoint(op);
323 auto if_op = builder.create<mhlo::IfOp>(loc, op.getResultTypes(), op.cond());
324 if_op.true_branch().takeBody(op.then_branch());
325 if_op.false_branch().takeBody(op.else_branch());
326
327 // Replace all uses of `op` results with that of `mhlo.IfOp`.
328 op->replaceAllUsesWith(if_op);
329
330 op.erase();
331 }
332
LowerCaseRegion(TF::CaseRegionOp op)333 void LowerCaseRegion(TF::CaseRegionOp op) {
334 Location loc = op.getLoc();
335 OpBuilder builder(op);
336
337 for (Region& region : op.branches()) {
338 builder.setInsertionPoint(op);
339 ReplaceTerminator(®ion.front(), /*extra_results=*/{}, &builder,
340 /*tuple_return=*/false);
341 }
342
343 // Create the new `mhlo.case` op and take ownership of regions from
344 // `tf.CaseRegion` op.
345 builder.setInsertionPoint(op);
346 auto case_op = builder.create<mhlo::CaseOp>(
347 loc, op.getResultTypes(), op.branch_index(), op.branches().size());
348 for (auto region : llvm::zip(case_op.branches(), op.branches()))
349 std::get<0>(region).takeBody(std::get<1>(region));
350
351 // Replace all uses of `op` results with that of `mhlo.CaseOp`.
352 op.replaceAllUsesWith(case_op);
353 op.erase();
354 }
355
LowerWhileRegion(TF::WhileRegionOp op)356 void LowerWhileRegion(TF::WhileRegionOp op) {
357 Location loc = op.getLoc();
358 OpBuilder builder(op);
359
360 SmallVector<Value, 3> inputs(op.input());
361 const int inputs_size = inputs.size();
362 llvm::SetVector<Value> implicit_inputs;
363 getUsedValuesDefinedAbove(op.getOperation()->getRegions(), implicit_inputs);
364 inputs.append(implicit_inputs.begin(), implicit_inputs.end());
365
366 builder.setInsertionPoint(op);
367
368 // Create the new `mhlo.while` op with 'inputs'. Implicit inputs are also
369 // returned.
370 auto while_result_types = llvm::to_vector<4>(op.getResultTypes());
371 while_result_types.reserve(while_result_types.size() +
372 implicit_inputs.size());
373 for (const auto& implicit_input : implicit_inputs)
374 while_result_types.emplace_back(implicit_input.getType());
375 auto while_op =
376 builder.create<mhlo::WhileOp>(loc, while_result_types, inputs);
377
378 // Rewrite cond and associated block arguments and terminator. Ownership of
379 // cond region is transfered over from `tf.WhileRegion` to `mhlo.while`.
380 Region& cond = while_op.cond();
381 cond.takeBody(op.cond());
382 Block& cond_block = cond.front();
383 builder.setInsertionPointToStart(&cond_block);
384
385 // Add args corresponding to 'implicit_inputs'.
386 for (const auto& implicit_input : implicit_inputs)
387 cond_block.addArgument(implicit_input.getType(), loc);
388 ReplaceImplicitInputs(&cond_block, inputs_size,
389 implicit_inputs.getArrayRef());
390 // Cond always returns a single result of bool type.
391 ReplaceTerminator(&cond_block, /*extra_results=*/{}, &builder,
392 /*tuple_return=*/false);
393
394 // Rewrite body and associated block arguments and terminator. Ownership of
395 // body region is transfered over from `tf.WhileRegion` to `mhlo.while`.
396 Region& body = while_op.body();
397 body.takeBody(op.body());
398 Block& body_block = body.front();
399 builder.setInsertionPointToStart(&body_block);
400 // Add args corresponding to 'implicit_inputs'.
401 for (const auto& implicit_input : implicit_inputs)
402 body_block.addArgument(implicit_input.getType(), loc);
403 auto implicit_input_elements = ReplaceImplicitInputs(
404 &body_block, inputs_size, implicit_inputs.getArrayRef());
405 ReplaceTerminator(&body_block, implicit_input_elements, &builder, false);
406
407 // Replace all uses of `op` results with that of `mhlo.while`.
408 builder.setInsertionPoint(op);
409 if (while_op.getNumResults() > 1) {
410 for (const auto& result_it : llvm::enumerate(op.getResults()))
411 result_it.value().replaceAllUsesWith(
412 while_op.getResult(result_it.index()));
413 } else {
414 op->replaceAllUsesWith(while_op);
415 }
416 op.erase();
417 }
418 } // namespace
419
runOnOperation()420 void LegalizeTFControlFlow::runOnOperation() {
421 getOperation().walk([&](Operation* op) {
422 if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
423 LowerWhile(while_op);
424 return;
425 }
426 if (auto while_region_op = dyn_cast<TF::WhileRegionOp>(op)) {
427 LowerWhileRegion(while_region_op);
428 return;
429 }
430 if (auto if_op = dyn_cast<TF::IfOp>(op)) {
431 LowerIf(if_op);
432 return;
433 }
434 if (auto if_region_op = dyn_cast<TF::IfRegionOp>(op)) {
435 LowerIfRegion(if_region_op);
436 return;
437 }
438 if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
439 LowerCase(case_op);
440 return;
441 }
442 if (auto case_region_op = dyn_cast<TF::CaseRegionOp>(op)) {
443 LowerCaseRegion(case_region_op);
444 return;
445 }
446 });
447 }
448 } // namespace mhlo
449 } // namespace mlir
450