/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"

#define DEBUG_TYPE "tf-gpu-op-fusion"

namespace mlir {
namespace TF {

namespace {

// GpuOpFusionPass is a pass performing fusion specific to GPU targets.
// This is an ad-hoc pass for now, but should be integrated with some notion
// of "target" in the MLIR pipeline in the future.
class GpuOpFusionPass : public TensorflowGPUFusionBase<GpuOpFusionPass> {
 public:
  void runOnOperation() final;
};

// Replaces a FusedBatchNormV3 followed by a Relu (optionally through an AddV2
// with a side input) with a single _FusedBatchNormEx:
//
//   %y:6 = "tf.FusedBatchNormV3"(%x, %scale, %offset, %mean, %variance)
//   %0 = "tf.Relu"(%y#0)
// ->
//   %y:6 = "tf._FusedBatchNormEx"(%x, %scale, %offset, %mean, %variance)
//
// Or:
//   %y:6 = "tf.FusedBatchNormV3"(%x, %scale, %offset, %mean, %variance)
//   %0 = "tf.AddV2"(%y#0, %side_input)
//   %1 = "tf.Relu"(%0)
// ->
//   %y:6 = "tf._FusedBatchNormEx"(%x, %scale, %offset, %mean, %variance,
//                                 %side_input)
//
// TODO(aminim): we should revisit this as a declarative pattern.
// For the second pattern, there is no good way in the framework to handle the
// commutativity of the AddV2: we want to match the FusedBatchNormV3 on either
// side. We also need some native code to handle the "hasOneUse" checks and the
// optional extra operand for the AddV2 case.
struct ReluToFusedBatchNorm : public OpRewritePattern<ReluOp> {
  using OpRewritePattern<ReluOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReluOp relu_op,
                                PatternRewriter &rewriter) const override {
    Operation *relu_input = relu_op.features().getDefiningOp();
    if (!relu_input) return failure();
    auto batch_norm = dyn_cast_or_null<FusedBatchNormV3Op>(relu_input);
    AddV2Op add_op;
    Value side_input;
    if (!batch_norm) {
      // The input to the Relu is not a FusedBatchNorm directly; it may still
      // be reachable through an AddV2.
      add_op = dyn_cast_or_null<AddV2Op>(relu_input);
      if (!add_op) return failure();

      batch_norm =
          dyn_cast_or_null<FusedBatchNormV3Op>(add_op.x().getDefiningOp());
      if (batch_norm) {
        side_input = add_op.y();
      } else {
        // Didn't get a FusedBatchNorm on the LHS of the AddV2, try the RHS.
        batch_norm =
            dyn_cast_or_null<FusedBatchNormV3Op>(add_op.y().getDefiningOp());
        if (!batch_norm) return failure();
        side_input = add_op.x();
      }
    }
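    // Only fuse inference-mode batch norms whose normalized output feeds
    // solely into the op(s) being fused: the fused op applies the activation
    // (and optional add) to its first result, so any other user would
    // observe a different value.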
    assert(batch_norm);
    if (batch_norm.is_training()) return failure();
    if (!batch_norm.y().hasOneUse()) return failure();

    // Build the newly fused operation to replace the batch norm
    OperationState state(batch_norm.getLoc(),
                         _FusedBatchNormExOp::getOperationName());
    state.addOperands(batch_norm.getOperands());
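    // When matching through an AddV2, its other operand is forwarded as the
    // extra side-input operand of the fused op.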
    if (side_input) state.operands.push_back(side_input);
    state.addTypes(batch_norm.getResultTypes());
    state.addAttributes(batch_norm->getAttrs());
    Operation *op = rewriter.create(state);
    rewriter.replaceOp(batch_norm, op->getResults());

    // Depending on the case, we may fuse the add, the relu, or both.
    if (!add_op || add_op.z().hasOneUse()) {
      // We fuse the Relu only if the add has a single use, otherwise we only
      // fuse the add itself.
      op->setAttr("activation_mode", rewriter.getStringAttr("Relu"));
      rewriter.replaceOp(relu_op, op->getResult(0));
    }
    if (add_op) {
      rewriter.replaceOp(add_op, op->getResult(0));
    }

    return success();
  }
};

void GpuOpFusionPass::runOnOperation() {
  func::FuncOp func = getOperation();
  RewritePatternSet patterns(&getContext());
  patterns.add<ReluToFusedBatchNorm>(&getContext());
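  // Apply the fusion pattern greedily until fixpoint; the driver's
  // LogicalResult is intentionally discarded, as non-convergence is not
  // treated as a pass failure here.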
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> CreateGpuOpFusionPass() {
  return std::make_unique<GpuOpFusionPass>();
}
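
// Example usage (sketch only, not the canonical TensorFlow pipeline wiring):
// the pass operates on func::FuncOp, so it is typically added as a nested
// pass when assembling a pipeline, e.g.:
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(mlir::TF::CreateGpuOpFusionPass());
//   if (failed(pm.run(module))) {
//     // Handle pipeline failure.
//   }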

}  // namespace TF
}  // namespace mlir