//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass interleaves around sext/zext/trunc instructions. MVE does not have
10 // a single sext/zext or trunc instruction that takes the bottom half of a
11 // vector and extends to a full width, like NEON has with MOVL. Instead it is
12 // expected that this happens through top/bottom instructions. So the MVE
13 // equivalent VMOVLT/B instructions take either the even or odd elements of the
14 // input and extend them to the larger type, producing a vector with half the
15 // number of elements each of double the bitwidth. As there is no simple
16 // instruction, we often have to turn sext/zext/trunc into a series of lane
17 // moves (or stack loads/stores, which we do not do yet).
18 //
19 // This pass takes vector code that starts at truncs, looks for interconnected
20 // blobs of operations that end with sext/zext (or constants/splats) of the
21 // form:
22 // %sa = sext v8i16 %a to v8i32
23 // %sb = sext v8i16 %b to v8i32
24 // %add = add v8i32 %sa, %sb
25 // %r = trunc %add to v8i16
// And adds shuffles to allow the use of VMOVL/VMOVN instructions:
27 // %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
28 // %sa = sext v8i16 %sha to v8i32
29 // %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
30 // %sb = sext v8i16 %shb to v8i32
31 // %add = add v8i32 %sa, %sb
32 // %r = trunc %add to v8i16
33 // %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
34 // Which can then be split and lowered to MVE instructions efficiently:
35 // %sa_b = VMOVLB.s16 %a
36 // %sa_t = VMOVLT.s16 %a
37 // %sb_b = VMOVLB.s16 %b
38 // %sb_t = VMOVLT.s16 %b
39 // %add_b = VADD.i32 %sa_b, %sb_b
40 // %add_t = VADD.i32 %sa_t, %sb_t
41 // %r = VMOVNT.i16 %add_b, %add_t
42 //
43 //===----------------------------------------------------------------------===//
44
45 #include "ARM.h"
46 #include "ARMBaseInstrInfo.h"
47 #include "ARMSubtarget.h"
48 #include "llvm/ADT/SetVector.h"
49 #include "llvm/Analysis/TargetTransformInfo.h"
50 #include "llvm/CodeGen/TargetLowering.h"
51 #include "llvm/CodeGen/TargetPassConfig.h"
52 #include "llvm/CodeGen/TargetSubtargetInfo.h"
53 #include "llvm/IR/BasicBlock.h"
54 #include "llvm/IR/Constant.h"
55 #include "llvm/IR/Constants.h"
56 #include "llvm/IR/DerivedTypes.h"
57 #include "llvm/IR/Function.h"
58 #include "llvm/IR/IRBuilder.h"
59 #include "llvm/IR/InstIterator.h"
60 #include "llvm/IR/InstrTypes.h"
61 #include "llvm/IR/Instruction.h"
62 #include "llvm/IR/Instructions.h"
63 #include "llvm/IR/IntrinsicInst.h"
64 #include "llvm/IR/Intrinsics.h"
65 #include "llvm/IR/IntrinsicsARM.h"
66 #include "llvm/IR/PatternMatch.h"
67 #include "llvm/IR/Type.h"
68 #include "llvm/IR/Value.h"
69 #include "llvm/InitializePasses.h"
70 #include "llvm/Pass.h"
71 #include "llvm/Support/Casting.h"
72 #include <algorithm>
73 #include <cassert>
74
75 using namespace llvm;
76
77 #define DEBUG_TYPE "mve-laneinterleave"
78
79 cl::opt<bool> EnableInterleave(
80 "enable-mve-interleave", cl::Hidden, cl::init(true),
81 cl::desc("Enable interleave MVE vector operation lowering"));
82
83 namespace {
84
85 class MVELaneInterleaving : public FunctionPass {
86 public:
87 static char ID; // Pass identification, replacement for typeid
88
MVELaneInterleaving()89 explicit MVELaneInterleaving() : FunctionPass(ID) {
90 initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
91 }
92
93 bool runOnFunction(Function &F) override;
94
getPassName() const95 StringRef getPassName() const override { return "MVE lane interleaving"; }
96
getAnalysisUsage(AnalysisUsage & AU) const97 void getAnalysisUsage(AnalysisUsage &AU) const override {
98 AU.setPreservesCFG();
99 AU.addRequired<TargetPassConfig>();
100 FunctionPass::getAnalysisUsage(AU);
101 }
102 };
103
104 } // end anonymous namespace
105
106 char MVELaneInterleaving::ID = 0;
107
108 INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
109 false)
110
createMVELaneInterleavingPass()111 Pass *llvm::createMVELaneInterleavingPass() {
112 return new MVELaneInterleaving();
113 }
114
isProfitableToInterleave(SmallSetVector<Instruction *,4> & Exts,SmallSetVector<Instruction *,4> & Truncs)115 static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
116 SmallSetVector<Instruction *, 4> &Truncs) {
117 // This is not always beneficial to transform. Exts can be incorporated into
118 // loads, Truncs can be folded into stores.
119 // Truncs are usually the same number of instructions,
120 // VSTRH.32(A);VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving
121 // Exts are unfortunately more instructions in the general case:
122 // A=VLDRH.32; B=VLDRH.32;
123 // vs with interleaving:
124 // T=VLDRH.16; A=VMOVNB T; B=VMOVNT T
125 // But those VMOVL may be folded into a VMULL.
126
127 // But expensive extends/truncs are always good to remove. FPExts always
128 // involve extra VCVT's so are always considered to be beneficial to convert.
129 for (auto *E : Exts) {
130 if (isa<FPExtInst>(E) || !isa<LoadInst>(E->getOperand(0))) {
131 LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
132 return true;
133 }
134 }
135 for (auto *T : Truncs) {
136 if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
137 LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
138 return true;
139 }
140 }
141
142 // Otherwise, we know we have a load(ext), see if any of the Extends are a
143 // vmull. This is a simple heuristic and certainly not perfect.
144 for (auto *E : Exts) {
145 if (!E->hasOneUse() ||
146 cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
147 LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
148 return false;
149 }
150 }
151 return true;
152 }
153
tryInterleave(Instruction * Start,SmallPtrSetImpl<Instruction * > & Visited)154 static bool tryInterleave(Instruction *Start,
155 SmallPtrSetImpl<Instruction *> &Visited) {
156 LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");
157 auto *VT = cast<FixedVectorType>(Start->getType());
158
159 if (!isa<Instruction>(Start->getOperand(0)))
160 return false;
161
162 // Look for connected operations starting from Ext's, terminating at Truncs.
163 std::vector<Instruction *> Worklist;
164 Worklist.push_back(Start);
165 Worklist.push_back(cast<Instruction>(Start->getOperand(0)));
166
167 SmallSetVector<Instruction *, 4> Truncs;
168 SmallSetVector<Instruction *, 4> Exts;
169 SmallSetVector<Use *, 4> OtherLeafs;
170 SmallSetVector<Instruction *, 4> Ops;
171
172 while (!Worklist.empty()) {
173 Instruction *I = Worklist.back();
174 Worklist.pop_back();
175
176 switch (I->getOpcode()) {
177 // Truncs
178 case Instruction::Trunc:
179 case Instruction::FPTrunc:
180 if (!Truncs.insert(I))
181 continue;
182 Visited.insert(I);
183 break;
184
185 // Extend leafs
186 case Instruction::SExt:
187 case Instruction::ZExt:
188 case Instruction::FPExt:
189 if (Exts.count(I))
190 continue;
191 for (auto *Use : I->users())
192 Worklist.push_back(cast<Instruction>(Use));
193 Exts.insert(I);
194 break;
195
196 case Instruction::Call: {
197 IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
198 if (!II)
199 return false;
200
201 switch (II->getIntrinsicID()) {
202 case Intrinsic::abs:
203 case Intrinsic::smin:
204 case Intrinsic::smax:
205 case Intrinsic::umin:
206 case Intrinsic::umax:
207 case Intrinsic::sadd_sat:
208 case Intrinsic::ssub_sat:
209 case Intrinsic::uadd_sat:
210 case Intrinsic::usub_sat:
211 case Intrinsic::minnum:
212 case Intrinsic::maxnum:
213 case Intrinsic::fabs:
214 case Intrinsic::fma:
215 case Intrinsic::ceil:
216 case Intrinsic::floor:
217 case Intrinsic::rint:
218 case Intrinsic::round:
219 case Intrinsic::trunc:
220 break;
221 default:
222 return false;
223 }
224 [[fallthrough]]; // Fall through to treating these like an operator below.
225 }
226 // Binary/tertiary ops
227 case Instruction::Add:
228 case Instruction::Sub:
229 case Instruction::Mul:
230 case Instruction::AShr:
231 case Instruction::LShr:
232 case Instruction::Shl:
233 case Instruction::ICmp:
234 case Instruction::FCmp:
235 case Instruction::FAdd:
236 case Instruction::FMul:
237 case Instruction::Select:
238 if (!Ops.insert(I))
239 continue;
240
241 for (Use &Op : I->operands()) {
242 if (!isa<FixedVectorType>(Op->getType()))
243 continue;
244 if (isa<Instruction>(Op))
245 Worklist.push_back(cast<Instruction>(&Op));
246 else
247 OtherLeafs.insert(&Op);
248 }
249
250 for (auto *Use : I->users())
251 Worklist.push_back(cast<Instruction>(Use));
252 break;
253
254 case Instruction::ShuffleVector:
255 // A shuffle of a splat is a splat.
256 if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
257 continue;
258 [[fallthrough]];
259
260 default:
261 LLVM_DEBUG(dbgs() << " Unhandled instruction: " << *I << "\n");
262 return false;
263 }
264 }
265
266 if (Exts.empty() && OtherLeafs.empty())
267 return false;
268
269 LLVM_DEBUG({
270 dbgs() << "Found group:\n Exts:";
271 for (auto *I : Exts)
272 dbgs() << " " << *I << "\n";
273 dbgs() << " Ops:";
274 for (auto *I : Ops)
275 dbgs() << " " << *I << "\n";
276 dbgs() << " OtherLeafs:";
277 for (auto *I : OtherLeafs)
278 dbgs() << " " << *I->get() << " of " << *I->getUser() << "\n";
279 dbgs() << "Truncs:";
280 for (auto *I : Truncs)
281 dbgs() << " " << *I << "\n";
282 });
283
284 assert(!Truncs.empty() && "Expected some truncs");
285
286 // Check types
287 unsigned NumElts = VT->getNumElements();
288 unsigned BaseElts = VT->getScalarSizeInBits() == 16
289 ? 8
290 : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
291 if (BaseElts == 0 || NumElts % BaseElts != 0) {
292 LLVM_DEBUG(dbgs() << " Type is unsupported\n");
293 return false;
294 }
295 if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
296 VT->getScalarSizeInBits() * 2) {
297 LLVM_DEBUG(dbgs() << " Type not double sized\n");
298 return false;
299 }
300 for (Instruction *I : Exts)
301 if (I->getOperand(0)->getType() != VT) {
302 LLVM_DEBUG(dbgs() << " Wrong type on " << *I << "\n");
303 return false;
304 }
305 for (Instruction *I : Truncs)
306 if (I->getType() != VT) {
307 LLVM_DEBUG(dbgs() << " Wrong type on " << *I << "\n");
308 return false;
309 }
310
311 // Check that it looks beneficial
312 if (!isProfitableToInterleave(Exts, Truncs))
313 return false;
314
315 // Create new shuffles around the extends / truncs / other leaves.
316 IRBuilder<> Builder(Start);
317
318 SmallVector<int, 16> LeafMask;
319 SmallVector<int, 16> TruncMask;
320 // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7 8, 10, 12, 14, 9, 11, 13, 15
321 // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7 8, 12, 9, 13, 10, 14, 11, 15
322 for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
323 for (unsigned i = 0; i < BaseElts / 2; i++)
324 LeafMask.push_back(Base + i * 2);
325 for (unsigned i = 0; i < BaseElts / 2; i++)
326 LeafMask.push_back(Base + i * 2 + 1);
327 }
328 for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
329 for (unsigned i = 0; i < BaseElts / 2; i++) {
330 TruncMask.push_back(Base + i);
331 TruncMask.push_back(Base + i + BaseElts / 2);
332 }
333 }
334
335 for (Instruction *I : Exts) {
336 LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
337 Builder.SetInsertPoint(I);
338 Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
339 bool FPext = isa<FPExtInst>(I);
340 bool Sext = isa<SExtInst>(I);
341 Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType())
342 : Sext ? Builder.CreateSExt(Shuffle, I->getType())
343 : Builder.CreateZExt(Shuffle, I->getType());
344 I->replaceAllUsesWith(Ext);
345 LLVM_DEBUG(dbgs() << " with " << *Shuffle << "\n");
346 }
347
348 for (Use *I : OtherLeafs) {
349 LLVM_DEBUG(dbgs() << "Replacing leaf " << *I << "\n");
350 Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
351 Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
352 I->getUser()->setOperand(I->getOperandNo(), Shuffle);
353 LLVM_DEBUG(dbgs() << " with " << *Shuffle << "\n");
354 }
355
356 for (Instruction *I : Truncs) {
357 LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");
358
359 Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
360 Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
361 I->replaceAllUsesWith(Shuf);
362 cast<Instruction>(Shuf)->setOperand(0, I);
363
364 LLVM_DEBUG(dbgs() << " with " << *Shuf << "\n");
365 }
366
367 return true;
368 }
369
runOnFunction(Function & F)370 bool MVELaneInterleaving::runOnFunction(Function &F) {
371 if (!EnableInterleave)
372 return false;
373 auto &TPC = getAnalysis<TargetPassConfig>();
374 auto &TM = TPC.getTM<TargetMachine>();
375 auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
376 if (!ST->hasMVEIntegerOps())
377 return false;
378
379 bool Changed = false;
380
381 SmallPtrSet<Instruction *, 16> Visited;
382 for (Instruction &I : reverse(instructions(F))) {
383 if (I.getType()->isVectorTy() &&
384 (isa<TruncInst>(I) || isa<FPTruncInst>(I)) && !Visited.count(&I))
385 Changed |= tryInterleave(&I, Visited);
386 }
387
388 return Changed;
389 }
390