1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <utility>
17
18 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
19 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
20 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
21 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"
23 #include "mlir/IR/BlockAndValueMapping.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27
28 namespace mlir {
29 namespace mhlo {
30 namespace {
31
32 // TODO(b/228448038): consider to move this pattern to mhlo.map canonicalizer.
33 // Pattern to convert map of pure elementwise ops to directly use elementwise
34 // ops without map. e.g.
35 // %0 = "mhlo.map"(%arg, %arg1) ({
36 // ^bb0(%a: tensor<f32>, %b: tensor<f32>):
37 // %output = mhlo.add %a, %b : tensor<f32>
38 // "mhlo.return"(%output) : (tensor<f32>) -> ()
39 // }) {dimensions = dense<[0]> : tensor<1xi64>} :
40 // (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
41 // To:
42 // %0 = mhlo.add %arg, %arg1 : tensor<?xf32>
43 struct ConvertMapOfElementwiseOps : public OpRewritePattern<MapOp> {
44 using OpRewritePattern<MapOp>::OpRewritePattern;
45
matchAndRewritemlir::mhlo::__anone5d2bdfb0111::ConvertMapOfElementwiseOps46 LogicalResult matchAndRewrite(MapOp map,
47 PatternRewriter &rewriter) const override {
48 // Matches that the computation block only has element-wise ops.
49 if (llvm::any_of(map.computation().front().without_terminator(),
50 [](Operation &op) {
51 return op.getNumResults() != 1 ||
52 !op.hasTrait<::mlir::OpTrait::Elementwise>();
53 })) {
54 return failure();
55 }
56
57 rewriter.setInsertionPointAfter(map);
58 BlockAndValueMapping blockAndValueMap;
59 for (mlir::BlockArgument barg : map.computation().front().getArguments()) {
60 blockAndValueMap.map(barg, map->getOperand(barg.getArgNumber()));
61 }
62 auto shape = map.getType().getShape();
63 for (Operation &op : map.computation().front().without_terminator()) {
64 SmallVector<Value, 2> operands;
65 // Remaps the operands.
66 operands.reserve(op.getNumOperands());
67 for (auto value : op.getOperands())
68 operands.push_back(blockAndValueMap.lookup(value));
69 auto *newOp = rewriter.create(
70 op.getLoc(), op.getName().getIdentifier(), operands,
71 op.getResultTypes()[0].cast<TensorType>().clone(shape));
72 // Maps the result.
73 blockAndValueMap.map(op.getResult(0), newOp->getResult(0));
74 }
75
76 auto retOp = cast<ReturnOp>(map.computation().front().back());
77 map->getResult(0).replaceAllUsesWith(
78 blockAndValueMap.lookup(retOp->getOperand(0)));
79 return success();
80 }
81 };
82
83 struct CollapseElementwiseMapPass
84 : public CollapseElementwiseMapPassBase<CollapseElementwiseMapPass> {
runOnOperationmlir::mhlo::__anone5d2bdfb0111::CollapseElementwiseMapPass85 void runOnOperation() override {
86 MLIRContext *ctx = &getContext();
87 RewritePatternSet patterns(ctx);
88 patterns.add<ConvertMapOfElementwiseOps>(ctx);
89 if (failed(
90 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
91 return signalPassFailure();
92 }
93 };
94 } // namespace
95
96 std::unique_ptr<OperationPass<func::FuncOp>>
createCollapseElementwiseMapPass()97 createCollapseElementwiseMapPass() {
98 return std::make_unique<CollapseElementwiseMapPass>();
99 }
100
101 } // namespace mhlo
102 } // namespace mlir
103