1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This transformation pass transforms region based control flow operations in
// the TensorFlow dialect to their functional counterparts, i.e.,
// tf.IfRegion -> tf.If and tf.WhileRegion -> tf.While
19 
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
24 #include "mlir/IR/Attributes.h"  // from @llvm-project
25 #include "mlir/IR/Builders.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
27 #include "mlir/IR/Operation.h"  // from @llvm-project
28 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
29 #include "mlir/IR/Value.h"  // from @llvm-project
30 #include "mlir/IR/Verifier.h"  // from @llvm-project
31 #include "mlir/IR/Visitors.h"  // from @llvm-project
32 #include "mlir/Pass/Pass.h"  // from @llvm-project
33 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
34 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
35 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
38 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
39 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
40 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
41 
42 #define DEBUG_TYPE "tf-region-cf-to-functional"
43 
44 namespace mlir {
45 namespace TF {
46 
47 namespace {
48 
// Attribute set to true on every functional op generated by this pass.
constexpr char kXlaPropagateCompileTimeConsts[] =
    "_xla_propagate_compile_time_consts";
// Optional string attributes on tf.IfRegion that, when present and non-empty,
// seed the names of the outlined `then` / `else` functions.
constexpr char kThenFuncNameAttr[] = "_then_func_name";
constexpr char kElseFuncNameAttr[] = "_else_func_name";
53 
54 struct RegionControlFlowToFunctional
55     : public TF::RegionControlFlowToFunctionalPassBase<
56           RegionControlFlowToFunctional> {
57   void runOnOperation() override;
58 
59  private:
60   LogicalResult ConvertIfOp(IfRegionOp if_region);
61   LogicalResult ConvertWhileOp(WhileRegionOp while_region);
62 
63   // Get unique name by using the loc to name mapping.
64   std::string GetName(Operation* op, StringRef suffix);
65 
66   tensorflow::OpOrArgLocNameMapper mapper;
67   llvm::SmallVector<func::FuncOp, 4> worklist;
68 };
69 
GetName(Operation * op,StringRef suffix)70 std::string RegionControlFlowToFunctional::GetName(Operation* op,
71                                                    StringRef suffix) {
72   return (mapper.GetUniqueName(op) + suffix).str();
73 }
74 
75 // Returns all the external values referenced from the given regions. If the
76 // external value is a constant, sink it into the region instead (and do not
77 // add it to the returned vector).
CollectExternValues(Region & first,Region & second)78 llvm::SmallVector<Value, 4> CollectExternValues(Region& first, Region& second) {
79   llvm::SetVector<Value> extern_values;
80 
81   for (Region* region : {&first, &second}) {
82     llvm::SetVector<Value> region_extern_values;
83     getUsedValuesDefinedAbove(*region, region_extern_values);
84 
85     // Sink down constants into the functions.
86     for (auto extern_value : region_extern_values) {
87       if (!matchPattern(extern_value, m_Constant())) {
88         extern_values.insert(extern_value);
89         continue;
90       }
91       // Add constant at start of region.
92       auto const_builder = OpBuilder::atBlockBegin(&region->front());
93       auto const_value = const_builder.clone(*extern_value.getDefiningOp());
94       replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
95                                  *region);
96     }
97   }
98 
99   return llvm::to_vector<4>(extern_values);
100 }
101 
102 // Copies over optional attributes from source region op `src` to the given
103 // functional op `dst` and appropriately overrides any necessary attributes.
CopyAndOverrideAttributes(Operation * src,Operation * dst,OpBuilder * builder)104 void CopyAndOverrideAttributes(Operation* src, Operation* dst,
105                                OpBuilder* builder) {
106   CopyDeviceAndUnderscoredAttributes(src, dst);
107 
108   // Explicitly override attribute to propagate constants to the functions
109   // before compiling to XLA. This is necessary along with conversion to
110   // functional format because inlined regions may have moved loop invariant ops
111   // outside of the region which may cause some new legalization failures.
112   // TODO(b/126739593): Enable this attribute in TensorFlow by default. Also,
113   // see b/185542519 for the context.
114   dst->setAttr(kXlaPropagateCompileTimeConsts, builder->getBoolAttr(true));
115 }
116 
117 // Extracts the contents of a region with a single block into a new function.
118 // `extern_values` is the set of external values that the region refers to.
119 //
120 // Inputs to the terminator of the region are converted to return values of
121 // the function. If `extern_values_passthrough` is true, all the extern values
122 // are also added as return values from the function
ExtractSingleBlockRegion(Region & region,StringRef name,llvm::SmallVectorImpl<Value> & extern_values,llvm::SmallVectorImpl<func::FuncOp> & worklist,bool extern_values_passthrough)123 void ExtractSingleBlockRegion(Region& region, StringRef name,
124                               llvm::SmallVectorImpl<Value>& extern_values,
125                               llvm::SmallVectorImpl<func::FuncOp>& worklist,
126                               bool extern_values_passthrough) {
127   ModuleOp module = region.getParentOfType<ModuleOp>();
128   auto builder = OpBuilder::atBlockBegin(module.getBody());
129   auto loc = region.getParentOp()->getLoc();
130   Block& entry = region.front();
131   int num_region_arguments = entry.getNumArguments();
132   Operation* terminator = entry.getTerminator();
133 
134   // Build the function type. Region arguments and extern values together
135   // become the function arguments, with region arguments going first.
136   auto input_types = llvm::to_vector<4>(entry.getArgumentTypes());
137   for (auto input : extern_values) input_types.push_back(input.getType());
138 
139   // Terminator operands and pass through extern values (if enabled) together
140   // become the function return values.
141   auto return_types = llvm::to_vector<4>(terminator->getOperandTypes());
142   if (extern_values_passthrough)
143     for (auto input : extern_values) return_types.push_back(input.getType());
144 
145   auto type = FunctionType::get(region.getContext(), input_types, return_types);
146 
147   // Create new function and extract region body into the function.
148   auto outlined_func = builder.create<func::FuncOp>(loc, name, type);
149   Region& func_region = outlined_func.getBody();
150   func_region.takeBody(region);
151   Block& first_block = func_region.front();
152 
153   // Replace all external uses with function arguments.
154   for (auto it : llvm::enumerate(extern_values)) {
155     Value arg = first_block.addArgument(it.value().getType(), loc);
156     replaceAllUsesInRegionWith(it.value(), arg, func_region);
157   }
158 
159   // Function return values are all the terminator operands + pass through
160   // extern values (if enabled).
161   auto return_values = llvm::to_vector<4>(terminator->getOperands());
162   if (extern_values_passthrough)
163     return_values.insert(return_values.end(),
164                          first_block.args_begin() + num_region_arguments,
165                          first_block.args_end());
166 
167   // Replace the existing terminator with a return.
168   terminator = first_block.getTerminator();
169   builder.setInsertionPoint(terminator);
170   builder.create<func::ReturnOp>(terminator->getLoc(), return_values);
171   terminator->erase();
172 
173   outlined_func.setPrivate();
174 
175   // Add the outlined function to the worklist in case its body has
176   // IfRegion or WhileRegion ops that need to converted.
177   worklist.push_back(outlined_func);
178 }
179 
180 // Returns call for region with single call whose result feeds into the
181 // terminator of the region. if `allow_to_bool` is true, also allows a single
182 // ToBoolOp between the region yield and the call. Returns none if the region
183 // does not conform to this pattern.
IsSingleCallRegion(Region & region,bool allow_to_bool=false)184 llvm::Optional<func::CallOp> IsSingleCallRegion(Region& region,
185                                                 bool allow_to_bool = false) {
186   if (!llvm::hasSingleElement(region)) return llvm::None;
187 
188   Block& block = region.front();
189   auto it = block.rbegin();
190   YieldOp yield = dyn_cast<YieldOp>(*it++);
191 
192   if (it == block.rend()) return llvm::None;
193 
194   // Operation which is expected to consume all the call results.
195   Operation* call_consumer = yield;
196 
197   // Allow a single ToBoolOp between the call and the yield (valid only
198   // when the yield has a single operand)
199   if (allow_to_bool && yield.getNumOperands() == 1 && isa<ToBoolOp>(*it)) {
200     if (it->getResult(0) != yield.getOperand(0)) return llvm::None;
201     call_consumer = cast<ToBoolOp>(*it);
202     it++;
203     if (it == block.rend()) return llvm::None;
204   }
205 
206   // Check if there is a Call before the Yield.
207   func::CallOp call = dyn_cast<func::CallOp>(*it++);
208   if (!call) return llvm::None;
209 
210   // All call results should feed into expected consumer
211   // All results of the call should feed into the yield.
212   if (call.getNumResults() != call_consumer->getNumOperands())
213     return llvm::None;
214 
215   for (auto res_it : llvm::zip(call.getResults(), call_consumer->getOperands()))
216     if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None;
217 
218   // There can only be non-truncating cast op's prior to the call.
219   for (; it != block.rend(); ++it) {
220     CastOp cast = dyn_cast<CastOp>(*it);
221     if (!cast || cast.Truncate()) return llvm::None;
222   }
223 
224   return call;
225 }
226 
// Predicate deciding whether two call arguments match. It is invoked as
// (first value, region of the first call, second value, region of the second
// call) and returns true on a match. This is a non-owning `function_ref`.
using ArgMatcherFn = function_ref<bool(Value, Region&, Value, Region&)>;
228 
229 // Returns whether the arguments of the given 2 calls are match (after looking
230 // through cast ops). `matcher` is the predicate used to check if two arguments
231 // match.
MatchCallArgs(func::CallOp first,func::CallOp second,ArgMatcherFn matcher)232 bool MatchCallArgs(func::CallOp first, func::CallOp second,
233                    ArgMatcherFn matcher) {
234   if (first.getNumOperands() != second.getNumOperands()) return false;
235 
236   Region& first_region = *first->getParentRegion();
237   Region& second_region = *second->getParentRegion();
238 
239   for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) {
240     // Get the defining Op, skipping over casts.
241     auto get_defining_op = [](Value value) {
242       while (auto cast_op =
243                  llvm::dyn_cast_or_null<CastOp>(value.getDefiningOp())) {
244         // Consider cast compatibility in case
245         //    %cast = "tf.Cast"(%0) : (tensor<2xi64>) -> tensor<2xf32>
246         // is skipped.
247         if (cast_op.SrcT() != cast_op.DstT()) {
248           break;
249         }
250         value = cast_op.getOperand();
251       }
252       return value;
253     };
254     Value first_arg = get_defining_op(std::get<0>(it));
255     Value second_arg = get_defining_op(std::get<1>(it));
256 
257     if (!matcher(first_arg, first_region, second_arg, second_region))
258       return false;
259   }
260   return true;
261 }
262 
263 // Summary information for trivially transforming region based op's to
264 // functional ops. A trivial transformation can be done when the regions are
265 // just calls to functions, in which case no outlining is needed.
266 struct TrivialTransformInfo {
267   // Can the op be transformed trivially?
268   bool can_transform = false;
269 
270   // List of callee names (one for each region).
271   llvm::SmallVector<StringRef, 2> callee_names;
272 
273   // Analyzes the given calls (from regions attached to the same parent op) to
274   // check if the parent op be transformed to functional form trivially (i.e.,
275   // reusing existing functions and without outlining). This is possible when
276   // all the regions are single call regions (checked using matchers outside
277   // this class) and the all the calls match using the given argument matcher.
278   //
279   // If such a trivial transformation is possible, stash the relevant
280   // information needed for the transformation, else indicate that a trivial
281   // transformation is not possible by setting `can_transform` to false.
TrivialTransformInfomlir::TF::__anon6658f75b0111::TrivialTransformInfo282   TrivialTransformInfo(llvm::Optional<func::CallOp> first_call,
283                        llvm::Optional<func::CallOp> second_call,
284                        ArgMatcherFn arg_matcher) {
285     if (!first_call || !second_call) return;
286 
287     if (!MatchCallArgs(first_call.getValue(), second_call.getValue(),
288                        arg_matcher))
289       return;
290 
291     can_transform = true;
292     callee_names = {first_call.getValue().getCallee(),
293                     second_call.getValue().getCallee()};
294   }
295 };
296 
297 // Transform IfRegionOp to IfOp.
ConvertIfOp(IfRegionOp if_region)298 LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) {
299   llvm::SmallVector<Value, 4> extern_values;
300 
301   // For IfOp, arguments of calls in the then and else regions match if they
302   // are the same value.
303   auto if_arg_matcher = [&](Value first, Region&, Value second, Region&) {
304     if (first != second) return false;
305 
306     // collect the call arguments post lookup through cast Op's
307     extern_values.push_back(first);
308     return true;
309   };
310 
311   const TrivialTransformInfo tti(IsSingleCallRegion(if_region.then_branch()),
312                                  IsSingleCallRegion(if_region.else_branch()),
313                                  if_arg_matcher);
314 
315   std::string then_name, else_name;
316 
317   if (tti.can_transform) {
318     // We can transform to functional form trivially without outlining.
319     then_name = tti.callee_names[0].str();
320     else_name = tti.callee_names[1].str();
321   } else {
322     // Collect external values that are used within the else and then bodies.
323     extern_values =
324         CollectExternValues(if_region.then_branch(), if_region.else_branch());
325 
326     // These external values need to be added as inputs to the generated If. The
327     // order is determined by the order of these values the `extern_vales`.
328 
329     // Create 2 new functions with the input signature matching this order,
330     // and outline the `then` and `else` regions by moving the bodies of these
331     // regions into these functions. Replace tf.yield with a regular return.
332     if (if_region->hasAttrOfType<StringAttr>(kThenFuncNameAttr) &&
333         !if_region._then_func_nameAttr().getValue().empty()) {
334       then_name =
335           mapper.GetUniqueName(if_region._then_func_nameAttr().getValue())
336               .str();
337     } else {
338       then_name = GetName(if_region, "_then");
339     }
340     ExtractSingleBlockRegion(if_region.then_branch(), then_name, extern_values,
341                              worklist, /*extern_values_passthrough=*/false);
342 
343     if (if_region->hasAttrOfType<StringAttr>(kElseFuncNameAttr) &&
344         !if_region._else_func_nameAttr().getValue().empty()) {
345       else_name =
346           mapper.GetUniqueName(if_region._else_func_nameAttr().getValue())
347               .str();
348     } else {
349       else_name = GetName(if_region, "_else");
350     }
351     ExtractSingleBlockRegion(if_region.else_branch(), else_name, extern_values,
352                              worklist, /*extern_values_passthrough=*/false);
353   }
354 
355   // Look through ToBool operations for the condition.
356   Value cond = if_region.cond();
357   auto to_bool = dyn_cast_or_null<ToBoolOp>(cond.getDefiningOp());
358   if (to_bool) cond = to_bool.getOperand();
359 
360   // Once we have the `then` and `else` functions ready (either outlined or
361   // existing ones), replace the region based op with a functional control flow
362   // op.
363   OpBuilder builder(if_region);
364   auto if_op = builder.create<IfOp>(
365       if_region.getLoc(), if_region.getResultTypes(), cond, extern_values,
366       then_name, else_name, if_region.is_stateless());
367   CopyAndOverrideAttributes(if_region, if_op, &builder);
368 
369   if_region.replaceAllUsesWith(if_op.getResults());
370   if_region.erase();
371 
372   if (to_bool && to_bool.use_empty()) to_bool.erase();
373   return success();
374 }
375 
376 // Transform WhileRegion to WhileOp.
ConvertWhileOp(WhileRegionOp while_region)377 LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
378     WhileRegionOp while_region) {
379   // For While, the arguments of the calls in the body and cond regions match
380   // if they are region arguments with the same region argument numbers. If the
381   // 2 calls have the same value (an extern value) used as an argument, we
382   // cannot do a trivial transformation because post transform, we will need to
383   // pass this extern value as an argument to the function, so we cannot use the
384   // existing function as is.
385   auto while_arg_matcher = [](Value first, Region& first_region, Value second,
386                               Region& second_region) {
387     if (!first.isa<BlockArgument>() || !second.isa<BlockArgument>())
388       return false;
389     BlockArgument first_block_arg = first.cast<BlockArgument>();
390     BlockArgument second_block_arg = second.cast<BlockArgument>();
391 
392     // 2 block arguments will match if they are the same argument number, and
393     // are block arguments of the corresponding containing regions.
394     return first_block_arg.getArgNumber() == second_block_arg.getArgNumber() &&
395            first_block_arg.getParentBlock() == &first_region.front() &&
396            second_block_arg.getParentBlock() == &second_region.front();
397   };
398 
399   const TrivialTransformInfo tti(
400       IsSingleCallRegion(while_region.cond(), /*allow_to_bool=*/true),
401       IsSingleCallRegion(while_region.body()), while_arg_matcher);
402 
403   // All existing inputs to while region are inputs to the functional while.
404   auto new_inputs = llvm::to_vector<4>(while_region.getOperands());
405 
406   // All existing results will also be generated by the functional while.
407   auto new_result_types = llvm::to_vector<4>(while_region.getResultTypes());
408 
409   std::string cond_name, body_name;
410   if (tti.can_transform) {
411     // We can transform to functional form trivially without outlining.
412     cond_name = tti.callee_names[0].str();
413     body_name = tti.callee_names[1].str();
414   } else {
415     // The WhileRegion regions can refer to either arguments of the region, or
416     // external values implicitly captured by the region. When converting to
417     // functional form, all such external values need to become function
418     // arguments of the outlined functions, and become pass through values in
419     // the outlined body function. So when outlining the while body, in addition
420     // to the region arguments, all these external references need to be added
421     // as function arguments.
422     llvm::SmallVector<Value, 4> extern_values =
423         CollectExternValues(while_region.cond(), while_region.body());
424 
425     // Outline the `cond` and `body` regions by moving the bodies of these
426     // regions into new functions. Replace tf.yield with a regular return.
427     cond_name = GetName(while_region, "_cond");
428     ExtractSingleBlockRegion(while_region.cond(), cond_name, extern_values,
429                              worklist, /*extern_values_passthrough=*/false);
430 
431     body_name = GetName(while_region, "_body");
432     ExtractSingleBlockRegion(while_region.body(), body_name, extern_values,
433                              worklist, /*extern_values_passthrough=*/true);
434 
435     // All extern values become additional inputs and additional output types
436     // for the functional while.
437     new_inputs.append(extern_values.begin(), extern_values.end());
438     for (auto ext : extern_values) new_result_types.push_back(ext.getType());
439   }
440 
441   // Once we have the `cond` and `body` functions ready (either outlined or
442   // existing ones), replace the region based op with a functional op.
443   OpBuilder builder(while_region);
444   auto while_op = builder.create<WhileOp>(
445       while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
446       while_region.parallel_iterations(), while_region.is_stateless(),
447       while_region.shape_invariant());
448   CopyAndOverrideAttributes(while_region, while_op, &builder);
449 
450   // Redirect old results to new results.
451   for (auto it : llvm::zip(
452            while_region.getResults(),
453            while_op.getResults().take_front(while_region.getNumResults())))
454     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
455 
456   while_region.erase();
457   return success();
458 }
459 
runOnOperation()460 void RegionControlFlowToFunctional::runOnOperation() {
461   ModuleOp module = getOperation();
462 
463   // Seed worklist with all functions in the module.
464   worklist = llvm::to_vector<4>(module.getOps<func::FuncOp>());
465   while (!worklist.empty()) {
466     func::FuncOp function = worklist.pop_back_val();
467 
468     auto result = function.walk([&](Operation* op) {
469       if (auto if_region = llvm::dyn_cast<IfRegionOp>(op)) {
470         if (failed(ConvertIfOp(if_region))) {
471           op->emitOpError() << "failed to convert to functional form";
472           return WalkResult::interrupt();
473         }
474       } else if (auto while_region = llvm::dyn_cast<WhileRegionOp>(op)) {
475         if (failed(ConvertWhileOp(while_region))) {
476           op->emitOpError() << "failed to convert to functional form";
477           return WalkResult::interrupt();
478         }
479       }
480       return WalkResult::advance();
481     });
482 
483     if (result.wasInterrupted()) return signalPassFailure();
484   }
485 }
486 
487 }  // namespace
488 
489 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFRegionControlFlowToFunctional()490 CreateTFRegionControlFlowToFunctional() {
491   return std::make_unique<RegionControlFlowToFunctional>();
492 }
493 
494 }  // namespace TF
495 }  // namespace mlir
496