xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 
18 #include "absl/memory/memory.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
23 #include "mlir/Dialect/Traits.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
26 #include "mlir/IR/Operation.h"  // from @llvm-project
27 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
28 #include "mlir/Pass/Pass.h"  // from @llvm-project
29 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
33 
34 namespace mlir {
35 namespace {
36 
37 class ConvertResultsBroadcastableShapeOp : public RewritePattern {
38  public:
ConvertResultsBroadcastableShapeOp(MLIRContext * context)39   ConvertResultsBroadcastableShapeOp(MLIRContext* context)
40       : RewritePattern(MatchAnyOpTypeTag(), 1, context) {}
41 
42   LogicalResult matchAndRewrite(Operation* op,
43                                 PatternRewriter& rewriter) const override;
44 
45  private:
46   template <typename Op>
47   LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
48 
49   LogicalResult RewriteOp(
50       Operation* op, PatternRewriter& rewriter,
51       const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
52                                SmallVectorImpl<int64_t>&)>&
53           get_broadcasted_shape) const;
54 
55   LogicalResult RewriteBatchMatMulV2Op(Operation* op,
56                                        PatternRewriter& rewriter) const;
57 };
58 
59 class BroadcastFoldPass : public TF::BroadcastFoldPassBase<BroadcastFoldPass> {
60  public:
61   void runOnOperation() override;
62 };
63 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const64 LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
65     Operation* op, PatternRewriter& rewriter) const {
66   if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
67     return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
68 
69   // tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
70   // incompatible_shape_error is `true` (what is also checked by the verifier).
71   if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
72   if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
73   if (succeeded(RewriteBatchMatMulV2Op(op, rewriter))) return success();
74 
75   return failure();
76 }
77 
RewriteBatchMatMulV2Op(Operation * op,PatternRewriter & rewriter) const78 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteBatchMatMulV2Op(
79     Operation* op, PatternRewriter& rewriter) const {
80   auto matmul_op = llvm::dyn_cast<TF::BatchMatMulV2Op>(op);
81   if (!matmul_op) return failure();
82 
83   // Gets the broadcasted output shape for tf.BatchMatMulV2Op. `shape_x` is the
84   // shape of op's first/left-hand-side operand and `shape_y` is the shape of
85   // op's second/right-hand-side operand.
86   const auto get_broadcasted_shape =
87       [&](ArrayRef<int64_t> shape_x, ArrayRef<int64_t> shape_y,
88           SmallVectorImpl<int64_t>& result_shape) {
89         if (shape_x.size() < 2 || shape_y.size() < 2) {
90           return false;
91         }
92 
93         // Checks outer dimensions (i.e., the dimensions higher than 2D) are
94         // broadcastable. If true, then get the broadcasted shape for outer
95         // dimension.
96         if (!OpTrait::util::getBroadcastedShape(
97                 shape_x.drop_back(2), shape_y.drop_back(2), result_shape)) {
98           return false;
99         }
100 
101         const int x_row =
102             matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
103         const int x_col =
104             !matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
105 
106         const int y_row =
107             matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
108         const int y_col =
109             !matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
110 
111         // Checks that matrix multiply can perform a valid contraction.
112         if (x_col != y_row) {
113           result_shape.clear();
114           return false;
115         }
116 
117         result_shape.push_back(x_row);
118         result_shape.push_back(y_col);
119         return true;
120       };
121 
122   return RewriteOp(op, rewriter, get_broadcasted_shape);
123 }
124 
125 template <typename Op>
RewriteEqOp(Operation * op,PatternRewriter & rewriter) const126 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
127     Operation* op, PatternRewriter& rewriter) const {
128   auto eq_op = llvm::dyn_cast_or_null<Op>(op);
129   if (eq_op && eq_op.incompatible_shape_error())
130     return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
131   return failure();
132 }
133 
RewriteOp(Operation * op,PatternRewriter & rewriter,const std::function<bool (ArrayRef<int64_t>,ArrayRef<int64_t>,SmallVectorImpl<int64_t> &)> & get_broadcasted_shape) const134 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
135     Operation* op, PatternRewriter& rewriter,
136     const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
137                              SmallVectorImpl<int64_t>&)>& get_broadcasted_shape)
138     const {
139   if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
140     return failure();
141 
142   // Check that the result shape is fully defined.
143   auto result_type =
144       op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>();
145   if (!result_type || !result_type.hasStaticShape()) return failure();
146 
147   bool changed = false;
148   for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) {
149     // Check that the i'th operand is a broadcast.
150     auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>(
151         op->getOpOperand(i).get().getDefiningOp());
152     if (!broadcast) continue;
153 
154     // Check that the operand of the broadcast has fully defined shape.
155     auto broadcast_arg_type =
156         broadcast.input().getType().dyn_cast_or_null<RankedTensorType>();
157     if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue;
158 
159     // Check that the other argument has fully defined shape.
160     auto argument_type = op->getOpOperand(1 - i)
161                              .get()
162                              .getType()
163                              .dyn_cast_or_null<RankedTensorType>();
164     if (!argument_type || !argument_type.hasStaticShape()) continue;
165 
166     // Get the unbroadcasted shapes in the operand order.
167     std::array<llvm::ArrayRef<int64_t>, 2> operand_shapes;
168     operand_shapes[i] = broadcast_arg_type.getShape();
169     operand_shapes[1 - i] = argument_type.getShape();
170 
171     // Check that the input of the broadcast and the other operand is broadcast
172     // compatible.
173     llvm::SmallVector<int64_t, 4> broadcasted_shape;
174     if (!get_broadcasted_shape(operand_shapes[0], operand_shapes[1],
175                                broadcasted_shape))
176       continue;
177 
178     // Check that an implicit broadcast between the operand of the broadcast and
179     // the other argument would result in the same type as the result type.
180     if (broadcasted_shape != result_type.getShape()) continue;
181 
182     // Update the operand of the op to be the operand of the broadcast.
183     rewriter.updateRootInPlace(
184         op, [&]() { op->getOpOperand(i).set(broadcast.input()); });
185     changed = true;
186   }
187   return success(changed);
188 }
189 
runOnOperation()190 void BroadcastFoldPass::runOnOperation() {
191   RewritePatternSet patterns(&getContext());
192   auto func = getOperation();
193 
194   patterns.add<ConvertResultsBroadcastableShapeOp>(func.getContext());
195   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
196 }
197 
198 }  // namespace
199 
200 namespace TF {
CreateBroadcastFoldPass()201 std::unique_ptr<OperationPass<func::FuncOp>> CreateBroadcastFoldPass() {
202   return std::make_unique<BroadcastFoldPass>();
203 }
204 }  // namespace TF
205 
206 }  // namespace mlir
207