xref: /aosp_15_r20/external/llvm/lib/Analysis/CostModel.cpp (revision 9880d6810fe72a1726cb53787c6711e909410d58)
1*9880d681SAndroid Build Coastguard Worker //===- CostModel.cpp ------ Cost Model Analysis ---------------------------===//
2*9880d681SAndroid Build Coastguard Worker //
3*9880d681SAndroid Build Coastguard Worker //                     The LLVM Compiler Infrastructure
4*9880d681SAndroid Build Coastguard Worker //
5*9880d681SAndroid Build Coastguard Worker // This file is distributed under the University of Illinois Open Source
6*9880d681SAndroid Build Coastguard Worker // License. See LICENSE.TXT for details.
7*9880d681SAndroid Build Coastguard Worker //
8*9880d681SAndroid Build Coastguard Worker //===----------------------------------------------------------------------===//
9*9880d681SAndroid Build Coastguard Worker //
10*9880d681SAndroid Build Coastguard Worker // This file defines the cost model analysis. It provides a very basic cost
11*9880d681SAndroid Build Coastguard Worker // estimation for LLVM-IR. This analysis uses the services of the codegen
12*9880d681SAndroid Build Coastguard Worker // to approximate the cost of any IR instruction when lowered to machine
13*9880d681SAndroid Build Coastguard Worker // instructions. The cost results are unit-less and the cost number represents
14*9880d681SAndroid Build Coastguard Worker // the throughput of the machine assuming that all loads hit the cache, all
15*9880d681SAndroid Build Coastguard Worker // branches are predicted, etc. The cost numbers can be added in order to
16*9880d681SAndroid Build Coastguard Worker // compare two or more transformation alternatives.
17*9880d681SAndroid Build Coastguard Worker //
18*9880d681SAndroid Build Coastguard Worker //===----------------------------------------------------------------------===//
19*9880d681SAndroid Build Coastguard Worker 
20*9880d681SAndroid Build Coastguard Worker #include "llvm/ADT/STLExtras.h"
21*9880d681SAndroid Build Coastguard Worker #include "llvm/Analysis/Passes.h"
22*9880d681SAndroid Build Coastguard Worker #include "llvm/Analysis/TargetTransformInfo.h"
23*9880d681SAndroid Build Coastguard Worker #include "llvm/IR/Function.h"
24*9880d681SAndroid Build Coastguard Worker #include "llvm/IR/Instructions.h"
25*9880d681SAndroid Build Coastguard Worker #include "llvm/IR/IntrinsicInst.h"
26*9880d681SAndroid Build Coastguard Worker #include "llvm/IR/Value.h"
27*9880d681SAndroid Build Coastguard Worker #include "llvm/Pass.h"
28*9880d681SAndroid Build Coastguard Worker #include "llvm/Support/CommandLine.h"
29*9880d681SAndroid Build Coastguard Worker #include "llvm/Support/Debug.h"
30*9880d681SAndroid Build Coastguard Worker #include "llvm/Support/raw_ostream.h"
31*9880d681SAndroid Build Coastguard Worker using namespace llvm;
32*9880d681SAndroid Build Coastguard Worker 
33*9880d681SAndroid Build Coastguard Worker #define CM_NAME "cost-model"
34*9880d681SAndroid Build Coastguard Worker #define DEBUG_TYPE CM_NAME
35*9880d681SAndroid Build Coastguard Worker 
36*9880d681SAndroid Build Coastguard Worker static cl::opt<bool> EnableReduxCost("costmodel-reduxcost", cl::init(false),
37*9880d681SAndroid Build Coastguard Worker                                      cl::Hidden,
38*9880d681SAndroid Build Coastguard Worker                                      cl::desc("Recognize reduction patterns."));
39*9880d681SAndroid Build Coastguard Worker 
40*9880d681SAndroid Build Coastguard Worker namespace {
41*9880d681SAndroid Build Coastguard Worker   class CostModelAnalysis : public FunctionPass {
42*9880d681SAndroid Build Coastguard Worker 
43*9880d681SAndroid Build Coastguard Worker   public:
44*9880d681SAndroid Build Coastguard Worker     static char ID; // Class identification, replacement for typeinfo
CostModelAnalysis()45*9880d681SAndroid Build Coastguard Worker     CostModelAnalysis() : FunctionPass(ID), F(nullptr), TTI(nullptr) {
46*9880d681SAndroid Build Coastguard Worker       initializeCostModelAnalysisPass(
47*9880d681SAndroid Build Coastguard Worker         *PassRegistry::getPassRegistry());
48*9880d681SAndroid Build Coastguard Worker     }
49*9880d681SAndroid Build Coastguard Worker 
50*9880d681SAndroid Build Coastguard Worker     /// Returns the expected cost of the instruction.
51*9880d681SAndroid Build Coastguard Worker     /// Returns -1 if the cost is unknown.
52*9880d681SAndroid Build Coastguard Worker     /// Note, this method does not cache the cost calculation and it
53*9880d681SAndroid Build Coastguard Worker     /// can be expensive in some cases.
54*9880d681SAndroid Build Coastguard Worker     unsigned getInstructionCost(const Instruction *I) const;
55*9880d681SAndroid Build Coastguard Worker 
56*9880d681SAndroid Build Coastguard Worker   private:
57*9880d681SAndroid Build Coastguard Worker     void getAnalysisUsage(AnalysisUsage &AU) const override;
58*9880d681SAndroid Build Coastguard Worker     bool runOnFunction(Function &F) override;
59*9880d681SAndroid Build Coastguard Worker     void print(raw_ostream &OS, const Module*) const override;
60*9880d681SAndroid Build Coastguard Worker 
61*9880d681SAndroid Build Coastguard Worker     /// The function that we analyze.
62*9880d681SAndroid Build Coastguard Worker     Function *F;
63*9880d681SAndroid Build Coastguard Worker     /// Target information.
64*9880d681SAndroid Build Coastguard Worker     const TargetTransformInfo *TTI;
65*9880d681SAndroid Build Coastguard Worker   };
66*9880d681SAndroid Build Coastguard Worker }  // End of anonymous namespace
67*9880d681SAndroid Build Coastguard Worker 
68*9880d681SAndroid Build Coastguard Worker // Register this pass.
69*9880d681SAndroid Build Coastguard Worker char CostModelAnalysis::ID = 0;
70*9880d681SAndroid Build Coastguard Worker static const char cm_name[] = "Cost Model Analysis";
INITIALIZE_PASS_BEGIN(CostModelAnalysis,CM_NAME,cm_name,false,true)71*9880d681SAndroid Build Coastguard Worker INITIALIZE_PASS_BEGIN(CostModelAnalysis, CM_NAME, cm_name, false, true)
72*9880d681SAndroid Build Coastguard Worker INITIALIZE_PASS_END  (CostModelAnalysis, CM_NAME, cm_name, false, true)
73*9880d681SAndroid Build Coastguard Worker 
74*9880d681SAndroid Build Coastguard Worker FunctionPass *llvm::createCostModelAnalysisPass() {
75*9880d681SAndroid Build Coastguard Worker   return new CostModelAnalysis();
76*9880d681SAndroid Build Coastguard Worker }
77*9880d681SAndroid Build Coastguard Worker 
78*9880d681SAndroid Build Coastguard Worker void
getAnalysisUsage(AnalysisUsage & AU) const79*9880d681SAndroid Build Coastguard Worker CostModelAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
80*9880d681SAndroid Build Coastguard Worker   AU.setPreservesAll();
81*9880d681SAndroid Build Coastguard Worker }
82*9880d681SAndroid Build Coastguard Worker 
83*9880d681SAndroid Build Coastguard Worker bool
runOnFunction(Function & F)84*9880d681SAndroid Build Coastguard Worker CostModelAnalysis::runOnFunction(Function &F) {
85*9880d681SAndroid Build Coastguard Worker  this->F = &F;
86*9880d681SAndroid Build Coastguard Worker  auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
87*9880d681SAndroid Build Coastguard Worker  TTI = TTIWP ? &TTIWP->getTTI(F) : nullptr;
88*9880d681SAndroid Build Coastguard Worker 
89*9880d681SAndroid Build Coastguard Worker  return false;
90*9880d681SAndroid Build Coastguard Worker }
91*9880d681SAndroid Build Coastguard Worker 
isReverseVectorMask(SmallVectorImpl<int> & Mask)92*9880d681SAndroid Build Coastguard Worker static bool isReverseVectorMask(SmallVectorImpl<int> &Mask) {
93*9880d681SAndroid Build Coastguard Worker   for (unsigned i = 0, MaskSize = Mask.size(); i < MaskSize; ++i)
94*9880d681SAndroid Build Coastguard Worker     if (Mask[i] > 0 && Mask[i] != (int)(MaskSize - 1 - i))
95*9880d681SAndroid Build Coastguard Worker       return false;
96*9880d681SAndroid Build Coastguard Worker   return true;
97*9880d681SAndroid Build Coastguard Worker }
98*9880d681SAndroid Build Coastguard Worker 
isAlternateVectorMask(SmallVectorImpl<int> & Mask)99*9880d681SAndroid Build Coastguard Worker static bool isAlternateVectorMask(SmallVectorImpl<int> &Mask) {
100*9880d681SAndroid Build Coastguard Worker   bool isAlternate = true;
101*9880d681SAndroid Build Coastguard Worker   unsigned MaskSize = Mask.size();
102*9880d681SAndroid Build Coastguard Worker 
103*9880d681SAndroid Build Coastguard Worker   // Example: shufflevector A, B, <0,5,2,7>
104*9880d681SAndroid Build Coastguard Worker   for (unsigned i = 0; i < MaskSize && isAlternate; ++i) {
105*9880d681SAndroid Build Coastguard Worker     if (Mask[i] < 0)
106*9880d681SAndroid Build Coastguard Worker       continue;
107*9880d681SAndroid Build Coastguard Worker     isAlternate = Mask[i] == (int)((i & 1) ? MaskSize + i : i);
108*9880d681SAndroid Build Coastguard Worker   }
109*9880d681SAndroid Build Coastguard Worker 
110*9880d681SAndroid Build Coastguard Worker   if (isAlternate)
111*9880d681SAndroid Build Coastguard Worker     return true;
112*9880d681SAndroid Build Coastguard Worker 
113*9880d681SAndroid Build Coastguard Worker   isAlternate = true;
114*9880d681SAndroid Build Coastguard Worker   // Example: shufflevector A, B, <4,1,6,3>
115*9880d681SAndroid Build Coastguard Worker   for (unsigned i = 0; i < MaskSize && isAlternate; ++i) {
116*9880d681SAndroid Build Coastguard Worker     if (Mask[i] < 0)
117*9880d681SAndroid Build Coastguard Worker       continue;
118*9880d681SAndroid Build Coastguard Worker     isAlternate = Mask[i] == (int)((i & 1) ? i : MaskSize + i);
119*9880d681SAndroid Build Coastguard Worker   }
120*9880d681SAndroid Build Coastguard Worker 
121*9880d681SAndroid Build Coastguard Worker   return isAlternate;
122*9880d681SAndroid Build Coastguard Worker }
123*9880d681SAndroid Build Coastguard Worker 
getOperandInfo(Value * V)124*9880d681SAndroid Build Coastguard Worker static TargetTransformInfo::OperandValueKind getOperandInfo(Value *V) {
125*9880d681SAndroid Build Coastguard Worker   TargetTransformInfo::OperandValueKind OpInfo =
126*9880d681SAndroid Build Coastguard Worker     TargetTransformInfo::OK_AnyValue;
127*9880d681SAndroid Build Coastguard Worker 
128*9880d681SAndroid Build Coastguard Worker   // Check for a splat of a constant or for a non uniform vector of constants.
129*9880d681SAndroid Build Coastguard Worker   if (isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) {
130*9880d681SAndroid Build Coastguard Worker     OpInfo = TargetTransformInfo::OK_NonUniformConstantValue;
131*9880d681SAndroid Build Coastguard Worker     if (cast<Constant>(V)->getSplatValue() != nullptr)
132*9880d681SAndroid Build Coastguard Worker       OpInfo = TargetTransformInfo::OK_UniformConstantValue;
133*9880d681SAndroid Build Coastguard Worker   }
134*9880d681SAndroid Build Coastguard Worker 
135*9880d681SAndroid Build Coastguard Worker   return OpInfo;
136*9880d681SAndroid Build Coastguard Worker }
137*9880d681SAndroid Build Coastguard Worker 
matchPairwiseShuffleMask(ShuffleVectorInst * SI,bool IsLeft,unsigned Level)138*9880d681SAndroid Build Coastguard Worker static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft,
139*9880d681SAndroid Build Coastguard Worker                                      unsigned Level) {
140*9880d681SAndroid Build Coastguard Worker   // We don't need a shuffle if we just want to have element 0 in position 0 of
141*9880d681SAndroid Build Coastguard Worker   // the vector.
142*9880d681SAndroid Build Coastguard Worker   if (!SI && Level == 0 && IsLeft)
143*9880d681SAndroid Build Coastguard Worker     return true;
144*9880d681SAndroid Build Coastguard Worker   else if (!SI)
145*9880d681SAndroid Build Coastguard Worker     return false;
146*9880d681SAndroid Build Coastguard Worker 
147*9880d681SAndroid Build Coastguard Worker   SmallVector<int, 32> Mask(SI->getType()->getVectorNumElements(), -1);
148*9880d681SAndroid Build Coastguard Worker 
149*9880d681SAndroid Build Coastguard Worker   // Build a mask of 0, 2, ... (left) or 1, 3, ... (right) depending on whether
150*9880d681SAndroid Build Coastguard Worker   // we look at the left or right side.
151*9880d681SAndroid Build Coastguard Worker   for (unsigned i = 0, e = (1 << Level), val = !IsLeft; i != e; ++i, val += 2)
152*9880d681SAndroid Build Coastguard Worker     Mask[i] = val;
153*9880d681SAndroid Build Coastguard Worker 
154*9880d681SAndroid Build Coastguard Worker   SmallVector<int, 16> ActualMask = SI->getShuffleMask();
155*9880d681SAndroid Build Coastguard Worker   return Mask == ActualMask;
156*9880d681SAndroid Build Coastguard Worker }
157*9880d681SAndroid Build Coastguard Worker 
matchPairwiseReductionAtLevel(const BinaryOperator * BinOp,unsigned Level,unsigned NumLevels)158*9880d681SAndroid Build Coastguard Worker static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp,
159*9880d681SAndroid Build Coastguard Worker                                           unsigned Level, unsigned NumLevels) {
160*9880d681SAndroid Build Coastguard Worker   // Match one level of pairwise operations.
161*9880d681SAndroid Build Coastguard Worker   // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
162*9880d681SAndroid Build Coastguard Worker   //       <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
163*9880d681SAndroid Build Coastguard Worker   // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
164*9880d681SAndroid Build Coastguard Worker   //       <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
165*9880d681SAndroid Build Coastguard Worker   // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
166*9880d681SAndroid Build Coastguard Worker   if (BinOp == nullptr)
167*9880d681SAndroid Build Coastguard Worker     return false;
168*9880d681SAndroid Build Coastguard Worker 
169*9880d681SAndroid Build Coastguard Worker   assert(BinOp->getType()->isVectorTy() && "Expecting a vector type");
170*9880d681SAndroid Build Coastguard Worker 
171*9880d681SAndroid Build Coastguard Worker   unsigned Opcode = BinOp->getOpcode();
172*9880d681SAndroid Build Coastguard Worker   Value *L = BinOp->getOperand(0);
173*9880d681SAndroid Build Coastguard Worker   Value *R = BinOp->getOperand(1);
174*9880d681SAndroid Build Coastguard Worker 
175*9880d681SAndroid Build Coastguard Worker   ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(L);
176*9880d681SAndroid Build Coastguard Worker   if (!LS && Level)
177*9880d681SAndroid Build Coastguard Worker     return false;
178*9880d681SAndroid Build Coastguard Worker   ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(R);
179*9880d681SAndroid Build Coastguard Worker   if (!RS && Level)
180*9880d681SAndroid Build Coastguard Worker     return false;
181*9880d681SAndroid Build Coastguard Worker 
182*9880d681SAndroid Build Coastguard Worker   // On level 0 we can omit one shufflevector instruction.
183*9880d681SAndroid Build Coastguard Worker   if (!Level && !RS && !LS)
184*9880d681SAndroid Build Coastguard Worker     return false;
185*9880d681SAndroid Build Coastguard Worker 
186*9880d681SAndroid Build Coastguard Worker   // Shuffle inputs must match.
187*9880d681SAndroid Build Coastguard Worker   Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr;
188*9880d681SAndroid Build Coastguard Worker   Value *NextLevelOpR = RS ? RS->getOperand(0) : nullptr;
189*9880d681SAndroid Build Coastguard Worker   Value *NextLevelOp = nullptr;
190*9880d681SAndroid Build Coastguard Worker   if (NextLevelOpR && NextLevelOpL) {
191*9880d681SAndroid Build Coastguard Worker     // If we have two shuffles their operands must match.
192*9880d681SAndroid Build Coastguard Worker     if (NextLevelOpL != NextLevelOpR)
193*9880d681SAndroid Build Coastguard Worker       return false;
194*9880d681SAndroid Build Coastguard Worker 
195*9880d681SAndroid Build Coastguard Worker     NextLevelOp = NextLevelOpL;
196*9880d681SAndroid Build Coastguard Worker   } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) {
197*9880d681SAndroid Build Coastguard Worker     // On the first level we can omit the shufflevector <0, undef,...>. So the
198*9880d681SAndroid Build Coastguard Worker     // input to the other shufflevector <1, undef> must match with one of the
199*9880d681SAndroid Build Coastguard Worker     // inputs to the current binary operation.
200*9880d681SAndroid Build Coastguard Worker     // Example:
201*9880d681SAndroid Build Coastguard Worker     //  %NextLevelOpL = shufflevector %R, <1, undef ...>
202*9880d681SAndroid Build Coastguard Worker     //  %BinOp        = fadd          %NextLevelOpL, %R
203*9880d681SAndroid Build Coastguard Worker     if (NextLevelOpL && NextLevelOpL != R)
204*9880d681SAndroid Build Coastguard Worker       return false;
205*9880d681SAndroid Build Coastguard Worker     else if (NextLevelOpR && NextLevelOpR != L)
206*9880d681SAndroid Build Coastguard Worker       return false;
207*9880d681SAndroid Build Coastguard Worker 
208*9880d681SAndroid Build Coastguard Worker     NextLevelOp = NextLevelOpL ? R : L;
209*9880d681SAndroid Build Coastguard Worker   } else
210*9880d681SAndroid Build Coastguard Worker     return false;
211*9880d681SAndroid Build Coastguard Worker 
212*9880d681SAndroid Build Coastguard Worker   // Check that the next levels binary operation exists and matches with the
213*9880d681SAndroid Build Coastguard Worker   // current one.
214*9880d681SAndroid Build Coastguard Worker   BinaryOperator *NextLevelBinOp = nullptr;
215*9880d681SAndroid Build Coastguard Worker   if (Level + 1 != NumLevels) {
216*9880d681SAndroid Build Coastguard Worker     if (!(NextLevelBinOp = dyn_cast<BinaryOperator>(NextLevelOp)))
217*9880d681SAndroid Build Coastguard Worker       return false;
218*9880d681SAndroid Build Coastguard Worker     else if (NextLevelBinOp->getOpcode() != Opcode)
219*9880d681SAndroid Build Coastguard Worker       return false;
220*9880d681SAndroid Build Coastguard Worker   }
221*9880d681SAndroid Build Coastguard Worker 
222*9880d681SAndroid Build Coastguard Worker   // Shuffle mask for pairwise operation must match.
223*9880d681SAndroid Build Coastguard Worker   if (matchPairwiseShuffleMask(LS, true, Level)) {
224*9880d681SAndroid Build Coastguard Worker     if (!matchPairwiseShuffleMask(RS, false, Level))
225*9880d681SAndroid Build Coastguard Worker       return false;
226*9880d681SAndroid Build Coastguard Worker   } else if (matchPairwiseShuffleMask(RS, true, Level)) {
227*9880d681SAndroid Build Coastguard Worker     if (!matchPairwiseShuffleMask(LS, false, Level))
228*9880d681SAndroid Build Coastguard Worker       return false;
229*9880d681SAndroid Build Coastguard Worker   } else
230*9880d681SAndroid Build Coastguard Worker     return false;
231*9880d681SAndroid Build Coastguard Worker 
232*9880d681SAndroid Build Coastguard Worker   if (++Level == NumLevels)
233*9880d681SAndroid Build Coastguard Worker     return true;
234*9880d681SAndroid Build Coastguard Worker 
235*9880d681SAndroid Build Coastguard Worker   // Match next level.
236*9880d681SAndroid Build Coastguard Worker   return matchPairwiseReductionAtLevel(NextLevelBinOp, Level, NumLevels);
237*9880d681SAndroid Build Coastguard Worker }
238*9880d681SAndroid Build Coastguard Worker 
matchPairwiseReduction(const ExtractElementInst * ReduxRoot,unsigned & Opcode,Type * & Ty)239*9880d681SAndroid Build Coastguard Worker static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
240*9880d681SAndroid Build Coastguard Worker                                    unsigned &Opcode, Type *&Ty) {
241*9880d681SAndroid Build Coastguard Worker   if (!EnableReduxCost)
242*9880d681SAndroid Build Coastguard Worker     return false;
243*9880d681SAndroid Build Coastguard Worker 
244*9880d681SAndroid Build Coastguard Worker   // Need to extract the first element.
245*9880d681SAndroid Build Coastguard Worker   ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
246*9880d681SAndroid Build Coastguard Worker   unsigned Idx = ~0u;
247*9880d681SAndroid Build Coastguard Worker   if (CI)
248*9880d681SAndroid Build Coastguard Worker     Idx = CI->getZExtValue();
249*9880d681SAndroid Build Coastguard Worker   if (Idx != 0)
250*9880d681SAndroid Build Coastguard Worker     return false;
251*9880d681SAndroid Build Coastguard Worker 
252*9880d681SAndroid Build Coastguard Worker   BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0));
253*9880d681SAndroid Build Coastguard Worker   if (!RdxStart)
254*9880d681SAndroid Build Coastguard Worker     return false;
255*9880d681SAndroid Build Coastguard Worker 
256*9880d681SAndroid Build Coastguard Worker   Type *VecTy = ReduxRoot->getOperand(0)->getType();
257*9880d681SAndroid Build Coastguard Worker   unsigned NumVecElems = VecTy->getVectorNumElements();
258*9880d681SAndroid Build Coastguard Worker   if (!isPowerOf2_32(NumVecElems))
259*9880d681SAndroid Build Coastguard Worker     return false;
260*9880d681SAndroid Build Coastguard Worker 
261*9880d681SAndroid Build Coastguard Worker   // We look for a sequence of shuffle,shuffle,add triples like the following
262*9880d681SAndroid Build Coastguard Worker   // that builds a pairwise reduction tree.
263*9880d681SAndroid Build Coastguard Worker   //
264*9880d681SAndroid Build Coastguard Worker   //  (X0, X1, X2, X3)
265*9880d681SAndroid Build Coastguard Worker   //   (X0 + X1, X2 + X3, undef, undef)
266*9880d681SAndroid Build Coastguard Worker   //    ((X0 + X1) + (X2 + X3), undef, undef, undef)
267*9880d681SAndroid Build Coastguard Worker   //
268*9880d681SAndroid Build Coastguard Worker   // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
269*9880d681SAndroid Build Coastguard Worker   //       <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
270*9880d681SAndroid Build Coastguard Worker   // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
271*9880d681SAndroid Build Coastguard Worker   //       <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
272*9880d681SAndroid Build Coastguard Worker   // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
273*9880d681SAndroid Build Coastguard Worker   // %rdx.shuf.1.0 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef,
274*9880d681SAndroid Build Coastguard Worker   //       <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef>
275*9880d681SAndroid Build Coastguard Worker   // %rdx.shuf.1.1 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef,
276*9880d681SAndroid Build Coastguard Worker   //       <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
277*9880d681SAndroid Build Coastguard Worker   // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1
278*9880d681SAndroid Build Coastguard Worker   // %r = extractelement <4 x float> %bin.rdx8, i32 0
279*9880d681SAndroid Build Coastguard Worker   if (!matchPairwiseReductionAtLevel(RdxStart, 0,  Log2_32(NumVecElems)))
280*9880d681SAndroid Build Coastguard Worker     return false;
281*9880d681SAndroid Build Coastguard Worker 
282*9880d681SAndroid Build Coastguard Worker   Opcode = RdxStart->getOpcode();
283*9880d681SAndroid Build Coastguard Worker   Ty = VecTy;
284*9880d681SAndroid Build Coastguard Worker 
285*9880d681SAndroid Build Coastguard Worker   return true;
286*9880d681SAndroid Build Coastguard Worker }
287*9880d681SAndroid Build Coastguard Worker 
288*9880d681SAndroid Build Coastguard Worker static std::pair<Value *, ShuffleVectorInst *>
getShuffleAndOtherOprd(BinaryOperator * B)289*9880d681SAndroid Build Coastguard Worker getShuffleAndOtherOprd(BinaryOperator *B) {
290*9880d681SAndroid Build Coastguard Worker 
291*9880d681SAndroid Build Coastguard Worker   Value *L = B->getOperand(0);
292*9880d681SAndroid Build Coastguard Worker   Value *R = B->getOperand(1);
293*9880d681SAndroid Build Coastguard Worker   ShuffleVectorInst *S = nullptr;
294*9880d681SAndroid Build Coastguard Worker 
295*9880d681SAndroid Build Coastguard Worker   if ((S = dyn_cast<ShuffleVectorInst>(L)))
296*9880d681SAndroid Build Coastguard Worker     return std::make_pair(R, S);
297*9880d681SAndroid Build Coastguard Worker 
298*9880d681SAndroid Build Coastguard Worker   S = dyn_cast<ShuffleVectorInst>(R);
299*9880d681SAndroid Build Coastguard Worker   return std::make_pair(L, S);
300*9880d681SAndroid Build Coastguard Worker }
301*9880d681SAndroid Build Coastguard Worker 
matchVectorSplittingReduction(const ExtractElementInst * ReduxRoot,unsigned & Opcode,Type * & Ty)302*9880d681SAndroid Build Coastguard Worker static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
303*9880d681SAndroid Build Coastguard Worker                                           unsigned &Opcode, Type *&Ty) {
304*9880d681SAndroid Build Coastguard Worker   if (!EnableReduxCost)
305*9880d681SAndroid Build Coastguard Worker     return false;
306*9880d681SAndroid Build Coastguard Worker 
307*9880d681SAndroid Build Coastguard Worker   // Need to extract the first element.
308*9880d681SAndroid Build Coastguard Worker   ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
309*9880d681SAndroid Build Coastguard Worker   unsigned Idx = ~0u;
310*9880d681SAndroid Build Coastguard Worker   if (CI)
311*9880d681SAndroid Build Coastguard Worker     Idx = CI->getZExtValue();
312*9880d681SAndroid Build Coastguard Worker   if (Idx != 0)
313*9880d681SAndroid Build Coastguard Worker     return false;
314*9880d681SAndroid Build Coastguard Worker 
315*9880d681SAndroid Build Coastguard Worker   BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0));
316*9880d681SAndroid Build Coastguard Worker   if (!RdxStart)
317*9880d681SAndroid Build Coastguard Worker     return false;
318*9880d681SAndroid Build Coastguard Worker   unsigned RdxOpcode = RdxStart->getOpcode();
319*9880d681SAndroid Build Coastguard Worker 
320*9880d681SAndroid Build Coastguard Worker   Type *VecTy = ReduxRoot->getOperand(0)->getType();
321*9880d681SAndroid Build Coastguard Worker   unsigned NumVecElems = VecTy->getVectorNumElements();
322*9880d681SAndroid Build Coastguard Worker   if (!isPowerOf2_32(NumVecElems))
323*9880d681SAndroid Build Coastguard Worker     return false;
324*9880d681SAndroid Build Coastguard Worker 
325*9880d681SAndroid Build Coastguard Worker   // We look for a sequence of shuffles and adds like the following matching one
326*9880d681SAndroid Build Coastguard Worker   // fadd, shuffle vector pair at a time.
327*9880d681SAndroid Build Coastguard Worker   //
328*9880d681SAndroid Build Coastguard Worker   // %rdx.shuf = shufflevector <4 x float> %rdx, <4 x float> undef,
329*9880d681SAndroid Build Coastguard Worker   //                           <4 x i32> <i32 2, i32 3, i32 undef, i32 undef>
330*9880d681SAndroid Build Coastguard Worker   // %bin.rdx = fadd <4 x float> %rdx, %rdx.shuf
331*9880d681SAndroid Build Coastguard Worker   // %rdx.shuf7 = shufflevector <4 x float> %bin.rdx, <4 x float> undef,
332*9880d681SAndroid Build Coastguard Worker   //                          <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
333*9880d681SAndroid Build Coastguard Worker   // %bin.rdx8 = fadd <4 x float> %bin.rdx, %rdx.shuf7
334*9880d681SAndroid Build Coastguard Worker   // %r = extractelement <4 x float> %bin.rdx8, i32 0
335*9880d681SAndroid Build Coastguard Worker 
336*9880d681SAndroid Build Coastguard Worker   unsigned MaskStart = 1;
337*9880d681SAndroid Build Coastguard Worker   Value *RdxOp = RdxStart;
338*9880d681SAndroid Build Coastguard Worker   SmallVector<int, 32> ShuffleMask(NumVecElems, 0);
339*9880d681SAndroid Build Coastguard Worker   unsigned NumVecElemsRemain = NumVecElems;
340*9880d681SAndroid Build Coastguard Worker   while (NumVecElemsRemain - 1) {
341*9880d681SAndroid Build Coastguard Worker     // Check for the right reduction operation.
342*9880d681SAndroid Build Coastguard Worker     BinaryOperator *BinOp;
343*9880d681SAndroid Build Coastguard Worker     if (!(BinOp = dyn_cast<BinaryOperator>(RdxOp)))
344*9880d681SAndroid Build Coastguard Worker       return false;
345*9880d681SAndroid Build Coastguard Worker     if (BinOp->getOpcode() != RdxOpcode)
346*9880d681SAndroid Build Coastguard Worker       return false;
347*9880d681SAndroid Build Coastguard Worker 
348*9880d681SAndroid Build Coastguard Worker     Value *NextRdxOp;
349*9880d681SAndroid Build Coastguard Worker     ShuffleVectorInst *Shuffle;
350*9880d681SAndroid Build Coastguard Worker     std::tie(NextRdxOp, Shuffle) = getShuffleAndOtherOprd(BinOp);
351*9880d681SAndroid Build Coastguard Worker 
352*9880d681SAndroid Build Coastguard Worker     // Check the current reduction operation and the shuffle use the same value.
353*9880d681SAndroid Build Coastguard Worker     if (Shuffle == nullptr)
354*9880d681SAndroid Build Coastguard Worker       return false;
355*9880d681SAndroid Build Coastguard Worker     if (Shuffle->getOperand(0) != NextRdxOp)
356*9880d681SAndroid Build Coastguard Worker       return false;
357*9880d681SAndroid Build Coastguard Worker 
358*9880d681SAndroid Build Coastguard Worker     // Check that shuffle masks matches.
359*9880d681SAndroid Build Coastguard Worker     for (unsigned j = 0; j != MaskStart; ++j)
360*9880d681SAndroid Build Coastguard Worker       ShuffleMask[j] = MaskStart + j;
361*9880d681SAndroid Build Coastguard Worker     // Fill the rest of the mask with -1 for undef.
362*9880d681SAndroid Build Coastguard Worker     std::fill(&ShuffleMask[MaskStart], ShuffleMask.end(), -1);
363*9880d681SAndroid Build Coastguard Worker 
364*9880d681SAndroid Build Coastguard Worker     SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
365*9880d681SAndroid Build Coastguard Worker     if (ShuffleMask != Mask)
366*9880d681SAndroid Build Coastguard Worker       return false;
367*9880d681SAndroid Build Coastguard Worker 
368*9880d681SAndroid Build Coastguard Worker     RdxOp = NextRdxOp;
369*9880d681SAndroid Build Coastguard Worker     NumVecElemsRemain /= 2;
370*9880d681SAndroid Build Coastguard Worker     MaskStart *= 2;
371*9880d681SAndroid Build Coastguard Worker   }
372*9880d681SAndroid Build Coastguard Worker 
373*9880d681SAndroid Build Coastguard Worker   Opcode = RdxOpcode;
374*9880d681SAndroid Build Coastguard Worker   Ty = VecTy;
375*9880d681SAndroid Build Coastguard Worker   return true;
376*9880d681SAndroid Build Coastguard Worker }
377*9880d681SAndroid Build Coastguard Worker 
getInstructionCost(const Instruction * I) const378*9880d681SAndroid Build Coastguard Worker unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const {
379*9880d681SAndroid Build Coastguard Worker   if (!TTI)
380*9880d681SAndroid Build Coastguard Worker     return -1;
381*9880d681SAndroid Build Coastguard Worker 
382*9880d681SAndroid Build Coastguard Worker   switch (I->getOpcode()) {
383*9880d681SAndroid Build Coastguard Worker   case Instruction::GetElementPtr:
384*9880d681SAndroid Build Coastguard Worker     return TTI->getUserCost(I);
385*9880d681SAndroid Build Coastguard Worker 
386*9880d681SAndroid Build Coastguard Worker   case Instruction::Ret:
387*9880d681SAndroid Build Coastguard Worker   case Instruction::PHI:
388*9880d681SAndroid Build Coastguard Worker   case Instruction::Br: {
389*9880d681SAndroid Build Coastguard Worker     return TTI->getCFInstrCost(I->getOpcode());
390*9880d681SAndroid Build Coastguard Worker   }
391*9880d681SAndroid Build Coastguard Worker   case Instruction::Add:
392*9880d681SAndroid Build Coastguard Worker   case Instruction::FAdd:
393*9880d681SAndroid Build Coastguard Worker   case Instruction::Sub:
394*9880d681SAndroid Build Coastguard Worker   case Instruction::FSub:
395*9880d681SAndroid Build Coastguard Worker   case Instruction::Mul:
396*9880d681SAndroid Build Coastguard Worker   case Instruction::FMul:
397*9880d681SAndroid Build Coastguard Worker   case Instruction::UDiv:
398*9880d681SAndroid Build Coastguard Worker   case Instruction::SDiv:
399*9880d681SAndroid Build Coastguard Worker   case Instruction::FDiv:
400*9880d681SAndroid Build Coastguard Worker   case Instruction::URem:
401*9880d681SAndroid Build Coastguard Worker   case Instruction::SRem:
402*9880d681SAndroid Build Coastguard Worker   case Instruction::FRem:
403*9880d681SAndroid Build Coastguard Worker   case Instruction::Shl:
404*9880d681SAndroid Build Coastguard Worker   case Instruction::LShr:
405*9880d681SAndroid Build Coastguard Worker   case Instruction::AShr:
406*9880d681SAndroid Build Coastguard Worker   case Instruction::And:
407*9880d681SAndroid Build Coastguard Worker   case Instruction::Or:
408*9880d681SAndroid Build Coastguard Worker   case Instruction::Xor: {
409*9880d681SAndroid Build Coastguard Worker     TargetTransformInfo::OperandValueKind Op1VK =
410*9880d681SAndroid Build Coastguard Worker       getOperandInfo(I->getOperand(0));
411*9880d681SAndroid Build Coastguard Worker     TargetTransformInfo::OperandValueKind Op2VK =
412*9880d681SAndroid Build Coastguard Worker       getOperandInfo(I->getOperand(1));
413*9880d681SAndroid Build Coastguard Worker     return TTI->getArithmeticInstrCost(I->getOpcode(), I->getType(), Op1VK,
414*9880d681SAndroid Build Coastguard Worker                                        Op2VK);
415*9880d681SAndroid Build Coastguard Worker   }
416*9880d681SAndroid Build Coastguard Worker   case Instruction::Select: {
417*9880d681SAndroid Build Coastguard Worker     const SelectInst *SI = cast<SelectInst>(I);
418*9880d681SAndroid Build Coastguard Worker     Type *CondTy = SI->getCondition()->getType();
419*9880d681SAndroid Build Coastguard Worker     return TTI->getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy);
420*9880d681SAndroid Build Coastguard Worker   }
421*9880d681SAndroid Build Coastguard Worker   case Instruction::ICmp:
422*9880d681SAndroid Build Coastguard Worker   case Instruction::FCmp: {
423*9880d681SAndroid Build Coastguard Worker     Type *ValTy = I->getOperand(0)->getType();
424*9880d681SAndroid Build Coastguard Worker     return TTI->getCmpSelInstrCost(I->getOpcode(), ValTy);
425*9880d681SAndroid Build Coastguard Worker   }
426*9880d681SAndroid Build Coastguard Worker   case Instruction::Store: {
427*9880d681SAndroid Build Coastguard Worker     const StoreInst *SI = cast<StoreInst>(I);
428*9880d681SAndroid Build Coastguard Worker     Type *ValTy = SI->getValueOperand()->getType();
429*9880d681SAndroid Build Coastguard Worker     return TTI->getMemoryOpCost(I->getOpcode(), ValTy,
430*9880d681SAndroid Build Coastguard Worker                                  SI->getAlignment(),
431*9880d681SAndroid Build Coastguard Worker                                  SI->getPointerAddressSpace());
432*9880d681SAndroid Build Coastguard Worker   }
433*9880d681SAndroid Build Coastguard Worker   case Instruction::Load: {
434*9880d681SAndroid Build Coastguard Worker     const LoadInst *LI = cast<LoadInst>(I);
435*9880d681SAndroid Build Coastguard Worker     return TTI->getMemoryOpCost(I->getOpcode(), I->getType(),
436*9880d681SAndroid Build Coastguard Worker                                  LI->getAlignment(),
437*9880d681SAndroid Build Coastguard Worker                                  LI->getPointerAddressSpace());
438*9880d681SAndroid Build Coastguard Worker   }
439*9880d681SAndroid Build Coastguard Worker   case Instruction::ZExt:
440*9880d681SAndroid Build Coastguard Worker   case Instruction::SExt:
441*9880d681SAndroid Build Coastguard Worker   case Instruction::FPToUI:
442*9880d681SAndroid Build Coastguard Worker   case Instruction::FPToSI:
443*9880d681SAndroid Build Coastguard Worker   case Instruction::FPExt:
444*9880d681SAndroid Build Coastguard Worker   case Instruction::PtrToInt:
445*9880d681SAndroid Build Coastguard Worker   case Instruction::IntToPtr:
446*9880d681SAndroid Build Coastguard Worker   case Instruction::SIToFP:
447*9880d681SAndroid Build Coastguard Worker   case Instruction::UIToFP:
448*9880d681SAndroid Build Coastguard Worker   case Instruction::Trunc:
449*9880d681SAndroid Build Coastguard Worker   case Instruction::FPTrunc:
450*9880d681SAndroid Build Coastguard Worker   case Instruction::BitCast:
451*9880d681SAndroid Build Coastguard Worker   case Instruction::AddrSpaceCast: {
452*9880d681SAndroid Build Coastguard Worker     Type *SrcTy = I->getOperand(0)->getType();
453*9880d681SAndroid Build Coastguard Worker     return TTI->getCastInstrCost(I->getOpcode(), I->getType(), SrcTy);
454*9880d681SAndroid Build Coastguard Worker   }
455*9880d681SAndroid Build Coastguard Worker   case Instruction::ExtractElement: {
456*9880d681SAndroid Build Coastguard Worker     const ExtractElementInst * EEI = cast<ExtractElementInst>(I);
457*9880d681SAndroid Build Coastguard Worker     ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
458*9880d681SAndroid Build Coastguard Worker     unsigned Idx = -1;
459*9880d681SAndroid Build Coastguard Worker     if (CI)
460*9880d681SAndroid Build Coastguard Worker       Idx = CI->getZExtValue();
461*9880d681SAndroid Build Coastguard Worker 
462*9880d681SAndroid Build Coastguard Worker     // Try to match a reduction sequence (series of shufflevector and vector
463*9880d681SAndroid Build Coastguard Worker     // adds followed by a extractelement).
464*9880d681SAndroid Build Coastguard Worker     unsigned ReduxOpCode;
465*9880d681SAndroid Build Coastguard Worker     Type *ReduxType;
466*9880d681SAndroid Build Coastguard Worker 
467*9880d681SAndroid Build Coastguard Worker     if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType))
468*9880d681SAndroid Build Coastguard Worker       return TTI->getReductionCost(ReduxOpCode, ReduxType, false);
469*9880d681SAndroid Build Coastguard Worker     else if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType))
470*9880d681SAndroid Build Coastguard Worker       return TTI->getReductionCost(ReduxOpCode, ReduxType, true);
471*9880d681SAndroid Build Coastguard Worker 
472*9880d681SAndroid Build Coastguard Worker     return TTI->getVectorInstrCost(I->getOpcode(),
473*9880d681SAndroid Build Coastguard Worker                                    EEI->getOperand(0)->getType(), Idx);
474*9880d681SAndroid Build Coastguard Worker   }
475*9880d681SAndroid Build Coastguard Worker   case Instruction::InsertElement: {
476*9880d681SAndroid Build Coastguard Worker     const InsertElementInst * IE = cast<InsertElementInst>(I);
477*9880d681SAndroid Build Coastguard Worker     ConstantInt *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
478*9880d681SAndroid Build Coastguard Worker     unsigned Idx = -1;
479*9880d681SAndroid Build Coastguard Worker     if (CI)
480*9880d681SAndroid Build Coastguard Worker       Idx = CI->getZExtValue();
481*9880d681SAndroid Build Coastguard Worker     return TTI->getVectorInstrCost(I->getOpcode(),
482*9880d681SAndroid Build Coastguard Worker                                    IE->getType(), Idx);
483*9880d681SAndroid Build Coastguard Worker   }
484*9880d681SAndroid Build Coastguard Worker   case Instruction::ShuffleVector: {
485*9880d681SAndroid Build Coastguard Worker     const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
486*9880d681SAndroid Build Coastguard Worker     Type *VecTypOp0 = Shuffle->getOperand(0)->getType();
487*9880d681SAndroid Build Coastguard Worker     unsigned NumVecElems = VecTypOp0->getVectorNumElements();
488*9880d681SAndroid Build Coastguard Worker     SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
489*9880d681SAndroid Build Coastguard Worker 
490*9880d681SAndroid Build Coastguard Worker     if (NumVecElems == Mask.size()) {
491*9880d681SAndroid Build Coastguard Worker       if (isReverseVectorMask(Mask))
492*9880d681SAndroid Build Coastguard Worker         return TTI->getShuffleCost(TargetTransformInfo::SK_Reverse, VecTypOp0,
493*9880d681SAndroid Build Coastguard Worker                                    0, nullptr);
494*9880d681SAndroid Build Coastguard Worker       if (isAlternateVectorMask(Mask))
495*9880d681SAndroid Build Coastguard Worker         return TTI->getShuffleCost(TargetTransformInfo::SK_Alternate,
496*9880d681SAndroid Build Coastguard Worker                                    VecTypOp0, 0, nullptr);
497*9880d681SAndroid Build Coastguard Worker     }
498*9880d681SAndroid Build Coastguard Worker 
499*9880d681SAndroid Build Coastguard Worker     return -1;
500*9880d681SAndroid Build Coastguard Worker   }
501*9880d681SAndroid Build Coastguard Worker   case Instruction::Call:
502*9880d681SAndroid Build Coastguard Worker     if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
503*9880d681SAndroid Build Coastguard Worker       SmallVector<Value *, 4> Args;
504*9880d681SAndroid Build Coastguard Worker       for (unsigned J = 0, JE = II->getNumArgOperands(); J != JE; ++J)
505*9880d681SAndroid Build Coastguard Worker         Args.push_back(II->getArgOperand(J));
506*9880d681SAndroid Build Coastguard Worker 
507*9880d681SAndroid Build Coastguard Worker       FastMathFlags FMF;
508*9880d681SAndroid Build Coastguard Worker       if (auto *FPMO = dyn_cast<FPMathOperator>(II))
509*9880d681SAndroid Build Coastguard Worker         FMF = FPMO->getFastMathFlags();
510*9880d681SAndroid Build Coastguard Worker 
511*9880d681SAndroid Build Coastguard Worker       return TTI->getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(),
512*9880d681SAndroid Build Coastguard Worker                                         Args, FMF);
513*9880d681SAndroid Build Coastguard Worker     }
514*9880d681SAndroid Build Coastguard Worker     return -1;
515*9880d681SAndroid Build Coastguard Worker   default:
516*9880d681SAndroid Build Coastguard Worker     // We don't have any information on this instruction.
517*9880d681SAndroid Build Coastguard Worker     return -1;
518*9880d681SAndroid Build Coastguard Worker   }
519*9880d681SAndroid Build Coastguard Worker }
520*9880d681SAndroid Build Coastguard Worker 
print(raw_ostream & OS,const Module *) const521*9880d681SAndroid Build Coastguard Worker void CostModelAnalysis::print(raw_ostream &OS, const Module*) const {
522*9880d681SAndroid Build Coastguard Worker   if (!F)
523*9880d681SAndroid Build Coastguard Worker     return;
524*9880d681SAndroid Build Coastguard Worker 
525*9880d681SAndroid Build Coastguard Worker   for (BasicBlock &B : *F) {
526*9880d681SAndroid Build Coastguard Worker     for (Instruction &Inst : B) {
527*9880d681SAndroid Build Coastguard Worker       unsigned Cost = getInstructionCost(&Inst);
528*9880d681SAndroid Build Coastguard Worker       if (Cost != (unsigned)-1)
529*9880d681SAndroid Build Coastguard Worker         OS << "Cost Model: Found an estimated cost of " << Cost;
530*9880d681SAndroid Build Coastguard Worker       else
531*9880d681SAndroid Build Coastguard Worker         OS << "Cost Model: Unknown cost";
532*9880d681SAndroid Build Coastguard Worker 
533*9880d681SAndroid Build Coastguard Worker       OS << " for instruction: " << Inst << "\n";
534*9880d681SAndroid Build Coastguard Worker     }
535*9880d681SAndroid Build Coastguard Worker   }
536*9880d681SAndroid Build Coastguard Worker }
537