/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file supports the lowering of the CHLO/HLO/LHLO dialects to the
// Linalg dialect.

#ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_LEGALIZE_TO_LINALG_UTILS_H_
#define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_LEGALIZE_TO_LINALG_UTILS_H_

#include <algorithm>
#include <numeric>
#include <string>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace mhlo {

/// Returns a vector of `nLoops` iterator-type strings. All entries are
/// "parallel" except the last `nReduction` entries, which are "reduction".
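/// For example, `getParallelAndReductionIterators(3, 1)` returns
/// {"parallel", "parallel", "reduction"}.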
SmallVector<StringRef, 3> getParallelAndReductionIterators(unsigned nLoops,
                                                           unsigned nReduction);

/// Returns a vector of `nParallelLoops` "parallel" iterator-type strings.
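/// For example, `getNParallelLoopsAttrs(2)` returns {"parallel", "parallel"}.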
SmallVector<StringRef, 3> getNParallelLoopsAttrs(unsigned nParallelLoops);

/// Generates an init_tensor op in the linalg dialect.
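/// For example (a sketch, assuming a result type tensor<?x4xf32> whose
/// dynamic size is %d0), the generated op would look like:
///
///   %init = linalg.init_tensor [%d0, 4] : tensor<?x4xf32>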
Value getInitTensor(OpBuilder& b, Location loc, ShapedType type,
                    ArrayRef<Value> dynSizes);

/// Generates a tensor initialization for the result of the operation, which
/// may be a dense or a sparse tensor.
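///
/// For example, the pointwise converter below materializes its destination
/// with:
///
///   Value output = getInitTensorFor(rewriter, loc, *resultTy, op,
///                                   adaptor.getOperands());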
Value getInitTensorFor(OpBuilder& b, Location loc, ShapedType resultType,
                       Operation* op, ValueRange operands);

/// Sparsifies a (block of) operation(s) that cannot be handled directly
/// by the sparse compiler but has well-known semi-ring semantics.
///
/// This yields something of the following form:
///
///   %result = sparse_tensor.unary %values[0]
///     present={
///       ^bb1(%val):
///         ... codegen proceeds here using %val ....
///         sparse_tensor.yield
///     }
///     absent={}
///   linalg.yield %result
Value preSparsify(Operation* op, llvm::SmallVector<Value, 2>& values, Type rtp,
                  OpBuilder* b);

/// Finalizes sparse semi-ring construction.
Value postSparsify(Operation* op, Value semiring, Value result, OpBuilder* b);

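// Illustrative usage of the preSparsify/postSparsify pair (`scalarArgs`,
// `resultElemTy`, and `builder` are placeholder names), mirroring how the
// pointwise converter below brackets its scalar code generation:
//
//   SmallVector<Value, 2> scalarArgs = ...;  // scalar block arguments
//   Value semiring = preSparsify(op, scalarArgs, resultElemTy, &builder);
//   Value scalar = ...;  // emit the scalar computation on scalarArgs
//   scalar = postSparsify(op, semiring, scalar, &builder);

/// Returns all attributes of `op` except its registered (inherent)
/// attributes, so the remaining discardable attributes can be carried over
/// to the replacement op.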
template <typename OpTy>
SmallVector<NamedAttribute> pruneAttributeList(OpTy op) {
  auto opAttributes = op.getAttributeNames();
  llvm::StringSet<> elidedAttrs;
  elidedAttrs.insert(opAttributes.begin(), opAttributes.end());
  SmallVector<NamedAttribute> preservedAttrs;
  for (auto attr : op->getAttrs()) {
    if (elidedAttrs.count(attr.getName())) continue;
    preservedAttrs.push_back(attr);
  }
  return preservedAttrs;
}

/// Converts an HLO operation to a linalg.generic op that contains the
/// corresponding scalar operations.
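///
/// Illustrative registration in a conversion pass (a sketch; the actual
/// pattern lists live in the lowering passes, and `mhlo::AddOp` is just an
/// example op):
///
///   RewritePatternSet patterns(ctx);
///   patterns.add<PointwiseToLinalgConverter<mhlo::AddOp>>(typeConverter,
///                                                         ctx);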
template <typename OpTy>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy op, typename OpTy::Adaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    // Find maximum rank / number of loops.
    auto getRank = [](Value v) {
      return v.getType().cast<ShapedType>().getRank();
    };
    auto isScalar = [&](Value v) { return getRank(v) == 0; };
    auto it = llvm::find_if_not(adaptor.getOperands(), isScalar);
    Value maxRankArg =
        it != adaptor.getOperands().end() ? *it : adaptor.getOperands().front();
    int64_t nloops = getRank(maxRankArg);

    // Apply only if all operands are scalar or have the same rank. Some ops,
    // like `mhlo.select`, support implicit broadcasting of scalars.
    if (!llvm::all_of(adaptor.getOperands(), [&](Value v) {
          int64_t r = getRank(v);
          return r == 0 || r == nloops;
        })) {
      return rewriter.notifyMatchFailure(
          op, "operands must be of the same rank or scalar");
    }

    // Find result type, if on tensors.
    Optional<ShapedType> resultTy =
        this->typeConverter->convertType(op->getResultTypes().front())
            .template dyn_cast<ShapedType>();

    // Check result type compatibility.
    if (!resultTy || !resultTy->hasRank() || resultTy->getRank() != nloops ||
        !(resultTy->getElementType().isSignlessIntOrFloat() ||
          resultTy->getElementType().isa<ComplexType>())) {
      return rewriter.notifyMatchFailure(
          op, "mismatched operand/result types or iterator count");
    }

    auto loc = op.getLoc();
    // TODO(jreiffers): Enable this optimization outside of linalg ops. This
    // currently breaks KernelGen.
    if (nloops == 0 && isInBodyOfLinalgOps(op)) {
      // No need to create a linalg.generic if all inputs are scalars.
      SmallVector<Value> inputs;
      for (auto input : adaptor.getOperands()) {
        inputs.push_back(
            rewriter.create<tensor::ExtractOp>(loc, input, ValueRange()));
      }
      Value scalarResult = mhlo::MhloOpToStdScalarOp::mapOp(
          op, resultTy->getElementType(), inputs, &rewriter);
      if (!scalarResult) return failure();
      rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, *resultTy,
                                                          scalarResult);
      return success();
    }

    // Find input/output values and types.
    ValueRange inputs = adaptor.getOperands();
    Value output =
        getInitTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands());

    // Create indexing maps.
    AffineMap scalarMap = AffineMap::get(nloops, 0, rewriter.getContext());
    AffineMap idMap = rewriter.getMultiDimIdentityMap(nloops);
    SmallVector<AffineMap, 4> maps;
    for (Value v : inputs) maps.push_back(isScalar(v) ? scalarMap : idMap);
    maps.push_back(idMap);

    // Build `linalg.generic` op.
    bool failed = false;
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy ? *resultTy : TypeRange{}, inputs, output, maps,
        getNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nestedBuilder, Location /*nested_loc*/,
            ValueRange args) {
          Type innerResultTy = getElementTypeOrSelf(output);
          auto argvec = llvm::to_vector<2>(args.take_front(inputs.size()));
          auto semiring = preSparsify(op, argvec, innerResultTy, &rewriter);
          Value innerResult = mhlo::MhloOpToStdScalarOp::mapOp(
              op, innerResultTy, argvec, &rewriter);
          if (innerResult == nullptr) {
            failed = true;
          } else {
            innerResult = postSparsify(op, semiring, innerResult, &rewriter);
            nestedBuilder.create<linalg::YieldOp>(loc, innerResult);
          }
        },
        pruneAttributeList(op));
    if (failed) return failure();

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }

 private:
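  // Returns true if `op` is nested directly inside an op from the linalg
  // dialect (e.g. in the body of a `linalg.generic`).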
  static bool isInBodyOfLinalgOps(Operation* op) {
    auto* parentOp = op->getParentRegion()->getParentOp();
    return parentOp->getDialect() ==
           parentOp->getContext()->getLoadedDialect<linalg::LinalgDialect>();
  }
};

}  // namespace mhlo

}  // namespace mlir

#endif  // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_LEGALIZE_TO_LINALG_UTILS_H_