//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
// intrinsics to RISC-V strided load/store intrinsics when the accessed
// addresses follow a constant stride.
//
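// Conceptually (a sketch of the rewrite, not verbatim pass output), a gather
// such as
//   %v = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(
//            <8 x ptr> %ptrs, i32 4, <8 x i1> %mask, <8 x i32> %passthru)
// where %ptrs addresses every other i32 starting at %base becomes
//   %v = call <8 x i32> @llvm.riscv.masked.strided.load.v8i32.p0.i64(
//            <8 x i32> %passthru, ptr %base, i64 8, <8 x i1> %mask)
// with the stride expressed in bytes.
//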
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "riscv-gather-scatter-lowering"

namespace {

class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

  // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
  // used by multiple gathers/scatters, this allows us to reuse the scalar
  // instructions we created for the first gather/scatter for the others.
  DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISCV gather/scatter lowering";
  }

private:
  bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp);

  bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                 Value *AlignOp);

  std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
                                                     IRBuilder<> &Builder);

  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilder<> &Builder);
};

} // end anonymous namespace

char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISCV gather/scatter lowering pass", false, false)

FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}

bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType,
                                                         Value *AlignOp) {
  Type *ScalarType = DataType->getScalarType();
  if (!TLI->isLegalElementTypeForRVV(ScalarType))
    return false;

  MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
  if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedValue())
    return false;

  // FIXME: Let the backend type legalize by splitting/widening?
  EVT DataVT = TLI->getValueType(*DL, DataType);
  if (!TLI->isTypeLegal(DataVT))
    return false;

  return true;
}

// TODO: Should we consider the mask when looking for a stride?
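// For example, the constant <i64 0, i64 2, i64 4, i64 6> matches with start 0
// and stride 2, while <i64 0, i64 1, i64 3, i64 6> has no common difference
// between adjacent elements and fails with {nullptr, nullptr}.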
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();

  // Check that the start value is a strided constant.
  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);
  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

  return std::make_pair(StartVal, Stride);
}

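// Match the start value of a strided sequence and return {start, stride};
// e.g. (splat %x) + stepvector yields start %x and stride 1. Returns
// {nullptr, nullptr} if no strided pattern is found.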
static std::pair<Value *, Value *> matchStridedStart(Value *Start,
                                                     IRBuilder<> &Builder) {
  // Base case, start is a strided constant.
  auto *StartC = dyn_cast<Constant>(Start);
  if (StartC)
    return matchStridedConstant(StartC);

  // Base case, start is a stepvector.
  if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
    auto *Ty = Start->getType()->getScalarType();
    return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
  }

  // Not a constant, maybe it's a strided constant with a splat added to it.
  auto *BO = dyn_cast<BinaryOperator>(Start);
  if (!BO || BO->getOpcode() != Instruction::Add)
    return std::make_pair(nullptr, nullptr);

  // Look for an operand that is splatted.
  unsigned OtherIndex = 1;
  Value *Splat = getSplatValue(BO->getOperand(0));
  if (!Splat) {
    Splat = getSplatValue(BO->getOperand(1));
    OtherIndex = 0;
  }
  if (!Splat)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
                                              Builder);
  if (!Start)
    return std::make_pair(nullptr, nullptr);

  // Add the splat value to the start.
  Builder.SetInsertPoint(BO);
  Builder.SetCurrentDebugLocation(DebugLoc());
  Start = Builder.CreateAdd(Start, Splat);
  return std::make_pair(Start, Stride);
}

// Recursively walk back the use-def chain until we find a Phi with a strided
// start value. Build and update a scalar recurrence as we unwind the recursion.
// We also update the Stride as we unwind. Our goal is to move all of the
// arithmetic out of the loop.
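// For example (block names illustrative), a vector induction variable
//   %vec.iv = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %ph ],
//                           [ %vec.iv.next, %loop ]
//   %vec.iv.next = add <4 x i64> %vec.iv, <i64 4, i64 4, i64 4, i64 4>
// is rewritten to the scalar recurrence
//   %vec.iv.scalar = phi i64 [ 0, %ph ], [ %vec.iv.next.scalar, %loop ]
//   %vec.iv.next.scalar = add i64 %vec.iv.scalar, 4
// with Stride = 1 covering the per-lane offsets.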
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilder<> &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // A phi node we want to perform this function on should be from the
    // loop header.
    if (Phi->getParent() != L->getHeader())
      return false;

    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Only proceed if the step is loop invariant.
    if (!L->isLoopInvariant(Step))
      return false;

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    BasePtr =
        PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc);
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  if (BO->getOpcode() != Instruction::Add &&
      BO->getOpcode() != Instruction::Or &&
      BO->getOpcode() != Instruction::Mul &&
      BO->getOpcode() != Instruction::Shl)
    return false;

  // Only support shift by constant.
  if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
    return false;

  // We need to be able to treat Or as Add: e.g. (X << 2) | 1 equals
  // (X << 2) + 1 only because the two operands share no set bits.
  if (BO->getOpcode() == Instruction::Or &&
      !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
    return false;

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1)))) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());

  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.

    // If the start value is Zero, just take the SplatOp.
    if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
      Start = SplatOp;
    else
      Start = Builder.CreateAdd(Start, SplatOp, "start");
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
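  // Scaling the recurrence by SplatOp scales its start, per-iteration step,
  // and per-lane stride alike:
  //   (Start + i * Step + lane * Stride) * S
  //     == Start * S + i * (Step * S) + lane * (Stride * S)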
  case Instruction::Mul: {
    // If the start is zero we don't need to multiply.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateMul(Start, SplatOp, "start");

    Step = Builder.CreateMul(Step, SplatOp, "step");

    // If the Stride is 1 just take the SplatOp.
    if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
      Stride = SplatOp;
    else
      Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Shl: {
    // If the start is zero we don't need to shift.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateShl(Start, SplatOp, "start");
    Step = Builder.CreateShl(Step, SplatOp, "step");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  }

  return true;
}

std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
                                                   IRBuilder<> &Builder) {

  auto I = StridedAddrs.find(GEP);
  if (I != StridedAddrs.end())
    return I->second;

  SmallVector<Value *, 2> Ops(GEP->operands());

  // Base pointer needs to be a scalar.
  if (Ops[0]->getType()->isVectorTy())
    return std::make_pair(nullptr, nullptr);

  std::optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedValue();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice.
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy)
    return std::make_pair(nullptr, nullptr);

  // Handle the non-recursive case. This is what we see if the vectorizer
  // decides to use a scalar IV + vid on demand instead of a vector IV.
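  // For example, a GEP whose vector index is (splat %i) + stepvector
  // simplifies to a scalar GEP at index %i with an element stride of 1,
  // which is then scaled to a byte stride below.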
  auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
  if (Start) {
    assert(Stride);
    Builder.SetInsertPoint(GEP);

    // Replace the vector index with the scalar start and build a scalar GEP.
    Ops[*VecOperand] = Start;
    Type *SourceTy = GEP->getSourceElementType();
    Value *BasePtr =
        Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());

    // Convert stride to pointer size if needed.
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    assert(Stride->getType() == IntPtrTy && "Unexpected type");

    // Scale the stride by the size of the indexed type.
    if (TypeScale != 1)
      Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

    auto P = std::make_pair(BasePtr, Stride);
    StridedAddrs[GEP] = P;
    return P;
  }

  // Make sure we're in a loop and that it has a preheader and a single latch.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
    return std::make_pair(nullptr, nullptr);

  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  auto P = std::make_pair(BasePtr, Stride);
  StridedAddrs[GEP] = P;
  return P;
}

bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
                                                           Type *DataType,
                                                           Value *Ptr,
                                                           Value *AlignOp) {
  // Make sure the operation will be supported by the backend.
  if (!isLegalTypeAndAlignment(DataType, AlignOp))
    return false;

  // Pointer should be a GEP.
  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return false;

  IRBuilder<> Builder(GEP);

  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

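  // Note the operand orders: masked.gather is (ptrs, align, mask, passthru)
  // and masked.scatter is (data, ptrs, align, mask), while the strided
  // intrinsics below take (passthru-or-data, base, byte stride, mask).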
  CallInst *Call;
  if (II->getIntrinsicID() == Intrinsic::masked_gather)
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
  else
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  if (GEP->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(GEP);

  return true;
}

bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getParent()->getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  StridedAddrs.clear();

  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
        Gathers.push_back(II);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
        Scatters.push_back(II);
      }
    }
  }

  // Rewrite gather/scatter to form strided load/store if possible.
  for (auto *II : Gathers)
    Changed |= tryCreateStridedLoadStore(
        II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
  for (auto *II : Scatters)
    Changed |=
        tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
                                  II->getArgOperand(1), II->getArgOperand(2));

  // Remove any dead phis.
  while (!MaybeDeadPHIs.empty()) {
    if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}