xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/strings/match.h"
17 #include "llvm/ADT/DenseSet.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "mlir/IR/Attributes.h"  // from @llvm-project
21 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
22 #include "mlir/IR/Builders.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
24 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
25 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
26 #include "mlir/IR/Visitors.h"  // from @llvm-project
27 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
30 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
32 #include "tensorflow/core/transforms/toposort/pass.h"
33 #include "tensorflow/core/util/device_name_utils.h"
34 
35 namespace mlir {
36 namespace TF {
37 namespace {
38 
// Name given to the func.func created for a top-level tfg.graph.
// FIXME: This should be consistent with
// tensorflow::kImportModelDefaultGraphFuncName
constexpr char kImportModelDefaultGraphFuncName[] = "main";
42 
// Please refer to the TFG dialect description for the list of attributes used.
// Below are the attributes in TFE.
// TFE arguments and results (obtained from "_Arg", "_Retval", etc.):
//  NodeDef.device <-> "tf.device"
//  NodeDef.attr <-> "tf."
//
// TFE general operations:
//  NodeDef.device <-> "device"
//
// The following two functions are only used for mapping/excluding attributes
// that are inconsistent between TFG and TFE.
//
FilterTfgSpecificArgResultAttributes(mlir::MLIRContext * context,mlir::ArrayRef<Type> types,mlir::ArrayAttr array_attr,llvm::SmallVector<mlir::Type> & output_types,llvm::SmallVector<mlir::DictionaryAttr> & output_attrs)56 static mlir::LogicalResult FilterTfgSpecificArgResultAttributes(
57     mlir::MLIRContext *context, mlir::ArrayRef<Type> types,
58     mlir::ArrayAttr array_attr, llvm::SmallVector<mlir::Type> &output_types,
59     llvm::SmallVector<mlir::DictionaryAttr> &output_attrs) {
60   for (auto it : llvm::zip(
61            types, array_attr.template getAsRange<mlir::DictionaryAttr>())) {
62     if (std::get<0>(it).isa<tfg::ControlType>()) continue;
63     output_types.push_back(std::get<0>(it));
64 
65     mlir::NamedAttrList list;
66     for (mlir::NamedAttribute attr : std::get<1>(it).getValue()) {
67       // Skip if the attribute has "tfg" prefix.
68       if (attr.getName().getValue().startswith("tfg")) continue;
69       list.append(attr);
70     }
71     output_attrs.push_back(list.getDictionary(context));
72   }
73   return mlir::success();
74 }
75 
ReformatOpAttributes(mlir::MLIRContext * context,llvm::ArrayRef<mlir::NamedAttribute> attrs,llvm::SmallVectorImpl<mlir::NamedAttribute> & output)76 static mlir::LogicalResult ReformatOpAttributes(
77     mlir::MLIRContext *context, llvm::ArrayRef<mlir::NamedAttribute> attrs,
78     llvm::SmallVectorImpl<mlir::NamedAttribute> &output) {
79   for (mlir::NamedAttribute attr : attrs) {
80     if (attr.getName().strref().contains(
81             mlir::tfg::TFGraphDialect::getDeviceAttrKey())) {
82       tensorflow::DeviceNameUtils::ParsedName parsed_name;
83       if (!tensorflow::DeviceNameUtils::ParseFullName(
84               attr.getValue().cast<mlir::StringAttr>().getValue().str(),
85               &parsed_name))
86         return mlir::failure();
87       if (!parsed_name.has_type) {
88         parsed_name.type = "CPU";
89         parsed_name.has_type = true;
90       }
91       if (!parsed_name.has_id) {
92         parsed_name.id = 0;
93         parsed_name.has_id = true;
94       }
95       output.push_back(mlir::NamedAttribute(
96           mlir::StringAttr::get(context, "device"),
97           mlir::StringAttr::get(
98               context,
99               tensorflow::DeviceNameUtils::ParsedNameToString(parsed_name))));
100     } else {
101       output.push_back(attr);
102     }
103   }
104   return mlir::success();
105 }
106 
FilterOutBlockArgControlDep(ValueRange operands,llvm::SmallVectorImpl<Value> & filtered)107 static void FilterOutBlockArgControlDep(
108     ValueRange operands, llvm::SmallVectorImpl<Value> &filtered) {
109   for (Value value : operands)
110     if (!value.isa<mlir::BlockArgument>()) filtered.push_back(value);
111 }
112 
113 // Split the tfg.NextIteration into tf_executor::NextIterationSourceOp and
114 // tf_executor::NextIterationSinkOp to break the cycle introduced by itself.
SplitNextIteration(Block & block)115 static void SplitNextIteration(Block &block) {
116   // TODO(b/207144333): Supports callback for unregistered ops
117   block.walk([&](Operation *op) {
118     if (!op->getName().getStringRef().equals("tfg.NextIteration")) return;
119     mlir::OpBuilder builder(op);
120 
121     llvm::SmallVector<Value, 2> new_operands;
122     FilterOutBlockArgControlDep(op->getOperands().drop_front(), new_operands);
123 
124     auto source_op = builder.create<tf_executor::NextIterationSourceOp>(
125         op->getLoc(), op->getOperand(0).getType());
126     builder.create<tf_executor::NextIterationSinkOp>(
127         op->getLoc(), source_op.token(), /*input=*/op->getOperand(0),
128         /*controlInputs=*/new_operands);
129     op->replaceAllUsesWith(
130         ValueRange({source_op.output(), source_op.control()}));
131     op->erase();
132   });
133 }
134 
135 class ConvertGraphOp : public OpConversionPattern<tfg::GraphOp> {
136  public:
137   using OpConversionPattern::OpConversionPattern;
138 
matchAndRewrite(tfg::GraphOp graph,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const139   LogicalResult matchAndRewrite(
140       tfg::GraphOp graph, OpAdaptor adaptor,
141       ConversionPatternRewriter &rewriter) const final {
142     Location loc = graph.getLoc();
143     // To keep the import-as-graph logic taken by TFG, we create `void func()`
144     // to contain the ops in the tfg::GraphOp. That means the arguments/results
145     // will be the operations inside the function body rather than representing
146     // them in the function signature.
147     FunctionType func_type = rewriter.getFunctionType({}, {});
148     func::FuncOp func = rewriter.create<func::FuncOp>(
149         loc, kImportModelDefaultGraphFuncName, func_type);
150     rewriter.setInsertionPointToStart(func.addEntryBlock());
151     auto executor_graph =
152         rewriter.create<tf_executor::GraphOp>(loc, func_type.getResults());
153     rewriter.inlineRegionBefore(graph.nodes(), executor_graph.body(),
154                                 executor_graph.body().end());
155 
156     // Add terminator of tf_executor::graph
157     rewriter.setInsertionPointToEnd(&executor_graph.body().front());
158     rewriter.create<tf_executor::FetchOp>(loc);
159 
160     // Add terminator of func
161     rewriter.setInsertionPointToEnd(&func.getBody().front());
162     rewriter.create<func::ReturnOp>(loc);
163 
164     rewriter.replaceOp(graph.getOperation(), func.getOperation()->getResults());
165 
166     return success();
167   }
168 };
169 
170 class ConvertGraphFuncOp : public OpConversionPattern<tfg::GraphFuncOp> {
171  public:
172   using OpConversionPattern::OpConversionPattern;
173 
matchAndRewrite(tfg::GraphFuncOp graph_func,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const174   LogicalResult matchAndRewrite(
175       tfg::GraphFuncOp graph_func, OpAdaptor adaptor,
176       ConversionPatternRewriter &rewriter) const final {
177     assert(!graph_func.generic());
178     Location loc = graph_func.getLoc();
179     FunctionType ftype = graph_func.getFunctionType();
180 
181     func::FuncOp func = rewriter.create<func::FuncOp>(
182         graph_func.getLoc(),
183         graph_func->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
184             .getValue(),
185         ftype);
186 
187     func->setAttrs(graph_func->getAttrs());
188 
189     llvm::SmallVector<Type> arg_types;
190     llvm::SmallVector<Type> res_types;
191     llvm::SmallVector<DictionaryAttr> arg_attrs;
192     llvm::SmallVector<DictionaryAttr> res_attrs;
193     if (failed(FilterTfgSpecificArgResultAttributes(
194             getContext(), ftype.getInputs(), graph_func.getAllArgAttrs(),
195             arg_types, arg_attrs)) ||
196         failed(FilterTfgSpecificArgResultAttributes(
197             getContext(), ftype.getResults(), graph_func.getAllResultAttrs(),
198             res_types, res_attrs)))
199       return failure();
200 
201     // Update the function type which has excluded the control args.
202     func->setAttr("function_type", TypeAttr::get(rewriter.getFunctionType(
203                                        arg_types, res_types)));
204 
205     // Update arg/result attributes.
206     func.setAllArgAttrs(arg_attrs);
207     func.setAllResultAttrs(res_attrs);
208 
209     rewriter.setInsertionPointToStart(func.addEntryBlock());
210     // In TFE, the function body is inlined in a GraphOp. Create a GraphOp
211     // instance and move the regions from GraphFuncOp to GraphOp.
212     auto executor_graph = rewriter.create<tf_executor::GraphOp>(
213         loc, func.getFunctionType().getResults());
214 
215     // Replace the uses of block arguments with function arguments. Note that we
216     // can't erase the arguments here because the operations may still use them
217     // and these uses will be dropped after legalization of each op.
218     unsigned idx = 0;
219     Block &block = graph_func.body().front();
220     for (auto iter = block.args_begin(), end_iter = block.args_end();
221          iter != end_iter; ++iter) {
222       if (!iter->getType().isa<tfg::ControlType>())
223         iter->replaceAllUsesWith(func.getBody().getArgument(idx++));
224     }
225 
226     rewriter.inlineRegionBefore(graph_func.body(), executor_graph.body(),
227                                 executor_graph.body().end());
228 
229     rewriter.setInsertionPointToEnd(&func.getBody().front());
230     rewriter.create<func::ReturnOp>(
231         loc, executor_graph.getOperation()->getResults());
232 
233     rewriter.replaceOp(graph_func.getOperation(),
234                        func.getOperation()->getResults());
235 
236     return success();
237   }
238 };
239 
240 class ConvertReturnOp : public OpConversionPattern<tfg::ReturnOp> {
241  public:
242   using OpConversionPattern::OpConversionPattern;
matchAndRewrite(tfg::ReturnOp ret,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const243   LogicalResult matchAndRewrite(
244       tfg::ReturnOp ret, OpAdaptor adaptor,
245       ConversionPatternRewriter &rewriter) const final {
246     rewriter.replaceOpWithNewOp<tf_executor::FetchOp>(ret.getOperation(),
247                                                       adaptor.getOperands());
248     return success();
249   }
250 };
251 
252 class ConvertControlTriggerOp : public ConversionPattern {
253  public:
ConvertControlTriggerOp(MLIRContext * context)254   explicit ConvertControlTriggerOp(MLIRContext *context)
255       : ConversionPattern("tfg.ControlTrigger", PatternBenefit(1), context) {}
256 
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const257   LogicalResult matchAndRewrite(
258       Operation *op, llvm::ArrayRef<Value> operands,
259       ConversionPatternRewriter &rewriter) const final {
260     llvm::SmallVector<Type, 2> new_types(op->getResultTypes());
261     new_types.back() = rewriter.getType<tf_executor::ControlType>();
262 
263     llvm::SmallVector<Value, 2> new_operands;
264     FilterOutBlockArgControlDep(operands, new_operands);
265 
266     rewriter.replaceOpWithNewOp<tf_executor::ControlTriggerOp>(
267         op, new_types, new_operands, op->getAttrs());
268     return success();
269   }
270 };
271 
272 class ConvertEnterOp : public ConversionPattern {
273  public:
ConvertEnterOp(MLIRContext * context)274   explicit ConvertEnterOp(MLIRContext *context)
275       : ConversionPattern("tfg.Enter", PatternBenefit(1), context) {}
276 
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const277   LogicalResult matchAndRewrite(
278       Operation *op, llvm::ArrayRef<Value> operands,
279       ConversionPatternRewriter &rewriter) const final {
280     llvm::SmallVector<Type, 2> new_types(op->getResultTypes());
281     new_types.back() = rewriter.getType<tf_executor::ControlType>();
282 
283     llvm::SmallVector<Value, 2> new_operands;
284     FilterOutBlockArgControlDep(operands, new_operands);
285 
286     rewriter.replaceOpWithNewOp<tf_executor::EnterOp>(
287         op, new_types, new_operands, op->getAttrs());
288     return success();
289   }
290 };
291 
292 class ConvertExitOp : public ConversionPattern {
293  public:
ConvertExitOp(MLIRContext * context)294   explicit ConvertExitOp(MLIRContext *context)
295       : ConversionPattern("tfg.Exit", PatternBenefit(1), context) {}
296 
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const297   LogicalResult matchAndRewrite(
298       Operation *op, llvm::ArrayRef<Value> operands,
299       ConversionPatternRewriter &rewriter) const final {
300     llvm::SmallVector<Type, 2> new_types(op->getResultTypes());
301     new_types.back() = rewriter.getType<tf_executor::ControlType>();
302 
303     llvm::SmallVector<Value, 2> new_operands;
304     FilterOutBlockArgControlDep(operands, new_operands);
305 
306     rewriter.replaceOpWithNewOp<tf_executor::ExitOp>(
307         op, new_types, new_operands, op->getAttrs());
308     return success();
309   }
310 };
311 
312 class ConvertLoopCondOp : public ConversionPattern {
313  public:
ConvertLoopCondOp(MLIRContext * context)314   explicit ConvertLoopCondOp(MLIRContext *context)
315       : ConversionPattern("tfg.LoopCond", PatternBenefit(1), context) {}
316 
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const317   LogicalResult matchAndRewrite(
318       Operation *op, llvm::ArrayRef<Value> operands,
319       ConversionPatternRewriter &rewriter) const final {
320     llvm::SmallVector<Type, 2> new_types(op->getResultTypes());
321     new_types.back() = rewriter.getType<tf_executor::ControlType>();
322 
323     llvm::SmallVector<Value, 2> new_operands;
324     FilterOutBlockArgControlDep(operands, new_operands);
325 
326     rewriter.replaceOpWithNewOp<tf_executor::LoopCondOp>(
327         op, new_types, new_operands, op->getAttrs());
328     return success();
329   }
330 };
331 
332 class ConvertMergeOp : public ConversionPattern {
333  public:
ConvertMergeOp(MLIRContext * context)334   explicit ConvertMergeOp(MLIRContext *context)
335       : ConversionPattern("tfg.Merge", PatternBenefit(1), context) {}
336 
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const337   LogicalResult matchAndRewrite(
338       Operation *op, llvm::ArrayRef<Value> operands,
339       ConversionPatternRewriter &rewriter) const final {
340     llvm::SmallVector<Type, 2> new_types(op->getResultTypes());
341     new_types.back() = rewriter.getType<tf_executor::ControlType>();
342 
343     llvm::SmallVector<Value, 2> new_operands;
344     FilterOutBlockArgControlDep(operands, new_operands);
345 
346     rewriter.replaceOpWithNewOp<tf_executor::MergeOp>(
347         op, new_types, new_operands, op->getAttrs());
348     return success();
349   }
350 };
351 
352 class ConvertSwitchOp : public ConversionPattern {
353  public:
ConvertSwitchOp(MLIRContext * context)354   explicit ConvertSwitchOp(MLIRContext *context)
355       : ConversionPattern("tfg.Switch", PatternBenefit(1), context) {}
356 
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const357   LogicalResult matchAndRewrite(
358       Operation *op, llvm::ArrayRef<Value> operands,
359       ConversionPatternRewriter &rewriter) const final {
360     llvm::SmallVector<Type, 2> new_types(op->getResultTypes());
361     new_types.back() = rewriter.getType<tf_executor::ControlType>();
362 
363     llvm::SmallVector<Value, 2> new_operands;
364     FilterOutBlockArgControlDep(operands, new_operands);
365 
366     rewriter.replaceOpWithNewOp<tf_executor::SwitchOp>(
367         op, new_types, new_operands, op->getAttrs());
368     return success();
369   }
370 };
371 
372 class ConvertSwitchNOp : public ConversionPattern {
373  public:
ConvertSwitchNOp(MLIRContext * context)374   explicit ConvertSwitchNOp(MLIRContext *context)
375       : ConversionPattern("tfg.SwitchN", PatternBenefit(1), context) {}
376 
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const377   LogicalResult matchAndRewrite(
378       Operation *op, llvm::ArrayRef<Value> operands,
379       ConversionPatternRewriter &rewriter) const final {
380     llvm::SmallVector<Type, 2> new_types(op->getResultTypes());
381     new_types.back() = rewriter.getType<tf_executor::ControlType>();
382 
383     llvm::SmallVector<Value, 2> new_operands;
384     FilterOutBlockArgControlDep(operands, new_operands);
385 
386     rewriter.replaceOpWithNewOp<tf_executor::SwitchNOp>(
387         op, new_types, new_operands, op->getAttrs());
388     return success();
389   }
390 };
391 
392 class ConvertGeneralOp : public ConversionPattern {
393  public:
ConvertGeneralOp(MLIRContext * context,const DenseSet<StringRef> & func_symbols)394   ConvertGeneralOp(MLIRContext *context,
395                    const DenseSet<StringRef> &func_symbols)
396       : ConversionPattern(MatchAnyOpTypeTag(), PatternBenefit(1), context),
397         func_symbols_(func_symbols) {}
398 
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const399   LogicalResult matchAndRewrite(
400       Operation *op, llvm::ArrayRef<Value> operands,
401       ConversionPatternRewriter &rewriter) const final {
402     if (!llvm::isa<tfg::TFGraphDialect>(op->getDialect())) return failure();
403 
404     Location loc = op->getLoc();
405     llvm::SmallVector<mlir::Type, 2> new_types(op->getResultTypes());
406     // Update the control type from tf_type.control to tf_executor.control.
407     new_types.back() = rewriter.getType<tf_executor::ControlType>();
408 
409     // Control operand is attached on tf_executor::IslandOp.
410     llvm::SmallVector<Value> island_control_operands;
411     llvm::SmallVector<Value> inner_op_operands;
412 
413     for (Value value : operands) {
414       // Because of the property of graph region, the control operands may
415       // not have been converted to tf_executor::ControlType.
416       if (value.getType().isa<tfg::ControlType>() ||
417           value.getType().isa<tf_executor::ControlType>()) {
418         if (!value.isa<BlockArgument>())
419           island_control_operands.push_back(value);
420       } else {
421         inner_op_operands.push_back(value);
422       }
423     }
424 
425     auto island = rewriter.create<tf_executor::IslandOp>(
426         loc, new_types, island_control_operands);
427     island.body().push_back(new mlir::Block);
428 
429     rewriter.setInsertionPointToEnd(&island.body().front());
430 
431     // Control dependency has been applied on tf_executor.island. Remove it
432     // while creating the tf operations.
433     new_types.pop_back();
434 
435     llvm::SmallVector<std::unique_ptr<Region>, 1> new_regions;
436     for (auto &region : op->getRegions()) {
437       new_regions.push_back(std::make_unique<Region>());
438       new_regions.back()->takeBody(region);
439     }
440 
441     llvm::SmallVector<NamedAttribute, 4> attrs;
442     if (failed(ReformatOpAttributes(getContext(), op->getAttrs(), attrs)))
443       return failure();
444 
445     Operation *inner_op;
446 
447     StringRef op_name = op->getName().stripDialect();
448     if (!func_symbols_.contains(op_name)) {
449       std::string tf_op_name = llvm::formatv(
450           "{0}.{1}", TF::TensorFlowDialect::getDialectNamespace(), op_name);
451       OperationState state =
452           OperationState(loc, tf_op_name, inner_op_operands, new_types, attrs,
453                          op->getSuccessors(), new_regions);
454       inner_op = rewriter.create(state);
455     } else {
456       bool disable_call_shape_inference = false;
457       if (op->hasAttr("_disable_call_shape_inference")) {
458         disable_call_shape_inference =
459             op->getAttrOfType<BoolAttr>("_disable_call_shape_inference")
460                 .getValue();
461       }
462       inner_op =
463           rewriter.create<LegacyCallOp>(loc, new_types, inner_op_operands,
464                                         op_name, disable_call_shape_inference);
465     }
466 
467     rewriter.create<tf_executor::YieldOp>(loc, inner_op->getResults());
468 
469     rewriter.replaceOp(op, island.getOperation()->getResults());
470 
471     return success();
472   }
473 
474  private:
475   const DenseSet<StringRef> &func_symbols_;
476 };
477 
478 class LegalizeTFGToTFE : public TF::LegalizeTFGToTFPassBase<LegalizeTFGToTFE> {
getDependentDialects(DialectRegistry & registry) const479   void getDependentDialects(DialectRegistry &registry) const override {
480     RegisterAllTensorFlowDialects(registry);
481   }
482 
483   void runOnOperation() override;
484 };
485 
486 }  // namespace
487 
runOnOperation()488 void LegalizeTFGToTFE::runOnOperation() {
489   MLIRContext &context = getContext();
490   ModuleOp module = getOperation();
491 
492   DenseSet<StringRef> func_symbols;
493   for (auto &op : module.getBodyRegion().getOps()) {
494     if (auto func = llvm::dyn_cast<tfg::GraphFuncOp>(op)) {
495       func_symbols.insert(
496           func->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
497               .getValue());
498     }
499   }
500 
501   ConversionTarget target(context);
502   target.addLegalDialect<TF::TensorFlowDialect>();
503   target.addLegalDialect<tf_executor::TensorFlowExecutorDialect>();
504   target.addLegalOp<ModuleOp>();
505   target.addLegalOp<func::FuncOp>();
506   target.addLegalOp<func::ReturnOp>();
507 
508   RewritePatternSet patterns(&context);
509   patterns.add<ConvertGraphOp>(&context);
510   patterns.add<ConvertGraphFuncOp>(&context);
511   patterns.add<ConvertReturnOp>(&context);
512   patterns.add<ConvertGeneralOp>(&context, func_symbols);
513   // Control flow V1 operation conversion patterns.
514   patterns.add<ConvertControlTriggerOp>(&context);
515   patterns.add<ConvertEnterOp>(&context);
516   patterns.add<ConvertExitOp>(&context);
517   patterns.add<ConvertLoopCondOp>(&context);
518   patterns.add<ConvertMergeOp>(&context);
519   patterns.add<ConvertSwitchOp>(&context);
520   patterns.add<ConvertSwitchNOp>(&context);
521   FrozenRewritePatternSet finalPatterns(std::move(patterns));
522 
523   // Turn the graph region into SSACFG region by applying an order to the
524   // operations.
525   for (auto &op : module.getBodyRegion().getOps()) {
526     for (auto &region : op.getRegions()) {
527       for (auto &block : region) {
528         // Split tfg.NextIteration to break the cycle.
529         SplitNextIteration(block);
530         tfg::SortTopologically(&block);
531       }
532     }
533   }
534 
535   // Version information is embedded in graph operation in TFG. In TFE, it's
536   // embedded in the module operation.
537   for (auto &op : module.getBodyRegion().getOps()) {
538     auto graph = dyn_cast<tfg::GraphOp>(op);
539     if (!graph) continue;
540     Builder b(&context);
541     auto producer = b.getNamedAttr(
542         "producer", b.getI32IntegerAttr(graph.version().getProducer()));
543     auto min_consumer = b.getNamedAttr(
544         "min_consumer", b.getI32IntegerAttr(graph.version().getMinConsumer()));
545     auto bad_consumers = b.getNamedAttr(
546         "bad_consumers", b.getI32ArrayAttr(graph.version().getBadConsumers()));
547     module->setAttr("tf.versions",
548                     b.getDictionaryAttr(llvm::ArrayRef<NamedAttribute>(
549                         {producer, min_consumer, bad_consumers})));
550     break;
551   }
552 
553   if (failed(applyFullConversion(module.getOperation(), target, finalPatterns)))
554     signalPassFailure();
555 
556   // The uses of arg control dependency has been dropped. We can safely remove
557   // the block argument here.
558   module.walk([&](tf_executor::GraphOp graph) {
559     graph.body().front().eraseArguments([](BlockArgument arg) { return true; });
560   });
561 }
562 
CreateLegalizeTFGToTFEPass()563 std::unique_ptr<Pass> CreateLegalizeTFGToTFEPass() {
564   return std::make_unique<LegalizeTFGToTFE>();
565 }
566 
567 }  // end namespace TF
568 }  // end namespace mlir
569