1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This transformation pass transforms region-based control flow operations in
// the TensorFlow dialect to their functional counterparts, i.e.,
// tf.IfRegion -> tf.If and tf.WhileRegion -> tf.While
19
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
24 #include "mlir/IR/Attributes.h" // from @llvm-project
25 #include "mlir/IR/Builders.h" // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
27 #include "mlir/IR/Operation.h" // from @llvm-project
28 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
29 #include "mlir/IR/Value.h" // from @llvm-project
30 #include "mlir/IR/Verifier.h" // from @llvm-project
31 #include "mlir/IR/Visitors.h" // from @llvm-project
32 #include "mlir/Pass/Pass.h" // from @llvm-project
33 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
34 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
35 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
38 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
39 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
40 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
41
42 #define DEBUG_TYPE "tf-region-cf-to-functional"
43
44 namespace mlir {
45 namespace TF {
46
47 namespace {
48
// Attribute on tf.IfRegion that, when present and non-empty, provides the base
// name used for the outlined `then` function.
constexpr char kThenFuncNameAttr[] = "_then_func_name";
// Attribute on tf.IfRegion that, when present and non-empty, provides the base
// name used for the outlined `else` function.
constexpr char kElseFuncNameAttr[] = "_else_func_name";
// Attribute that is set to true on every functional op created by this pass.
constexpr char kXlaPropagateCompileTimeConsts[] =
    "_xla_propagate_compile_time_consts";
53
54 struct RegionControlFlowToFunctional
55 : public TF::RegionControlFlowToFunctionalPassBase<
56 RegionControlFlowToFunctional> {
57 void runOnOperation() override;
58
59 private:
60 LogicalResult ConvertIfOp(IfRegionOp if_region);
61 LogicalResult ConvertWhileOp(WhileRegionOp while_region);
62
63 // Get unique name by using the loc to name mapping.
64 std::string GetName(Operation* op, StringRef suffix);
65
66 tensorflow::OpOrArgLocNameMapper mapper;
67 llvm::SmallVector<func::FuncOp, 4> worklist;
68 };
69
GetName(Operation * op,StringRef suffix)70 std::string RegionControlFlowToFunctional::GetName(Operation* op,
71 StringRef suffix) {
72 return (mapper.GetUniqueName(op) + suffix).str();
73 }
74
75 // Returns all the external values referenced from the given regions. If the
76 // external value is a constant, sink it into the region instead (and do not
77 // add it to the returned vector).
CollectExternValues(Region & first,Region & second)78 llvm::SmallVector<Value, 4> CollectExternValues(Region& first, Region& second) {
79 llvm::SetVector<Value> extern_values;
80
81 for (Region* region : {&first, &second}) {
82 llvm::SetVector<Value> region_extern_values;
83 getUsedValuesDefinedAbove(*region, region_extern_values);
84
85 // Sink down constants into the functions.
86 for (auto extern_value : region_extern_values) {
87 if (!matchPattern(extern_value, m_Constant())) {
88 extern_values.insert(extern_value);
89 continue;
90 }
91 // Add constant at start of region.
92 auto const_builder = OpBuilder::atBlockBegin(®ion->front());
93 auto const_value = const_builder.clone(*extern_value.getDefiningOp());
94 replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
95 *region);
96 }
97 }
98
99 return llvm::to_vector<4>(extern_values);
100 }
101
102 // Copies over optional attributes from source region op `src` to the given
103 // functional op `dst` and appropriately overrides any necessary attributes.
CopyAndOverrideAttributes(Operation * src,Operation * dst,OpBuilder * builder)104 void CopyAndOverrideAttributes(Operation* src, Operation* dst,
105 OpBuilder* builder) {
106 CopyDeviceAndUnderscoredAttributes(src, dst);
107
108 // Explicitly override attribute to propagate constants to the functions
109 // before compiling to XLA. This is necessary along with conversion to
110 // functional format because inlined regions may have moved loop invariant ops
111 // outside of the region which may cause some new legalization failures.
112 // TODO(b/126739593): Enable this attribute in TensorFlow by default. Also,
113 // see b/185542519 for the context.
114 dst->setAttr(kXlaPropagateCompileTimeConsts, builder->getBoolAttr(true));
115 }
116
117 // Extracts the contents of a region with a single block into a new function.
118 // `extern_values` is the set of external values that the region refers to.
119 //
120 // Inputs to the terminator of the region are converted to return values of
121 // the function. If `extern_values_passthrough` is true, all the extern values
122 // are also added as return values from the function
ExtractSingleBlockRegion(Region & region,StringRef name,llvm::SmallVectorImpl<Value> & extern_values,llvm::SmallVectorImpl<func::FuncOp> & worklist,bool extern_values_passthrough)123 void ExtractSingleBlockRegion(Region& region, StringRef name,
124 llvm::SmallVectorImpl<Value>& extern_values,
125 llvm::SmallVectorImpl<func::FuncOp>& worklist,
126 bool extern_values_passthrough) {
127 ModuleOp module = region.getParentOfType<ModuleOp>();
128 auto builder = OpBuilder::atBlockBegin(module.getBody());
129 auto loc = region.getParentOp()->getLoc();
130 Block& entry = region.front();
131 int num_region_arguments = entry.getNumArguments();
132 Operation* terminator = entry.getTerminator();
133
134 // Build the function type. Region arguments and extern values together
135 // become the function arguments, with region arguments going first.
136 auto input_types = llvm::to_vector<4>(entry.getArgumentTypes());
137 for (auto input : extern_values) input_types.push_back(input.getType());
138
139 // Terminator operands and pass through extern values (if enabled) together
140 // become the function return values.
141 auto return_types = llvm::to_vector<4>(terminator->getOperandTypes());
142 if (extern_values_passthrough)
143 for (auto input : extern_values) return_types.push_back(input.getType());
144
145 auto type = FunctionType::get(region.getContext(), input_types, return_types);
146
147 // Create new function and extract region body into the function.
148 auto outlined_func = builder.create<func::FuncOp>(loc, name, type);
149 Region& func_region = outlined_func.getBody();
150 func_region.takeBody(region);
151 Block& first_block = func_region.front();
152
153 // Replace all external uses with function arguments.
154 for (auto it : llvm::enumerate(extern_values)) {
155 Value arg = first_block.addArgument(it.value().getType(), loc);
156 replaceAllUsesInRegionWith(it.value(), arg, func_region);
157 }
158
159 // Function return values are all the terminator operands + pass through
160 // extern values (if enabled).
161 auto return_values = llvm::to_vector<4>(terminator->getOperands());
162 if (extern_values_passthrough)
163 return_values.insert(return_values.end(),
164 first_block.args_begin() + num_region_arguments,
165 first_block.args_end());
166
167 // Replace the existing terminator with a return.
168 terminator = first_block.getTerminator();
169 builder.setInsertionPoint(terminator);
170 builder.create<func::ReturnOp>(terminator->getLoc(), return_values);
171 terminator->erase();
172
173 outlined_func.setPrivate();
174
175 // Add the outlined function to the worklist in case its body has
176 // IfRegion or WhileRegion ops that need to converted.
177 worklist.push_back(outlined_func);
178 }
179
180 // Returns call for region with single call whose result feeds into the
181 // terminator of the region. if `allow_to_bool` is true, also allows a single
182 // ToBoolOp between the region yield and the call. Returns none if the region
183 // does not conform to this pattern.
IsSingleCallRegion(Region & region,bool allow_to_bool=false)184 llvm::Optional<func::CallOp> IsSingleCallRegion(Region& region,
185 bool allow_to_bool = false) {
186 if (!llvm::hasSingleElement(region)) return llvm::None;
187
188 Block& block = region.front();
189 auto it = block.rbegin();
190 YieldOp yield = dyn_cast<YieldOp>(*it++);
191
192 if (it == block.rend()) return llvm::None;
193
194 // Operation which is expected to consume all the call results.
195 Operation* call_consumer = yield;
196
197 // Allow a single ToBoolOp between the call and the yield (valid only
198 // when the yield has a single operand)
199 if (allow_to_bool && yield.getNumOperands() == 1 && isa<ToBoolOp>(*it)) {
200 if (it->getResult(0) != yield.getOperand(0)) return llvm::None;
201 call_consumer = cast<ToBoolOp>(*it);
202 it++;
203 if (it == block.rend()) return llvm::None;
204 }
205
206 // Check if there is a Call before the Yield.
207 func::CallOp call = dyn_cast<func::CallOp>(*it++);
208 if (!call) return llvm::None;
209
210 // All call results should feed into expected consumer
211 // All results of the call should feed into the yield.
212 if (call.getNumResults() != call_consumer->getNumOperands())
213 return llvm::None;
214
215 for (auto res_it : llvm::zip(call.getResults(), call_consumer->getOperands()))
216 if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None;
217
218 // There can only be non-truncating cast op's prior to the call.
219 for (; it != block.rend(); ++it) {
220 CastOp cast = dyn_cast<CastOp>(*it);
221 if (!cast || cast.Truncate()) return llvm::None;
222 }
223
224 return call;
225 }
226
227 using ArgMatcherFn = function_ref<bool(Value, Region&, Value, Region&)>;
228
229 // Returns whether the arguments of the given 2 calls are match (after looking
230 // through cast ops). `matcher` is the predicate used to check if two arguments
231 // match.
MatchCallArgs(func::CallOp first,func::CallOp second,ArgMatcherFn matcher)232 bool MatchCallArgs(func::CallOp first, func::CallOp second,
233 ArgMatcherFn matcher) {
234 if (first.getNumOperands() != second.getNumOperands()) return false;
235
236 Region& first_region = *first->getParentRegion();
237 Region& second_region = *second->getParentRegion();
238
239 for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) {
240 // Get the defining Op, skipping over casts.
241 auto get_defining_op = [](Value value) {
242 while (auto cast_op =
243 llvm::dyn_cast_or_null<CastOp>(value.getDefiningOp())) {
244 // Consider cast compatibility in case
245 // %cast = "tf.Cast"(%0) : (tensor<2xi64>) -> tensor<2xf32>
246 // is skipped.
247 if (cast_op.SrcT() != cast_op.DstT()) {
248 break;
249 }
250 value = cast_op.getOperand();
251 }
252 return value;
253 };
254 Value first_arg = get_defining_op(std::get<0>(it));
255 Value second_arg = get_defining_op(std::get<1>(it));
256
257 if (!matcher(first_arg, first_region, second_arg, second_region))
258 return false;
259 }
260 return true;
261 }
262
263 // Summary information for trivially transforming region based op's to
264 // functional ops. A trivial transformation can be done when the regions are
265 // just calls to functions, in which case no outlining is needed.
266 struct TrivialTransformInfo {
267 // Can the op be transformed trivially?
268 bool can_transform = false;
269
270 // List of callee names (one for each region).
271 llvm::SmallVector<StringRef, 2> callee_names;
272
273 // Analyzes the given calls (from regions attached to the same parent op) to
274 // check if the parent op be transformed to functional form trivially (i.e.,
275 // reusing existing functions and without outlining). This is possible when
276 // all the regions are single call regions (checked using matchers outside
277 // this class) and the all the calls match using the given argument matcher.
278 //
279 // If such a trivial transformation is possible, stash the relevant
280 // information needed for the transformation, else indicate that a trivial
281 // transformation is not possible by setting `can_transform` to false.
TrivialTransformInfomlir::TF::__anon6658f75b0111::TrivialTransformInfo282 TrivialTransformInfo(llvm::Optional<func::CallOp> first_call,
283 llvm::Optional<func::CallOp> second_call,
284 ArgMatcherFn arg_matcher) {
285 if (!first_call || !second_call) return;
286
287 if (!MatchCallArgs(first_call.getValue(), second_call.getValue(),
288 arg_matcher))
289 return;
290
291 can_transform = true;
292 callee_names = {first_call.getValue().getCallee(),
293 second_call.getValue().getCallee()};
294 }
295 };
296
297 // Transform IfRegionOp to IfOp.
ConvertIfOp(IfRegionOp if_region)298 LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) {
299 llvm::SmallVector<Value, 4> extern_values;
300
301 // For IfOp, arguments of calls in the then and else regions match if they
302 // are the same value.
303 auto if_arg_matcher = [&](Value first, Region&, Value second, Region&) {
304 if (first != second) return false;
305
306 // collect the call arguments post lookup through cast Op's
307 extern_values.push_back(first);
308 return true;
309 };
310
311 const TrivialTransformInfo tti(IsSingleCallRegion(if_region.then_branch()),
312 IsSingleCallRegion(if_region.else_branch()),
313 if_arg_matcher);
314
315 std::string then_name, else_name;
316
317 if (tti.can_transform) {
318 // We can transform to functional form trivially without outlining.
319 then_name = tti.callee_names[0].str();
320 else_name = tti.callee_names[1].str();
321 } else {
322 // Collect external values that are used within the else and then bodies.
323 extern_values =
324 CollectExternValues(if_region.then_branch(), if_region.else_branch());
325
326 // These external values need to be added as inputs to the generated If. The
327 // order is determined by the order of these values the `extern_vales`.
328
329 // Create 2 new functions with the input signature matching this order,
330 // and outline the `then` and `else` regions by moving the bodies of these
331 // regions into these functions. Replace tf.yield with a regular return.
332 if (if_region->hasAttrOfType<StringAttr>(kThenFuncNameAttr) &&
333 !if_region._then_func_nameAttr().getValue().empty()) {
334 then_name =
335 mapper.GetUniqueName(if_region._then_func_nameAttr().getValue())
336 .str();
337 } else {
338 then_name = GetName(if_region, "_then");
339 }
340 ExtractSingleBlockRegion(if_region.then_branch(), then_name, extern_values,
341 worklist, /*extern_values_passthrough=*/false);
342
343 if (if_region->hasAttrOfType<StringAttr>(kElseFuncNameAttr) &&
344 !if_region._else_func_nameAttr().getValue().empty()) {
345 else_name =
346 mapper.GetUniqueName(if_region._else_func_nameAttr().getValue())
347 .str();
348 } else {
349 else_name = GetName(if_region, "_else");
350 }
351 ExtractSingleBlockRegion(if_region.else_branch(), else_name, extern_values,
352 worklist, /*extern_values_passthrough=*/false);
353 }
354
355 // Look through ToBool operations for the condition.
356 Value cond = if_region.cond();
357 auto to_bool = dyn_cast_or_null<ToBoolOp>(cond.getDefiningOp());
358 if (to_bool) cond = to_bool.getOperand();
359
360 // Once we have the `then` and `else` functions ready (either outlined or
361 // existing ones), replace the region based op with a functional control flow
362 // op.
363 OpBuilder builder(if_region);
364 auto if_op = builder.create<IfOp>(
365 if_region.getLoc(), if_region.getResultTypes(), cond, extern_values,
366 then_name, else_name, if_region.is_stateless());
367 CopyAndOverrideAttributes(if_region, if_op, &builder);
368
369 if_region.replaceAllUsesWith(if_op.getResults());
370 if_region.erase();
371
372 if (to_bool && to_bool.use_empty()) to_bool.erase();
373 return success();
374 }
375
376 // Transform WhileRegion to WhileOp.
ConvertWhileOp(WhileRegionOp while_region)377 LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
378 WhileRegionOp while_region) {
379 // For While, the arguments of the calls in the body and cond regions match
380 // if they are region arguments with the same region argument numbers. If the
381 // 2 calls have the same value (an extern value) used as an argument, we
382 // cannot do a trivial transformation because post transform, we will need to
383 // pass this extern value as an argument to the function, so we cannot use the
384 // existing function as is.
385 auto while_arg_matcher = [](Value first, Region& first_region, Value second,
386 Region& second_region) {
387 if (!first.isa<BlockArgument>() || !second.isa<BlockArgument>())
388 return false;
389 BlockArgument first_block_arg = first.cast<BlockArgument>();
390 BlockArgument second_block_arg = second.cast<BlockArgument>();
391
392 // 2 block arguments will match if they are the same argument number, and
393 // are block arguments of the corresponding containing regions.
394 return first_block_arg.getArgNumber() == second_block_arg.getArgNumber() &&
395 first_block_arg.getParentBlock() == &first_region.front() &&
396 second_block_arg.getParentBlock() == &second_region.front();
397 };
398
399 const TrivialTransformInfo tti(
400 IsSingleCallRegion(while_region.cond(), /*allow_to_bool=*/true),
401 IsSingleCallRegion(while_region.body()), while_arg_matcher);
402
403 // All existing inputs to while region are inputs to the functional while.
404 auto new_inputs = llvm::to_vector<4>(while_region.getOperands());
405
406 // All existing results will also be generated by the functional while.
407 auto new_result_types = llvm::to_vector<4>(while_region.getResultTypes());
408
409 std::string cond_name, body_name;
410 if (tti.can_transform) {
411 // We can transform to functional form trivially without outlining.
412 cond_name = tti.callee_names[0].str();
413 body_name = tti.callee_names[1].str();
414 } else {
415 // The WhileRegion regions can refer to either arguments of the region, or
416 // external values implicitly captured by the region. When converting to
417 // functional form, all such external values need to become function
418 // arguments of the outlined functions, and become pass through values in
419 // the outlined body function. So when outlining the while body, in addition
420 // to the region arguments, all these external references need to be added
421 // as function arguments.
422 llvm::SmallVector<Value, 4> extern_values =
423 CollectExternValues(while_region.cond(), while_region.body());
424
425 // Outline the `cond` and `body` regions by moving the bodies of these
426 // regions into new functions. Replace tf.yield with a regular return.
427 cond_name = GetName(while_region, "_cond");
428 ExtractSingleBlockRegion(while_region.cond(), cond_name, extern_values,
429 worklist, /*extern_values_passthrough=*/false);
430
431 body_name = GetName(while_region, "_body");
432 ExtractSingleBlockRegion(while_region.body(), body_name, extern_values,
433 worklist, /*extern_values_passthrough=*/true);
434
435 // All extern values become additional inputs and additional output types
436 // for the functional while.
437 new_inputs.append(extern_values.begin(), extern_values.end());
438 for (auto ext : extern_values) new_result_types.push_back(ext.getType());
439 }
440
441 // Once we have the `cond` and `body` functions ready (either outlined or
442 // existing ones), replace the region based op with a functional op.
443 OpBuilder builder(while_region);
444 auto while_op = builder.create<WhileOp>(
445 while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
446 while_region.parallel_iterations(), while_region.is_stateless(),
447 while_region.shape_invariant());
448 CopyAndOverrideAttributes(while_region, while_op, &builder);
449
450 // Redirect old results to new results.
451 for (auto it : llvm::zip(
452 while_region.getResults(),
453 while_op.getResults().take_front(while_region.getNumResults())))
454 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
455
456 while_region.erase();
457 return success();
458 }
459
runOnOperation()460 void RegionControlFlowToFunctional::runOnOperation() {
461 ModuleOp module = getOperation();
462
463 // Seed worklist with all functions in the module.
464 worklist = llvm::to_vector<4>(module.getOps<func::FuncOp>());
465 while (!worklist.empty()) {
466 func::FuncOp function = worklist.pop_back_val();
467
468 auto result = function.walk([&](Operation* op) {
469 if (auto if_region = llvm::dyn_cast<IfRegionOp>(op)) {
470 if (failed(ConvertIfOp(if_region))) {
471 op->emitOpError() << "failed to convert to functional form";
472 return WalkResult::interrupt();
473 }
474 } else if (auto while_region = llvm::dyn_cast<WhileRegionOp>(op)) {
475 if (failed(ConvertWhileOp(while_region))) {
476 op->emitOpError() << "failed to convert to functional form";
477 return WalkResult::interrupt();
478 }
479 }
480 return WalkResult::advance();
481 });
482
483 if (result.wasInterrupted()) return signalPassFailure();
484 }
485 }
486
487 } // namespace
488
489 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFRegionControlFlowToFunctional()490 CreateTFRegionControlFlowToFunctional() {
491 return std::make_unique<RegionControlFlowToFunctional>();
492 }
493
494 } // namespace TF
495 } // namespace mlir
496