/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <sys/types.h>

#include <algorithm>
#include <memory>
#include <string>
#include <utility>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Analysis/shape_component_analysis.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

39 namespace tensorflow {
40 namespace {
41 
42 using llvm::ArrayRef;
43 using llvm::SmallVector;
44 
45 using mlir::AffineExpr;
46 using mlir::AffineMap;
47 using mlir::failure;
48 using mlir::Location;
49 using mlir::LogicalResult;
50 using mlir::MLIRContext;
51 using mlir::OpBuilder;
52 using mlir::OperationPass;
53 using mlir::RankedTensorType;
54 using mlir::ShapeComponentAnalysis;
55 using mlir::success;
56 using mlir::TypeRange;
57 using mlir::Value;
58 using mlir::ValueRange;
59 using mlir::arith::ConstantIndexOp;
60 using mlir::arith::ConstantOp;
61 using mlir::arith::IndexCastOp;
62 using mlir::func::FuncOp;
63 
64 namespace linalg = mlir::linalg;
65 namespace mhlo = mlir::mhlo;
66 namespace shape = mlir::shape;
67 namespace tensor = mlir::tensor;
68 
69 #define GEN_PASS_CLASSES
70 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
71 
72 // -------------------------------------------------------------------------- //
73 
74 
75 
76 
77 
78 
79 // Replace shape.broadcast with a shape if it's statically known.
80 class BroadcastOpLowering final
81     : public mlir::OpRewritePattern<shape::BroadcastOp> {
82  public:
BroadcastOpLowering(MLIRContext * ctx)83   explicit BroadcastOpLowering(MLIRContext* ctx) : OpRewritePattern(ctx) {}
84 
85   LogicalResult matchAndRewrite(shape::BroadcastOp op,
86                                 mlir::PatternRewriter& rewriter) const override;
87 };
88 
89 // Returns a shape tensor if the shapes can be broadcasted to a known shape.
90 // Will either return one of the shapes or a generated mix of the shapes.
simplifyBroadcast(ShapeComponentAnalysis & analysis,ValueRange shapes,Location loc,OpBuilder * builder)91 llvm::Optional<Value> simplifyBroadcast(ShapeComponentAnalysis& analysis,
92                                         ValueRange shapes, Location loc,
93                                         OpBuilder* builder) {
94   // First find the input shape with the largest rank.
95   SmallVector<ArrayRef<ShapeComponentAnalysis::SymbolicExpr>> shapes_found;
96   size_t maxRank = 0;
97   for (const auto &shape : llvm::enumerate(shapes)) {
98     auto found_shape = analysis.GetValueInfo(shape.value());
99     if (!found_shape) return {};
100     shapes_found.push_back(*found_shape);
101     maxRank = std::max(maxRank, found_shape->size());
102   }
103   if (maxRank == 0) {
104     return Value(builder->create<tensor::FromElementsOp>(
105         loc, shapes[0].getType(), SmallVector<Value>()));
106   }
107 
108   SmallVector<const ShapeComponentAnalysis::SymbolicExpr*> joined_dimensions(
109       maxRank);
110   SmallVector<std::pair<Value, int64_t>> shape_and_rank_for_dim(maxRank);
111   for (const auto &shape : llvm::enumerate(shapes_found)) {
112     for (const auto &dim : llvm::enumerate(llvm::reverse(shape.value()))) {
113       // 1 dimensions don't contribute to the final result.
114       if (dim.value().isConstant(1)) continue;
115       // If it's not a 1 dimension it will be present in the result. Remember
116       // where it came from.
117       auto index = maxRank - dim.index() - 1;
118       if (!joined_dimensions[index]) {
119         joined_dimensions[index] = &dim.value();
120         shape_and_rank_for_dim[index] =
121             std::make_pair(shapes[shape.index()], shape.value().size());
122         continue;
123       }
124       // Bail if the dimensions are neither equal nor 1.
125       if (*joined_dimensions[index] != dim.value()) return {};
126     }
127   }
128   // If the output is the same as one of the inputs just return that.
129   if (llvm::is_splat(shape_and_rank_for_dim) &&
130       shape_and_rank_for_dim[0].first) {
131     return shape_and_rank_for_dim[0].first;
132   }
133   // Otherwise rematerialize the shape from the pieces we have.
134   SmallVector<Value> elements;
135   for (int i = 0; i != maxRank; ++i) {
136     // 1 dimensions are filtered above, recreate the constant.
137     if (!shape_and_rank_for_dim[i].first) {
138       auto one = builder->getIntegerAttr(
139           shapes[0].getType().cast<RankedTensorType>().getElementType(), 1);
140       elements.push_back(builder->create<ConstantOp>(loc, one));
141       continue;
142     }
143     // Extract from one of the shapes, accounting for the reverse indexing
144     // performed by broadcast.
145     Value index = builder->create<ConstantIndexOp>(
146         loc, i - maxRank + shape_and_rank_for_dim[i].second);
147     elements.push_back(builder->create<tensor::ExtractOp>(
148         loc, shape_and_rank_for_dim[i].first, index));
149   }
150   return Value(builder->create<tensor::FromElementsOp>(loc, elements));
151 }
152 
matchAndRewrite(shape::BroadcastOp op,mlir::PatternRewriter & rewriter) const153 LogicalResult BroadcastOpLowering::matchAndRewrite(
154     shape::BroadcastOp op, mlir::PatternRewriter& rewriter) const {
155   ShapeComponentAnalysis shape_component_analysis;
156   auto new_broadcast = simplifyBroadcast(
157       shape_component_analysis, op.getShapes(), op.getLoc(), &rewriter);
158   if (!new_broadcast) return failure();
159   rewriter.replaceOp(op, {*new_broadcast});
160   return success();
161 }
162 
163 // -------------------------------------------------------------------------- //
164 
165 // Rewrite mhlo.dynamic_broadcast_in_dim operation into linalg.generic operation
166 // if can infer the indexing maps for the operand from the symbolic shapes.
167 class DynamicBroadcastInDimOpLowering
168     : public mlir::OpRewritePattern<mhlo::DynamicBroadcastInDimOp> {
169  public:
170   using Base = OpRewritePattern<mhlo::DynamicBroadcastInDimOp>;
171 
172   explicit DynamicBroadcastInDimOpLowering(MLIRContext* ctx);
173 
174   LogicalResult matchAndRewrite(mhlo::DynamicBroadcastInDimOp op,
175                                 mlir::PatternRewriter& rewriter) const override;
176 };
177 
DynamicBroadcastInDimOpLowering(MLIRContext * ctx)178 DynamicBroadcastInDimOpLowering::DynamicBroadcastInDimOpLowering(
179     MLIRContext* ctx)
180     : Base(ctx) {}
181 
182 // Check if broadcasting `from` to `to_shape` is statically known to only have
183 // dimensions that never expand or always expand.
isNonExpandingBroadcast(ShapeComponentAnalysis & analysis,Value from,Value to_shape)184 llvm::Optional<AffineMap> isNonExpandingBroadcast(
185     ShapeComponentAnalysis& analysis, Value from, Value to_shape) {
186   auto in_shape = analysis.GetShapeInfo(from);
187   auto out_shape = analysis.GetValueInfo(to_shape);
188   if (!in_shape || !out_shape) return {};
189 
190   SmallVector<AffineExpr> input_map_exprs;
191   size_t rank = out_shape->size();
192   MLIRContext* ctx = (*out_shape)[0].expr.getContext();
193   size_t d = 0;
194   auto affine_zero = getAffineConstantExpr(0, ctx);
195   for (auto zip :
196        llvm::zip(llvm::reverse(*in_shape), llvm::reverse(*out_shape))) {
197     const auto& in = std::get<0>(zip);
198     const auto& out = std::get<1>(zip);
199     bool extend = in.isConstant(1) && !out.isConstant(1);
200     input_map_exprs.push_back(extend ? affine_zero
201                                      : getAffineDimExpr(rank - d - 1, ctx));
202     ++d;
203 
204     // Bail if this is neither a known expansion nor a known non-expansion.
205     if (!extend && in != out) return {};
206   }
207   // Any leading dimensions will be expanded.
208   input_map_exprs.resize(in_shape->size(), affine_zero);
209   std::reverse(input_map_exprs.begin(), input_map_exprs.end());
210   return AffineMap::get(/*dimCount=*/rank,
211                         /*symbolCount=*/0, input_map_exprs, ctx);
212 }
213 
matchAndRewrite(mhlo::DynamicBroadcastInDimOp op,mlir::PatternRewriter & rewriter) const214 LogicalResult DynamicBroadcastInDimOpLowering::matchAndRewrite(
215     mhlo::DynamicBroadcastInDimOp op, mlir::PatternRewriter& rewriter) const {
216   MLIRContext* ctx = getContext();
217 
218   auto in_type = op.operand().getType().dyn_cast<RankedTensorType>();
219   auto out_type = op.getResult().getType().dyn_cast<RankedTensorType>();
220   if (!in_type || !out_type) return failure();
221 
222   // Check that broadcast is right-aligned (numpy style), so that operand
223   // dimensions broadcasted to match inner-most dimensions of the output.
224   auto bcast_dims = op.broadcast_dimensions().getValues<int64_t>();
225   auto expected_bcast_dims = llvm::seq<int64_t>(
226       out_type.getRank() - in_type.getRank(), out_type.getRank());
227   if (!llvm::equal(bcast_dims, expected_bcast_dims)) return failure();
228 
229   ShapeComponentAnalysis shape_component_analysis;
230   auto input_map = isNonExpandingBroadcast(
231       shape_component_analysis, op.operand(), op.output_dimensions());
232   if (!input_map) return failure();
233 
234   // Resolve dynamic output dimensions for the `linalg.init_tensor` operation.
235   SmallVector<Value> output_dyn_dimensions;
236   Location loc = op.getLoc();
237   int64_t rank = out_type.getRank();
238   for (size_t d = 0; d < rank; ++d) {
239     int64_t output_dim = out_type.getShape()[d];
240 
241     // Skip static output dimensions, they will be resolved from the shape.
242     if (output_dim >= 0) continue;
243 
244     // Resolve the dynamic size of the output dimension.
245     Value output_dyn_dim = rewriter.create<tensor::ExtractOp>(
246         loc, op.output_dimensions(),
247         ValueRange{rewriter.create<ConstantIndexOp>(loc, d)});
248 
249     // Symbolic shape analysis might have given us an i32 or i64. Cast to index.
250     if (!output_dyn_dim.getType().isIndex())
251       output_dyn_dim = rewriter.create<IndexCastOp>(
252           loc, rewriter.getIndexType(), output_dyn_dim);
253 
254     output_dyn_dimensions.push_back(output_dyn_dim);
255   }
256 
257   // Create a linalg.tensor_init operation to initialize output.
258   Value init = rewriter.create<linalg::InitTensorOp>(loc, output_dyn_dimensions,
259                                                      out_type.getShape(),
260                                                      out_type.getElementType());
261 
262   // Output indexing map is an identity with `rank` number of loops.
263   AffineMap output_map = AffineMap::getMultiDimIdentityMap(rank, ctx);
264 
265   // All iterators are parallel.
266   SmallVector<llvm::StringRef> iterator_types(rank, "parallel");
267 
268   rewriter.replaceOpWithNewOp<linalg::GenericOp>(
269       op, /*resultTensorTypes=*/TypeRange{init.getType()},
270       /*inputs=*/ValueRange{op.operand()},
271       /*outputs=*/ValueRange{init},
272       /*indexingMaps=*/llvm::makeArrayRef({*input_map, output_map}),
273       /*iteratorTypes=*/iterator_types,
274       [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
275         nested_builder.create<linalg::YieldOp>(nested_loc, args[0]);
276       });
277 
278   return success();
279 }
280 
281 // -------------------------------------------------------------------------- //
282 // Optimize function based on the symbolic shape attributes.
283 // -------------------------------------------------------------------------- //
284 
285 struct SymbolicShapeOptimizationPass
286     : public SymbolicShapeOptimizationBase<SymbolicShapeOptimizationPass> {
287   SymbolicShapeOptimizationPass() = default;
288 
SymbolicShapeOptimizationPasstensorflow::__anondb63220c0111::SymbolicShapeOptimizationPass289   explicit SymbolicShapeOptimizationPass(bool constraints_only) {
290     this->optimize_only_constraints = constraints_only;
291   }
292 
runOnOperationtensorflow::__anondb63220c0111::SymbolicShapeOptimizationPass293   void runOnOperation() override {
294     MLIRContext* ctx = &getContext();
295     mlir::RewritePatternSet patterns(ctx);
296 
297     // Rewrite shape.broadcast based on the symbolic shapes.
298     patterns.add<BroadcastOpLowering>(ctx);
299 
300     // Rewrite broadcasts based on the symbolic shapes if enabled.
301     if (!optimize_only_constraints)
302       patterns.add<DynamicBroadcastInDimOpLowering>(ctx);
303 
304     // Add shape dialect canonicalization patterns to fold shape operations
305     // after constraints are replaced with constant witness.
306     for (auto op : ctx->getRegisteredOperations()) {
307       if (llvm::isa<shape::ShapeDialect>(op.getDialect()))
308         op.getCanonicalizationPatterns(patterns, ctx);
309     }
310 
311     if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
312                                                   std::move(patterns)))) {
313       return signalPassFailure();
314     }
315   }
316 };
317 
318 }  // namespace
319 
CreateSymbolicShapeOptimizationPass(bool constraints_only)320 std::unique_ptr<OperationPass<FuncOp>> CreateSymbolicShapeOptimizationPass(
321     bool constraints_only) {
322   return std::make_unique<SymbolicShapeOptimizationPass>(constraints_only);
323 }
324 
325 }  // namespace tensorflow
326