/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file provides basic utilities for the elemental lowering of
// each node.

#include "mlir-hlo/Dialect/lhlo/transforms/lhlo_elemental_utils.h"

#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir-hlo/utils/codegen_utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using mlir::memref::DimOp;
using mlir::memref::LoadOp;
using mlir::memref::StoreOp;

namespace mlir {
namespace lmhlo {

Value createLoadOrUseCachedValue(Location loc, OpBuilder* b, Value memref,
                                 ValueRange indices,
                                 OpBuilder::InsertPoint insertPoint) {
  // Check whether there is a cached value that can be reused within the
  // current Block. Alternatively, this could be done for all Blocks that
  // dominate the current Block, but that would be considerably more
  // complicated.
  std::vector<StoreOp> storeOps;
  insertPoint.getBlock()->walk(
      insertPoint.getBlock()->begin(), insertPoint.getPoint(),
      [&](StoreOp storeOp) {
        if (storeOp.getOperation()->getBlock() != insertPoint.getBlock())
          return;
        if ((storeOp.getMemRef() == memref) &&
            (storeOp.getIndices() == indices))
          storeOps.emplace_back(storeOp);
      });
  if (!storeOps.empty()) return storeOps[0].getOperand(0);
  int rank = memref.getType().dyn_cast<MemRefType>().getRank();
  return rank > 0 ? b->create<LoadOp>(loc, memref, indices)
                  : b->create<LoadOp>(loc, memref);
}
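
// Illustrative sketch of the caching behavior above (the IR and names are
// hypothetical): given
//   memref.store %v, %buf[%i] : memref<?xf32>
//   ... // insertion point here, in the same block
// a call such as
//   Value v = createLoadOrUseCachedValue(loc, &b, buf, {i},
//                                        b.saveInsertionPoint());
// returns %v directly instead of emitting a fresh memref.load, because the
// preceding store to the same memref and indices is found.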

DenseSet<Operation*> noLoaderUser(SmallVectorImpl<Operation*>& ops) {
  SmallVector<Operation*, 4> worklist;
  DenseSet<Operation*> hasLoaderOps;
  for (Operation* op : ops) {
    Value memref = cast<LmhloOp>(op).getResultBuffer();
    if (memref == nullptr) continue;
    for (auto* user : memref.getUsers()) {
      if (isa<memref::LoadOp>(user)) {
        worklist.push_back(op);
        hasLoaderOps.insert(op);
      }
    }
  }

  while (!worklist.empty()) {
    Operation* op = worklist.pop_back_val();
    int numOperands = op->getNumOperands();
    // The last operand is the result buffer; only the input operands are
    // inspected here.
    for (int i = 0; i < numOperands - 1; ++i) {
      Value memref = op->getOperand(i);
      for (Operation* user : memref.getUsers()) {
        if ((!isa<LmhloOp>(user)) || hasLoaderOps.count(user)) continue;
        if (cast<LmhloOp>(user).getResultBuffer() == memref) {
          worklist.push_back(user);
          hasLoaderOps.insert(user);
        }
      }
    }
  }

  DenseSet<Operation*> noLoaderOps;
  for (Operation* op : ops)
    if (!hasLoaderOps.count(op)) noLoaderOps.insert(op);
  return noLoaderOps;
}
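
// Illustrative sketch of what noLoaderUser computes (the IR is hypothetical):
//   "lmhlo.add"(%a, %b, %t) : (...) -> ()
//   "lmhlo.abs"(%t, %out) : (...) -> ()
// If some memref.load reads %out, both ops are kept: abs has a direct loader,
// and add transitively feeds it through %t. If nothing ever loads %t or %out,
// noLoaderUser returns both ops, marking them as dead.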

void cleanUnusedLhloOps(Block* parent) {
  SmallVector<Operation*, 4> lhloOps;
  for (Operation& op : parent->getOperations()) {
    if (op.getDialect() == op.getContext()->getLoadedDialect("lmhlo") &&
        (!isa<lmhlo::TerminatorOp>(op)))
      lhloOps.push_back(&op);
  }
  for (auto* lhloOp : noLoaderUser(lhloOps)) lhloOp->erase();
}
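
// Illustrative usage sketch (fusionBlock is a hypothetical name): after a
// fusion pass has rewritten loads of intermediate buffers into scalar
// computations, a call like
//   cleanUnusedLhloOps(&fusionBlock);
// erases the lmhlo ops whose result buffers are no longer read.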

template <typename LHLO_OpTy>
Value elementalLower(OpBuilder* b, Location loc, LHLO_OpTy op,
                     ValueRange outputIndex, bool checkCache);

template <>
Value elementalLower<lmhlo::RealDynamicSliceOp>(OpBuilder* b, Location loc,
                                                lmhlo::RealDynamicSliceOp op,
                                                ValueRange outputIndex,
                                                bool checkCache) {
  // Operands: (operand, start_indices, limit_indices, strides, result).
  Value startIndicesMemref = op->getOperand(1);
  Value stridesMemref = op->getOperand(3);
  int rank = outputIndex.size();
  SmallVector<Value, 4> inputIndex;
  for (int dim = 0; dim < rank; ++dim) {
    SmallVector<Value, 4> dimIndex;
    dimIndex.push_back(b->create<arith::ConstantOp>(
        loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), dim)));
    auto startIndexLoad =
        b->create<LoadOp>(loc, startIndicesMemref, ValueRange{dimIndex});
    auto startIndex =
        b->create<arith::IndexCastOp>(loc, b->getIndexType(), startIndexLoad);
    auto strideLoad =
        b->create<LoadOp>(loc, stridesMemref, ValueRange{dimIndex});
    auto stride =
        b->create<arith::IndexCastOp>(loc, b->getIndexType(), strideLoad);
    // input_dim = out_dim * stride + start_index
    auto inputDim = b->create<arith::AddIOp>(
        loc, b->create<arith::MulIOp>(loc, outputIndex[dim], stride),
        startIndex);
    inputIndex.push_back(inputDim);
  }

  Value operandMemref = *(op->getOperands().begin());

  if (!checkCache) return b->create<LoadOp>(loc, operandMemref, inputIndex);
  return createLoadOrUseCachedValue(loc, b, operandMemref, inputIndex,
                                    b->saveInsertionPoint());
}
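
// Worked example (illustrative): for a slice with start_indices = [2] and
// strides = [3], output element 4 reads input element 4 * 3 + 2 = 14.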

namespace {

template <typename T>
Value elementalLowerImplForBroadcastInDimOps(OpBuilder* b, Location loc,
                                             T broadcastInDim,
                                             ValueRange outputIndex,
                                             bool checkCache) {
  auto broadcastDimensions =
      broadcastInDim.getBroadcastDimensions().template getValues<int64_t>();
  int outRank = outputIndex.size();
  Value operandMemref = broadcastInDim->getOperand(0);
  SmallVector<Value, 4> inputIndex;
  for (int64_t dim = 0; dim < outRank; ++dim) {
    auto it =
        std::find(broadcastDimensions.begin(), broadcastDimensions.end(), dim);

    bool isBroadcastDim = (it != broadcastDimensions.end());
    if (isBroadcastDim) {
      int inputDim = std::distance(broadcastDimensions.begin(), it);
      int64_t staticDimSize =
          operandMemref.getType().cast<MemRefType>().getShape()[inputDim];
      if (staticDimSize == 1) {
        // This dim is known at compile time to be broadcast.
        auto zero = b->create<arith::ConstantOp>(
            loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 0));
        inputIndex.push_back(zero);
      } else if (staticDimSize == ShapedType::kDynamicSize) {
        // Whether this dim is broadcast is unknown at compile time, so emit a
        // runtime check: if the dim size is 1, index with 0, otherwise pass
        // the output index through.
        auto dimSize = b->create<DimOp>(loc, operandMemref, inputDim);
        auto one = b->create<arith::ConstantOp>(
            loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 1));
        auto zero = b->create<arith::ConstantOp>(
            loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 0));
        auto dimSizeIs1 = b->create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::eq, dimSize, one);
        inputIndex.push_back(b->create<mlir::arith::SelectOp>(
            loc, dimSizeIs1, zero, outputIndex[dim]));
      } else {
        // This dim is known at compile time not to be broadcast.
        inputIndex.push_back(outputIndex[dim]);
      }
    }
  }

  if (!checkCache) {
    int rank = operandMemref.getType().dyn_cast<MemRefType>().getRank();
    return (rank > 0) ? b->create<LoadOp>(loc, operandMemref, inputIndex)
                      : b->create<LoadOp>(loc, operandMemref, ValueRange());
  }
  return createLoadOrUseCachedValue(loc, b, operandMemref, inputIndex,
                                    b->saveInsertionPoint());
}

}  // namespace

template <>
Value elementalLower<lmhlo::DynamicBroadcastInDimOp>(
    OpBuilder* b, Location loc, lmhlo::DynamicBroadcastInDimOp op,
    ValueRange outputIndex, bool checkCache) {
  return elementalLowerImplForBroadcastInDimOps(b, loc, op, outputIndex,
                                                checkCache);
}

template <>
Value elementalLower<lmhlo::BroadcastInDimOp>(OpBuilder* b, Location loc,
                                              lmhlo::BroadcastInDimOp op,
                                              ValueRange outputIndex,
                                              bool checkCache) {
  return elementalLowerImplForBroadcastInDimOps(b, loc, op, outputIndex,
                                                checkCache);
}
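
// Worked example (illustrative): broadcasting a memref<?x1xf32> operand into
// a rank-2 output with broadcast_dimensions = [0, 1]. Output index (i, j)
// maps to input index (select(size0 == 1, 0, i), 0): the static size-1 dim is
// resolved to index 0 at compile time, while the dynamic dim is resolved at
// runtime.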

scf::ForOp createLoopAndSetInsPt(OpBuilder& b, Location loc, Value& var,
                                 Value lb, Value ub, Value step,
                                 ArrayRef<Value> initValues) {
  auto forOp = b.create<scf::ForOp>(loc, lb, ub, step, initValues);
  b.setInsertionPointToStart(forOp.getBody());
  var = forOp.getInductionVar();
  return forOp;
}
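
// Illustrative usage sketch (lb/ub/step/elem/buffer are hypothetical values):
//   Value iv;
//   scf::ForOp loop = createLoopAndSetInsPt(b, loc, iv, lb, ub, step, {});
//   // The builder now inserts into the loop body, with iv as the induction
//   // variable.
//   b.create<memref::StoreOp>(loc, elem, buffer, ValueRange{iv});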

scf::ParallelOp createParallelAndSetInsPt(OpBuilder& b, Location loc,
                                          SmallVectorImpl<Value>& vars,
                                          ArrayRef<Value> lbs,
                                          ArrayRef<Value> ubs,
                                          ArrayRef<Value> steps,
                                          ArrayRef<Value> initValues) {
  auto parOp = b.create<scf::ParallelOp>(loc, lbs, ubs, steps, initValues,
                                         /*bodyBuilderFn=*/nullptr);
  b.setInsertionPointToStart(parOp.getBody());
  vars.append(parOp.getInductionVars().begin(), parOp.getInductionVars().end());
  return parOp;
}
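
// Illustrative usage sketch (zero/one/dim0/dim1 are hypothetical values):
// emitting a rank-2 elementwise loop nest as a single scf.parallel.
//   SmallVector<Value, 2> ivs;
//   auto par = createParallelAndSetInsPt(b, loc, ivs, {zero, zero},
//                                        {dim0, dim1}, {one, one}, {});
//   // The builder now inserts into the parallel body; ivs holds one
//   // induction variable per dimension.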

// Reinterpret-cast the input memref into a 1-D memref with the same total
// number of elements; the input is required to have an identity layout.
memref::ReinterpretCastOp createMemRef1DReinterpretCast(OpBuilder& b,
                                                        Location loc,
                                                        Value memref) {
  auto memrefTy = memref.getType().cast<MemRefType>();
  assert(memrefTy.getLayout().isIdentity());
  Value size = codegen_utils::emitNumElementsComputation(b, loc, memref);
  Value stride = b.create<mlir::arith::ConstantOp>(
      loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 1));
  Value zero = b.create<mlir::arith::ConstantOp>(
      loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 0));
  auto memref1dType =
      MemRefType::get({ShapedType::kDynamicSize}, memrefTy.getElementType(),
                      b.getMultiDimIdentityMap(1), memrefTy.getMemorySpace());
  return b.create<memref::ReinterpretCastOp>(
      loc, memref1dType, memref, zero, ValueRange{size}, ValueRange{stride});
}
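
// Illustrative example of the IR this produces for a memref<4x8xf32> input
// (the exact form of %size depends on emitNumElementsComputation):
//   %c0 = arith.constant 0 : index
//   %c1 = arith.constant 1 : index
//   %flat = memref.reinterpret_cast %arg to offset: [%c0], sizes: [%size],
//       strides: [%c1] : memref<4x8xf32> to memref<?xf32>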

void createOffsetStore(OpBuilder& b, Location loc, Value res, Value memref,
                       Value offset) {
  Value memref1d = createMemRef1DReinterpretCast(b, loc, memref);
  b.create<memref::StoreOp>(loc, res, memref1d, ValueRange{offset});
}

memref::LoadOp createOffsetLoad(OpBuilder& b, Location loc, Value memref,
                                Value offset) {
  Value memref1d = createMemRef1DReinterpretCast(b, loc, memref);
  return b.create<memref::LoadOp>(loc, memref1d, ValueRange{offset});
}
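
// Illustrative usage sketch (srcBuffer/dstBuffer/linearIdx are hypothetical
// values): the offset helpers address multi-dimensional buffers by linear
// index.
//   Value v = createOffsetLoad(b, loc, srcBuffer, linearIdx);
//   createOffsetStore(b, loc, v, dstBuffer, linearIdx);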

}  // namespace lmhlo
}  // namespace mlir