1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface_impl.h"
17
18 #include <functional>
19 #include <tuple>
20 #include <utility>
21
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/Optional.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/iterator_range.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/Error.h"
29 #include "llvm/Support/MathExtras.h"
30 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
31 #include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface.h"
32 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
33 #include "mlir/Dialect/Linalg/IR/Linalg.h"
34 #include "mlir/Dialect/Tensor/IR/Tensor.h"
35 #include "mlir/IR/AffineExpr.h"
36 #include "mlir/IR/BuiltinAttributes.h"
37 #include "mlir/IR/BuiltinTypes.h"
38 #include "mlir/IR/Location.h"
39 #include "mlir/IR/TypeRange.h"
40
41 namespace mlir {
42 namespace gml_st {
43
44 namespace {
45
46 // Whether the operand needs the materialization of a point, given the output
47 // subset.
48 // This is the case if a) the outputput subset is a point and b) there are
49 // no reductions.
operandMaterializesToPoint(Value operand,const SmallVector<int64_t> & operandDimsToOutputDims,Value subset)50 bool operandMaterializesToPoint(
51 Value operand, const SmallVector<int64_t>& operandDimsToOutputDims,
52 Value subset) {
53 if (!subset.getType().isa<PointType>()) return false;
54
55 const auto& operandShape =
56 operand.getType().cast<RankedTensorType>().getShape();
57 return llvm::all_of(llvm::zip(operandDimsToOutputDims, operandShape),
58 [](const auto& e) {
59 auto outputDimIndex = std::get<0>(e);
60 auto operandDimSize = std::get<1>(e);
61 // If the operand dimension maps to an output dim (the
62 // output is already known to be a point), or the
63 // operand's dimensions has a static size of 1, the
64 // operand can materialize as a point.
65 return (outputDimIndex >= 0) || (operandDimSize == 1);
66 });
67 }
68
buildPointOp(Location loc,OpBuilder & builder,Value operand,Value subset,const SmallVector<int64_t> & operandDimsToOutputDims)69 Value buildPointOp(Location loc, OpBuilder& builder, Value operand,
70 Value subset,
71 const SmallVector<int64_t>& operandDimsToOutputDims) {
72 auto operandShape = operand.getType().cast<RankedTensorType>().getShape();
73
74 SmallVector<int64_t> staticOffsets(operandShape.size(),
75 ShapedType::kDynamicStrideOrOffset);
76 SmallVector<Value> dynamicOffsets;
77 for (int i = 0; i < operandShape.size(); ++i) {
78 if (int outputDim = operandDimsToOutputDims[i]; outputDim >= 0) {
79 auto index = builder.create<arith::ConstantIndexOp>(loc, outputDim);
80 dynamicOffsets.push_back(builder.create<OffsetOp>(loc, subset, index));
81 } else {
82 staticOffsets[i] = 0;
83 }
84 }
85
86 SmallVector<int64_t> staticSizes(operandShape.size(), 1);
87 auto staticSizesAttr = builder.getI64ArrayAttr(staticSizes);
88 SpaceOp spaceOp =
89 builder.create<SpaceOp>(loc, builder.getType<TileType>(staticSizes),
90 ValueRange{}, staticSizesAttr);
91
92 auto staticOffsetsAttr = builder.getI64ArrayAttr(staticOffsets);
93 return builder.create<PointOp>(loc, builder.getType<PointType>(), spaceOp,
94 dynamicOffsets, staticOffsetsAttr);
95 }
96
buildTileOp(Location loc,OpBuilder & builder,Value operand,Value subset,const SmallVector<int64_t> & operandDimsToOutputDims)97 Value buildTileOp(Location loc, OpBuilder& builder, Value operand, Value subset,
98 const SmallVector<int64_t>& operandDimsToOutputDims) {
99 auto operandRank = operand.getType().cast<RankedTensorType>().getRank();
100
101 SmallVector<int64_t> staticSizes(operandRank, ShapedType::kDynamicSize);
102 SmallVector<int64_t> staticStrides(operandRank,
103 ShapedType::kDynamicStrideOrOffset);
104 SmallVector<int64_t> staticOffsets(operandRank,
105 ShapedType::kDynamicStrideOrOffset);
106
107 SmallVector<Value> dynamicSizes;
108 SmallVector<Value> dynamicStrides;
109 SmallVector<Value> dynamicOffsets;
110 for (int i = 0; i < operandRank; ++i) {
111 if (int outputDim = operandDimsToOutputDims[i]; outputDim >= 0) {
112 auto index = builder.create<arith::ConstantIndexOp>(loc, outputDim);
113 dynamicOffsets.push_back(builder.create<OffsetOp>(loc, subset, index));
114 if (subset.getType().isa<PointType>()) {
115 staticSizes[i] = 1;
116 staticStrides[i] = 1;
117 } else {
118 dynamicStrides.push_back(builder.create<StrideOp>(loc, subset, index));
119 dynamicSizes.push_back(builder.create<SizeOp>(loc, subset, index));
120 }
121 } else {
122 staticOffsets[i] = 0;
123 staticStrides[i] = 1;
124 dynamicSizes.push_back(builder.create<tensor::DimOp>(loc, operand, i));
125 }
126 }
127
128 auto staticSizesAttr = builder.getI64ArrayAttr(staticSizes);
129 auto tileType = builder.getType<TileType>(staticSizes);
130 SpaceOp spaceOp =
131 builder.create<SpaceOp>(loc, tileType, dynamicSizes, staticSizesAttr);
132
133 auto staticOffsetsAttr = builder.getI64ArrayAttr(staticOffsets);
134 auto staticStridesAttr = builder.getI64ArrayAttr(staticStrides);
135 return builder.create<TileOp>(loc, tileType, spaceOp, dynamicOffsets,
136 dynamicSizes, dynamicStrides, staticOffsetsAttr,
137 staticSizesAttr, staticStridesAttr);
138 }
139
140 // For each iterator, returns the dimension in the output affine map where it
141 // occurs (unless it's a reduction iterator).
mapIteratorsToOutputs(AffineMap outputMap)142 Optional<SmallVector<Optional<int32_t>>> mapIteratorsToOutputs(
143 AffineMap outputMap) {
144 SmallVector<Optional<int32_t>> result(outputMap.getNumInputs());
145 for (uint32_t i = 0; i < outputMap.getNumResults(); ++i) {
146 auto dim = outputMap.getResult(i).dyn_cast<AffineDimExpr>();
147 if (!dim) return {};
148 if (result[dim.getPosition()]) return {};
149 result[dim.getPosition()] = i;
150 }
151 return result;
152 }
153
154 struct LinalgGenericFusionInterface
155 : public FusionInterface::ExternalModel<LinalgGenericFusionInterface,
156 linalg::GenericOp> {
157 // Supports linalg.generics with a single output, if all output dimensions in
158 // all affine maps are affine dimensions (e.g.., (a,b,c) -> (a,b), but not
159 // (a,b,c) -> (a, 0)).
160 // See the test file tiling_and_fusion.mlir for examples.
fusemlir::gml_st::__anon49664e550111::LinalgGenericFusionInterface161 Value fuse(Operation* op, Location loc, Value subset,
162 OpBuilder& builder) const {
163 auto genericOp = llvm::cast<linalg::GenericOp>(op);
164 if (genericOp.getNumOutputs() != 1) return {};
165 Value output = genericOp.outputs().front();
166 auto outputRank = output.getType().cast<RankedTensorType>().getRank();
167
168 auto indexingMaps =
169 to_vector(genericOp.indexing_maps().getAsValueRange<AffineMapAttr>());
170 auto maybeIteratorsToOutputs = mapIteratorsToOutputs(indexingMaps.back());
171 if (!maybeIteratorsToOutputs) return {};
172 const SmallVector<Optional<int32_t>>& iteratorsToOutputs =
173 *maybeIteratorsToOutputs;
174
175 SmallVector<Value> materializedOperands;
176 SmallVector<bool> operandsArePoints;
177 for (const auto&& [operand, operandMap] :
178 llvm::zip(genericOp.inputs(), indexingMaps)) {
179 // Mapping from an operand dimension to an output dimension, or -1 if it
180 // doesn't occur in the output.
181 SmallVector<int64_t> operandDimsToOutputDims;
182
183 // Whether the composition of the inverse of the operand's affine map and
184 // the output's affine map is the identity function (i.e., a given output
185 // coordinate maps to the same coordinate in the input).
186 bool isIdentity = operandMap.getResults().size() == outputRank;
187 SmallVector<bool> containsDim(outputRank);
188 for (const AffineExpr& expression : operandMap.getResults()) {
189 auto dim = expression.dyn_cast<AffineDimExpr>();
190 if (!dim) return {};
191 auto output = iteratorsToOutputs[dim.getPosition()];
192 operandDimsToOutputDims.push_back(output.value_or(-1));
193 if (output) containsDim[*output] = true;
194 isIdentity &= output.value_or(-1) == operandDimsToOutputDims.size() - 1;
195 }
196
197 Value operandSubset;
198 if (isIdentity) {
199 operandSubset = subset;
200 operandsArePoints.push_back(subset.getType().isa<PointType>());
201 } else if (operandDimsToOutputDims.size() == outputRank &&
202 !llvm::is_contained(containsDim, false)) {
203 operandSubset = builder.create<TransposeDimsOp>(
204 loc, subset,
205 DenseI64ArrayAttr::get(builder.getContext(),
206 operandDimsToOutputDims));
207 operandsArePoints.push_back(subset.getType().isa<PointType>());
208 } else if (operandMaterializesToPoint(operand, operandDimsToOutputDims,
209 subset)) {
210 operandSubset = buildPointOp(loc, builder, operand, subset,
211 operandDimsToOutputDims);
212 operandsArePoints.push_back(true);
213 } else {
214 operandSubset =
215 buildTileOp(loc, builder, operand, subset, operandDimsToOutputDims);
216 operandsArePoints.push_back(false);
217 }
218
219 materializedOperands.push_back(
220 builder.create<MaterializeOp>(loc, operand, operandSubset));
221 }
222
223 materializedOperands.push_back(
224 builder.create<MaterializeOp>(loc, output, subset));
225 if (!llvm::is_contained(operandsArePoints, false)) {
226 // Create scalar computation by copying from the `linalg.generic`
227 // body.
228 BlockAndValueMapping bvm;
229 Block* block = genericOp.getBody();
230 assert(block->getArguments().size() == materializedOperands.size() &&
231 "block argument count and sub operand count should be equal");
232 for (const auto&& [arg, materialized] :
233 llvm::zip(block->getArguments(), materializedOperands)) {
234 bvm.map(arg, materialized);
235 }
236 for (auto& it : block->without_terminator()) builder.clone(it, bvm);
237 auto innerResults = block->getTerminator()->getOperands();
238 assert(innerResults.size() == 1 && "expect unique inner result");
239 return bvm.lookup(innerResults.front());
240 }
241
242 // Materialize tiled `linalg.generic` op.
243 auto outputTy = output.getType().cast<RankedTensorType>();
244 RankedTensorType subResultTy;
245 if (subset.getType().isa<TileType>()) {
246 subResultTy =
247 RankedTensorType::get(subset.getType().cast<TileType>().getShape(),
248 outputTy.getElementType());
249 } else {
250 // Replace the materialized operand: it must be a tensor.
251 subResultTy = RankedTensorType::get(
252 SmallVector<int64_t>(outputTy.getShape().size(), 1),
253 outputTy.getElementType());
254 materializedOperands.back() = builder.create<tensor::FromElementsOp>(
255 loc, subResultTy, materializedOperands.back());
256 }
257
258 linalg::LinalgOp linalgOp = genericOp;
259 auto outputOp = cast<linalg::GenericOp>(
260 *linalgOp.clone(builder, loc, subResultTy, materializedOperands));
261
262 // If any operands are points...
263 if (llvm::is_contained(operandsArePoints, true)) {
264 SmallVector<AffineMap> newIndexingMaps;
265 for (const auto&& [isPoint, indexingMap] :
266 llvm::zip(operandsArePoints, indexingMaps)) {
267 if (isPoint) {
268 // Replace the affine map for the input with (...) -> () - the input
269 // to the linalg.generic is a scalar.
270 newIndexingMaps.push_back(AffineMap::get(indexingMap.getNumInputs(),
271 indexingMap.getNumSymbols(),
272 indexingMap.getContext()));
273 } else {
274 newIndexingMaps.push_back(indexingMap);
275 }
276 }
277 newIndexingMaps.push_back(indexingMaps.back());
278 outputOp.setIndexingMapsAttr(
279 builder.getAffineMapArrayAttr(newIndexingMaps));
280 }
281
282 Value result = outputOp.getResults().front();
283 if (subset.getType().isa<PointType>()) {
284 result = builder.create<tensor::ExtractOp>(
285 loc, result,
286 SmallVector<Value>(outputTy.getShape().size(),
287 builder.create<arith::ConstantIndexOp>(loc, 0)));
288 }
289 return result;
290 }
291 };
292
293 } // namespace
294
registerFusionInterfaceExternalModels(DialectRegistry & registry)295 void registerFusionInterfaceExternalModels(DialectRegistry& registry) {
296 registry.insert<linalg::LinalgDialect>();
297 registry.addExtension(+[](MLIRContext* ctx, linalg::LinalgDialect*) {
298 linalg::GenericOp::attachInterface<LinalgGenericFusionInterface>(*ctx);
299 });
300 }
301
302 } // namespace gml_st
303 } // namespace mlir
304