1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "absl/strings/match.h"
17 #include "llvm/ADT/DenseSet.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "mlir/IR/Attributes.h" // from @llvm-project
21 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
22 #include "mlir/IR/Builders.h" // from @llvm-project
23 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
24 #include "mlir/IR/PatternMatch.h" // from @llvm-project
25 #include "mlir/IR/SymbolTable.h" // from @llvm-project
26 #include "mlir/IR/Visitors.h" // from @llvm-project
27 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
28 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
30 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
32 #include "tensorflow/core/transforms/toposort/pass.h"
33 #include "tensorflow/core/util/device_name_utils.h"
34
35 namespace mlir {
36 namespace TF {
37 namespace {
38
// Symbol name given to the func.func that wraps a top-level tfg.graph.
// FIXME: This should be kept consistent with
// tensorflow::kImportModelDefaultGraphFuncName.
constexpr char kImportModelDefaultGraphFuncName[] = "main";
42
// Please refer to the TFG dialect description for the list of attributes it
// uses. Below are the corresponding attributes in TFE.
//
// TFE arguments and results (derived from "_Arg", "_Retval", etc.):
//   NodeDef.device <-> "tf.device"
//   NodeDef.attr   <-> "tf."
//
// TFE general operations:
//   NodeDef.device <-> "device"
//
// The following two functions are used only for mapping or excluding
// attributes that are inconsistent between TFG and TFE.
//
FilterTfgSpecificArgResultAttributes(mlir::MLIRContext * context,mlir::ArrayRef<Type> types,mlir::ArrayAttr array_attr,llvm::SmallVector<mlir::Type> & output_types,llvm::SmallVector<mlir::DictionaryAttr> & output_attrs)56 static mlir::LogicalResult FilterTfgSpecificArgResultAttributes(
57 mlir::MLIRContext *context, mlir::ArrayRef<Type> types,
58 mlir::ArrayAttr array_attr, llvm::SmallVector<mlir::Type> &output_types,
59 llvm::SmallVector<mlir::DictionaryAttr> &output_attrs) {
60 for (auto it : llvm::zip(
61 types, array_attr.template getAsRange<mlir::DictionaryAttr>())) {
62 if (std::get<0>(it).isa<tfg::ControlType>()) continue;
63 output_types.push_back(std::get<0>(it));
64
65 mlir::NamedAttrList list;
66 for (mlir::NamedAttribute attr : std::get<1>(it).getValue()) {
67 // Skip if the attribute has "tfg" prefix.
68 if (attr.getName().getValue().startswith("tfg")) continue;
69 list.append(attr);
70 }
71 output_attrs.push_back(list.getDictionary(context));
72 }
73 return mlir::success();
74 }
75
ReformatOpAttributes(mlir::MLIRContext * context,llvm::ArrayRef<mlir::NamedAttribute> attrs,llvm::SmallVectorImpl<mlir::NamedAttribute> & output)76 static mlir::LogicalResult ReformatOpAttributes(
77 mlir::MLIRContext *context, llvm::ArrayRef<mlir::NamedAttribute> attrs,
78 llvm::SmallVectorImpl<mlir::NamedAttribute> &output) {
79 for (mlir::NamedAttribute attr : attrs) {
80 if (attr.getName().strref().contains(
81 mlir::tfg::TFGraphDialect::getDeviceAttrKey())) {
82 tensorflow::DeviceNameUtils::ParsedName parsed_name;
83 if (!tensorflow::DeviceNameUtils::ParseFullName(
84 attr.getValue().cast<mlir::StringAttr>().getValue().str(),
85 &parsed_name))
86 return mlir::failure();
87 if (!parsed_name.has_type) {
88 parsed_name.type = "CPU";
89 parsed_name.has_type = true;
90 }
91 if (!parsed_name.has_id) {
92 parsed_name.id = 0;
93 parsed_name.has_id = true;
94 }
95 output.push_back(mlir::NamedAttribute(
96 mlir::StringAttr::get(context, "device"),
97 mlir::StringAttr::get(
98 context,
99 tensorflow::DeviceNameUtils::ParsedNameToString(parsed_name))));
100 } else {
101 output.push_back(attr);
102 }
103 }
104 return mlir::success();
105 }
106
FilterOutBlockArgControlDep(ValueRange operands,llvm::SmallVectorImpl<Value> & filtered)107 static void FilterOutBlockArgControlDep(
108 ValueRange operands, llvm::SmallVectorImpl<Value> &filtered) {
109 for (Value value : operands)
110 if (!value.isa<mlir::BlockArgument>()) filtered.push_back(value);
111 }
112
113 // Split the tfg.NextIteration into tf_executor::NextIterationSourceOp and
114 // tf_executor::NextIterationSinkOp to break the cycle introduced by itself.
SplitNextIteration(Block & block)115 static void SplitNextIteration(Block &block) {
116 // TODO(b/207144333): Supports callback for unregistered ops
117 block.walk([&](Operation *op) {
118 if (!op->getName().getStringRef().equals("tfg.NextIteration")) return;
119 mlir::OpBuilder builder(op);
120
121 llvm::SmallVector<Value, 2> new_operands;
122 FilterOutBlockArgControlDep(op->getOperands().drop_front(), new_operands);
123
124 auto source_op = builder.create<tf_executor::NextIterationSourceOp>(
125 op->getLoc(), op->getOperand(0).getType());
126 builder.create<tf_executor::NextIterationSinkOp>(
127 op->getLoc(), source_op.token(), /*input=*/op->getOperand(0),
128 /*controlInputs=*/new_operands);
129 op->replaceAllUsesWith(
130 ValueRange({source_op.output(), source_op.control()}));
131 op->erase();
132 });
133 }
134
135 class ConvertGraphOp : public OpConversionPattern<tfg::GraphOp> {
136 public:
137 using OpConversionPattern::OpConversionPattern;
138
matchAndRewrite(tfg::GraphOp graph,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const139 LogicalResult matchAndRewrite(
140 tfg::GraphOp graph, OpAdaptor adaptor,
141 ConversionPatternRewriter &rewriter) const final {
142 Location loc = graph.getLoc();
143 // To keep the import-as-graph logic taken by TFG, we create `void func()`
144 // to contain the ops in the tfg::GraphOp. That means the arguments/results
145 // will be the operations inside the function body rather than representing
146 // them in the function signature.
147 FunctionType func_type = rewriter.getFunctionType({}, {});
148 func::FuncOp func = rewriter.create<func::FuncOp>(
149 loc, kImportModelDefaultGraphFuncName, func_type);
150 rewriter.setInsertionPointToStart(func.addEntryBlock());
151 auto executor_graph =
152 rewriter.create<tf_executor::GraphOp>(loc, func_type.getResults());
153 rewriter.inlineRegionBefore(graph.nodes(), executor_graph.body(),
154 executor_graph.body().end());
155
156 // Add terminator of tf_executor::graph
157 rewriter.setInsertionPointToEnd(&executor_graph.body().front());
158 rewriter.create<tf_executor::FetchOp>(loc);
159
160 // Add terminator of func
161 rewriter.setInsertionPointToEnd(&func.getBody().front());
162 rewriter.create<func::ReturnOp>(loc);
163
164 rewriter.replaceOp(graph.getOperation(), func.getOperation()->getResults());
165
166 return success();
167 }
168 };
169
170 class ConvertGraphFuncOp : public OpConversionPattern<tfg::GraphFuncOp> {
171 public:
172 using OpConversionPattern::OpConversionPattern;
173
matchAndRewrite(tfg::GraphFuncOp graph_func,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const174 LogicalResult matchAndRewrite(
175 tfg::GraphFuncOp graph_func, OpAdaptor adaptor,
176 ConversionPatternRewriter &rewriter) const final {
177 assert(!graph_func.generic());
178 Location loc = graph_func.getLoc();
179 FunctionType ftype = graph_func.getFunctionType();
180
181 func::FuncOp func = rewriter.create<func::FuncOp>(
182 graph_func.getLoc(),
183 graph_func->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
184 .getValue(),
185 ftype);
186
187 func->setAttrs(graph_func->getAttrs());
188
189 llvm::SmallVector<Type> arg_types;
190 llvm::SmallVector<Type> res_types;
191 llvm::SmallVector<DictionaryAttr> arg_attrs;
192 llvm::SmallVector<DictionaryAttr> res_attrs;
193 if (failed(FilterTfgSpecificArgResultAttributes(
194 getContext(), ftype.getInputs(), graph_func.getAllArgAttrs(),
195 arg_types, arg_attrs)) ||
196 failed(FilterTfgSpecificArgResultAttributes(
197 getContext(), ftype.getResults(), graph_func.getAllResultAttrs(),
198 res_types, res_attrs)))
199 return failure();
200
201 // Update the function type which has excluded the control args.
202 func->setAttr("function_type", TypeAttr::get(rewriter.getFunctionType(
203 arg_types, res_types)));
204
205 // Update arg/result attributes.
206 func.setAllArgAttrs(arg_attrs);
207 func.setAllResultAttrs(res_attrs);
208
209 rewriter.setInsertionPointToStart(func.addEntryBlock());
210 // In TFE, the function body is inlined in a GraphOp. Create a GraphOp
211 // instance and move the regions from GraphFuncOp to GraphOp.
212 auto executor_graph = rewriter.create<tf_executor::GraphOp>(
213 loc, func.getFunctionType().getResults());
214
215 // Replace the uses of block arguments with function arguments. Note that we
216 // can't erase the arguments here because the operations may still use them
217 // and these uses will be dropped after legalization of each op.
218 unsigned idx = 0;
219 Block &block = graph_func.body().front();
220 for (auto iter = block.args_begin(), end_iter = block.args_end();
221 iter != end_iter; ++iter) {
222 if (!iter->getType().isa<tfg::ControlType>())
223 iter->replaceAllUsesWith(func.getBody().getArgument(idx++));
224 }
225
226 rewriter.inlineRegionBefore(graph_func.body(), executor_graph.body(),
227 executor_graph.body().end());
228
229 rewriter.setInsertionPointToEnd(&func.getBody().front());
230 rewriter.create<func::ReturnOp>(
231 loc, executor_graph.getOperation()->getResults());
232
233 rewriter.replaceOp(graph_func.getOperation(),
234 func.getOperation()->getResults());
235
236 return success();
237 }
238 };
239
240 class ConvertReturnOp : public OpConversionPattern<tfg::ReturnOp> {
241 public:
242 using OpConversionPattern::OpConversionPattern;
matchAndRewrite(tfg::ReturnOp ret,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const243 LogicalResult matchAndRewrite(
244 tfg::ReturnOp ret, OpAdaptor adaptor,
245 ConversionPatternRewriter &rewriter) const final {
246 rewriter.replaceOpWithNewOp<tf_executor::FetchOp>(ret.getOperation(),
247 adaptor.getOperands());
248 return success();
249 }
250 };
251
252 class ConvertControlTriggerOp : public ConversionPattern {
253 public:
ConvertControlTriggerOp(MLIRContext * context)254 explicit ConvertControlTriggerOp(MLIRContext *context)
255 : ConversionPattern("tfg.ControlTrigger", PatternBenefit(1), context) {}
256
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const257 LogicalResult matchAndRewrite(
258 Operation *op, llvm::ArrayRef<Value> operands,
259 ConversionPatternRewriter &rewriter) const final {
260 llvm::SmallVector<Type, 2> new_types(op->getResultTypes());
261 new_types.back() = rewriter.getType<tf_executor::ControlType>();
262
263 llvm::SmallVector<Value, 2> new_operands;
264 FilterOutBlockArgControlDep(operands, new_operands);
265
266 rewriter.replaceOpWithNewOp<tf_executor::ControlTriggerOp>(
267 op, new_types, new_operands, op->getAttrs());
268 return success();
269 }
270 };
271
272 class ConvertEnterOp : public ConversionPattern {
273 public:
ConvertEnterOp(MLIRContext * context)274 explicit ConvertEnterOp(MLIRContext *context)
275 : ConversionPattern("tfg.Enter", PatternBenefit(1), context) {}
276
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const277 LogicalResult matchAndRewrite(
278 Operation *op, llvm::ArrayRef<Value> operands,
279 ConversionPatternRewriter &rewriter) const final {
280 llvm::SmallVector<Type, 2> new_types(op->getResultTypes());
281 new_types.back() = rewriter.getType<tf_executor::ControlType>();
282
283 llvm::SmallVector<Value, 2> new_operands;
284 FilterOutBlockArgControlDep(operands, new_operands);
285
286 rewriter.replaceOpWithNewOp<tf_executor::EnterOp>(
287 op, new_types, new_operands, op->getAttrs());
288 return success();
289 }
290 };
291
292 class ConvertExitOp : public ConversionPattern {
293 public:
ConvertExitOp(MLIRContext * context)294 explicit ConvertExitOp(MLIRContext *context)
295 : ConversionPattern("tfg.Exit", PatternBenefit(1), context) {}
296
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const297 LogicalResult matchAndRewrite(
298 Operation *op, llvm::ArrayRef<Value> operands,
299 ConversionPatternRewriter &rewriter) const final {
300 llvm::SmallVector<Type, 2> new_types(op->getResultTypes());
301 new_types.back() = rewriter.getType<tf_executor::ControlType>();
302
303 llvm::SmallVector<Value, 2> new_operands;
304 FilterOutBlockArgControlDep(operands, new_operands);
305
306 rewriter.replaceOpWithNewOp<tf_executor::ExitOp>(
307 op, new_types, new_operands, op->getAttrs());
308 return success();
309 }
310 };
311
312 class ConvertLoopCondOp : public ConversionPattern {
313 public:
ConvertLoopCondOp(MLIRContext * context)314 explicit ConvertLoopCondOp(MLIRContext *context)
315 : ConversionPattern("tfg.LoopCond", PatternBenefit(1), context) {}
316
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const317 LogicalResult matchAndRewrite(
318 Operation *op, llvm::ArrayRef<Value> operands,
319 ConversionPatternRewriter &rewriter) const final {
320 llvm::SmallVector<Type, 2> new_types(op->getResultTypes());
321 new_types.back() = rewriter.getType<tf_executor::ControlType>();
322
323 llvm::SmallVector<Value, 2> new_operands;
324 FilterOutBlockArgControlDep(operands, new_operands);
325
326 rewriter.replaceOpWithNewOp<tf_executor::LoopCondOp>(
327 op, new_types, new_operands, op->getAttrs());
328 return success();
329 }
330 };
331
332 class ConvertMergeOp : public ConversionPattern {
333 public:
ConvertMergeOp(MLIRContext * context)334 explicit ConvertMergeOp(MLIRContext *context)
335 : ConversionPattern("tfg.Merge", PatternBenefit(1), context) {}
336
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const337 LogicalResult matchAndRewrite(
338 Operation *op, llvm::ArrayRef<Value> operands,
339 ConversionPatternRewriter &rewriter) const final {
340 llvm::SmallVector<Type, 2> new_types(op->getResultTypes());
341 new_types.back() = rewriter.getType<tf_executor::ControlType>();
342
343 llvm::SmallVector<Value, 2> new_operands;
344 FilterOutBlockArgControlDep(operands, new_operands);
345
346 rewriter.replaceOpWithNewOp<tf_executor::MergeOp>(
347 op, new_types, new_operands, op->getAttrs());
348 return success();
349 }
350 };
351
352 class ConvertSwitchOp : public ConversionPattern {
353 public:
ConvertSwitchOp(MLIRContext * context)354 explicit ConvertSwitchOp(MLIRContext *context)
355 : ConversionPattern("tfg.Switch", PatternBenefit(1), context) {}
356
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const357 LogicalResult matchAndRewrite(
358 Operation *op, llvm::ArrayRef<Value> operands,
359 ConversionPatternRewriter &rewriter) const final {
360 llvm::SmallVector<Type, 2> new_types(op->getResultTypes());
361 new_types.back() = rewriter.getType<tf_executor::ControlType>();
362
363 llvm::SmallVector<Value, 2> new_operands;
364 FilterOutBlockArgControlDep(operands, new_operands);
365
366 rewriter.replaceOpWithNewOp<tf_executor::SwitchOp>(
367 op, new_types, new_operands, op->getAttrs());
368 return success();
369 }
370 };
371
372 class ConvertSwitchNOp : public ConversionPattern {
373 public:
ConvertSwitchNOp(MLIRContext * context)374 explicit ConvertSwitchNOp(MLIRContext *context)
375 : ConversionPattern("tfg.SwitchN", PatternBenefit(1), context) {}
376
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const377 LogicalResult matchAndRewrite(
378 Operation *op, llvm::ArrayRef<Value> operands,
379 ConversionPatternRewriter &rewriter) const final {
380 llvm::SmallVector<Type, 2> new_types(op->getResultTypes());
381 new_types.back() = rewriter.getType<tf_executor::ControlType>();
382
383 llvm::SmallVector<Value, 2> new_operands;
384 FilterOutBlockArgControlDep(operands, new_operands);
385
386 rewriter.replaceOpWithNewOp<tf_executor::SwitchNOp>(
387 op, new_types, new_operands, op->getAttrs());
388 return success();
389 }
390 };
391
392 class ConvertGeneralOp : public ConversionPattern {
393 public:
ConvertGeneralOp(MLIRContext * context,const DenseSet<StringRef> & func_symbols)394 ConvertGeneralOp(MLIRContext *context,
395 const DenseSet<StringRef> &func_symbols)
396 : ConversionPattern(MatchAnyOpTypeTag(), PatternBenefit(1), context),
397 func_symbols_(func_symbols) {}
398
matchAndRewrite(Operation * op,llvm::ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const399 LogicalResult matchAndRewrite(
400 Operation *op, llvm::ArrayRef<Value> operands,
401 ConversionPatternRewriter &rewriter) const final {
402 if (!llvm::isa<tfg::TFGraphDialect>(op->getDialect())) return failure();
403
404 Location loc = op->getLoc();
405 llvm::SmallVector<mlir::Type, 2> new_types(op->getResultTypes());
406 // Update the control type from tf_type.control to tf_executor.control.
407 new_types.back() = rewriter.getType<tf_executor::ControlType>();
408
409 // Control operand is attached on tf_executor::IslandOp.
410 llvm::SmallVector<Value> island_control_operands;
411 llvm::SmallVector<Value> inner_op_operands;
412
413 for (Value value : operands) {
414 // Because of the property of graph region, the control operands may
415 // not have been converted to tf_executor::ControlType.
416 if (value.getType().isa<tfg::ControlType>() ||
417 value.getType().isa<tf_executor::ControlType>()) {
418 if (!value.isa<BlockArgument>())
419 island_control_operands.push_back(value);
420 } else {
421 inner_op_operands.push_back(value);
422 }
423 }
424
425 auto island = rewriter.create<tf_executor::IslandOp>(
426 loc, new_types, island_control_operands);
427 island.body().push_back(new mlir::Block);
428
429 rewriter.setInsertionPointToEnd(&island.body().front());
430
431 // Control dependency has been applied on tf_executor.island. Remove it
432 // while creating the tf operations.
433 new_types.pop_back();
434
435 llvm::SmallVector<std::unique_ptr<Region>, 1> new_regions;
436 for (auto ®ion : op->getRegions()) {
437 new_regions.push_back(std::make_unique<Region>());
438 new_regions.back()->takeBody(region);
439 }
440
441 llvm::SmallVector<NamedAttribute, 4> attrs;
442 if (failed(ReformatOpAttributes(getContext(), op->getAttrs(), attrs)))
443 return failure();
444
445 Operation *inner_op;
446
447 StringRef op_name = op->getName().stripDialect();
448 if (!func_symbols_.contains(op_name)) {
449 std::string tf_op_name = llvm::formatv(
450 "{0}.{1}", TF::TensorFlowDialect::getDialectNamespace(), op_name);
451 OperationState state =
452 OperationState(loc, tf_op_name, inner_op_operands, new_types, attrs,
453 op->getSuccessors(), new_regions);
454 inner_op = rewriter.create(state);
455 } else {
456 bool disable_call_shape_inference = false;
457 if (op->hasAttr("_disable_call_shape_inference")) {
458 disable_call_shape_inference =
459 op->getAttrOfType<BoolAttr>("_disable_call_shape_inference")
460 .getValue();
461 }
462 inner_op =
463 rewriter.create<LegacyCallOp>(loc, new_types, inner_op_operands,
464 op_name, disable_call_shape_inference);
465 }
466
467 rewriter.create<tf_executor::YieldOp>(loc, inner_op->getResults());
468
469 rewriter.replaceOp(op, island.getOperation()->getResults());
470
471 return success();
472 }
473
474 private:
475 const DenseSet<StringRef> &func_symbols_;
476 };
477
478 class LegalizeTFGToTFE : public TF::LegalizeTFGToTFPassBase<LegalizeTFGToTFE> {
getDependentDialects(DialectRegistry & registry) const479 void getDependentDialects(DialectRegistry ®istry) const override {
480 RegisterAllTensorFlowDialects(registry);
481 }
482
483 void runOnOperation() override;
484 };
485
486 } // namespace
487
runOnOperation()488 void LegalizeTFGToTFE::runOnOperation() {
489 MLIRContext &context = getContext();
490 ModuleOp module = getOperation();
491
492 DenseSet<StringRef> func_symbols;
493 for (auto &op : module.getBodyRegion().getOps()) {
494 if (auto func = llvm::dyn_cast<tfg::GraphFuncOp>(op)) {
495 func_symbols.insert(
496 func->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
497 .getValue());
498 }
499 }
500
501 ConversionTarget target(context);
502 target.addLegalDialect<TF::TensorFlowDialect>();
503 target.addLegalDialect<tf_executor::TensorFlowExecutorDialect>();
504 target.addLegalOp<ModuleOp>();
505 target.addLegalOp<func::FuncOp>();
506 target.addLegalOp<func::ReturnOp>();
507
508 RewritePatternSet patterns(&context);
509 patterns.add<ConvertGraphOp>(&context);
510 patterns.add<ConvertGraphFuncOp>(&context);
511 patterns.add<ConvertReturnOp>(&context);
512 patterns.add<ConvertGeneralOp>(&context, func_symbols);
513 // Control flow V1 operation conversion patterns.
514 patterns.add<ConvertControlTriggerOp>(&context);
515 patterns.add<ConvertEnterOp>(&context);
516 patterns.add<ConvertExitOp>(&context);
517 patterns.add<ConvertLoopCondOp>(&context);
518 patterns.add<ConvertMergeOp>(&context);
519 patterns.add<ConvertSwitchOp>(&context);
520 patterns.add<ConvertSwitchNOp>(&context);
521 FrozenRewritePatternSet finalPatterns(std::move(patterns));
522
523 // Turn the graph region into SSACFG region by applying an order to the
524 // operations.
525 for (auto &op : module.getBodyRegion().getOps()) {
526 for (auto ®ion : op.getRegions()) {
527 for (auto &block : region) {
528 // Split tfg.NextIteration to break the cycle.
529 SplitNextIteration(block);
530 tfg::SortTopologically(&block);
531 }
532 }
533 }
534
535 // Version information is embedded in graph operation in TFG. In TFE, it's
536 // embedded in the module operation.
537 for (auto &op : module.getBodyRegion().getOps()) {
538 auto graph = dyn_cast<tfg::GraphOp>(op);
539 if (!graph) continue;
540 Builder b(&context);
541 auto producer = b.getNamedAttr(
542 "producer", b.getI32IntegerAttr(graph.version().getProducer()));
543 auto min_consumer = b.getNamedAttr(
544 "min_consumer", b.getI32IntegerAttr(graph.version().getMinConsumer()));
545 auto bad_consumers = b.getNamedAttr(
546 "bad_consumers", b.getI32ArrayAttr(graph.version().getBadConsumers()));
547 module->setAttr("tf.versions",
548 b.getDictionaryAttr(llvm::ArrayRef<NamedAttribute>(
549 {producer, min_consumer, bad_consumers})));
550 break;
551 }
552
553 if (failed(applyFullConversion(module.getOperation(), target, finalPatterns)))
554 signalPassFailure();
555
556 // The uses of arg control dependency has been dropped. We can safely remove
557 // the block argument here.
558 module.walk([&](tf_executor::GraphOp graph) {
559 graph.body().front().eraseArguments([](BlockArgument arg) { return true; });
560 });
561 }
562
CreateLegalizeTFGToTFEPass()563 std::unique_ptr<Pass> CreateLegalizeTFGToTFEPass() {
564 return std::make_unique<LegalizeTFGToTFE>();
565 }
566
567 } // end namespace TF
568 } // end namespace mlir
569