/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include <sys/types.h>
17
18 #include <string>
19 #include <utility>
20
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/BuiltinOps.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/IR/OperationSupport.h"
28 #include "mlir/IR/TypeRange.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/iterator_range.h"
32 #include "llvm/Support/Casting.h"
33 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
34 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
35 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
36 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Analysis/shape_component_analysis.h"
37 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
38
39 namespace tensorflow {
40 namespace {
41
42 using llvm::ArrayRef;
43 using llvm::SmallVector;
44
45 using mlir::AffineExpr;
46 using mlir::AffineMap;
47 using mlir::failure;
48 using mlir::Location;
49 using mlir::LogicalResult;
50 using mlir::MLIRContext;
51 using mlir::OpBuilder;
52 using mlir::OperationPass;
53 using mlir::RankedTensorType;
54 using mlir::ShapeComponentAnalysis;
55 using mlir::success;
56 using mlir::TypeRange;
57 using mlir::Value;
58 using mlir::ValueRange;
59 using mlir::arith::ConstantIndexOp;
60 using mlir::arith::ConstantOp;
61 using mlir::arith::IndexCastOp;
62 using mlir::func::FuncOp;
63
64 namespace linalg = mlir::linalg;
65 namespace mhlo = mlir::mhlo;
66 namespace shape = mlir::shape;
67 namespace tensor = mlir::tensor;
68
69 #define GEN_PASS_CLASSES
70 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
71
72 // -------------------------------------------------------------------------- //
73
74
75
76
77
78
79 // Replace shape.broadcast with a shape if it's statically known.
80 class BroadcastOpLowering final
81 : public mlir::OpRewritePattern<shape::BroadcastOp> {
82 public:
BroadcastOpLowering(MLIRContext * ctx)83 explicit BroadcastOpLowering(MLIRContext* ctx) : OpRewritePattern(ctx) {}
84
85 LogicalResult matchAndRewrite(shape::BroadcastOp op,
86 mlir::PatternRewriter& rewriter) const override;
87 };
88
89 // Returns a shape tensor if the shapes can be broadcasted to a known shape.
90 // Will either return one of the shapes or a generated mix of the shapes.
simplifyBroadcast(ShapeComponentAnalysis & analysis,ValueRange shapes,Location loc,OpBuilder * builder)91 llvm::Optional<Value> simplifyBroadcast(ShapeComponentAnalysis& analysis,
92 ValueRange shapes, Location loc,
93 OpBuilder* builder) {
94 // First find the input shape with the largest rank.
95 SmallVector<ArrayRef<ShapeComponentAnalysis::SymbolicExpr>> shapes_found;
96 size_t maxRank = 0;
97 for (const auto &shape : llvm::enumerate(shapes)) {
98 auto found_shape = analysis.GetValueInfo(shape.value());
99 if (!found_shape) return {};
100 shapes_found.push_back(*found_shape);
101 maxRank = std::max(maxRank, found_shape->size());
102 }
103 if (maxRank == 0) {
104 return Value(builder->create<tensor::FromElementsOp>(
105 loc, shapes[0].getType(), SmallVector<Value>()));
106 }
107
108 SmallVector<const ShapeComponentAnalysis::SymbolicExpr*> joined_dimensions(
109 maxRank);
110 SmallVector<std::pair<Value, int64_t>> shape_and_rank_for_dim(maxRank);
111 for (const auto &shape : llvm::enumerate(shapes_found)) {
112 for (const auto &dim : llvm::enumerate(llvm::reverse(shape.value()))) {
113 // 1 dimensions don't contribute to the final result.
114 if (dim.value().isConstant(1)) continue;
115 // If it's not a 1 dimension it will be present in the result. Remember
116 // where it came from.
117 auto index = maxRank - dim.index() - 1;
118 if (!joined_dimensions[index]) {
119 joined_dimensions[index] = &dim.value();
120 shape_and_rank_for_dim[index] =
121 std::make_pair(shapes[shape.index()], shape.value().size());
122 continue;
123 }
124 // Bail if the dimensions are neither equal nor 1.
125 if (*joined_dimensions[index] != dim.value()) return {};
126 }
127 }
128 // If the output is the same as one of the inputs just return that.
129 if (llvm::is_splat(shape_and_rank_for_dim) &&
130 shape_and_rank_for_dim[0].first) {
131 return shape_and_rank_for_dim[0].first;
132 }
133 // Otherwise rematerialize the shape from the pieces we have.
134 SmallVector<Value> elements;
135 for (int i = 0; i != maxRank; ++i) {
136 // 1 dimensions are filtered above, recreate the constant.
137 if (!shape_and_rank_for_dim[i].first) {
138 auto one = builder->getIntegerAttr(
139 shapes[0].getType().cast<RankedTensorType>().getElementType(), 1);
140 elements.push_back(builder->create<ConstantOp>(loc, one));
141 continue;
142 }
143 // Extract from one of the shapes, accounting for the reverse indexing
144 // performed by broadcast.
145 Value index = builder->create<ConstantIndexOp>(
146 loc, i - maxRank + shape_and_rank_for_dim[i].second);
147 elements.push_back(builder->create<tensor::ExtractOp>(
148 loc, shape_and_rank_for_dim[i].first, index));
149 }
150 return Value(builder->create<tensor::FromElementsOp>(loc, elements));
151 }
152
matchAndRewrite(shape::BroadcastOp op,mlir::PatternRewriter & rewriter) const153 LogicalResult BroadcastOpLowering::matchAndRewrite(
154 shape::BroadcastOp op, mlir::PatternRewriter& rewriter) const {
155 ShapeComponentAnalysis shape_component_analysis;
156 auto new_broadcast = simplifyBroadcast(
157 shape_component_analysis, op.getShapes(), op.getLoc(), &rewriter);
158 if (!new_broadcast) return failure();
159 rewriter.replaceOp(op, {*new_broadcast});
160 return success();
161 }
162
163 // -------------------------------------------------------------------------- //
164
165 // Rewrite mhlo.dynamic_broadcast_in_dim operation into linalg.generic operation
166 // if can infer the indexing maps for the operand from the symbolic shapes.
167 class DynamicBroadcastInDimOpLowering
168 : public mlir::OpRewritePattern<mhlo::DynamicBroadcastInDimOp> {
169 public:
170 using Base = OpRewritePattern<mhlo::DynamicBroadcastInDimOp>;
171
172 explicit DynamicBroadcastInDimOpLowering(MLIRContext* ctx);
173
174 LogicalResult matchAndRewrite(mhlo::DynamicBroadcastInDimOp op,
175 mlir::PatternRewriter& rewriter) const override;
176 };
177
DynamicBroadcastInDimOpLowering(MLIRContext * ctx)178 DynamicBroadcastInDimOpLowering::DynamicBroadcastInDimOpLowering(
179 MLIRContext* ctx)
180 : Base(ctx) {}
181
182 // Check if broadcasting `from` to `to_shape` is statically known to only have
183 // dimensions that never expand or always expand.
isNonExpandingBroadcast(ShapeComponentAnalysis & analysis,Value from,Value to_shape)184 llvm::Optional<AffineMap> isNonExpandingBroadcast(
185 ShapeComponentAnalysis& analysis, Value from, Value to_shape) {
186 auto in_shape = analysis.GetShapeInfo(from);
187 auto out_shape = analysis.GetValueInfo(to_shape);
188 if (!in_shape || !out_shape) return {};
189
190 SmallVector<AffineExpr> input_map_exprs;
191 size_t rank = out_shape->size();
192 MLIRContext* ctx = (*out_shape)[0].expr.getContext();
193 size_t d = 0;
194 auto affine_zero = getAffineConstantExpr(0, ctx);
195 for (auto zip :
196 llvm::zip(llvm::reverse(*in_shape), llvm::reverse(*out_shape))) {
197 const auto& in = std::get<0>(zip);
198 const auto& out = std::get<1>(zip);
199 bool extend = in.isConstant(1) && !out.isConstant(1);
200 input_map_exprs.push_back(extend ? affine_zero
201 : getAffineDimExpr(rank - d - 1, ctx));
202 ++d;
203
204 // Bail if this is neither a known expansion nor a known non-expansion.
205 if (!extend && in != out) return {};
206 }
207 // Any leading dimensions will be expanded.
208 input_map_exprs.resize(in_shape->size(), affine_zero);
209 std::reverse(input_map_exprs.begin(), input_map_exprs.end());
210 return AffineMap::get(/*dimCount=*/rank,
211 /*symbolCount=*/0, input_map_exprs, ctx);
212 }
213
matchAndRewrite(mhlo::DynamicBroadcastInDimOp op,mlir::PatternRewriter & rewriter) const214 LogicalResult DynamicBroadcastInDimOpLowering::matchAndRewrite(
215 mhlo::DynamicBroadcastInDimOp op, mlir::PatternRewriter& rewriter) const {
216 MLIRContext* ctx = getContext();
217
218 auto in_type = op.operand().getType().dyn_cast<RankedTensorType>();
219 auto out_type = op.getResult().getType().dyn_cast<RankedTensorType>();
220 if (!in_type || !out_type) return failure();
221
222 // Check that broadcast is right-aligned (numpy style), so that operand
223 // dimensions broadcasted to match inner-most dimensions of the output.
224 auto bcast_dims = op.broadcast_dimensions().getValues<int64_t>();
225 auto expected_bcast_dims = llvm::seq<int64_t>(
226 out_type.getRank() - in_type.getRank(), out_type.getRank());
227 if (!llvm::equal(bcast_dims, expected_bcast_dims)) return failure();
228
229 ShapeComponentAnalysis shape_component_analysis;
230 auto input_map = isNonExpandingBroadcast(
231 shape_component_analysis, op.operand(), op.output_dimensions());
232 if (!input_map) return failure();
233
234 // Resolve dynamic output dimensions for the `linalg.init_tensor` operation.
235 SmallVector<Value> output_dyn_dimensions;
236 Location loc = op.getLoc();
237 int64_t rank = out_type.getRank();
238 for (size_t d = 0; d < rank; ++d) {
239 int64_t output_dim = out_type.getShape()[d];
240
241 // Skip static output dimensions, they will be resolved from the shape.
242 if (output_dim >= 0) continue;
243
244 // Resolve the dynamic size of the output dimension.
245 Value output_dyn_dim = rewriter.create<tensor::ExtractOp>(
246 loc, op.output_dimensions(),
247 ValueRange{rewriter.create<ConstantIndexOp>(loc, d)});
248
249 // Symbolic shape analysis might have given us an i32 or i64. Cast to index.
250 if (!output_dyn_dim.getType().isIndex())
251 output_dyn_dim = rewriter.create<IndexCastOp>(
252 loc, rewriter.getIndexType(), output_dyn_dim);
253
254 output_dyn_dimensions.push_back(output_dyn_dim);
255 }
256
257 // Create a linalg.tensor_init operation to initialize output.
258 Value init = rewriter.create<linalg::InitTensorOp>(loc, output_dyn_dimensions,
259 out_type.getShape(),
260 out_type.getElementType());
261
262 // Output indexing map is an identity with `rank` number of loops.
263 AffineMap output_map = AffineMap::getMultiDimIdentityMap(rank, ctx);
264
265 // All iterators are parallel.
266 SmallVector<llvm::StringRef> iterator_types(rank, "parallel");
267
268 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
269 op, /*resultTensorTypes=*/TypeRange{init.getType()},
270 /*inputs=*/ValueRange{op.operand()},
271 /*outputs=*/ValueRange{init},
272 /*indexingMaps=*/llvm::makeArrayRef({*input_map, output_map}),
273 /*iteratorTypes=*/iterator_types,
274 [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
275 nested_builder.create<linalg::YieldOp>(nested_loc, args[0]);
276 });
277
278 return success();
279 }
280
281 // -------------------------------------------------------------------------- //
282 // Optimize function based on the symbolic shape attributes.
283 // -------------------------------------------------------------------------- //
284
285 struct SymbolicShapeOptimizationPass
286 : public SymbolicShapeOptimizationBase<SymbolicShapeOptimizationPass> {
287 SymbolicShapeOptimizationPass() = default;
288
SymbolicShapeOptimizationPasstensorflow::__anondb63220c0111::SymbolicShapeOptimizationPass289 explicit SymbolicShapeOptimizationPass(bool constraints_only) {
290 this->optimize_only_constraints = constraints_only;
291 }
292
runOnOperationtensorflow::__anondb63220c0111::SymbolicShapeOptimizationPass293 void runOnOperation() override {
294 MLIRContext* ctx = &getContext();
295 mlir::RewritePatternSet patterns(ctx);
296
297 // Rewrite shape.broadcast based on the symbolic shapes.
298 patterns.add<BroadcastOpLowering>(ctx);
299
300 // Rewrite broadcasts based on the symbolic shapes if enabled.
301 if (!optimize_only_constraints)
302 patterns.add<DynamicBroadcastInDimOpLowering>(ctx);
303
304 // Add shape dialect canonicalization patterns to fold shape operations
305 // after constraints are replaced with constant witness.
306 for (auto op : ctx->getRegisteredOperations()) {
307 if (llvm::isa<shape::ShapeDialect>(op.getDialect()))
308 op.getCanonicalizationPatterns(patterns, ctx);
309 }
310
311 if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
312 std::move(patterns)))) {
313 return signalPassFailure();
314 }
315 }
316 };
317
318 } // namespace
319
CreateSymbolicShapeOptimizationPass(bool constraints_only)320 std::unique_ptr<OperationPass<FuncOp>> CreateSymbolicShapeOptimizationPass(
321 bool constraints_only) {
322 return std::make_unique<SymbolicShapeOptimizationPass>(constraints_only);
323 }
324
325 } // namespace tensorflow
326