1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for flattening tuples in HLO ops.
17
18 #include <cassert>
19 #include <string>
20 #include <utility>
21
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/MapVector.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
27 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
28 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
29 #include "mlir/Dialect/Func/IR/FuncOps.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/Location.h"
32 #include "mlir/IR/Operation.h"
33 #include "mlir/IR/Value.h"
34 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35
36 namespace mlir {
37 namespace mhlo {
38 namespace {
39
40 // Calculates the flatten types of a value.
flattenTupleType(Value value,llvm::SmallVectorImpl<Type> & types)41 void flattenTupleType(Value value, llvm::SmallVectorImpl<Type> &types) {
42 if (!value.getType().isa<TupleType>()) {
43 types.push_back(value.getType());
44 return;
45 }
46
47 // This function doesn't handle nested tuple.
48 auto tupleType = value.getType().cast<TupleType>();
49 types.append(tupleType.begin(), tupleType.end());
50 }
51
52 // FlattenTupleValue and CreateTupleValue is a pair of functions to create and
53 // flatten tuples in the exact same order. CreateTupleValue returns the result
54 // of the root TupleOp or given value if the type is not TupleType.
createTupleValue(OpBuilder & builder,Location loc,ValueRange flattenValues,Type tupleType)55 Value createTupleValue(OpBuilder &builder, Location loc,
56 ValueRange flattenValues, Type tupleType) {
57 if (!tupleType.isa<TupleType>()) {
58 assert(flattenValues.size() == 1);
59 return flattenValues[0];
60 }
61
62 assert(tupleType.cast<TupleType>().getTypes().size() == flattenValues.size());
63 return builder.create<mhlo::TupleOp>(loc, flattenValues);
64 }
65
66 struct FlattenCustomCallOp : public OpRewritePattern<CustomCallOp> {
67 using OpRewritePattern::OpRewritePattern;
68
matchAndRewritemlir::mhlo::__anonc5a502700111::FlattenCustomCallOp69 LogicalResult matchAndRewrite(CustomCallOp op,
70 PatternRewriter &rewriter) const override {
71 llvm::SmallVector<Type, 4> flattenedResultTypes;
72 if (op->getNumResults() != 1 ||
73 !op->getResult(0).getType().isa<TupleType>())
74 return failure();
75
76 // Check for nested tuples.
77 for (Type innerType :
78 op->getResult(0).getType().cast<TupleType>().getTypes())
79 if (innerType.isa<TupleType>()) return failure();
80
81 for (auto result : op->getResults())
82 flattenTupleType(result, flattenedResultTypes);
83
84 auto flattenedCall = rewriter.create<mhlo::CustomCallOp>(
85 op->getLoc(), flattenedResultTypes, op->getOperands(), op->getAttrs());
86
87 auto tuple =
88 createTupleValue(rewriter, op->getLoc(), flattenedCall.getResults(),
89 op->getResult(0).getType());
90 rewriter.replaceOp(op, tuple);
91 return success();
92 }
93 };
94
95 class FlattenTuplePass : public FlattenTuplePassBase<FlattenTuplePass> {
96 public:
runOnOperation()97 void runOnOperation() override {
98 MLIRContext *context = &getContext();
99 RewritePatternSet patterns(context);
100 patterns.add<FlattenCustomCallOp>(context);
101 if (failed(applyPatternsAndFoldGreedily(getOperation(),
102 std::move(patterns)))) {
103 signalPassFailure();
104 }
105 }
106 };
107 } // end namespace
108
// Registers FlattenTuplePass with MLIR's global pass registry during static
// initialization, so the pass can be instantiated by its registered
// command-line argument.
static PassRegistration<FlattenTuplePass> pass;
110
createFlattenTuplePass()111 std::unique_ptr<OperationPass<func::FuncOp>> createFlattenTuplePass() {
112 return std::make_unique<FlattenTuplePass>();
113 }
114
115 } // end namespace mhlo
116 } // end namespace mlir
117