/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>

#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
#include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface.h"
#include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface_impl.h"
#include "mlir-hlo/Dialect/gml_st/transforms/pass_detail.h"
#include "mlir-hlo/Dialect/gml_st/transforms/passes.h"
#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace gml_st {
namespace {

// TODO(frgossen): Move this to the shape reification pass.
39 struct DimOpFissionPattern : public OpRewritePattern<tensor::ExtractOp> {
40 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
41
matchAndRewritemlir::gml_st::__anonc39c8f940111::DimOpFissionPattern42 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
43 PatternRewriter& rewriter) const override {
44 auto shapeDef = llvm::dyn_cast_or_null<shape::ShapeOfOp>(
45 extract.getTensor().getDefiningOp());
46 if (!shapeDef || extract.getIndices().size() != 1) return failure();
47 rewriter.replaceOpWithNewOp<tensor::DimOp>(extract, shapeDef.getArg(),
48 extract.getIndices().front());
49 return success();
50 }
51 };

// TODO(frgossen): Implement this through the shape reification interface and
// move this pattern to the shape reification pass.
55 struct DimOpReificationPattern : public OpRewritePattern<tensor::DimOp> {
56 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
57
matchAndRewritemlir::gml_st::__anonc39c8f940111::DimOpReificationPattern58 LogicalResult matchAndRewrite(tensor::DimOp op,
59 PatternRewriter& rewriter) const override {
60 Operation* def = op.getSource().getDefiningOp();
61 if (!def) return failure();
62
63 // Case MaterializeOp.
64 if (auto materializeOp = llvm::dyn_cast<MaterializeOp>(def)) {
65 assert(materializeOp->getNumResults() == 1 && "assume single result");
66 Value set = materializeOp.set();
67 if (!set.getType().isa<TileType>()) return failure();
68 rewriter.replaceOpWithNewOp<gml_st::SizeOp>(op, set, op.getIndex());
69 return success();
70 }
71
72 // Case GenericOp.
73 if (auto genericOp = llvm::dyn_cast<linalg::GenericOp>(def)) {
74 if (genericOp.getNumResults() != 1 || !genericOp.hasTensorSemantics()) {
75 return failure();
76 }
77 Value outputOperand = genericOp.getOutputOperand(0)->get();
78 rewriter.replaceOpWithNewOp<tensor::DimOp>(op, outputOperand,
79 op.getIndex());
80 return success();
81 }
82
83 // Case InitTensorOp.
84 if (auto initTensorOp = llvm::dyn_cast<linalg::InitTensorOp>(def)) {
85 if (auto indexConstantOp = llvm::dyn_cast_or_null<arith::ConstantOp>(
86 op.getIndex().getDefiningOp())) {
87 int64_t idx =
88 indexConstantOp.getValue().dyn_cast<IntegerAttr>().getInt();
89 OpFoldResult dim = initTensorOp.getMixedSizes()[idx];
90 Value dimValue;
91 if (dim.is<Value>()) {
92 dimValue = dim.get<Value>();
93 } else {
94 assert(dim.is<Attribute>() && "expected Value or Attribute");
95 int64_t dimInt = dim.get<Attribute>().cast<IntegerAttr>().getInt();
96 dimValue =
97 rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dimInt);
98 }
99 assert(dimValue);
100 rewriter.replaceOp(op, ValueRange{dimValue});
101 return success();
102 }
103 }
104
105 // Case ConcatenateOp.
106 if (auto concat = llvm::dyn_cast<thlo::ConcatenateOp>(def)) {
107 rewriter.replaceOpWithNewOp<tensor::DimOp>(op, concat.init(),
108 op.getIndex());
109 return success();
110 }
111
112 // Case DynamicBroadcastInDimOp.
113 if (auto bcast = llvm::dyn_cast<thlo::DynamicBroadcastInDimOp>(def)) {
114 rewriter.replaceOpWithNewOp<tensor::DimOp>(op, bcast.init(),
115 op.getIndex());
116 return success();
117 }
118
119 return failure();
120 }
121 };
123 struct FusionPattern : public OpRewritePattern<MaterializeOp> {
124 using OpRewritePattern<MaterializeOp>::OpRewritePattern;
125
matchAndRewritemlir::gml_st::__anonc39c8f940111::FusionPattern126 LogicalResult matchAndRewrite(MaterializeOp op,
127 PatternRewriter& rewriter) const override {
128 Operation* def = op.source().getDefiningOp();
129 if (!def) return failure();
130
131 auto iface = llvm::dyn_cast<FusionInterface>(def);
132 if (!iface) return failure();
133
134 Value fused = iface.fuse(op.getLoc(), op.set(), rewriter);
135 if (!fused) return failure();
136
137 rewriter.replaceOp(op, fused);
138 return success();
139 }
140 };
142 class FusionPass : public FusionPassBase<FusionPass> {
getDependentDialects(DialectRegistry & registry) const143 void getDependentDialects(DialectRegistry& registry) const final {
144 registry.insert<scf::SCFDialect>();
145 registerFusionInterfaceExternalModels(registry);
146 }
147
runOnOperation()148 void runOnOperation() final {
149 MLIRContext* ctx = &getContext();
150
151 // Populate patterns.
152 RewritePatternSet patterns(ctx);
153 // clang-format off
154 patterns.insert<
155 DimOpFissionPattern,
156 DimOpReificationPattern,
157 FusionPattern>(ctx);
158 // clang-format on
159
160 if (failed(applyPatternsAndFoldGreedily(getOperation(),
161 std::move(patterns)))) {
162 return signalPassFailure();
163 }
164 }
165 };

}  // namespace

createFusionPass()169 std::unique_ptr<OperationPass<func::FuncOp>> createFusionPass() {
170 return std::make_unique<FusionPass>();
171 }

}  // namespace gml_st
}  // namespace mlir
