1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
21 //
22 // Replacement:
23 // This step traverses the graph built up by identification, delegating to the
24 // target to validate and generate the correct intrinsics, and plumbs them
25 // together connecting each end of the new intrinsics graph to the existing
26 // use-def chain. This step is assumed to finish successfully, as all
27 // information is expected to be correct by this point.
28 //
29 //
30 // Internal data structure:
31 // ComplexDeinterleavingGraph:
32 // Keeps references to all the valid CompositeNodes formed as part of the
33 // transformation, and every Instruction contained within said nodes. It also
34 // holds onto a reference to the root Instruction, and the root node that should
35 // replace it.
36 //
37 // ComplexDeinterleavingCompositeNode:
38 // A CompositeNode represents a single transformation point; each node should
39 // transform into a single complex instruction (ignoring vector splitting, which
40 // would generate more instructions per node). They are identified in a
41 // depth-first manner, traversing and identifying the operands of each
42 // instruction in the order they appear in the IR.
43 // Each node maintains a reference to its Real and Imaginary instructions,
44 // as well as any additional instructions that make up the identified operation
45 // (Internal instructions should only have uses within their containing node).
46 // A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
50 //
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
54 //
55 //===----------------------------------------------------------------------===//
56
57 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
58 #include "llvm/ADT/Statistic.h"
59 #include "llvm/Analysis/TargetLibraryInfo.h"
60 #include "llvm/Analysis/TargetTransformInfo.h"
61 #include "llvm/CodeGen/TargetLowering.h"
62 #include "llvm/CodeGen/TargetPassConfig.h"
63 #include "llvm/CodeGen/TargetSubtargetInfo.h"
64 #include "llvm/IR/IRBuilder.h"
65 #include "llvm/InitializePasses.h"
66 #include "llvm/Target/TargetMachine.h"
67 #include "llvm/Transforms/Utils/Local.h"
68 #include <algorithm>
69
70 using namespace llvm;
71 using namespace PatternMatch;
72
73 #define DEBUG_TYPE "complex-deinterleaving"
74
// Number of composite nodes for which the target emitted replacement IR
// (bumped once per node materialized in ComplexDeinterleavingGraph::replaceNode).
STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");

// Hidden kill switch for the whole transformation. When false,
// ComplexDeinterleaving::runOnFunction returns immediately without changes.
// Defaults to enabled.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);
81
/// Checks the given mask, and determines whether said mask is interleaving.
///
/// To be interleaving, a mask must alternate between `i` and `i + (Length /
/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
/// 4x vector interleaving mask would be <0, 2, 1, 3>). Masks with an odd
/// number of elements are rejected. Whether the shuffle operands really are
/// half-width vectors is not checked here.
static bool isInterleavingMask(ArrayRef<int> Mask);

/// Checks the given mask, and determines whether said mask is deinterleaving.
///
/// To be deinterleaving, a mask must increment in steps of 2, and either start
/// with 0 or 1.
/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
/// <1, 3, 5, 7>).
/// NOTE: the starting value itself is not validated by this helper; callers
/// (identifyNode) check that the real mask starts at 0 and the imaginary mask
/// at 1.
static bool isDeinterleavingMask(ArrayRef<int> Mask);
96
97 namespace {
98
99 class ComplexDeinterleavingLegacyPass : public FunctionPass {
100 public:
101 static char ID;
102
ComplexDeinterleavingLegacyPass(const TargetMachine * TM=nullptr)103 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
104 : FunctionPass(ID), TM(TM) {
105 initializeComplexDeinterleavingLegacyPassPass(
106 *PassRegistry::getPassRegistry());
107 }
108
getPassName() const109 StringRef getPassName() const override {
110 return "Complex Deinterleaving Pass";
111 }
112
113 bool runOnFunction(Function &F) override;
getAnalysisUsage(AnalysisUsage & AU) const114 void getAnalysisUsage(AnalysisUsage &AU) const override {
115 AU.addRequired<TargetLibraryInfoWrapperPass>();
116 AU.setPreservesCFG();
117 }
118
119 private:
120 const TargetMachine *TM;
121 };
122
123 class ComplexDeinterleavingGraph;
124 struct ComplexDeinterleavingCompositeNode {
125
ComplexDeinterleavingCompositeNode__anonbedf4bd40111::ComplexDeinterleavingCompositeNode126 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
127 Instruction *R, Instruction *I)
128 : Operation(Op), Real(R), Imag(I) {}
129
130 private:
131 friend class ComplexDeinterleavingGraph;
132 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
133 using RawNodePtr = ComplexDeinterleavingCompositeNode *;
134
135 public:
136 ComplexDeinterleavingOperation Operation;
137 Instruction *Real;
138 Instruction *Imag;
139
140 // Instructions that should only exist within this node, there should be no
141 // users of these instructions outside the node. An example of these would be
142 // the multiply instructions of a partial multiply operation.
143 SmallVector<Instruction *> InternalInstructions;
144 ComplexDeinterleavingRotation Rotation;
145 SmallVector<RawNodePtr> Operands;
146 Value *ReplacementNode = nullptr;
147
addInstruction__anonbedf4bd40111::ComplexDeinterleavingCompositeNode148 void addInstruction(Instruction *I) { InternalInstructions.push_back(I); }
addOperand__anonbedf4bd40111::ComplexDeinterleavingCompositeNode149 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
150
151 bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions);
152
dump__anonbedf4bd40111::ComplexDeinterleavingCompositeNode153 void dump() { dump(dbgs()); }
dump__anonbedf4bd40111::ComplexDeinterleavingCompositeNode154 void dump(raw_ostream &OS) {
155 auto PrintValue = [&](Value *V) {
156 if (V) {
157 OS << "\"";
158 V->print(OS, true);
159 OS << "\"\n";
160 } else
161 OS << "nullptr\n";
162 };
163 auto PrintNodeRef = [&](RawNodePtr Ptr) {
164 if (Ptr)
165 OS << Ptr << "\n";
166 else
167 OS << "nullptr\n";
168 };
169
170 OS << "- CompositeNode: " << this << "\n";
171 OS << " Real: ";
172 PrintValue(Real);
173 OS << " Imag: ";
174 PrintValue(Imag);
175 OS << " ReplacementNode: ";
176 PrintValue(ReplacementNode);
177 OS << " Operation: " << (int)Operation << "\n";
178 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
179 OS << " Operands: \n";
180 for (const auto &Op : Operands) {
181 OS << " - ";
182 PrintNodeRef(Op);
183 }
184 OS << " InternalInstructions:\n";
185 for (const auto &I : InternalInstructions) {
186 OS << " - \"";
187 I->print(OS, true);
188 OS << "\"\n";
189 }
190 }
191 };
192
193 class ComplexDeinterleavingGraph {
194 public:
195 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
196 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
ComplexDeinterleavingGraph(const TargetLowering * tl)197 explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {}
198
199 private:
200 const TargetLowering *TL;
201 Instruction *RootValue;
202 NodePtr RootNode;
203 SmallVector<NodePtr> CompositeNodes;
204 SmallPtrSet<Instruction *, 16> AllInstructions;
205
prepareCompositeNode(ComplexDeinterleavingOperation Operation,Instruction * R,Instruction * I)206 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
207 Instruction *R, Instruction *I) {
208 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
209 I);
210 }
211
submitCompositeNode(NodePtr Node)212 NodePtr submitCompositeNode(NodePtr Node) {
213 CompositeNodes.push_back(Node);
214 AllInstructions.insert(Node->Real);
215 AllInstructions.insert(Node->Imag);
216 for (auto *I : Node->InternalInstructions)
217 AllInstructions.insert(I);
218 return Node;
219 }
220
getContainingComposite(Value * R,Value * I)221 NodePtr getContainingComposite(Value *R, Value *I) {
222 for (const auto &CN : CompositeNodes) {
223 if (CN->Real == R && CN->Imag == I)
224 return CN;
225 }
226 return nullptr;
227 }
228
229 /// Identifies a complex partial multiply pattern and its rotation, based on
230 /// the following patterns
231 ///
232 /// 0: r: cr + ar * br
233 /// i: ci + ar * bi
234 /// 90: r: cr - ai * bi
235 /// i: ci + ai * br
236 /// 180: r: cr - ar * br
237 /// i: ci - ar * bi
238 /// 270: r: cr + ai * bi
239 /// i: ci - ai * br
240 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
241
242 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
243 /// is partially known from identifyPartialMul, filling in the other half of
244 /// the complex pair.
245 NodePtr identifyNodeWithImplicitAdd(
246 Instruction *I, Instruction *J,
247 std::pair<Instruction *, Instruction *> &CommonOperandI);
248
249 /// Identifies a complex add pattern and its rotation, based on the following
250 /// patterns.
251 ///
252 /// 90: r: ar - bi
253 /// i: ai + br
254 /// 270: r: ar + bi
255 /// i: ai - br
256 NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
257
258 NodePtr identifyNode(Instruction *I, Instruction *J);
259
260 Value *replaceNode(RawNodePtr Node);
261
262 public:
dump()263 void dump() { dump(dbgs()); }
dump(raw_ostream & OS)264 void dump(raw_ostream &OS) {
265 for (const auto &Node : CompositeNodes)
266 Node->dump(OS);
267 }
268
269 /// Returns false if the deinterleaving operation should be cancelled for the
270 /// current graph.
271 bool identifyNodes(Instruction *RootI);
272
273 /// Perform the actual replacement of the underlying instruction graph.
274 /// Returns false if the deinterleaving operation should be cancelled for the
275 /// current graph.
276 void replaceNodes();
277 };
278
279 class ComplexDeinterleaving {
280 public:
ComplexDeinterleaving(const TargetLowering * tl,const TargetLibraryInfo * tli)281 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
282 : TL(tl), TLI(tli) {}
283 bool runOnFunction(Function &F);
284
285 private:
286 bool evaluateBasicBlock(BasicBlock *B);
287
288 const TargetLowering *TL = nullptr;
289 const TargetLibraryInfo *TLI = nullptr;
290 };
291
292 } // namespace
293
// Address of this variable is the legacy pass's identity.
char ComplexDeinterleavingLegacyPass::ID = 0;

// Registers the legacy pass under the DEBUG_TYPE name
// ("complex-deinterleaving").
// NOTE(review): getAnalysisUsage() requires TargetLibraryInfoWrapperPass, but
// no INITIALIZE_PASS_DEPENDENCY is declared between BEGIN and END — confirm
// this is intentional.
INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                      "Complex Deinterleaving", false, false)
INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                    "Complex Deinterleaving", false, false)
300
run(Function & F,FunctionAnalysisManager & AM)301 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
302 FunctionAnalysisManager &AM) {
303 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
304 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
305 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
306 return PreservedAnalyses::all();
307
308 PreservedAnalyses PA;
309 PA.preserve<FunctionAnalysisManagerModuleProxy>();
310 return PA;
311 }
312
createComplexDeinterleavingPass(const TargetMachine * TM)313 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
314 return new ComplexDeinterleavingLegacyPass(TM);
315 }
316
runOnFunction(Function & F)317 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
318 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
319 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
320 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
321 }
322
runOnFunction(Function & F)323 bool ComplexDeinterleaving::runOnFunction(Function &F) {
324 if (!ComplexDeinterleavingEnabled) {
325 LLVM_DEBUG(
326 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
327 return false;
328 }
329
330 if (!TL->isComplexDeinterleavingSupported()) {
331 LLVM_DEBUG(
332 dbgs() << "Complex deinterleaving has been disabled, target does "
333 "not support lowering of complex number operations.\n");
334 return false;
335 }
336
337 bool Changed = false;
338 for (auto &B : F)
339 Changed |= evaluateBasicBlock(&B);
340
341 return Changed;
342 }
343
isInterleavingMask(ArrayRef<int> Mask)344 static bool isInterleavingMask(ArrayRef<int> Mask) {
345 // If the size is not even, it's not an interleaving mask
346 if ((Mask.size() & 1))
347 return false;
348
349 int HalfNumElements = Mask.size() / 2;
350 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
351 int MaskIdx = Idx * 2;
352 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
353 return false;
354 }
355
356 return true;
357 }
358
isDeinterleavingMask(ArrayRef<int> Mask)359 static bool isDeinterleavingMask(ArrayRef<int> Mask) {
360 int Offset = Mask[0];
361 int HalfNumElements = Mask.size() / 2;
362
363 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
364 if (Mask[Idx] != (Idx * 2) + Offset)
365 return false;
366 }
367
368 return true;
369 }
370
evaluateBasicBlock(BasicBlock * B)371 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
372 bool Changed = false;
373
374 SmallVector<Instruction *> DeadInstrRoots;
375
376 for (auto &I : *B) {
377 auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
378 if (!SVI)
379 continue;
380
381 // Look for a shufflevector that takes separate vectors of the real and
382 // imaginary components and recombines them into a single vector.
383 if (!isInterleavingMask(SVI->getShuffleMask()))
384 continue;
385
386 ComplexDeinterleavingGraph Graph(TL);
387 if (!Graph.identifyNodes(SVI))
388 continue;
389
390 Graph.replaceNodes();
391 DeadInstrRoots.push_back(SVI);
392 Changed = true;
393 }
394
395 for (const auto &I : DeadInstrRoots) {
396 if (!I || I->getParent() == nullptr)
397 continue;
398 llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
399 }
400
401 return Changed;
402 }
403
/// Matches the accumulator half of a partial multiply, where the "add" is
/// implicit: Real and Imag must be single-use fmuls that share one operand.
/// Each side may have one fneg'd operand, and the fneg placement selects the
/// rotation:
///   neither        -> 0
///   Real only      -> 90
///   Real and Imag  -> 180
///   Imag only      -> 270
/// For rotations 0/180 the shared operand is stored into PartialMatch.first,
/// for 90/270 into PartialMatch.second. The node is only formed once both
/// entries are non-null (the caller pre-fills the other entry). On success the
/// node has operands {common pair, uncommon pair} and owns the stripped fnegs
/// as internal instructions.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
    Instruction *Real, Instruction *Imag,
    std::pair<Instruction *, Instruction *> &PartialMatch) {
  LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
                    << "\n");

  if (!Real->hasOneUse() || !Imag->hasOneUse()) {
    LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
    return nullptr;
  }

  if (Real->getOpcode() != Instruction::FMul ||
      Imag->getOpcode() != Instruction::FMul) {
    LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n");
    return nullptr;
  }

  Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0));
  Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1));
  Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
  Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
  if (!R0 || !R1 || !I0 || !I1) {
    LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n");
    return nullptr;
  }

  // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
  // rotations and use the operand.
  // Negs encodes the rotation as 0/1/2/3 (x90 degrees): a negated real product
  // sets bit 0; a negated imaginary product sets bit 1 and toggles bit 0, so
  // imag-only gives 3 (270) and both give 2 (180).
  // NOTE(review): if both operands of one fmul are fneg, only the first is
  // stripped here.
  unsigned Negs = 0;
  SmallVector<Instruction *> FNegs;
  if (R0->getOpcode() == Instruction::FNeg ||
      R1->getOpcode() == Instruction::FNeg) {
    Negs |= 1;
    if (R0->getOpcode() == Instruction::FNeg) {
      FNegs.push_back(R0);
      R0 = dyn_cast<Instruction>(R0->getOperand(0));
    } else {
      FNegs.push_back(R1);
      R1 = dyn_cast<Instruction>(R1->getOperand(0));
    }
    if (!R0 || !R1)
      return nullptr;
  }
  if (I0->getOpcode() == Instruction::FNeg ||
      I1->getOpcode() == Instruction::FNeg) {
    Negs |= 2;
    Negs ^= 1;
    if (I0->getOpcode() == Instruction::FNeg) {
      FNegs.push_back(I0);
      I0 = dyn_cast<Instruction>(I0->getOperand(0));
    } else {
      FNegs.push_back(I1);
      I1 = dyn_cast<Instruction>(I1->getOperand(0));
    }
    if (!I0 || !I1)
      return nullptr;
  }

  ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;

  Instruction *CommonOperand;
  Instruction *UncommonRealOp;
  Instruction *UncommonImagOp;

  // Find the operand shared by both products; the remaining operands form
  // the "uncommon" pair.
  if (R0 == I0 || R0 == I1) {
    CommonOperand = R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << " - No equal operand\n");
    return nullptr;
  }

  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  // For 90/270 the real product uses the imaginary-side operand and vice
  // versa, so swap to get a (real, imaginary) ordered pair.
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // Between identifyPartialMul and here we need to have found a complete valid
  // pair from the CommonOperand of each part.
  if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_180)
    PartialMatch.first = CommonOperand;
  else
    PartialMatch.second = CommonOperand;

  if (!PartialMatch.first || !PartialMatch.second) {
    LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
    return nullptr;
  }

  NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonNode) {
    LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
    return nullptr;
  }

  NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonNode) {
    LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
    return nullptr;
  }

  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonNode);
  Node->addOperand(UncommonNode);
  // The stripped fnegs are folded into this node and must not have other users.
  Node->InternalInstructions.append(FNegs);
  return submitCompositeNode(Node);
}
518
519 ComplexDeinterleavingGraph::NodePtr
identifyPartialMul(Instruction * Real,Instruction * Imag)520 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
521 Instruction *Imag) {
522 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
523 << "\n");
524 // Determine rotation
525 ComplexDeinterleavingRotation Rotation;
526 if (Real->getOpcode() == Instruction::FAdd &&
527 Imag->getOpcode() == Instruction::FAdd)
528 Rotation = ComplexDeinterleavingRotation::Rotation_0;
529 else if (Real->getOpcode() == Instruction::FSub &&
530 Imag->getOpcode() == Instruction::FAdd)
531 Rotation = ComplexDeinterleavingRotation::Rotation_90;
532 else if (Real->getOpcode() == Instruction::FSub &&
533 Imag->getOpcode() == Instruction::FSub)
534 Rotation = ComplexDeinterleavingRotation::Rotation_180;
535 else if (Real->getOpcode() == Instruction::FAdd &&
536 Imag->getOpcode() == Instruction::FSub)
537 Rotation = ComplexDeinterleavingRotation::Rotation_270;
538 else {
539 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
540 return nullptr;
541 }
542
543 if (!Real->getFastMathFlags().allowContract() ||
544 !Imag->getFastMathFlags().allowContract()) {
545 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
546 return nullptr;
547 }
548
549 Value *CR = Real->getOperand(0);
550 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
551 if (!RealMulI)
552 return nullptr;
553 Value *CI = Imag->getOperand(0);
554 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
555 if (!ImagMulI)
556 return nullptr;
557
558 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
559 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
560 return nullptr;
561 }
562
563 Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
564 Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
565 Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
566 Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
567 if (!R0 || !R1 || !I0 || !I1) {
568 LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n");
569 return nullptr;
570 }
571
572 Instruction *CommonOperand;
573 Instruction *UncommonRealOp;
574 Instruction *UncommonImagOp;
575
576 if (R0 == I0 || R0 == I1) {
577 CommonOperand = R0;
578 UncommonRealOp = R1;
579 } else if (R1 == I0 || R1 == I1) {
580 CommonOperand = R1;
581 UncommonRealOp = R0;
582 } else {
583 LLVM_DEBUG(dbgs() << " - No equal operand\n");
584 return nullptr;
585 }
586
587 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
588 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
589 Rotation == ComplexDeinterleavingRotation::Rotation_270)
590 std::swap(UncommonRealOp, UncommonImagOp);
591
592 std::pair<Instruction *, Instruction *> PartialMatch(
593 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
594 Rotation == ComplexDeinterleavingRotation::Rotation_180)
595 ? CommonOperand
596 : nullptr,
597 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
598 Rotation == ComplexDeinterleavingRotation::Rotation_270)
599 ? CommonOperand
600 : nullptr);
601 NodePtr CNode = identifyNodeWithImplicitAdd(
602 cast<Instruction>(CR), cast<Instruction>(CI), PartialMatch);
603 if (!CNode) {
604 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
605 return nullptr;
606 }
607
608 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
609 if (!UncommonRes) {
610 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
611 return nullptr;
612 }
613
614 assert(PartialMatch.first && PartialMatch.second);
615 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
616 if (!CommonRes) {
617 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
618 return nullptr;
619 }
620
621 NodePtr Node = prepareCompositeNode(
622 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
623 Node->addInstruction(RealMulI);
624 Node->addInstruction(ImagMulI);
625 Node->Rotation = Rotation;
626 Node->addOperand(CommonRes);
627 Node->addOperand(UncommonRes);
628 Node->addOperand(CNode);
629 return submitCompositeNode(Node);
630 }
631
632 ComplexDeinterleavingGraph::NodePtr
identifyAdd(Instruction * Real,Instruction * Imag)633 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
634 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
635
636 // Determine rotation
637 ComplexDeinterleavingRotation Rotation;
638 if ((Real->getOpcode() == Instruction::FSub &&
639 Imag->getOpcode() == Instruction::FAdd) ||
640 (Real->getOpcode() == Instruction::Sub &&
641 Imag->getOpcode() == Instruction::Add))
642 Rotation = ComplexDeinterleavingRotation::Rotation_90;
643 else if ((Real->getOpcode() == Instruction::FAdd &&
644 Imag->getOpcode() == Instruction::FSub) ||
645 (Real->getOpcode() == Instruction::Add &&
646 Imag->getOpcode() == Instruction::Sub))
647 Rotation = ComplexDeinterleavingRotation::Rotation_270;
648 else {
649 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
650 return nullptr;
651 }
652
653 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
654 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
655 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
656 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
657
658 if (!AR || !AI || !BR || !BI) {
659 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
660 return nullptr;
661 }
662
663 NodePtr ResA = identifyNode(AR, AI);
664 if (!ResA) {
665 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
666 return nullptr;
667 }
668 NodePtr ResB = identifyNode(BR, BI);
669 if (!ResB) {
670 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
671 return nullptr;
672 }
673
674 NodePtr Node =
675 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
676 Node->Rotation = Rotation;
677 Node->addOperand(ResA);
678 Node->addOperand(ResB);
679 return submitCompositeNode(Node);
680 }
681
isInstructionPairAdd(Instruction * A,Instruction * B)682 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
683 unsigned OpcA = A->getOpcode();
684 unsigned OpcB = B->getOpcode();
685
686 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
687 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
688 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
689 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
690 }
691
isInstructionPairMul(Instruction * A,Instruction * B)692 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
693 auto Pattern =
694 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
695
696 return match(A, Pattern) && match(B, Pattern);
697 }
698
/// Identifies the composite node for a (Real, Imag) pair:
///  - returns an existing node if this exact pair was already identified;
///  - for a pair of deinterleaving shuffles (masks starting at 0 and 1, same
///    source vector, second operand undef/zero) creates a Shuffle placeholder
///    whose replacement is the shuffles' source vector;
///  - otherwise dispatches to identifyPartialMul or identifyAdd when the
///    target supports the operation on the doubled vector type.
/// Returns nullptr when nothing matches.
/// NOTE(review): if identifyPartialMul fails, identifyAdd is not attempted, and
/// nodes submitted by a failed sub-match remain in CompositeNodes.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
  if (NodePtr CN = getContainingComposite(Real, Imag)) {
    LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
    return CN;
  }

  auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
  auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
  if (RealShuffle && ImagShuffle) {
    // Leaf case: both halves must be single-source shuffles of one vector.
    Value *RealOp1 = RealShuffle->getOperand(1);
    if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
      LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
      return nullptr;
    }
    Value *ImagOp1 = ImagShuffle->getOperand(1);
    if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
      LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
      return nullptr;
    }

    Value *RealOp0 = RealShuffle->getOperand(0);
    Value *ImagOp0 = ImagShuffle->getOperand(0);

    if (RealOp0 != ImagOp0) {
      LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
      return nullptr;
    }

    ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
    ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
    if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
      LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
      return nullptr;
    }

    // The real half must take the even lanes, the imaginary half the odd ones.
    if (RealMask[0] != 0 || ImagMask[0] != 1) {
      LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
      return nullptr;
    }

    // Type checking, the shuffle type should be a vector type of the same
    // scalar type, but half the size
    auto CheckType = [&](ShuffleVectorInst *Shuffle) {
      Value *Op = Shuffle->getOperand(0);
      auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
      auto *OpTy = cast<FixedVectorType>(Op->getType());

      if (OpTy->getScalarType() != ShuffleTy->getScalarType())
        return false;
      if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
        return false;

      return true;
    };

    auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
      if (!CheckType(Shuffle))
        return false;

      ArrayRef<int> Mask = Shuffle->getShuffleMask();
      int Last = *Mask.rbegin();

      Value *Op = Shuffle->getOperand(0);
      auto *OpTy = cast<FixedVectorType>(Op->getType());
      int NumElements = OpTy->getNumElements();

      // Ensure that the deinterleaving shuffle only pulls from the first
      // shuffle operand.
      return Last < NumElements;
    };

    if (RealShuffle->getType() != ImagShuffle->getType()) {
      LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
      return nullptr;
    }
    if (!CheckDeinterleavingShuffle(RealShuffle)) {
      LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
      return nullptr;
    }
    if (!CheckDeinterleavingShuffle(ImagShuffle)) {
      LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
      return nullptr;
    }

    // The interleaved source vector replaces the pair directly; no complex
    // instruction is emitted for this node.
    NodePtr PlaceholderNode =
        prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle,
                             RealShuffle, ImagShuffle);
    PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
    return submitCompositeNode(PlaceholderNode);
  }
  // A shuffle paired with a non-shuffle cannot be matched.
  if (RealShuffle || ImagShuffle)
    return nullptr;

  // Support queries are made on the interleaved type: twice the width of one
  // half.
  auto *VTy = cast<FixedVectorType>(Real->getType());
  auto *NewVTy =
      FixedVectorType::get(VTy->getScalarType(), VTy->getNumElements() * 2);

  if (TL->isComplexDeinterleavingOperationSupported(
          ComplexDeinterleavingOperation::CMulPartial, NewVTy) &&
      isInstructionPairMul(Real, Imag)) {
    return identifyPartialMul(Real, Imag);
  }

  if (TL->isComplexDeinterleavingOperationSupported(
          ComplexDeinterleavingOperation::CAdd, NewVTy) &&
      isInstructionPairAdd(Real, Imag)) {
    return identifyAdd(Real, Imag);
  }

  return nullptr;
}
812
identifyNodes(Instruction * RootI)813 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
814 Instruction *Real;
815 Instruction *Imag;
816 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
817 return false;
818
819 RootValue = RootI;
820 AllInstructions.insert(RootI);
821 RootNode = identifyNode(Real, Imag);
822
823 LLVM_DEBUG({
824 Function *F = RootI->getFunction();
825 BasicBlock *B = RootI->getParent();
826 dbgs() << "Complex deinterleaving graph for " << F->getName()
827 << "::" << B->getName() << ".\n";
828 dump(dbgs());
829 dbgs() << "\n";
830 });
831
832 // Check all instructions have internal uses
833 for (const auto &Node : CompositeNodes) {
834 if (!Node->hasAllInternalUses(AllInstructions)) {
835 LLVM_DEBUG(dbgs() << " - Invalid internal uses\n");
836 return false;
837 }
838 }
839 return RootNode != nullptr;
840 }
841
replaceNode(ComplexDeinterleavingGraph::RawNodePtr Node)842 Value *ComplexDeinterleavingGraph::replaceNode(
843 ComplexDeinterleavingGraph::RawNodePtr Node) {
844 if (Node->ReplacementNode)
845 return Node->ReplacementNode;
846
847 Value *Input0 = replaceNode(Node->Operands[0]);
848 Value *Input1 = replaceNode(Node->Operands[1]);
849 Value *Accumulator =
850 Node->Operands.size() > 2 ? replaceNode(Node->Operands[2]) : nullptr;
851
852 assert(Input0->getType() == Input1->getType() &&
853 "Node inputs need to be of the same type");
854
855 Node->ReplacementNode = TL->createComplexDeinterleavingIR(
856 Node->Real, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
857
858 assert(Node->ReplacementNode && "Target failed to create Intrinsic call.");
859 NumComplexTransformations += 1;
860 return Node->ReplacementNode;
861 }
862
replaceNodes()863 void ComplexDeinterleavingGraph::replaceNodes() {
864 Value *R = replaceNode(RootNode.get());
865 assert(R && "Unable to find replacement for RootValue");
866 RootValue->replaceAllUsesWith(R);
867 }
868
hasAllInternalUses(SmallPtrSet<Instruction *,16> & AllInstructions)869 bool ComplexDeinterleavingCompositeNode::hasAllInternalUses(
870 SmallPtrSet<Instruction *, 16> &AllInstructions) {
871 if (Operation == ComplexDeinterleavingOperation::Shuffle)
872 return true;
873
874 for (auto *User : Real->users()) {
875 if (!AllInstructions.contains(cast<Instruction>(User)))
876 return false;
877 }
878 for (auto *User : Imag->users()) {
879 if (!AllInstructions.contains(cast<Instruction>(User)))
880 return false;
881 }
882 for (auto *I : InternalInstructions) {
883 for (auto *User : I->users()) {
884 if (!AllInstructions.contains(cast<Instruction>(User)))
885 return false;
886 }
887 }
888 return true;
889 }
890