/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface_impl.h"

#include <functional>
#include <tuple>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MathExtras.h"
#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
#include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/TypeRange.h"

namespace mlir {
namespace gml_st {

namespace {

// Whether the operand needs the materialization of a point, given the output
// subset.
// This is the case if a) the output subset is a point and b) there are no
// reductions.
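// E.g., an operand of type tensor<8x1xf32> with operandDimsToOutputDims
// [0, -1] materializes to a point: dimension 0 maps to an output dimension
// (which is already known to be a point) and dimension 1 has static size 1.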
bool operandMaterializesToPoint(
    Value operand, const SmallVector<int64_t>& operandDimsToOutputDims,
    Value subset) {
  if (!subset.getType().isa<PointType>()) return false;

  const auto& operandShape =
      operand.getType().cast<RankedTensorType>().getShape();
  return llvm::all_of(llvm::zip(operandDimsToOutputDims, operandShape),
                      [](const auto& e) {
                        auto outputDimIndex = std::get<0>(e);
                        auto operandDimSize = std::get<1>(e);
                        // If the operand dimension maps to an output dim (the
                        // output is already known to be a point), or the
                        // operand's dimension has a static size of 1, the
                        // operand can materialize as a point.
                        return (outputDimIndex >= 0) || (operandDimSize == 1);
                      });
}

Value buildPointOp(Location loc, OpBuilder& builder, Value operand,
                   Value subset,
                   const SmallVector<int64_t>& operandDimsToOutputDims) {
  auto operandShape = operand.getType().cast<RankedTensorType>().getShape();

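  // Offsets for dimensions that map to an output dimension come from the
  // output subset; the remaining dimensions have size 1, so their offset is
  // a static 0.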
  SmallVector<int64_t> staticOffsets(operandShape.size(),
                                     ShapedType::kDynamicStrideOrOffset);
  SmallVector<Value> dynamicOffsets;
  for (int i = 0; i < operandShape.size(); ++i) {
    if (int outputDim = operandDimsToOutputDims[i]; outputDim >= 0) {
      auto index = builder.create<arith::ConstantIndexOp>(loc, outputDim);
      dynamicOffsets.push_back(builder.create<OffsetOp>(loc, subset, index));
    } else {
      staticOffsets[i] = 0;
    }
  }

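  // The point is defined relative to a space in which every dimension has
  // static size 1.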
  SmallVector<int64_t> staticSizes(operandShape.size(), 1);
  auto staticSizesAttr = builder.getI64ArrayAttr(staticSizes);
  SpaceOp spaceOp =
      builder.create<SpaceOp>(loc, builder.getType<TileType>(staticSizes),
                              ValueRange{}, staticSizesAttr);

  auto staticOffsetsAttr = builder.getI64ArrayAttr(staticOffsets);
  return builder.create<PointOp>(loc, builder.getType<PointType>(), spaceOp,
                                 dynamicOffsets, staticOffsetsAttr);
}

Value buildTileOp(Location loc, OpBuilder& builder, Value operand, Value subset,
                  const SmallVector<int64_t>& operandDimsToOutputDims) {
  auto operandRank = operand.getType().cast<RankedTensorType>().getRank();

  SmallVector<int64_t> staticSizes(operandRank, ShapedType::kDynamicSize);
  SmallVector<int64_t> staticStrides(operandRank,
                                     ShapedType::kDynamicStrideOrOffset);
  SmallVector<int64_t> staticOffsets(operandRank,
                                     ShapedType::kDynamicStrideOrOffset);

  SmallVector<Value> dynamicSizes;
  SmallVector<Value> dynamicStrides;
  SmallVector<Value> dynamicOffsets;
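  // Dimensions that map to the output inherit their offset (and, for tile
  // subsets, size and stride) from the output subset; a point subset pins
  // size and stride to 1. Dimensions that do not occur in the output are
  // taken in full: offset 0, stride 1, and the operand's dynamic size.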
  for (int i = 0; i < operandRank; ++i) {
    if (int outputDim = operandDimsToOutputDims[i]; outputDim >= 0) {
      auto index = builder.create<arith::ConstantIndexOp>(loc, outputDim);
      dynamicOffsets.push_back(builder.create<OffsetOp>(loc, subset, index));
      if (subset.getType().isa<PointType>()) {
        staticSizes[i] = 1;
        staticStrides[i] = 1;
      } else {
        dynamicStrides.push_back(builder.create<StrideOp>(loc, subset, index));
        dynamicSizes.push_back(builder.create<SizeOp>(loc, subset, index));
      }
    } else {
      staticOffsets[i] = 0;
      staticStrides[i] = 1;
      dynamicSizes.push_back(builder.create<tensor::DimOp>(loc, operand, i));
    }
  }

  auto staticSizesAttr = builder.getI64ArrayAttr(staticSizes);
  auto tileType = builder.getType<TileType>(staticSizes);
  SpaceOp spaceOp =
      builder.create<SpaceOp>(loc, tileType, dynamicSizes, staticSizesAttr);

  auto staticOffsetsAttr = builder.getI64ArrayAttr(staticOffsets);
  auto staticStridesAttr = builder.getI64ArrayAttr(staticStrides);
  return builder.create<TileOp>(loc, tileType, spaceOp, dynamicOffsets,
                                dynamicSizes, dynamicStrides, staticOffsetsAttr,
                                staticSizesAttr, staticStridesAttr);
}

// For each iterator, returns the dimension in the output affine map where it
// occurs (unless it's a reduction iterator).
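// E.g., for the output map (d0, d1, d2) -> (d2, d0), the result is
// [1, None, 0]: iterator d0 occurs at output dimension 1, d1 does not occur
// (it is a reduction), and d2 occurs at output dimension 0.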
Optional<SmallVector<Optional<int32_t>>> mapIteratorsToOutputs(
    AffineMap outputMap) {
  SmallVector<Optional<int32_t>> result(outputMap.getNumInputs());
  for (uint32_t i = 0; i < outputMap.getNumResults(); ++i) {
    auto dim = outputMap.getResult(i).dyn_cast<AffineDimExpr>();
    if (!dim) return {};
    if (result[dim.getPosition()]) return {};
    result[dim.getPosition()] = i;
  }
  return result;
}

struct LinalgGenericFusionInterface
    : public FusionInterface::ExternalModel<LinalgGenericFusionInterface,
                                            linalg::GenericOp> {
  // Supports linalg.generic ops with a single output, if all output
  // dimensions in all affine maps are affine dimensions (e.g., (a,b,c) ->
  // (a,b), but not (a,b,c) -> (a, 0)).
  // See the test file tiling_and_fusion.mlir for examples.
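  // Per operand, the cheapest correct subset is chosen: the output subset is
  // reused as-is when the operand's map composes to the identity, transposed
  // for pure permutations, shrunk to a point where possible, and otherwise
  // generalized to a tile.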
  Value fuse(Operation* op, Location loc, Value subset,
             OpBuilder& builder) const {
    auto genericOp = llvm::cast<linalg::GenericOp>(op);
    if (genericOp.getNumOutputs() != 1) return {};
    Value output = genericOp.outputs().front();
    auto outputRank = output.getType().cast<RankedTensorType>().getRank();

    auto indexingMaps =
        to_vector(genericOp.indexing_maps().getAsValueRange<AffineMapAttr>());
    auto maybeIteratorsToOutputs = mapIteratorsToOutputs(indexingMaps.back());
    if (!maybeIteratorsToOutputs) return {};
    const SmallVector<Optional<int32_t>>& iteratorsToOutputs =
        *maybeIteratorsToOutputs;

    SmallVector<Value> materializedOperands;
    SmallVector<bool> operandsArePoints;
    for (const auto&& [operand, operandMap] :
         llvm::zip(genericOp.inputs(), indexingMaps)) {
      // Mapping from an operand dimension to an output dimension, or -1 if it
      // doesn't occur in the output.
      SmallVector<int64_t> operandDimsToOutputDims;

      // Whether the composition of the inverse of the operand's affine map and
      // the output's affine map is the identity function (i.e., a given output
      // coordinate maps to the same coordinate in the input).
      bool isIdentity = operandMap.getResults().size() == outputRank;
      SmallVector<bool> containsDim(outputRank);
      for (const AffineExpr& expression : operandMap.getResults()) {
        auto dim = expression.dyn_cast<AffineDimExpr>();
        if (!dim) return {};
        auto outputDim = iteratorsToOutputs[dim.getPosition()];
        operandDimsToOutputDims.push_back(outputDim.value_or(-1));
        if (outputDim) containsDim[*outputDim] = true;
        isIdentity &=
            outputDim.value_or(-1) == operandDimsToOutputDims.size() - 1;
      }

      Value operandSubset;
      if (isIdentity) {
        operandSubset = subset;
        operandsArePoints.push_back(subset.getType().isa<PointType>());
      } else if (operandDimsToOutputDims.size() == outputRank &&
                 !llvm::is_contained(containsDim, false)) {
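        // The operand's dimensions are a permutation of the output's;
        // transpose the output subset into the operand's dimension order.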
        operandSubset = builder.create<TransposeDimsOp>(
            loc, subset,
            DenseI64ArrayAttr::get(builder.getContext(),
                                   operandDimsToOutputDims));
        operandsArePoints.push_back(subset.getType().isa<PointType>());
      } else if (operandMaterializesToPoint(operand, operandDimsToOutputDims,
                                            subset)) {
        operandSubset = buildPointOp(loc, builder, operand, subset,
                                     operandDimsToOutputDims);
        operandsArePoints.push_back(true);
      } else {
        operandSubset =
            buildTileOp(loc, builder, operand, subset, operandDimsToOutputDims);
        operandsArePoints.push_back(false);
      }

      materializedOperands.push_back(
          builder.create<MaterializeOp>(loc, operand, operandSubset));
    }

    materializedOperands.push_back(
        builder.create<MaterializeOp>(loc, output, subset));
    if (!llvm::is_contained(operandsArePoints, false)) {
      // Create scalar computation by copying from the `linalg.generic`
      // body.
      BlockAndValueMapping bvm;
      Block* block = genericOp.getBody();
      assert(block->getArguments().size() == materializedOperands.size() &&
             "block argument count and sub operand count should be equal");
      for (const auto&& [arg, materialized] :
           llvm::zip(block->getArguments(), materializedOperands)) {
        bvm.map(arg, materialized);
      }
      for (auto& it : block->without_terminator()) builder.clone(it, bvm);
      auto innerResults = block->getTerminator()->getOperands();
      assert(innerResults.size() == 1 && "expect unique inner result");
      return bvm.lookup(innerResults.front());
    }

    // Materialize tiled `linalg.generic` op.
    auto outputTy = output.getType().cast<RankedTensorType>();
    RankedTensorType subResultTy;
    if (subset.getType().isa<TileType>()) {
      subResultTy =
          RankedTensorType::get(subset.getType().cast<TileType>().getShape(),
                                outputTy.getElementType());
    } else {
      // The subset is a point, so the materialized output operand is a
      // scalar, but `linalg.generic` needs a tensor output: wrap the scalar
      // in a tensor with one element per dimension.
      subResultTy = RankedTensorType::get(
          SmallVector<int64_t>(outputTy.getShape().size(), 1),
          outputTy.getElementType());
      materializedOperands.back() = builder.create<tensor::FromElementsOp>(
          loc, subResultTy, materializedOperands.back());
    }

    linalg::LinalgOp linalgOp = genericOp;
    auto outputOp = cast<linalg::GenericOp>(
        *linalgOp.clone(builder, loc, subResultTy, materializedOperands));

    // If any operands were materialized as scalars (points), their indexing
    // maps must be rewritten accordingly.
    if (llvm::is_contained(operandsArePoints, true)) {
      SmallVector<AffineMap> newIndexingMaps;
      for (const auto&& [isPoint, indexingMap] :
           llvm::zip(operandsArePoints, indexingMaps)) {
        if (isPoint) {
          // Replace the affine map for the input with (...) -> () - the input
          // to the linalg.generic is a scalar.
          newIndexingMaps.push_back(AffineMap::get(indexingMap.getNumInputs(),
                                                   indexingMap.getNumSymbols(),
                                                   indexingMap.getContext()));
        } else {
          newIndexingMaps.push_back(indexingMap);
        }
      }
      newIndexingMaps.push_back(indexingMaps.back());
      outputOp.setIndexingMapsAttr(
          builder.getAffineMapArrayAttr(newIndexingMaps));
    }

    Value result = outputOp.getResults().front();
    if (subset.getType().isa<PointType>()) {
      result = builder.create<tensor::ExtractOp>(
          loc, result,
          SmallVector<Value>(outputTy.getShape().size(),
                             builder.create<arith::ConstantIndexOp>(loc, 0)));
    }
    return result;
  }
};

}  // namespace

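// Example wiring (a sketch): register the extension on a registry before the
// context loads the Linalg dialect, e.g.
//   DialectRegistry registry;
//   registerFusionInterfaceExternalModels(registry);
//   MLIRContext context(registry);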
void registerFusionInterfaceExternalModels(DialectRegistry& registry) {
  registry.insert<linalg::LinalgDialect>();
  registry.addExtension(+[](MLIRContext* ctx, linalg::LinalgDialect*) {
    linalg::GenericOp::attachInterface<LinalgGenericFusionInterface>(*ctx);
  });
}

}  // namespace gml_st
}  // namespace mlir