/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <type_traits>
#include <utility>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"

namespace tensorflow {

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"

// -------------------------------------------------------------------------- //
// Fuse Linalg generic operations on Tensors.
// -------------------------------------------------------------------------- //

using mlir::dyn_cast;
using mlir::isa;

using mlir::AffineMap;
using mlir::MLIRContext;
using mlir::Operation;
using mlir::OpOperand;
using mlir::OpResult;
using mlir::RewritePatternSet;

namespace linalg = mlir::linalg;
namespace tensor = mlir::tensor;

// Returns true if `op` is a linalg.generic operation that only broadcasts
// its input.
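//
// For illustration, a sketch (not from the original file) of a linalg.generic
// this predicate matches: a broadcast of a 1-D tensor along a new dimension,
// whose input indexing map is a projected permutation with fewer results than
// loop dimensions:
//
//   %0 = linalg.generic
//          {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
//                            affine_map<(d0, d1) -> (d0, d1)>],
//           iterator_types = ["parallel", "parallel"]}
//          ins(%arg0 : tensor<8xf32>) outs(%init : tensor<8x16xf32>) {
//        ^bb0(%in: f32, %out: f32):
//          linalg.yield %in : f32
//        } -> tensor<8x16xf32>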
static bool IsBroadcast(Operation *op) {
  // Operation must be a generic linalg operation.
  auto generic = dyn_cast<linalg::GenericOp>(op);
  if (!generic) return false;

  // All iterators must be parallel.
  if (generic.getNumParallelLoops() != generic.getNumLoops()) return false;

  // The body must simply forward the input to the output.
  if (!isa<linalg::YieldOp>(generic.getBody()->front())) return false;

  // The operation must have a single input and a single output.
  if (generic.getNumInputs() != 1 || generic.getNumOutputs() != 1) return false;

  // Check the input operand indexing map.
  OpOperand *operand = generic.getInputOperand(0);
  AffineMap indexing_map = generic.getTiedIndexingMap(operand);

  if (!indexing_map.isProjectedPermutation() ||
      indexing_map.getNumDims() == indexing_map.getNumResults())
    return false;

  // We found a generic linalg operation that is a simple broadcast.
  return true;
}

// Decides if the producer operation should be fused into the consumer.
static bool ControlElementwiseOpsFusion(const OpResult &producer_result,
                                        OpOperand &) {
  // TODO(ezhulenev): This is a very simplistic heuristic; we need something
  // better to decide when fusion is beneficial.

  // Always fuse broadcasts into the consumer.
  if (IsBroadcast(producer_result.getOwner())) return true;

  // If the producer result has multiple users, do not fuse it into the
  // consumer.
  if (!producer_result.hasOneUse()) return false;

  return true;
}
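//
// For illustration, a sketch (not from the original file) of the effect of
// this heuristic: a broadcasting producer is fused into its elementwise
// consumer, so
//
//   %b = linalg.generic ... ins(%v)     outs(...)  // broadcast (IsBroadcast)
//   %r = linalg.generic ... ins(%b, %x) outs(...)  // elementwise consumer
//
// becomes a single linalg.generic that applies the broadcast indexing map
// directly to %v inside the fused operation.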

// Checks if the reshape operation only expands into or collapses unit
// dimensions.
template <typename TensorReshapeOp>
static bool IsUnitDimExpansionOnly(TensorReshapeOp reshape_op) {
  constexpr bool is_expanding =
      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
  llvm::ArrayRef<int64_t> expanded_shape =
      (is_expanding ? reshape_op.getResultType().getShape()
                    : reshape_op.getSrcType().getShape());
  for (auto &indices : reshape_op.getReassociationIndices()) {
    unsigned num_unit_dims = 0;
    for (int64_t position : indices)
      if (expanded_shape[position] == 1) num_unit_dims++;
    if (num_unit_dims != indices.size() - 1) return false;
  }
  return true;
}
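//
// For illustration, a sketch (not from the original file): with reassociation
// [[0, 1]], expanding tensor<8xf32> into tensor<8x1xf32> only introduces a
// unit dimension, so it qualifies; expanding tensor<8xf32> into
// tensor<2x4xf32> splits the group into two non-unit dimensions, so it does
// not:
//
//   %0 = tensor.expand_shape %arg0 [[0, 1]]
//       : tensor<8xf32> into tensor<8x1xf32>  // unit-dim expansion only
//   %1 = tensor.expand_shape %arg0 [[0, 1]]
//       : tensor<8xf32> into tensor<2x4xf32>  // general reshape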

// Control function for fusing reshapes by expansion: do not fuse reshapes
// that only expand into or collapse unit dimensions, and do not fuse
// producers with multiple users.
static bool SkipUnitDimReshape(const OpResult &producer, OpOperand &consumer) {
  // If the producer result has multiple users, do not fuse it into the
  // consumer.
  if (!producer.hasOneUse()) return false;

  if (auto producer_collapse_op =
          dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
    return !IsUnitDimExpansionOnly(producer_collapse_op);
  }
  if (auto consumer_expand_op =
          dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
    return !IsUnitDimExpansionOnly(consumer_expand_op);
  }
  return true;
}

struct FusionPass : public FusionBase<FusionPass> {
  void runOnOperation() override {
    Operation *op = getOperation();

    MLIRContext *context = &getContext();
    RewritePatternSet patterns(op->getContext());
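    // Fuse elementwise linalg.generic producers into their consumers, using
    // the heuristic above to decide when fusion is beneficial.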
    linalg::populateElementwiseOpsFusionPatterns(patterns,
                                                 ControlElementwiseOpsFusion);

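    // Fuse tensor reshapes with linalg.generic operations by expanding the
    // generic operation to the higher-rank shape, skipping unit-dimension
    // reshapes via the control function above.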
    linalg::populateFoldReshapeOpsByExpansionPatterns(patterns,
                                                      SkipUnitDimReshape);

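    // Constant-fold linalg operations, guarded by the same fusion heuristic.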
    linalg::populateConstantFoldLinalgOperations(patterns,
                                                 ControlElementwiseOpsFusion);

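    // Add canonicalization patterns for the ops involved in fusion.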
    mlir::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
    linalg::GenericOp::getCanonicalizationPatterns(patterns, context);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
    context->getLoadedDialect<linalg::LinalgDialect>()
        ->getCanonicalizationPatterns(patterns);
    // Use top-down traversal for compile-time reasons.
    mlir::GreedyRewriteConfig grc;
    grc.useTopDownTraversal = true;
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
                                       grc);
  }
};

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> CreateFusionPass() {
  return std::make_unique<FusionPass>();
}
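//
// Illustrative usage sketch (assumes a standard mlir::PassManager setup; the
// `module` and `pm` names are hypothetical, not from the original file):
//
//   mlir::PassManager pm(module.getContext());
//   pm.addNestedPass<mlir::func::FuncOp>(CreateFusionPass());
//   if (mlir::failed(pm.run(module))) { /* handle failure */ }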

}  // namespace tensorflow