/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file provides basic utilities for the elemental lowering of
// each lmhlo op.

#include "mlir-hlo/Dialect/lhlo/transforms/lhlo_elemental_utils.h"

#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir-hlo/utils/codegen_utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using mlir::memref::DimOp;
using mlir::memref::LoadOp;
using mlir::memref::StoreOp;

namespace mlir {
namespace lmhlo {

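// Returns the element of `memref` at `indices`. Before emitting a load,
// scans the block containing `insertPoint` for an earlier memref.StoreOp
// to the same memref at the same indices; if one is found, its stored
// value is reused instead of re-loading it.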
Value createLoadOrUseCachedValue(Location loc, OpBuilder* b, Value memref,
                                 ValueRange indices,
                                 OpBuilder::InsertPoint insertPoint) {
  // Check if there is a cached value that can be reused within the
  // current Block. Alternatively, we could do this for all Blocks that
  // dominate this Block, but that would be more complicated.
  std::vector<StoreOp> storeOps;
  insertPoint.getBlock()->walk(
      insertPoint.getBlock()->begin(), insertPoint.getPoint(),
      [&](StoreOp storeOp) {
        if (storeOp.getOperation()->getBlock() != insertPoint.getBlock())
          return;
        if ((storeOp.getMemRef() == memref) &&
            (storeOp.getIndices() == indices))
          storeOps.emplace_back(storeOp);
      });
  if (!storeOps.empty()) return storeOps[0].getOperand(0);
  int rank = memref.getType().dyn_cast<MemRefType>().getRank();
  return rank > 0 ? b->create<LoadOp>(loc, memref, indices)
                  : b->create<LoadOp>(loc, memref);
}

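// Returns the subset of `ops` whose result buffers are never loaded,
// either directly via a memref.LoadOp or transitively as inputs of other
// lmhlo ops whose results are loaded.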
DenseSet<Operation*> noLoaderUser(SmallVectorImpl<Operation*>& ops) {
  SmallVector<Operation*, 4> worklist;
  DenseSet<Operation*> hasLoaderOps;
  // Seed the worklist with ops whose result buffer is read directly.
  for (Operation* op : ops) {
    Value memref = cast<LmhloOp>(op).getResultBuffer();
    if (memref == nullptr) continue;
    for (auto* user : memref.getUsers()) {
      if (isa<memref::LoadOp>(user)) {
        worklist.push_back(op);
        hasLoaderOps.insert(op);
      }
    }
  }

  // Propagate the property backwards: an lmhlo op whose result buffer is
  // an input of a (transitively) loaded op is itself considered loaded.
  while (!worklist.empty()) {
    Operation* op = worklist.pop_back_val();
    int numOperands = op->getNumOperands();
    for (int i = 0; i < numOperands - 1; ++i) {
      Value memref = op->getOperand(i);
      for (Operation* user : memref.getUsers()) {
        if ((!isa<LmhloOp>(user)) || hasLoaderOps.count(user)) continue;
        if (cast<LmhloOp>(user).getResultBuffer() == memref) {
          worklist.push_back(user);
          hasLoaderOps.insert(user);
        }
      }
    }
  }

  DenseSet<Operation*> noLoaderOps;
  for (Operation* op : ops)
    if (!hasLoaderOps.count(op)) noLoaderOps.insert(op);
  return noLoaderOps;
}

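// Erases every lmhlo op in `parent` (terminator excluded) whose result is
// never loaded, as determined by noLoaderUser.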
void cleanUnusedLhloOps(Block* parent) {
  SmallVector<Operation*, 4> lhloOps;
  for (Operation& op : parent->getOperations()) {
    if (op.getDialect() == op.getContext()->getLoadedDialect("lmhlo") &&
        (!isa<lmhlo::TerminatorOp>(op)))
      lhloOps.push_back(&op);
  }
  for (auto* lhloOp : noLoaderUser(lhloOps)) lhloOp->erase();
}

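// Emits the scalar computation of one element of `op`'s result, identified
// by `outputIndex`, and returns the resulting Value. If `checkCache` is
// true, loads may reuse values cached by earlier stores in the same block
// (see createLoadOrUseCachedValue).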
template <typename LHLO_OpTy>
Value elementalLower(OpBuilder* b, Location loc, LHLO_OpTy op,
                     ValueRange outputIndex, bool checkCache);

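// For RealDynamicSliceOp, operand(1) holds the start indices and
// operand(3) holds the strides; each input index is computed as
// out_index * stride + start_index.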
template <>
Value elementalLower<lmhlo::RealDynamicSliceOp>(OpBuilder* b, Location loc,
                                                lmhlo::RealDynamicSliceOp op,
                                                ValueRange outputIndex,
                                                bool checkCache) {
  Value startIndicesMemref = op->getOperand(1);
  Value stridesMemref = op->getOperand(3);
  int rank = outputIndex.size();
  SmallVector<Value, 4> inputIndex;
  for (int dim = 0; dim < rank; ++dim) {
    SmallVector<Value, 4> dimIndex;
    dimIndex.push_back(b->create<arith::ConstantOp>(
        loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), dim)));
    auto startIndexLoad =
        b->create<LoadOp>(loc, startIndicesMemref, ValueRange{dimIndex});
    auto startIndex =
        b->create<arith::IndexCastOp>(loc, b->getIndexType(), startIndexLoad);
    auto strideLoad =
        b->create<LoadOp>(loc, stridesMemref, ValueRange{dimIndex});
    auto stride =
        b->create<arith::IndexCastOp>(loc, b->getIndexType(), strideLoad);
    // input_dim = out_dim * stride + start_index
    auto inputDim = b->create<arith::AddIOp>(
        loc, b->create<arith::MulIOp>(loc, outputIndex[dim], stride),
        startIndex);
    inputIndex.push_back(inputDim);
  }

  Value operandMemref = *(op->getOperands().begin());

  if (!checkCache) return b->create<LoadOp>(loc, operandMemref, inputIndex);
  return createLoadOrUseCachedValue(loc, b, operandMemref, inputIndex,
                                    b->saveInsertionPoint());
}

namespace {

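// Shared lowering for (Dynamic)BroadcastInDimOp. For every output dim that
// maps to an input dim via broadcast_dimensions: if the input dim is
// statically 1, index 0 is used; if it is dynamic, a runtime select picks
// between 0 and the output index; otherwise the output index is passed
// through unchanged.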
template <typename T>
Value elementalLowerImplForBroadcastInDimOps(OpBuilder* b, Location loc,
                                             T broadcastInDim,
                                             ValueRange outputIndex,
                                             bool checkCache) {
  auto broadcastDimensions =
      broadcastInDim.getBroadcastDimensions().template getValues<int64_t>();
  int outRank = outputIndex.size();
  Value operandMemref = broadcastInDim->getOperand(0);
  SmallVector<Value, 4> inputIndex;
  for (int64_t dim = 0; dim < outRank; ++dim) {
    auto it =
        std::find(broadcastDimensions.begin(), broadcastDimensions.end(), dim);

    bool isBroadcastDim = (it != broadcastDimensions.end());
    if (isBroadcastDim) {
      int inputDim = std::distance(broadcastDimensions.begin(), it);
      int64_t staticDimSize =
          operandMemref.getType().cast<MemRefType>().getShape()[inputDim];
      if (staticDimSize == 1) {
        // We know at compile time that this dim is broadcast.
        auto zero = b->create<arith::ConstantOp>(
            loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 0));
        inputIndex.push_back(zero);
      } else if (staticDimSize == ShapedType::kDynamicSize) {
        // We cannot tell at compile time whether this dim is broadcast, so
        // decide at runtime: if the dim size is 1, index 0 is used.
        auto dimSize = b->create<DimOp>(loc, operandMemref, inputDim);
        auto one = b->create<arith::ConstantOp>(
            loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 1));
        auto zero = b->create<arith::ConstantOp>(
            loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 0));
        auto dimSizeIs1 = b->create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::eq, dimSize, one);
        inputIndex.push_back(b->create<mlir::arith::SelectOp>(
            loc, dimSizeIs1, zero, outputIndex[dim]));
      } else {
        // We know at compile time that this dim is not broadcast.
        inputIndex.push_back(outputIndex[dim]);
      }
    }
  }

  if (!checkCache) {
    int rank = operandMemref.getType().dyn_cast<MemRefType>().getRank();
    return (rank > 0) ? b->create<LoadOp>(loc, operandMemref, inputIndex)
                      : b->create<LoadOp>(loc, operandMemref, ValueRange());
  }
  return createLoadOrUseCachedValue(loc, b, operandMemref, inputIndex,
                                    b->saveInsertionPoint());
}

}  // namespace

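// The two broadcast specializations below only forward to the shared
// implementation above; the indexing logic is identical for both.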
template <>
Value elementalLower<lmhlo::DynamicBroadcastInDimOp>(
    OpBuilder* b, Location loc, lmhlo::DynamicBroadcastInDimOp op,
    ValueRange outputIndex, bool checkCache) {
  return elementalLowerImplForBroadcastInDimOps(b, loc, op, outputIndex,
                                                checkCache);
}

template <>
Value elementalLower<lmhlo::BroadcastInDimOp>(OpBuilder* b, Location loc,
                                              lmhlo::BroadcastInDimOp op,
                                              ValueRange outputIndex,
                                              bool checkCache) {
  return elementalLowerImplForBroadcastInDimOps(b, loc, op, outputIndex,
                                                checkCache);
}

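// Creates an scf.for loop, moves the builder's insertion point to the start
// of its body, and returns the loop. `var` receives the induction variable.
// A minimal usage sketch (assumes `b`, `loc`, `lb`, `ub`, and `step` are a
// valid builder, location, and index-typed bounds):
//   Value iv;
//   scf::ForOp loop = createLoopAndSetInsPt(b, loc, iv, lb, ub, step, {});
//   // ... emit the loop body here using `iv` ...
//   b.setInsertionPointAfter(loop);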
scf::ForOp createLoopAndSetInsPt(OpBuilder& b, Location loc, Value& var,
                                 Value lb, Value ub, Value step,
                                 ArrayRef<Value> initValues) {
  auto forOp = b.create<scf::ForOp>(loc, lb, ub, step, initValues);
  b.setInsertionPointToStart(forOp.getBody());
  var = forOp.getInductionVar();
  return forOp;
}

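// Parallel counterpart of createLoopAndSetInsPt: creates an scf.parallel
// loop nest, moves the insertion point into its body, and appends the
// induction variables to `vars`.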
scf::ParallelOp createParallelAndSetInsPt(OpBuilder& b, Location loc,
                                          SmallVectorImpl<Value>& vars,
                                          ArrayRef<Value> lbs,
                                          ArrayRef<Value> ubs,
                                          ArrayRef<Value> steps,
                                          ArrayRef<Value> initValues) {
  auto parOp = b.create<scf::ParallelOp>(loc, lbs, ubs, steps, initValues,
                                         /*bodyBuilderFn=*/nullptr);
  b.setInsertionPointToStart(parOp.getBody());
  vars.append(parOp.getInductionVars().begin(), parOp.getInductionVars().end());
  return parOp;
}

// Reinterpret-casts the input memref into a flat 1D memref of the same
// element type. The input must have an identity layout (asserted below).
memref::ReinterpretCastOp createMemRef1DReinterpretCast(OpBuilder& b,
                                                        Location loc,
                                                        Value memref) {
  auto memrefTy = memref.getType().cast<MemRefType>();
  assert(memrefTy.getLayout().isIdentity());
  Value size = codegen_utils::emitNumElementsComputation(b, loc, memref);
  Value stride = b.create<mlir::arith::ConstantOp>(
      loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 1));
  Value zero = b.create<mlir::arith::ConstantOp>(
      loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 0));
  auto memref1dType =
      MemRefType::get({ShapedType::kDynamicSize}, memrefTy.getElementType(),
                      b.getMultiDimIdentityMap(1), memrefTy.getMemorySpace());
  return b.create<memref::ReinterpretCastOp>(
      loc, memref1dType, memref, zero, ValueRange{size}, ValueRange{stride});
}

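// Stores `res` into `memref` at the linear `offset`, addressing the buffer
// through a flattened 1D view. With the identity (row-major) layout this
// implies, e.g., that offset 4 of a 2x3 memref addresses element (1, 1).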
void createOffsetStore(OpBuilder& b, Location loc, Value res, Value memref,
                       Value offset) {
  Value memref1d = createMemRef1DReinterpretCast(b, loc, memref);
  b.create<memref::StoreOp>(loc, res, memref1d, ValueRange{offset});
}

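// Loads the element of `memref` at the linear `offset`, addressing the
// buffer through a flattened 1D view.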
memref::LoadOp createOffsetLoad(OpBuilder& b, Location loc, Value memref,
                                Value offset) {
  Value memref1d = createMemRef1DReinterpretCast(b, loc, memref);
  return b.create<memref::LoadOp>(loc, memref1d, ValueRange{offset});
}

}  // namespace lmhlo
}  // namespace mlir