//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
// intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
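// As an illustrative sketch (types and value names are chosen only for this
// example), a call such as
//   %v = call <2 x i32> @llvm.masked.load(<2 x i32>* %p, i32 4,
//                                         <2 x i1> %m, <2 x i32> %passthru)
// becomes a per-element branch that conditionally loads each lane and merges
// it into the result vector; the expected shapes are documented above each
// scalarize* helper below.
//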
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <cassert>
#include <optional>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}

static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// When a <VectorWidth x i1> mask is bitcast to an integer, lane 0 maps to the
// least significant bit on little-endian targets and to the most significant
// bit on big-endian targets; adjust the bit index accordingly.
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one-by-one if the
// appropriate mask bit is set
//
// %1 = bitcast i8* %addr to i32*
// %2 = extractelement <16 x i1> %mask, i32 0
// br i1 %2, label %cond.load, label %else
//
// cond.load: ; preds = %0
// %3 = getelementptr i32* %1, i32 0
// %4 = load i32* %3
// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
// br label %else
//
// else: ; preds = %0, %cond.load
// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
// %6 = extractelement <16 x i1> %mask, i32 1
// br i1 %6, label %cond.load1, label %else2
//
// cond.load1: ; preds = %else
// %7 = getelementptr i32* %1, i32 1
// %8 = load i32* %7
// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
// br label %else2
//
// else2: ; preds = %else, %cond.load1
// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
// %10 = extractelement <16 x i1> %mask, i32 2
// br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
                                DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    // %mask_1 = and i16 %scalar_mask, (1 << Idx)
    // %cond = icmp ne i16 %mask_1, 0
    // br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, storing the elements one-by-one if the
// appropriate mask bit is set
//
// %1 = bitcast i8* %addr to i32*
// %2 = extractelement <16 x i1> %mask, i32 0
// br i1 %2, label %cond.store, label %else
//
// cond.store: ; preds = %0
// %3 = extractelement <16 x i32> %val, i32 0
// %4 = getelementptr i32* %1, i32 0
// store i32 %3, i32* %4
// br label %else
//
// else: ; preds = %0, %cond.store
// %5 = extractelement <16 x i1> %mask, i32 1
// br i1 %5, label %cond.store1, label %else2
//
// cond.store1: ; preds = %else
// %6 = extractelement <16 x i32> %val, i32 1
// %7 = getelementptr i32* %1, i32 1
// store i32 %6, i32* %7
// br label %else2
// . . .
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %mask_1 = and i16 %scalar_mask, (1 << Idx)
    // %cond = icmp ne i16 %mask_1, 0
    // br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    // %OneElt = extractelement <16 x i32> %Src, i32 Idx
    // %EltAddr = getelementptr i32* %1, i32 0
    // store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one-by-one if the
// appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %Mask1 = and i16 %scalar_mask, (1 << Idx)
    // %cond = icmp ne i16 %Mask1, 0
    // br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, storing the elements one-by-one if the
// appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
// . . .
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %Mask1 = and i16 %scalar_mask, (1 << Idx)
    // %cond = icmp ne i16 %Mask1, 0
    // br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    // %Elt1 = extractelement <16 x i32> %Src, i32 1
    // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    // store i32 %Elt1, i32* %Ptr1
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

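// Translate a masked expandload intrinsic, e.g. (an illustrative sketch; the
// concrete types and value names are chosen only for the example)
//   <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                             <16 x i32> %passthru)
// to a chain of basic blocks. Elements are read from consecutive memory
// locations: a lane whose mask bit is set loads the next unread element and
// advances the pointer, while a disabled lane keeps the corresponding element
// of %passthru and leaves the pointer unchanged.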
static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Shorten the way if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/undefs as necessary and then
  // shuffle blend with the pass through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = PoisonValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = UndefValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    // br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

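// Translate a masked compressstore intrinsic, e.g. (an illustrative sketch;
// the concrete types and value names are chosen only for the example)
//   void @llvm.masked.compressstore.v16i32(<16 x i32> %value, i32* %ptr,
//                                          <16 x i1> %mask)
// to a chain of basic blocks. Enabled elements are written to consecutive
// memory locations: a lane whose mask bit is set stores its element and
// advances the pointer, while a disabled lane stores nothing and leaves the
// pointer unchanged, so no gaps are left in memory.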
static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
                                         DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    // br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    // %OneElt = extractelement <16 x i32> %Src, i32 Idx
    // %EltAddr = getelementptr i32* %1, i32 0
    // store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, Align(1));

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  std::optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getParent()->getDataLayout();
  while (MadeChange) {
    MadeChange = false;
    for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
                                  DTU ? &*DTU : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->args(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, LoadTy->getScalarType());
      if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
          !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
        return false;
      scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, StoreTy->getScalarType());
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
          !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                           Alignment))
        return false;
      scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
      return true;
    }
  }

  return false;
}