xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/fusion.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
16 #include <memory>
17 #include <utility>
18 
19 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
20 #include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface.h"
21 #include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface_impl.h"
22 #include "mlir-hlo/Dialect/gml_st/transforms/pass_detail.h"
23 #include "mlir-hlo/Dialect/gml_st/transforms/passes.h"
24 #include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h"
25 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h"
27 #include "mlir/Dialect/Linalg/IR/Linalg.h"
28 #include "mlir/Dialect/SCF/IR/SCF.h"
29 #include "mlir/Dialect/Shape/IR/Shape.h"
30 #include "mlir/Dialect/Tensor/IR/Tensor.h"
31 #include "mlir/IR/Attributes.h"
32 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
33 
34 namespace mlir {
35 namespace gml_st {
36 namespace {
37 
38 // TODO(frgossen): Move this to the shape reification pass.
39 struct DimOpFissionPattern : public OpRewritePattern<tensor::ExtractOp> {
40   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
41 
matchAndRewritemlir::gml_st::__anonc39c8f940111::DimOpFissionPattern42   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
43                                 PatternRewriter& rewriter) const override {
44     auto shapeDef = llvm::dyn_cast_or_null<shape::ShapeOfOp>(
45         extract.getTensor().getDefiningOp());
46     if (!shapeDef || extract.getIndices().size() != 1) return failure();
47     rewriter.replaceOpWithNewOp<tensor::DimOp>(extract, shapeDef.getArg(),
48                                                extract.getIndices().front());
49     return success();
50   }
51 };
52 
53 // TODO(frgossen): Implement this through the shape reification interface and
54 // move this pattern to the shape reification pass.
55 struct DimOpReificationPattern : public OpRewritePattern<tensor::DimOp> {
56   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
57 
matchAndRewritemlir::gml_st::__anonc39c8f940111::DimOpReificationPattern58   LogicalResult matchAndRewrite(tensor::DimOp op,
59                                 PatternRewriter& rewriter) const override {
60     Operation* def = op.getSource().getDefiningOp();
61     if (!def) return failure();
62 
63     // Case MaterializeOp.
64     if (auto materializeOp = llvm::dyn_cast<MaterializeOp>(def)) {
65       assert(materializeOp->getNumResults() == 1 && "assume single result");
66       Value set = materializeOp.set();
67       if (!set.getType().isa<TileType>()) return failure();
68       rewriter.replaceOpWithNewOp<gml_st::SizeOp>(op, set, op.getIndex());
69       return success();
70     }
71 
72     // Case GenericOp.
73     if (auto genericOp = llvm::dyn_cast<linalg::GenericOp>(def)) {
74       if (genericOp.getNumResults() != 1 || !genericOp.hasTensorSemantics()) {
75         return failure();
76       }
77       Value outputOperand = genericOp.getOutputOperand(0)->get();
78       rewriter.replaceOpWithNewOp<tensor::DimOp>(op, outputOperand,
79                                                  op.getIndex());
80       return success();
81     }
82 
83     // Case InitTensorOp.
84     if (auto initTensorOp = llvm::dyn_cast<linalg::InitTensorOp>(def)) {
85       if (auto indexConstantOp = llvm::dyn_cast_or_null<arith::ConstantOp>(
86               op.getIndex().getDefiningOp())) {
87         int64_t idx =
88             indexConstantOp.getValue().dyn_cast<IntegerAttr>().getInt();
89         OpFoldResult dim = initTensorOp.getMixedSizes()[idx];
90         Value dimValue;
91         if (dim.is<Value>()) {
92           dimValue = dim.get<Value>();
93         } else {
94           assert(dim.is<Attribute>() && "expected Value or Attribute");
95           int64_t dimInt = dim.get<Attribute>().cast<IntegerAttr>().getInt();
96           dimValue =
97               rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dimInt);
98         }
99         assert(dimValue);
100         rewriter.replaceOp(op, ValueRange{dimValue});
101         return success();
102       }
103     }
104 
105     // Case ConcatenateOp.
106     if (auto concat = llvm::dyn_cast<thlo::ConcatenateOp>(def)) {
107       rewriter.replaceOpWithNewOp<tensor::DimOp>(op, concat.init(),
108                                                  op.getIndex());
109       return success();
110     }
111 
112     // Case DynamicBroadcastInDimOp.
113     if (auto bcast = llvm::dyn_cast<thlo::DynamicBroadcastInDimOp>(def)) {
114       rewriter.replaceOpWithNewOp<tensor::DimOp>(op, bcast.init(),
115                                                  op.getIndex());
116       return success();
117     }
118 
119     return failure();
120   }
121 };
122 
123 struct FusionPattern : public OpRewritePattern<MaterializeOp> {
124   using OpRewritePattern<MaterializeOp>::OpRewritePattern;
125 
matchAndRewritemlir::gml_st::__anonc39c8f940111::FusionPattern126   LogicalResult matchAndRewrite(MaterializeOp op,
127                                 PatternRewriter& rewriter) const override {
128     Operation* def = op.source().getDefiningOp();
129     if (!def) return failure();
130 
131     auto iface = llvm::dyn_cast<FusionInterface>(def);
132     if (!iface) return failure();
133 
134     Value fused = iface.fuse(op.getLoc(), op.set(), rewriter);
135     if (!fused) return failure();
136 
137     rewriter.replaceOp(op, fused);
138     return success();
139   }
140 };
141 
142 class FusionPass : public FusionPassBase<FusionPass> {
getDependentDialects(DialectRegistry & registry) const143   void getDependentDialects(DialectRegistry& registry) const final {
144     registry.insert<scf::SCFDialect>();
145     registerFusionInterfaceExternalModels(registry);
146   }
147 
runOnOperation()148   void runOnOperation() final {
149     MLIRContext* ctx = &getContext();
150 
151     // Populate patterns.
152     RewritePatternSet patterns(ctx);
153     // clang-format off
154     patterns.insert<
155         DimOpFissionPattern,
156         DimOpReificationPattern,
157         FusionPattern>(ctx);
158     // clang-format on
159 
160     if (failed(applyPatternsAndFoldGreedily(getOperation(),
161                                             std::move(patterns)))) {
162       return signalPassFailure();
163     }
164   }
165 };
166 
167 }  // namespace
168 
createFusionPass()169 std::unique_ptr<OperationPass<func::FuncOp>> createFusionPass() {
170   return std::make_unique<FusionPass>();
171 }
172 
173 }  // namespace gml_st
174 }  // namespace mlir
175