xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Converts TF While to TFL While with single call in body and cond.
17 
18 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
19 #include "mlir/IR/Attributes.h"  // from @llvm-project
20 #include "mlir/IR/Builders.h"  // from @llvm-project
21 #include "mlir/IR/Operation.h"  // from @llvm-project
22 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
23 #include "mlir/Pass/Pass.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
25 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
27 
28 namespace mlir {
29 namespace TFL {
30 namespace {
31 #define GEN_PASS_CLASSES
32 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
33 
34 // Legalize TF While to TFL While with calls to the original functions from the
35 // cond and body regions.
36 struct LegalizeWhilePass : public LegalizeWhilePassBase<LegalizeWhilePass> {
37   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LegalizeWhilePass)
38   void RunOnFunction(func::FuncOp func);
39 
runOnOperationmlir::TFL::__anonb11ea8e50111::LegalizeWhilePass40   void runOnOperation() override {
41     for (auto op : getOperation().getOps<func::FuncOp>()) RunOnFunction(op);
42   }
43 };
44 
45 }  // namespace
46 
47 // Inserts call to the given function into the 'region'.
CreateRegionWithCall(func::FuncOp func,Region & region,Location loc)48 void CreateRegionWithCall(func::FuncOp func, Region& region, Location loc) {
49   OpBuilder builder(region);
50   auto block = builder.createBlock(&region);
51   SmallVector<Value, 4> new_operands;
52   for (Type t : func.getFunctionType().getInputs())
53     new_operands.push_back(block->addArgument(t, loc));
54   auto call = builder.create<func::CallOp>(loc, func, new_operands);
55   builder.create<YieldOp>(loc, call.getResults());
56   // Mark old function as private so that it can be DCE'd if not called.
57   func.setPrivate();
58 }
59 
RunOnWhile(TF::WhileOp while_op)60 void RunOnWhile(TF::WhileOp while_op) {
61   Operation* op = while_op.getOperation();
62   // Create new TFL While op that will be used to replace TF While op.
63   auto new_op = OpBuilder(op).create<TFL::WhileOp>(
64       op->getLoc(), op->getResultTypes(), op->getOperands(),
65       while_op.is_stateless());
66   Location loc = while_op->getLoc();
67   CreateRegionWithCall(while_op.cond_function(), new_op.cond(), loc);
68   CreateRegionWithCall(while_op.body_function(), new_op.body(), loc);
69 
70   op->replaceAllUsesWith(new_op.getResults());
71   op->erase();
72 }
73 
RunOnFunction(func::FuncOp func)74 void LegalizeWhilePass::RunOnFunction(func::FuncOp func) {
75   // Convert all TF WhileOps inside the function body to TFL While ops.
76   func.getBody().walk([](TF::WhileOp while_op) { RunOnWhile(while_op); });
77 }
78 
79 // Creates an instance of the TensorFlow While to TFLite While pass.
CreateLegalizeTFWhilePass()80 std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFWhilePass() {
81   return std::make_unique<LegalizeWhilePass>();
82 }
83 
84 static PassRegistration<LegalizeWhilePass> pass;
85 
86 }  // namespace TFL
87 }  // namespace mlir
88