xref: /aosp_15_r20/external/tensorflow/tensorflow/core/transforms/functional_to_region/impl.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/transforms/functional_to_region/impl.h"
17 
18 #include <algorithm>
19 #include <tuple>
20 
21 #include "llvm/ADT/DenseSet.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/Sequence.h"
24 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
28 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
29 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
30 #include "mlir/IR/Value.h"  // from @llvm-project
31 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
32 #include "tensorflow/core/ir/dialect.h"
33 #include "tensorflow/core/ir/ops.h"
34 #include "tensorflow/core/ir/types/dialect.h"
35 #include "tensorflow/core/ir/utility.h"
36 #include "tensorflow/core/transforms/utils/utils.h"
37 
38 namespace mlir {
39 namespace tfg {
40 
41 //===----------------------------------------------------------------------===//
42 // Pattern Definitions
43 //===----------------------------------------------------------------------===//
44 
45 namespace {
46 // Base class for patterns that convert functional ops to region-based ops. This
47 // class contains common utility functions and class members.
48 class BasePattern {
49  public:
BasePattern(SymbolTable & table,TFGraphDialect & dialect)50   BasePattern(SymbolTable &table, TFGraphDialect &dialect)
51       : table_(table), dialect_(dialect) {}
52 
53  protected:
54   // Lookup, using the symbol table, a graph function.
LookupFunc(FuncAttr func_ref) const55   GraphFuncOp LookupFunc(FuncAttr func_ref) const {
56     return table_.lookup<GraphFuncOp>(func_ref.getName().getLeafReference());
57   }
58 
59   // Split a range of non-control and control operands.
SplitControl(ValueRange values) const60   std::pair<ValueRange, ValueRange> SplitControl(ValueRange values) const {
61     return SplitDataAndControlValues(values, dialect_.getControlType());
62   }
63 
64   // Convert the terminator of a region from `return` to `yield`.
65   YieldOp ReplaceReturnWithYield(Block &block, TypeRange types,
66                                  PatternRewriter &rewriter) const;
67 
68   // Copy a region from a function body to a loop body, reordering the arguments
69   // from function order (pairs of data and control values) to loop order (all
70   // data values followed by all control values).
71   void CloneAndReorderArgs(TypeRange types, Region &from, Region &to,
72                            PatternRewriter &rewriter) const;
73 
74   // Clone ops from one region to another with a given value mapping. Rename
75   // clone ops with unique names.
76   void CloneAndRename(Region &from, Region &to, BlockAndValueMapping &bv) const;
77 
78  protected:
79   // Symbol table for looking up branch/loop functions.
80   SymbolTable &table_;
81   // Dialect reference for getting cached values.
82   TFGraphDialect &dialect_;
83 };
84 
85 // Base class for converting a functional control-flow `SourceOp` to a
86 // region-based `DestOp`.
87 template <typename SourceOp, typename DestOp>
88 class ConvertFunctionalToRegionPattern : public OpRewritePattern<SourceOp>,
89                                          public BasePattern {
90  public:
ConvertFunctionalToRegionPattern(MLIRContext * context,SymbolTable & table,TFGraphDialect & dialect)91   explicit ConvertFunctionalToRegionPattern(MLIRContext *context,
92                                             SymbolTable &table,
93                                             TFGraphDialect &dialect)
94       : OpRewritePattern<SourceOp>(context, /*benefit=*/1,
95                                    {DestOp::getOperationName()}),
96         BasePattern(table, dialect) {}
97 };
98 
99 // Base class for patterns to convert an if-like TFG op to region form.
100 template <typename IfLikeOp, typename IfLikeRegionOp>
101 struct ConvertIfLikeOp
102     : public ConvertFunctionalToRegionPattern<IfLikeOp, IfLikeRegionOp> {
103   using ConvertFunctionalToRegionPattern<
104       IfLikeOp, IfLikeRegionOp>::ConvertFunctionalToRegionPattern;
105 
106   LogicalResult matchAndRewrite(IfLikeOp op,
107                                 PatternRewriter &rewriter) const override;
108 };
109 
110 using ConvertIfOp = ConvertIfLikeOp<IfOp, IfRegionOp>;
111 using ConvertStatelessIfOp =
112     ConvertIfLikeOp<StatelessIfOp, StatelessIfRegionOp>;
113 using ConvertStatefulIfOp = ConvertIfLikeOp<StatefulIfOp, StatefulIfRegionOp>;
114 
115 // Base class for patterns to convert a case-like TFG op to region form.
116 template <typename CaseLikeOp, typename CaseLikeRegionOp>
117 struct ConvertCaseLikeOp
118     : public ConvertFunctionalToRegionPattern<CaseLikeOp, CaseLikeRegionOp> {
119   using ConvertFunctionalToRegionPattern<
120       CaseLikeOp, CaseLikeRegionOp>::ConvertFunctionalToRegionPattern;
121 
122   LogicalResult matchAndRewrite(CaseLikeOp op,
123                                 PatternRewriter &rewriter) const override;
124 };
125 
126 using ConvertCaseOp = ConvertCaseLikeOp<CaseOp, CaseRegionOp>;
127 using ConvertStatelessCaseOp =
128     ConvertCaseLikeOp<StatelessCaseOp, StatelessCaseRegionOp>;
129 using ConvertStatefulCaseOp =
130     ConvertCaseLikeOp<StatefulCaseOp, StatefulCaseRegionOp>;
131 
132 // Base class for patterns to convert a while-like TFG op to region form.
133 template <typename WhileLikeOp, typename WhileLikeRegionOp>
134 struct ConvertWhileLikeOp
135     : public ConvertFunctionalToRegionPattern<WhileLikeOp, WhileLikeRegionOp> {
136   using ConvertFunctionalToRegionPattern<
137       WhileLikeOp, WhileLikeRegionOp>::ConvertFunctionalToRegionPattern;
138 
139   LogicalResult matchAndRewrite(WhileLikeOp op,
140                                 PatternRewriter &rewriter) const override;
141 };
142 
143 using ConvertWhileOp = ConvertWhileLikeOp<WhileOp, WhileRegionOp>;
144 using ConvertStatelessWhileOp =
145     ConvertWhileLikeOp<StatelessWhileOp, StatelessWhileRegionOp>;
146 using ConvertStatefulWhileOp =
147     ConvertWhileLikeOp<StatefulWhileOp, StatefulWhileRegionOp>;
148 
149 // Convert a functional for-loop to a region-based for-loop.
150 struct ConvertForOp
151     : public ConvertFunctionalToRegionPattern<ForOp, ForRegionOp> {
152   using ConvertFunctionalToRegionPattern<
153       ForOp, ForRegionOp>::ConvertFunctionalToRegionPattern;
154 
155   LogicalResult matchAndRewrite(tfg::ForOp op,
156                                 PatternRewriter &rewriter) const override;
157 };
158 
159 }  // namespace
160 
161 //===----------------------------------------------------------------------===//
162 // Utility Functions
163 //===----------------------------------------------------------------------===//
164 
165 // We cannot inline or modify a function if it does not exist, if it is generic,
166 // if it has a computed gradient, or if it is marked for compilation (e.g. by
167 // XLA).
CannotInline(GraphFuncOp func)168 static bool CannotInline(GraphFuncOp func) {
169   return !func || func.generic() || func.gradient() ||
170          func.isMarkedForCompilation();
171 }
172 
173 // Determine which optional attributes of a non-generic function to preserve.
174 // Preserved attributes:
175 // - `description`
176 // - `is_stateful`
177 // - `resource_arg_unique_ids_keys`
178 // - `resource_arg_unique_ids_values`
179 //
180 // The attributes of a non-generic function to preserve:
181 // - Intrinsic `tfg.*` attributes are preserved.
182 // - Non-intrinsic `tf.*` attributes are preserved.
183 //
184 // The result attributes of a non-generic function to preserve:
185 // - Intrinsic `tfg.*` attributes are preserved.
PreserveFunctionAttributes(GraphFuncOp func)186 static DictionaryAttr PreserveFunctionAttributes(GraphFuncOp func) {
187   NamedAttrList preserved_attrs;
188   const auto preserve = [&](StringAttr name) {
189     if (Attribute attr = func->getAttr(name))
190       preserved_attrs.append(name, attr);
191   };
192   preserve(func.descriptionAttrName());
193   preserve(func.is_statefulAttrName());
194   preserve(func.resource_arg_unique_ids_keysAttrName());
195   preserve(func.resource_arg_unique_ids_valuesAttrName());
196   // Propagate tf.* attributes.
197   // TODO(jeffniu): `tf` dialect is not loaded.
198   for (const NamedAttribute &attr : func->getAttrs())
199     if (attr.getName().getValue().startswith("tf."))
200       preserved_attrs.append(attr);
201 
202   // Certain pipelines (Brella) will split a graph into subgraphs before merging
203   // them back together. If the subgraphs pass through conversion to and from
204   // region form, the previously unique branch/loop body function names become
205   // not unique, which prevents the graphs from being correctly merged back
206   // together. Also, if an op is referenced in two different subgraphs, if
207   // Grappler changes the function name, the reference will only be valid in the
208   // first subgraph, leading to a function-not-found error. Preserve the
209   // original function name.
210   preserve(func.sym_nameAttrName());
211 
212   return preserved_attrs.getDictionary(func.getContext());
213 }
214 
215 // Given the function, argument, and result attributes to be preserved,
216 // determine if they are empty and can be dropped.
ArePreservedAttrsEmpty(DictionaryAttr func_attrs,ArrayAttr arg_attrs,ArrayAttr res_attrs)217 static bool ArePreservedAttrsEmpty(DictionaryAttr func_attrs,
218                                    ArrayAttr arg_attrs, ArrayAttr res_attrs) {
219   const auto is_empty = [](DictionaryAttr dict) { return dict.empty(); };
220   return func_attrs.empty() &&
221          llvm::all_of(arg_attrs.getAsRange<DictionaryAttr>(), is_empty) &&
222          llvm::all_of(res_attrs.getAsRange<DictionaryAttr>(), is_empty);
223 }
224 
225 // Determine if the region attributes are empty.
AreRegionAttrsEmpty(RegionAttr attrs)226 static bool AreRegionAttrsEmpty(RegionAttr attrs) {
227   return ArePreservedAttrsEmpty(attrs.getAttrs(), attrs.getArgAttrs(),
228                                 attrs.getResAttrs());
229 }
230 
231 // Preserve certain attributes of a function so that they can be used later if
232 // the region op is converted back to functional form. When `If` and `Case` are
233 // converted, all arguments attributes are dropped because the arguments are
234 // converted to implicit captures. For `While` and `For`, no arguments are
235 // removed.
236 //
237 // If `drop_args` is set, then all argument attributes are dropped, regardless
238 // of the number of arguments in the function.
239 //
240 // If `allow_empty` is set, then this function will always return a non-null
241 // attribute, even if the region attributes are empty.
PreserveAttributes(GraphFuncOp func,bool drop_args=false,bool allow_empty=false)242 static RegionAttr PreserveAttributes(GraphFuncOp func, bool drop_args = false,
243                                      bool allow_empty = false) {
244   DictionaryAttr func_attrs = PreserveFunctionAttributes(func);
245   // Since all argument and result attributes are preserved, just propagate the
246   // array attributes. Remove the control argument attributes from the argument
247   // attributes.
248   const auto every_other = [](ArrayAttr attrs) {
249     SmallVector<Attribute> others;
250     for (unsigned i = 0; i < attrs.size(); i += 2) others.push_back(attrs[i]);
251     return ArrayAttr::get(attrs.getContext(), others);
252   };
253 
254   ArrayAttr arg_attrs = drop_args || !func.arg_attrs()
255                             ? ArrayAttr::get(func.getContext(), {})
256                             : every_other(*func.arg_attrs());
257   ArrayAttr res_attrs = func.res_attrs()
258                             ? *func.res_attrs()
259                             : ArrayAttr::get(func.getContext(), {});
260 
261   if (!allow_empty && ArePreservedAttrsEmpty(func_attrs, arg_attrs, res_attrs))
262     return nullptr;
263   return RegionAttr::get(func_attrs, arg_attrs, res_attrs);
264 }
265 
ReplaceReturnWithYield(Block & block,TypeRange types,PatternRewriter & rewriter) const266 YieldOp BasePattern::ReplaceReturnWithYield(Block &block, TypeRange types,
267                                             PatternRewriter &rewriter) const {
268   auto op = cast<ReturnOp>(block.getTerminator());
269   rewriter.setInsertionPoint(op);
270   ValueRange args, ctls;
271   std::tie(args, ctls) = SplitControl(op.getOperands());
272   return rewriter.replaceOpWithNewOp<YieldOp>(op, args, ctls);
273 }
274 
CloneAndReorderArgs(TypeRange types,Region & from,Region & to,PatternRewriter & rewriter) const275 void BasePattern::CloneAndReorderArgs(TypeRange types, Region &from, Region &to,
276                                       PatternRewriter &rewriter) const {
277   ControlType control_ty = dialect_.getControlType();
278   BlockAndValueMapping bv;
279   CloneAndRename(from, to, bv);
280   SmallVector<Location> arg_locs(types.size(), from.getLoc());
281   for (auto &it :
282        llvm::enumerate(llvm::to_vector(to.addArguments(types, arg_locs)))) {
283     BlockArgument arg = to.getArgument(it.index() * 2);
284     BlockArgument ctl = to.getArgument(arg.getArgNumber() + 1);
285     arg.replaceAllUsesWith(it.value());
286     ctl.replaceAllUsesWith(to.addArgument(control_ty, arg.getLoc()));
287   }
288   llvm::BitVector erase_indices(to.getNumArguments());
289   erase_indices.set(0, types.size() * 2);
290   to.front().eraseArguments(erase_indices);
291 }
292 
CloneAndRename(Region & from,Region & to,BlockAndValueMapping & bv) const293 void BasePattern::CloneAndRename(Region &from, Region &to,
294                                  BlockAndValueMapping &bv) const {
295   from.cloneInto(&to, bv);
296   StringAttr name_id = dialect_.getNameAttrIdentifier();
297   auto op_name = to.getParentOp()->getAttrOfType<StringAttr>(name_id);
298   if (!op_name) return;
299   for (Operation &op : to.getOps()) {
300     if (auto name = op.getAttrOfType<StringAttr>(name_id)) {
301       auto new_name =
302           StringAttr::get(op.getContext(), name.getValue() + "_tfg_inlined_" +
303                                                op_name.getValue() + "_" +
304                                                Twine(to.getRegionNumber()));
305       op.setAttr(name_id, new_name);
306     }
307   }
308 }
309 
310 //===----------------------------------------------------------------------===//
311 // ConvertIfLikeOp
312 //===----------------------------------------------------------------------===//
313 
314 template <typename IfLikeOp, typename IfLikeRegionOp>
matchAndRewrite(IfLikeOp op,PatternRewriter & rewriter) const315 LogicalResult ConvertIfLikeOp<IfLikeOp, IfLikeRegionOp>::matchAndRewrite(
316     IfLikeOp op, PatternRewriter &rewriter) const {
317   GraphFuncOp then_func = this->LookupFunc(op.then_branch());
318   GraphFuncOp else_func = this->LookupFunc(op.else_branch());
319   if (CannotInline(then_func) || CannotInline(else_func)) return failure();
320 
321   // Create the region-based op, passing in the required attributes.
322   ValueRange args, ctls;
323   std::tie(args, ctls) = this->SplitControl(op.args());
324   auto region_op = rewriter.create<IfLikeRegionOp>(
325       op.getLoc(), op.getResultTypes(), op.cond(), ctls,
326       op.then_branch().getAttrs(), op.else_branch().getAttrs(),
327       PreserveAttributes(then_func, /*drop_args=*/true),
328       PreserveAttributes(else_func, /*drop_args=*/true));
329   util::ForwardNonIntrinsicAttributes(op, region_op);
330 
331   // Move the regions over and replace the block arguments.
332   ControlType control_ty = this->dialect_.getControlType();
333   BlockAndValueMapping then_bv, else_bv;
334   auto func_args =
335       llvm::zip(then_func.getArguments(), else_func.getArguments()).begin();
336   rewriter.setInsertionPoint(region_op);
337   Value then_arg, else_arg, then_ctl, else_ctl;
338   for (Value arg : args) {
339     std::tie(then_arg, else_arg) = *func_args;
340     ++func_args;
341     std::tie(then_ctl, else_ctl) = *func_args;
342     ++func_args;
343     Value ctl = LookupControlDependency(arg);
344     then_bv.map(then_arg, arg);
345     else_bv.map(else_arg, arg);
346     then_bv.map(then_ctl, ctl);
347     else_bv.map(else_ctl, ctl);
348   }
349   this->CloneAndRename(then_func.body(), region_op.then_region(), then_bv);
350   this->CloneAndRename(else_func.body(), region_op.else_region(), else_bv);
351 
352   // Replace the terminators `return` with `yield`.
353   TypeRange ret_types = region_op.outs().getTypes();
354   this->ReplaceReturnWithYield(region_op.then_block(), ret_types, rewriter);
355   this->ReplaceReturnWithYield(region_op.else_block(), ret_types, rewriter);
356   rewriter.replaceOp(op, region_op.getResults());
357   return success();
358 }
359 
360 //===----------------------------------------------------------------------===//
361 // ConvertCaseLikeOp
362 //===----------------------------------------------------------------------===//
363 
364 template <typename CaseLikeOp, typename CaseLikeRegionOp>
matchAndRewrite(CaseLikeOp op,PatternRewriter & rewriter) const365 LogicalResult ConvertCaseLikeOp<CaseLikeOp, CaseLikeRegionOp>::matchAndRewrite(
366     CaseLikeOp op, PatternRewriter &rewriter) const {
367   // Lookup all the branch functions and save their attributes.
368   SmallVector<GraphFuncOp> branch_funcs;
369   SmallVector<Attribute> branch_attrs;
370   branch_funcs.reserve(op.branches().size());
371   for (auto attr : op.branches().template getAsRange<FuncAttr>()) {
372     GraphFuncOp branch_func = this->LookupFunc(attr);
373     if (CannotInline(branch_func)) return failure();
374     branch_funcs.push_back(branch_func);
375     branch_attrs.push_back(attr.getAttrs());
376   }
377 
378   SmallVector<Attribute> preserved_attrs;
379   for (GraphFuncOp func : branch_funcs) {
380     preserved_attrs.push_back(
381         PreserveAttributes(func, /*drop_args=*/true, /*allow_empty=*/true));
382   }
383   ArrayAttr region_attrs = nullptr;
384   if (!llvm::all_of(preserved_attrs, [](Attribute attr) {
385         return AreRegionAttrsEmpty(attr.cast<RegionAttr>());
386       }))
387     region_attrs = rewriter.getArrayAttr(preserved_attrs);
388 
389   // Create the region-based op, passing in the required attributes.
390   ValueRange args, ctls;
391   std::tie(args, ctls) = this->SplitControl(op.args());
392   auto region_op = rewriter.create<CaseLikeRegionOp>(
393       op.getLoc(), op.getResultTypes(), op.branch_index(), ctls,
394       rewriter.getArrayAttr(branch_attrs), region_attrs, op.branches().size());
395   util::ForwardNonIntrinsicAttributes(op, region_op);
396 
397   // Move the regions over and replace the block arguments.
398   ControlType control_ty = this->dialect_.getControlType();
399   SmallVector<BlockAndValueMapping> bvs(branch_funcs.size(), {});
400   rewriter.setInsertionPoint(region_op);
401   for (auto &arg : llvm::enumerate(args)) {
402     for (auto it : llvm::zip(branch_funcs, bvs)) {
403       BlockArgument branch_arg =
404           GraphFuncOp::getDataValue(std::get<0>(it).body(), arg.index());
405       BlockAndValueMapping &bv = std::get<1>(it);
406       bv.map(branch_arg, arg.value());
407       bv.map(GraphFuncOp::getControlTokenOf(branch_arg),
408              LookupControlDependency(arg.value()));
409     }
410   }
411   for (auto it : llvm::zip(branch_funcs, region_op.branches(), bvs)) {
412     this->CloneAndRename(std::get<0>(it).body(), std::get<1>(it),
413                          std::get<2>(it));
414   }
415 
416   // Replace the terminators `return` with `yield`.
417   TypeRange ret_types = region_op.outs().getTypes();
418   for (Region &branch : region_op.branches())
419     this->ReplaceReturnWithYield(branch.front(), ret_types, rewriter);
420   rewriter.replaceOp(op, region_op.getResults());
421   return success();
422 }
423 
424 //===----------------------------------------------------------------------===//
425 // ConvertWhileLikeOp
426 //===----------------------------------------------------------------------===//
427 
428 template <typename WhileLikeOp, typename WhileLikeRegionOp>
429 LogicalResult
matchAndRewrite(WhileLikeOp op,PatternRewriter & rewriter) const430 ConvertWhileLikeOp<WhileLikeOp, WhileLikeRegionOp>::matchAndRewrite(
431     WhileLikeOp op, PatternRewriter &rewriter) const {
432   GraphFuncOp cond_func = this->LookupFunc(op.cond());
433   GraphFuncOp body_func = this->LookupFunc(op.body());
434   if (CannotInline(cond_func) || CannotInline(body_func)) return failure();
435 
436   // Note that `tfg.While` may not have the same input and output types. We will
437   // need to insert casts.
438   // TODO(jeffniu): Change this to call the infer return types builder.
439   ValueRange init, ctls;
440   std::tie(init, ctls) = this->SplitControl(op.args());
441   auto region_op = rewriter.create<WhileLikeRegionOp>(
442       op.getLoc(), op.getResultTypes(), init, ctls,
443       op.parallel_iterationsAttr(), op.cond().getAttrs(), op.body().getAttrs(),
444       PreserveAttributes(cond_func), PreserveAttributes(body_func));
445   util::ForwardNonIntrinsicAttributes(op, region_op);
446 
447   // Just copy the function bodies into the regions. `RegionBranchOpInterface`
448   // requires that we re-order the block arguments such that the control tokens
449   // all come after the data arguments.
450   this->CloneAndReorderArgs(init.getTypes(), cond_func.body(),
451                             region_op.cond_region(), rewriter);
452   this->CloneAndReorderArgs(init.getTypes(), body_func.body(),
453                             region_op.body_region(), rewriter);
454   this->ReplaceReturnWithYield(region_op.body_block(), init.getTypes(),
455                                rewriter);
456 
457   // Replace `return(tensor<*xi1>)` with `condition`.
458   auto ret_op = cast<ReturnOp>(region_op.cond_block().getTerminator());
459   ValueRange ret_args, ret_ctls;
460   std::tie(ret_args, ret_ctls) = this->SplitControl(ret_op.getOperands());
461   rewriter.setInsertionPoint(ret_op);
462   rewriter.replaceOpWithNewOp<ConditionOp>(
463       ret_op, ret_args.front(), GetLoopRegionDataArgs(region_op.cond_region()),
464       ret_ctls);
465   rewriter.replaceOp(op, region_op->getResults());
466   return success();
467 }
468 
469 //===----------------------------------------------------------------------===//
470 // ConvertForOp
471 //===----------------------------------------------------------------------===//
472 
matchAndRewrite(tfg::ForOp op,PatternRewriter & rewriter) const473 LogicalResult ConvertForOp::matchAndRewrite(tfg::ForOp op,
474                                             PatternRewriter &rewriter) const {
475   GraphFuncOp body_func = LookupFunc(op.body());
476   if (CannotInline(body_func)) return failure();
477 
478   // Note that `For` may not have the same input and output typse, although
479   // `ForRegion` does. We will need to insert casts.
480   ValueRange init, ctls;
481   std::tie(init, ctls) = SplitControl(op.args());
482   auto region_op = rewriter.create<ForRegionOp>(
483       op.getLoc(), op.getResultTypes(), op.start(), op.limit(), op.delta(),
484       init, ctls, op.body().getAttrs(), PreserveAttributes(body_func));
485   util::ForwardNonIntrinsicAttributes(op, region_op);
486 
487   // Copy the function body into the region. One index type must be added.
488   OperandRange args = op.getOperands().drop_front(2).drop_back(ctls.size());
489   CloneAndReorderArgs(args.getTypes(), body_func.body(),
490                       region_op.body_region(), rewriter);
491   ReplaceReturnWithYield(region_op.body_block(), init.getTypes(), rewriter);
492   rewriter.replaceOp(op, region_op->getResults());
493   return success();
494 }
495 
496 //===----------------------------------------------------------------------===//
497 // Populate Patterns
498 //===----------------------------------------------------------------------===//
499 
PopulateFunctionalToRegionPatterns(RewritePatternSet & patterns,SymbolTable & table)500 void PopulateFunctionalToRegionPatterns(RewritePatternSet &patterns,
501                                         SymbolTable &table) {
502   patterns.insert<ConvertIfOp, ConvertStatelessIfOp, ConvertStatefulIfOp,
503                   ConvertWhileOp, ConvertStatelessWhileOp,
504                   ConvertStatefulWhileOp, ConvertCaseOp, ConvertStatelessCaseOp,
505                   ConvertStatefulCaseOp, ConvertForOp>(
506       patterns.getContext(), table,
507       *patterns.getContext()->getOrLoadDialect<TFGraphDialect>());
508 }
509 
510 }  // namespace tfg
511 }  // namespace mlir
512