1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17
18 #include "absl/memory/memory.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
23 #include "mlir/Dialect/Traits.h" // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
26 #include "mlir/IR/Operation.h" // from @llvm-project
27 #include "mlir/IR/PatternMatch.h" // from @llvm-project
28 #include "mlir/Pass/Pass.h" // from @llvm-project
29 #include "mlir/Support/LogicalResult.h" // from @llvm-project
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
33
34 namespace mlir {
35 namespace {
36
37 class ConvertResultsBroadcastableShapeOp : public RewritePattern {
38 public:
ConvertResultsBroadcastableShapeOp(MLIRContext * context)39 ConvertResultsBroadcastableShapeOp(MLIRContext* context)
40 : RewritePattern(MatchAnyOpTypeTag(), 1, context) {}
41
42 LogicalResult matchAndRewrite(Operation* op,
43 PatternRewriter& rewriter) const override;
44
45 private:
46 template <typename Op>
47 LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
48
49 LogicalResult RewriteOp(
50 Operation* op, PatternRewriter& rewriter,
51 const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
52 SmallVectorImpl<int64_t>&)>&
53 get_broadcasted_shape) const;
54
55 LogicalResult RewriteBatchMatMulV2Op(Operation* op,
56 PatternRewriter& rewriter) const;
57 };
58
59 class BroadcastFoldPass : public TF::BroadcastFoldPassBase<BroadcastFoldPass> {
60 public:
61 void runOnOperation() override;
62 };
63
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const64 LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
65 Operation* op, PatternRewriter& rewriter) const {
66 if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
67 return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
68
69 // tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
70 // incompatible_shape_error is `true` (what is also checked by the verifier).
71 if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
72 if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
73 if (succeeded(RewriteBatchMatMulV2Op(op, rewriter))) return success();
74
75 return failure();
76 }
77
RewriteBatchMatMulV2Op(Operation * op,PatternRewriter & rewriter) const78 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteBatchMatMulV2Op(
79 Operation* op, PatternRewriter& rewriter) const {
80 auto matmul_op = llvm::dyn_cast<TF::BatchMatMulV2Op>(op);
81 if (!matmul_op) return failure();
82
83 // Gets the broadcasted output shape for tf.BatchMatMulV2Op. `shape_x` is the
84 // shape of op's first/left-hand-side operand and `shape_y` is the shape of
85 // op's second/right-hand-side operand.
86 const auto get_broadcasted_shape =
87 [&](ArrayRef<int64_t> shape_x, ArrayRef<int64_t> shape_y,
88 SmallVectorImpl<int64_t>& result_shape) {
89 if (shape_x.size() < 2 || shape_y.size() < 2) {
90 return false;
91 }
92
93 // Checks outer dimensions (i.e., the dimensions higher than 2D) are
94 // broadcastable. If true, then get the broadcasted shape for outer
95 // dimension.
96 if (!OpTrait::util::getBroadcastedShape(
97 shape_x.drop_back(2), shape_y.drop_back(2), result_shape)) {
98 return false;
99 }
100
101 const int x_row =
102 matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
103 const int x_col =
104 !matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
105
106 const int y_row =
107 matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
108 const int y_col =
109 !matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
110
111 // Checks that matrix multiply can perform a valid contraction.
112 if (x_col != y_row) {
113 result_shape.clear();
114 return false;
115 }
116
117 result_shape.push_back(x_row);
118 result_shape.push_back(y_col);
119 return true;
120 };
121
122 return RewriteOp(op, rewriter, get_broadcasted_shape);
123 }
124
125 template <typename Op>
RewriteEqOp(Operation * op,PatternRewriter & rewriter) const126 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
127 Operation* op, PatternRewriter& rewriter) const {
128 auto eq_op = llvm::dyn_cast_or_null<Op>(op);
129 if (eq_op && eq_op.incompatible_shape_error())
130 return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
131 return failure();
132 }
133
RewriteOp(Operation * op,PatternRewriter & rewriter,const std::function<bool (ArrayRef<int64_t>,ArrayRef<int64_t>,SmallVectorImpl<int64_t> &)> & get_broadcasted_shape) const134 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
135 Operation* op, PatternRewriter& rewriter,
136 const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
137 SmallVectorImpl<int64_t>&)>& get_broadcasted_shape)
138 const {
139 if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
140 return failure();
141
142 // Check that the result shape is fully defined.
143 auto result_type =
144 op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>();
145 if (!result_type || !result_type.hasStaticShape()) return failure();
146
147 bool changed = false;
148 for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) {
149 // Check that the i'th operand is a broadcast.
150 auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>(
151 op->getOpOperand(i).get().getDefiningOp());
152 if (!broadcast) continue;
153
154 // Check that the operand of the broadcast has fully defined shape.
155 auto broadcast_arg_type =
156 broadcast.input().getType().dyn_cast_or_null<RankedTensorType>();
157 if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue;
158
159 // Check that the other argument has fully defined shape.
160 auto argument_type = op->getOpOperand(1 - i)
161 .get()
162 .getType()
163 .dyn_cast_or_null<RankedTensorType>();
164 if (!argument_type || !argument_type.hasStaticShape()) continue;
165
166 // Get the unbroadcasted shapes in the operand order.
167 std::array<llvm::ArrayRef<int64_t>, 2> operand_shapes;
168 operand_shapes[i] = broadcast_arg_type.getShape();
169 operand_shapes[1 - i] = argument_type.getShape();
170
171 // Check that the input of the broadcast and the other operand is broadcast
172 // compatible.
173 llvm::SmallVector<int64_t, 4> broadcasted_shape;
174 if (!get_broadcasted_shape(operand_shapes[0], operand_shapes[1],
175 broadcasted_shape))
176 continue;
177
178 // Check that an implicit broadcast between the operand of the broadcast and
179 // the other argument would result in the same type as the result type.
180 if (broadcasted_shape != result_type.getShape()) continue;
181
182 // Update the operand of the op to be the operand of the broadcast.
183 rewriter.updateRootInPlace(
184 op, [&]() { op->getOpOperand(i).set(broadcast.input()); });
185 changed = true;
186 }
187 return success(changed);
188 }
189
runOnOperation()190 void BroadcastFoldPass::runOnOperation() {
191 RewritePatternSet patterns(&getContext());
192 auto func = getOperation();
193
194 patterns.add<ConvertResultsBroadcastableShapeOp>(func.getContext());
195 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
196 }
197
198 } // namespace
199
200 namespace TF {
CreateBroadcastFoldPass()201 std::unique_ptr<OperationPass<func::FuncOp>> CreateBroadcastFoldPass() {
202 return std::make_unique<BroadcastFoldPass>();
203 }
204 } // namespace TF
205
206 } // namespace mlir
207