1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <utility>
17
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/Dialect/Linalg/IR/Linalg.h"
21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
25
26 namespace tensorflow {
27
28 #define GEN_PASS_CLASSES
29 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
30
31 // -------------------------------------------------------------------------- //
32 // Fuse Linalg generic operations on Tensors.
33 // -------------------------------------------------------------------------- //
34
35 using mlir::dyn_cast;
36 using mlir::isa;
37
38 using mlir::AffineMap;
39 using mlir::MLIRContext;
40 using mlir::Operation;
41 using mlir::OpOperand;
42 using mlir::OpResult;
43 using mlir::RewritePatternSet;
44
45 namespace linalg = mlir::linalg;
46 namespace tensor = mlir::tensor;
47
48 // Returns true if `op` is a linalg generic operation that only does the
49 // broadcast of the input.
IsBroadcast(Operation * op)50 static bool IsBroadcast(Operation *op) {
51 // Operation must be a generic linalg operation.
52 auto generic = dyn_cast<linalg::GenericOp>(op);
53 if (!generic) return false;
54
55 // All iterators must be parallel.
56 if (generic.getNumParallelLoops() != generic.getNumLoops()) return false;
57
58 // The body must simple forward input to the output.
59 if (!isa<linalg::YieldOp>(generic.getBody()->front())) return false;
60
61 // Operation must have single input and output.
62 if (generic.getNumInputs() != 1 || generic.getNumOutputs() != 1) return false;
63
64 // Check the input operand indexing map.
65 OpOperand *operand = generic.getInputOperand(0);
66 AffineMap indexing_map = generic.getTiedIndexingMap(operand);
67
68 if (!indexing_map.isProjectedPermutation() ||
69 indexing_map.getNumDims() == indexing_map.getNumResults())
70 return false;
71
72 // We found a generic linalg operation that is a simple broadcast.
73 return true;
74 }
75
76 // Decide if the producer operation should be fused into the consumer.
ControlElementwiseOpsFusion(const OpResult & producer_result,OpOperand &)77 static bool ControlElementwiseOpsFusion(const OpResult &producer_result,
78 OpOperand &) {
79 // TODO(ezhulenev): This is a very simplistic heuristic, we need something
80 // better to decide when fusion is beneficial.
81
82 // Always fuse broadcasts into the consumer.
83 if (IsBroadcast(producer_result.getOwner())) return true;
84
85 // If producer result has multiple users do not fuse it into the consumer.
86 if (!producer_result.hasOneUse()) return false;
87
88 return true;
89 }
90
91 // Check if the reshape operation is only expansion into/collapsing of
92 // unit-dimension.
93 template <typename TensorReshapeOp>
IsUnitDimExpansionOnly(TensorReshapeOp reshape_op)94 static bool IsUnitDimExpansionOnly(TensorReshapeOp reshape_op) {
95 constexpr bool is_expanding =
96 std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
97 llvm::ArrayRef<int64_t> expanded_shape =
98 (is_expanding ? reshape_op.getResultType().getShape()
99 : reshape_op.getSrcType().getShape());
100 for (auto &indices : reshape_op.getReassociationIndices()) {
101 unsigned num_unit_dims = 0;
102 for (int64_t position : indices)
103 if (expanded_shape[position] == 1) num_unit_dims++;
104 if (num_unit_dims != indices.size() - 1) return false;
105 }
106 return true;
107 }
108
109 // Control function to skip unit dim reshape when fusing reshapes by expansion.
SkipUnitDimReshape(const OpResult & producer,OpOperand & consumer)110 static bool SkipUnitDimReshape(const OpResult &producer, OpOperand &consumer) {
111 // If producer result has multiple users do not fuse it into the consumer.
112 if (!producer.hasOneUse()) return false;
113
114 if (auto producer_collapse_op =
115 dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
116 return !IsUnitDimExpansionOnly(producer_collapse_op);
117 }
118 if (auto consumer_expand_op =
119 dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
120 return !IsUnitDimExpansionOnly(consumer_expand_op);
121 }
122 return true;
123 }
124
125 struct FusionPass : public FusionBase<FusionPass> {
runOnOperationtensorflow::FusionPass126 void runOnOperation() override {
127 Operation *op = getOperation();
128
129 MLIRContext *context = &getContext();
130 RewritePatternSet patterns(op->getContext());
131 linalg::populateElementwiseOpsFusionPatterns(patterns,
132 ControlElementwiseOpsFusion);
133
134 linalg::populateFoldReshapeOpsByExpansionPatterns(patterns,
135 SkipUnitDimReshape);
136
137 linalg::populateConstantFoldLinalgOperations(patterns,
138 ControlElementwiseOpsFusion);
139
140 mlir::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
141 linalg::GenericOp::getCanonicalizationPatterns(patterns, context);
142 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
143 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
144 context->getLoadedDialect<linalg::LinalgDialect>()
145 ->getCanonicalizationPatterns(patterns);
146 // Use TopDownTraversal for compile time reasons.
147 mlir::GreedyRewriteConfig grc;
148 grc.useTopDownTraversal = true;
149 (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
150 grc);
151 }
152 };
153
CreateFusionPass()154 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> CreateFusionPass() {
155 return std::make_unique<FusionPass>();
156 }
157
158 } // namespace tensorflow
159