/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file supports the lowering of the CHLO/HLO/LHLO dialects to the
// Linalg dialect.

#ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_LEGALIZE_TO_LINALG_UTILS_H_
#define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_LEGALIZE_TO_LINALG_UTILS_H_

#include <algorithm>
#include <numeric>
#include <string>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace mhlo {

/// Returns a vector of `nLoops` iterator names. All entries are "parallel"
/// except for the last `nReduction` entries, which are "reduction".
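/// For example, getParallelAndReductionIterators(3, 1) yields
/// {"parallel", "parallel", "reduction"}.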
SmallVector<StringRef, 3> getParallelAndReductionIterators(unsigned nLoops,
                                                           unsigned nReduction);

/// Returns a vector of `nParallelLoops` "parallel" iterator names.
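/// Effectively getParallelAndReductionIterators(nParallelLoops, 0).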
SmallVector<StringRef, 3> getNParallelLoopsAttrs(unsigned nParallelLoops);

/// Generates an initTensor op in the linalg dialect.
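/// `dynSizes` is expected to hold one Value per dynamic dimension of `type`,
/// e.g. a single extent value for a tensor<?x4xf32> result.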
Value getInitTensor(OpBuilder& b, Location loc, ShapedType type,
                    ArrayRef<Value> dynSizes);

/// Generates a tensor initialization for the result of the operation, which
/// may be a dense or a sparse tensor.
Value getInitTensorFor(OpBuilder& b, Location loc, ShapedType resultType,
                       Operation* op, ValueRange operands);

/// Sparsifies a (block of) operation(s) that cannot be handled directly
/// by the sparse compiler but has well-known semi-ring semantics.
///
/// This yields something of the following form:
///
///   %result = sparse_tensor.unary %values[0]
///     present={
///       ^bb1(%val):
///         ... codegen proceeds here using %val ....
///         sparse_tensor.yield
///     }
///     absent={}
///   linalg.yield %result
Value preSparsify(Operation* op, llvm::SmallVector<Value, 2>& values, Type rtp,
                  OpBuilder* b);

/// Finalizes the sparse semi-ring construction started by preSparsify.
Value postSparsify(Operation* op, Value semiring, Value result, OpBuilder* b);

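/// Returns the attributes of `op` that are not part of the op's declared
/// attribute list; these can then be carried over to a newly created op
/// (see the `linalg.generic` construction in PointwiseToLinalgConverter
/// below).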
template <typename OpTy>
SmallVector<NamedAttribute> pruneAttributeList(OpTy op) {
  auto opAttributes = op.getAttributeNames();
  llvm::StringSet<> elidedAttrs;
  elidedAttrs.insert(opAttributes.begin(), opAttributes.end());
  SmallVector<NamedAttribute> preservedAttrs;
  for (auto attr : op->getAttrs()) {
    if (elidedAttrs.count(attr.getName())) continue;
    preservedAttrs.push_back(attr);
  }
  return preservedAttrs;
}

/// Converts a HLO operation to a linalg.generic op that contains the
/// corresponding scalar operations.
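///
/// For example, an `mhlo.add` on `tensor<4xf32>` operands becomes a
/// `linalg.generic` with identity indexing maps whose body computes an
/// `arith.addf` and yields the result.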
template <typename OpTy>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy op, typename OpTy::Adaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    // Find maximum rank / number of loops.
    auto getRank = [](Value v) {
      return v.getType().cast<ShapedType>().getRank();
    };
    auto isScalar = [&](Value v) { return getRank(v) == 0; };
    auto it = llvm::find_if_not(adaptor.getOperands(), isScalar);
    Value maxRankArg =
        it != adaptor.getOperands().end() ? *it : adaptor.getOperands().front();
    int64_t nloops = getRank(maxRankArg);

    // Apply only if all operands are scalar or have the same rank. Some ops,
    // like `mhlo.select`, support implicit broadcasting of scalars.
    if (!llvm::all_of(adaptor.getOperands(), [&](Value v) {
          int64_t r = getRank(v);
          return r == 0 || r == nloops;
        })) {
      return rewriter.notifyMatchFailure(
          op, "Operands must be of same rank or scalar.");
    }

    // Find result type, if on tensors.
    Optional<ShapedType> resultTy;
    resultTy = this->typeConverter->convertType(op->getResultTypes().front())
                   .template dyn_cast<ShapedType>();

    // Check result type compatibility.
    if (!resultTy || !resultTy->hasRank() || resultTy->getRank() != nloops ||
        !(resultTy->getElementType().isSignlessIntOrFloat() ||
          resultTy->getElementType().isa<ComplexType>())) {
      return rewriter.notifyMatchFailure(
          op, "mismatched operand/result types or iterator count");
    }

    auto loc = op.getLoc();
    // TODO(jreiffers): Enable this optimization outside of linalg ops. This
    // currently breaks KernelGen.
    if (nloops == 0 && isInBodyOfLinalgOps(op)) {
      // No need to create a linalg.generic if all inputs are scalars.
      SmallVector<Value> inputs;
      for (auto input : adaptor.getOperands()) {
        inputs.push_back(
            rewriter.create<tensor::ExtractOp>(loc, input, ValueRange()));
      }
      Value scalarResult = mhlo::MhloOpToStdScalarOp::mapOp(
          op, resultTy->getElementType(), inputs, &rewriter);
      if (!scalarResult) return failure();
      rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, *resultTy,
                                                          scalarResult);
      return success();
    }

    // Find input/output values and types.
    ValueRange inputs = adaptor.getOperands();
    Value output =
        getInitTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands());

    // Create indexing maps.
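    // Scalar (rank-0) operands get an empty indexing map, which broadcasts
    // them across all loops; all other operands and the output use the
    // identity map.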
    AffineMap scalarMap = AffineMap::get(nloops, 0, rewriter.getContext());
    AffineMap idMap = rewriter.getMultiDimIdentityMap(nloops);
    SmallVector<AffineMap, 4> maps;
    for (Value v : inputs) maps.push_back(isScalar(v) ? scalarMap : idMap);
    maps.push_back(idMap);

    // Build `linalg.generic` op.
    bool failed = false;
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy ? *resultTy : TypeRange{}, inputs, output, maps,
        getNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nestedBuilder, Location /*nested_loc*/,
            ValueRange args) {
          Type innerResultTy = getElementTypeOrSelf(output);
          auto argvec = llvm::to_vector<2>(args.take_front(inputs.size()));
          auto semiring = preSparsify(op, argvec, innerResultTy, &rewriter);
          Value innerResult = mhlo::MhloOpToStdScalarOp::mapOp(
              op, innerResultTy, argvec, &rewriter);
          if (innerResult == nullptr) {
            failed = true;
          } else {
            innerResult = postSparsify(op, semiring, innerResult, &rewriter);
            nestedBuilder.create<linalg::YieldOp>(loc, innerResult);
          }
        },
        pruneAttributeList(op));
    if (failed) return failure();

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }

 private:
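  // Returns true if the op that immediately encloses `op` belongs to the
  // Linalg dialect.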
  static bool isInBodyOfLinalgOps(Operation* op) {
    auto* parentOp = op->getParentRegion()->getParentOp();
    return parentOp->getDialect() ==
           parentOp->getContext()->getLoadedDialect<linalg::LinalgDialect>();
  }
};
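
// Example usage (a sketch, not prescribed by this header): a conversion pass
// would typically instantiate the pattern once per MHLO op it lowers, e.g.
//
//   RewritePatternSet patterns(context);
//   patterns.add<PointwiseToLinalgConverter<mhlo::AddOp>>(typeConverter,
//                                                         context);
//
// mhlo::AddOp is an illustrative choice; any pointwise MHLO op with a scalar
// mapping in MhloOpToStdScalarOp can be registered the same way.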

}  // namespace mhlo

}  // namespace mlir

#endif  // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_LEGALIZE_TO_LINALG_UTILS_H_