//===- CostModel.cpp ------ Cost Model Analysis ---------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines the cost model analysis. It provides a very basic cost
// estimation for LLVM-IR. This analysis uses the services of the codegen
// to approximate the cost of any IR instruction when lowered to machine
// instructions. The cost results are unit-less and the cost number represents
// the throughput of the machine assuming that all loads hit the cache, all
// branches are predicted, etc. The cost numbers can be added in order to
// compare two or more transformation alternatives.
//
//===----------------------------------------------------------------------===//
19*9880d681SAndroid Build Coastguard Worker
20*9880d681SAndroid Build Coastguard Worker #include "llvm/ADT/STLExtras.h"
21*9880d681SAndroid Build Coastguard Worker #include "llvm/Analysis/Passes.h"
22*9880d681SAndroid Build Coastguard Worker #include "llvm/Analysis/TargetTransformInfo.h"
23*9880d681SAndroid Build Coastguard Worker #include "llvm/IR/Function.h"
24*9880d681SAndroid Build Coastguard Worker #include "llvm/IR/Instructions.h"
25*9880d681SAndroid Build Coastguard Worker #include "llvm/IR/IntrinsicInst.h"
26*9880d681SAndroid Build Coastguard Worker #include "llvm/IR/Value.h"
27*9880d681SAndroid Build Coastguard Worker #include "llvm/Pass.h"
28*9880d681SAndroid Build Coastguard Worker #include "llvm/Support/CommandLine.h"
29*9880d681SAndroid Build Coastguard Worker #include "llvm/Support/Debug.h"
30*9880d681SAndroid Build Coastguard Worker #include "llvm/Support/raw_ostream.h"
31*9880d681SAndroid Build Coastguard Worker using namespace llvm;
32*9880d681SAndroid Build Coastguard Worker
33*9880d681SAndroid Build Coastguard Worker #define CM_NAME "cost-model"
34*9880d681SAndroid Build Coastguard Worker #define DEBUG_TYPE CM_NAME
35*9880d681SAndroid Build Coastguard Worker
36*9880d681SAndroid Build Coastguard Worker static cl::opt<bool> EnableReduxCost("costmodel-reduxcost", cl::init(false),
37*9880d681SAndroid Build Coastguard Worker cl::Hidden,
38*9880d681SAndroid Build Coastguard Worker cl::desc("Recognize reduction patterns."));
39*9880d681SAndroid Build Coastguard Worker
40*9880d681SAndroid Build Coastguard Worker namespace {
41*9880d681SAndroid Build Coastguard Worker class CostModelAnalysis : public FunctionPass {
42*9880d681SAndroid Build Coastguard Worker
43*9880d681SAndroid Build Coastguard Worker public:
44*9880d681SAndroid Build Coastguard Worker static char ID; // Class identification, replacement for typeinfo
CostModelAnalysis()45*9880d681SAndroid Build Coastguard Worker CostModelAnalysis() : FunctionPass(ID), F(nullptr), TTI(nullptr) {
46*9880d681SAndroid Build Coastguard Worker initializeCostModelAnalysisPass(
47*9880d681SAndroid Build Coastguard Worker *PassRegistry::getPassRegistry());
48*9880d681SAndroid Build Coastguard Worker }
49*9880d681SAndroid Build Coastguard Worker
50*9880d681SAndroid Build Coastguard Worker /// Returns the expected cost of the instruction.
51*9880d681SAndroid Build Coastguard Worker /// Returns -1 if the cost is unknown.
52*9880d681SAndroid Build Coastguard Worker /// Note, this method does not cache the cost calculation and it
53*9880d681SAndroid Build Coastguard Worker /// can be expensive in some cases.
54*9880d681SAndroid Build Coastguard Worker unsigned getInstructionCost(const Instruction *I) const;
55*9880d681SAndroid Build Coastguard Worker
56*9880d681SAndroid Build Coastguard Worker private:
57*9880d681SAndroid Build Coastguard Worker void getAnalysisUsage(AnalysisUsage &AU) const override;
58*9880d681SAndroid Build Coastguard Worker bool runOnFunction(Function &F) override;
59*9880d681SAndroid Build Coastguard Worker void print(raw_ostream &OS, const Module*) const override;
60*9880d681SAndroid Build Coastguard Worker
61*9880d681SAndroid Build Coastguard Worker /// The function that we analyze.
62*9880d681SAndroid Build Coastguard Worker Function *F;
63*9880d681SAndroid Build Coastguard Worker /// Target information.
64*9880d681SAndroid Build Coastguard Worker const TargetTransformInfo *TTI;
65*9880d681SAndroid Build Coastguard Worker };
66*9880d681SAndroid Build Coastguard Worker } // End of anonymous namespace
67*9880d681SAndroid Build Coastguard Worker
68*9880d681SAndroid Build Coastguard Worker // Register this pass.
69*9880d681SAndroid Build Coastguard Worker char CostModelAnalysis::ID = 0;
70*9880d681SAndroid Build Coastguard Worker static const char cm_name[] = "Cost Model Analysis";
INITIALIZE_PASS_BEGIN(CostModelAnalysis,CM_NAME,cm_name,false,true)71*9880d681SAndroid Build Coastguard Worker INITIALIZE_PASS_BEGIN(CostModelAnalysis, CM_NAME, cm_name, false, true)
72*9880d681SAndroid Build Coastguard Worker INITIALIZE_PASS_END (CostModelAnalysis, CM_NAME, cm_name, false, true)
73*9880d681SAndroid Build Coastguard Worker
74*9880d681SAndroid Build Coastguard Worker FunctionPass *llvm::createCostModelAnalysisPass() {
75*9880d681SAndroid Build Coastguard Worker return new CostModelAnalysis();
76*9880d681SAndroid Build Coastguard Worker }
77*9880d681SAndroid Build Coastguard Worker
78*9880d681SAndroid Build Coastguard Worker void
getAnalysisUsage(AnalysisUsage & AU) const79*9880d681SAndroid Build Coastguard Worker CostModelAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
80*9880d681SAndroid Build Coastguard Worker AU.setPreservesAll();
81*9880d681SAndroid Build Coastguard Worker }
82*9880d681SAndroid Build Coastguard Worker
83*9880d681SAndroid Build Coastguard Worker bool
runOnFunction(Function & F)84*9880d681SAndroid Build Coastguard Worker CostModelAnalysis::runOnFunction(Function &F) {
85*9880d681SAndroid Build Coastguard Worker this->F = &F;
86*9880d681SAndroid Build Coastguard Worker auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
87*9880d681SAndroid Build Coastguard Worker TTI = TTIWP ? &TTIWP->getTTI(F) : nullptr;
88*9880d681SAndroid Build Coastguard Worker
89*9880d681SAndroid Build Coastguard Worker return false;
90*9880d681SAndroid Build Coastguard Worker }
91*9880d681SAndroid Build Coastguard Worker
isReverseVectorMask(SmallVectorImpl<int> & Mask)92*9880d681SAndroid Build Coastguard Worker static bool isReverseVectorMask(SmallVectorImpl<int> &Mask) {
93*9880d681SAndroid Build Coastguard Worker for (unsigned i = 0, MaskSize = Mask.size(); i < MaskSize; ++i)
94*9880d681SAndroid Build Coastguard Worker if (Mask[i] > 0 && Mask[i] != (int)(MaskSize - 1 - i))
95*9880d681SAndroid Build Coastguard Worker return false;
96*9880d681SAndroid Build Coastguard Worker return true;
97*9880d681SAndroid Build Coastguard Worker }
98*9880d681SAndroid Build Coastguard Worker
isAlternateVectorMask(SmallVectorImpl<int> & Mask)99*9880d681SAndroid Build Coastguard Worker static bool isAlternateVectorMask(SmallVectorImpl<int> &Mask) {
100*9880d681SAndroid Build Coastguard Worker bool isAlternate = true;
101*9880d681SAndroid Build Coastguard Worker unsigned MaskSize = Mask.size();
102*9880d681SAndroid Build Coastguard Worker
103*9880d681SAndroid Build Coastguard Worker // Example: shufflevector A, B, <0,5,2,7>
104*9880d681SAndroid Build Coastguard Worker for (unsigned i = 0; i < MaskSize && isAlternate; ++i) {
105*9880d681SAndroid Build Coastguard Worker if (Mask[i] < 0)
106*9880d681SAndroid Build Coastguard Worker continue;
107*9880d681SAndroid Build Coastguard Worker isAlternate = Mask[i] == (int)((i & 1) ? MaskSize + i : i);
108*9880d681SAndroid Build Coastguard Worker }
109*9880d681SAndroid Build Coastguard Worker
110*9880d681SAndroid Build Coastguard Worker if (isAlternate)
111*9880d681SAndroid Build Coastguard Worker return true;
112*9880d681SAndroid Build Coastguard Worker
113*9880d681SAndroid Build Coastguard Worker isAlternate = true;
114*9880d681SAndroid Build Coastguard Worker // Example: shufflevector A, B, <4,1,6,3>
115*9880d681SAndroid Build Coastguard Worker for (unsigned i = 0; i < MaskSize && isAlternate; ++i) {
116*9880d681SAndroid Build Coastguard Worker if (Mask[i] < 0)
117*9880d681SAndroid Build Coastguard Worker continue;
118*9880d681SAndroid Build Coastguard Worker isAlternate = Mask[i] == (int)((i & 1) ? i : MaskSize + i);
119*9880d681SAndroid Build Coastguard Worker }
120*9880d681SAndroid Build Coastguard Worker
121*9880d681SAndroid Build Coastguard Worker return isAlternate;
122*9880d681SAndroid Build Coastguard Worker }
123*9880d681SAndroid Build Coastguard Worker
getOperandInfo(Value * V)124*9880d681SAndroid Build Coastguard Worker static TargetTransformInfo::OperandValueKind getOperandInfo(Value *V) {
125*9880d681SAndroid Build Coastguard Worker TargetTransformInfo::OperandValueKind OpInfo =
126*9880d681SAndroid Build Coastguard Worker TargetTransformInfo::OK_AnyValue;
127*9880d681SAndroid Build Coastguard Worker
128*9880d681SAndroid Build Coastguard Worker // Check for a splat of a constant or for a non uniform vector of constants.
129*9880d681SAndroid Build Coastguard Worker if (isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) {
130*9880d681SAndroid Build Coastguard Worker OpInfo = TargetTransformInfo::OK_NonUniformConstantValue;
131*9880d681SAndroid Build Coastguard Worker if (cast<Constant>(V)->getSplatValue() != nullptr)
132*9880d681SAndroid Build Coastguard Worker OpInfo = TargetTransformInfo::OK_UniformConstantValue;
133*9880d681SAndroid Build Coastguard Worker }
134*9880d681SAndroid Build Coastguard Worker
135*9880d681SAndroid Build Coastguard Worker return OpInfo;
136*9880d681SAndroid Build Coastguard Worker }
137*9880d681SAndroid Build Coastguard Worker
matchPairwiseShuffleMask(ShuffleVectorInst * SI,bool IsLeft,unsigned Level)138*9880d681SAndroid Build Coastguard Worker static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft,
139*9880d681SAndroid Build Coastguard Worker unsigned Level) {
140*9880d681SAndroid Build Coastguard Worker // We don't need a shuffle if we just want to have element 0 in position 0 of
141*9880d681SAndroid Build Coastguard Worker // the vector.
142*9880d681SAndroid Build Coastguard Worker if (!SI && Level == 0 && IsLeft)
143*9880d681SAndroid Build Coastguard Worker return true;
144*9880d681SAndroid Build Coastguard Worker else if (!SI)
145*9880d681SAndroid Build Coastguard Worker return false;
146*9880d681SAndroid Build Coastguard Worker
147*9880d681SAndroid Build Coastguard Worker SmallVector<int, 32> Mask(SI->getType()->getVectorNumElements(), -1);
148*9880d681SAndroid Build Coastguard Worker
149*9880d681SAndroid Build Coastguard Worker // Build a mask of 0, 2, ... (left) or 1, 3, ... (right) depending on whether
150*9880d681SAndroid Build Coastguard Worker // we look at the left or right side.
151*9880d681SAndroid Build Coastguard Worker for (unsigned i = 0, e = (1 << Level), val = !IsLeft; i != e; ++i, val += 2)
152*9880d681SAndroid Build Coastguard Worker Mask[i] = val;
153*9880d681SAndroid Build Coastguard Worker
154*9880d681SAndroid Build Coastguard Worker SmallVector<int, 16> ActualMask = SI->getShuffleMask();
155*9880d681SAndroid Build Coastguard Worker return Mask == ActualMask;
156*9880d681SAndroid Build Coastguard Worker }
157*9880d681SAndroid Build Coastguard Worker
matchPairwiseReductionAtLevel(const BinaryOperator * BinOp,unsigned Level,unsigned NumLevels)158*9880d681SAndroid Build Coastguard Worker static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp,
159*9880d681SAndroid Build Coastguard Worker unsigned Level, unsigned NumLevels) {
160*9880d681SAndroid Build Coastguard Worker // Match one level of pairwise operations.
161*9880d681SAndroid Build Coastguard Worker // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
162*9880d681SAndroid Build Coastguard Worker // <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
163*9880d681SAndroid Build Coastguard Worker // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
164*9880d681SAndroid Build Coastguard Worker // <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
165*9880d681SAndroid Build Coastguard Worker // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
166*9880d681SAndroid Build Coastguard Worker if (BinOp == nullptr)
167*9880d681SAndroid Build Coastguard Worker return false;
168*9880d681SAndroid Build Coastguard Worker
169*9880d681SAndroid Build Coastguard Worker assert(BinOp->getType()->isVectorTy() && "Expecting a vector type");
170*9880d681SAndroid Build Coastguard Worker
171*9880d681SAndroid Build Coastguard Worker unsigned Opcode = BinOp->getOpcode();
172*9880d681SAndroid Build Coastguard Worker Value *L = BinOp->getOperand(0);
173*9880d681SAndroid Build Coastguard Worker Value *R = BinOp->getOperand(1);
174*9880d681SAndroid Build Coastguard Worker
175*9880d681SAndroid Build Coastguard Worker ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(L);
176*9880d681SAndroid Build Coastguard Worker if (!LS && Level)
177*9880d681SAndroid Build Coastguard Worker return false;
178*9880d681SAndroid Build Coastguard Worker ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(R);
179*9880d681SAndroid Build Coastguard Worker if (!RS && Level)
180*9880d681SAndroid Build Coastguard Worker return false;
181*9880d681SAndroid Build Coastguard Worker
182*9880d681SAndroid Build Coastguard Worker // On level 0 we can omit one shufflevector instruction.
183*9880d681SAndroid Build Coastguard Worker if (!Level && !RS && !LS)
184*9880d681SAndroid Build Coastguard Worker return false;
185*9880d681SAndroid Build Coastguard Worker
186*9880d681SAndroid Build Coastguard Worker // Shuffle inputs must match.
187*9880d681SAndroid Build Coastguard Worker Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr;
188*9880d681SAndroid Build Coastguard Worker Value *NextLevelOpR = RS ? RS->getOperand(0) : nullptr;
189*9880d681SAndroid Build Coastguard Worker Value *NextLevelOp = nullptr;
190*9880d681SAndroid Build Coastguard Worker if (NextLevelOpR && NextLevelOpL) {
191*9880d681SAndroid Build Coastguard Worker // If we have two shuffles their operands must match.
192*9880d681SAndroid Build Coastguard Worker if (NextLevelOpL != NextLevelOpR)
193*9880d681SAndroid Build Coastguard Worker return false;
194*9880d681SAndroid Build Coastguard Worker
195*9880d681SAndroid Build Coastguard Worker NextLevelOp = NextLevelOpL;
196*9880d681SAndroid Build Coastguard Worker } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) {
197*9880d681SAndroid Build Coastguard Worker // On the first level we can omit the shufflevector <0, undef,...>. So the
198*9880d681SAndroid Build Coastguard Worker // input to the other shufflevector <1, undef> must match with one of the
199*9880d681SAndroid Build Coastguard Worker // inputs to the current binary operation.
200*9880d681SAndroid Build Coastguard Worker // Example:
201*9880d681SAndroid Build Coastguard Worker // %NextLevelOpL = shufflevector %R, <1, undef ...>
202*9880d681SAndroid Build Coastguard Worker // %BinOp = fadd %NextLevelOpL, %R
203*9880d681SAndroid Build Coastguard Worker if (NextLevelOpL && NextLevelOpL != R)
204*9880d681SAndroid Build Coastguard Worker return false;
205*9880d681SAndroid Build Coastguard Worker else if (NextLevelOpR && NextLevelOpR != L)
206*9880d681SAndroid Build Coastguard Worker return false;
207*9880d681SAndroid Build Coastguard Worker
208*9880d681SAndroid Build Coastguard Worker NextLevelOp = NextLevelOpL ? R : L;
209*9880d681SAndroid Build Coastguard Worker } else
210*9880d681SAndroid Build Coastguard Worker return false;
211*9880d681SAndroid Build Coastguard Worker
212*9880d681SAndroid Build Coastguard Worker // Check that the next levels binary operation exists and matches with the
213*9880d681SAndroid Build Coastguard Worker // current one.
214*9880d681SAndroid Build Coastguard Worker BinaryOperator *NextLevelBinOp = nullptr;
215*9880d681SAndroid Build Coastguard Worker if (Level + 1 != NumLevels) {
216*9880d681SAndroid Build Coastguard Worker if (!(NextLevelBinOp = dyn_cast<BinaryOperator>(NextLevelOp)))
217*9880d681SAndroid Build Coastguard Worker return false;
218*9880d681SAndroid Build Coastguard Worker else if (NextLevelBinOp->getOpcode() != Opcode)
219*9880d681SAndroid Build Coastguard Worker return false;
220*9880d681SAndroid Build Coastguard Worker }
221*9880d681SAndroid Build Coastguard Worker
222*9880d681SAndroid Build Coastguard Worker // Shuffle mask for pairwise operation must match.
223*9880d681SAndroid Build Coastguard Worker if (matchPairwiseShuffleMask(LS, true, Level)) {
224*9880d681SAndroid Build Coastguard Worker if (!matchPairwiseShuffleMask(RS, false, Level))
225*9880d681SAndroid Build Coastguard Worker return false;
226*9880d681SAndroid Build Coastguard Worker } else if (matchPairwiseShuffleMask(RS, true, Level)) {
227*9880d681SAndroid Build Coastguard Worker if (!matchPairwiseShuffleMask(LS, false, Level))
228*9880d681SAndroid Build Coastguard Worker return false;
229*9880d681SAndroid Build Coastguard Worker } else
230*9880d681SAndroid Build Coastguard Worker return false;
231*9880d681SAndroid Build Coastguard Worker
232*9880d681SAndroid Build Coastguard Worker if (++Level == NumLevels)
233*9880d681SAndroid Build Coastguard Worker return true;
234*9880d681SAndroid Build Coastguard Worker
235*9880d681SAndroid Build Coastguard Worker // Match next level.
236*9880d681SAndroid Build Coastguard Worker return matchPairwiseReductionAtLevel(NextLevelBinOp, Level, NumLevels);
237*9880d681SAndroid Build Coastguard Worker }
238*9880d681SAndroid Build Coastguard Worker
matchPairwiseReduction(const ExtractElementInst * ReduxRoot,unsigned & Opcode,Type * & Ty)239*9880d681SAndroid Build Coastguard Worker static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
240*9880d681SAndroid Build Coastguard Worker unsigned &Opcode, Type *&Ty) {
241*9880d681SAndroid Build Coastguard Worker if (!EnableReduxCost)
242*9880d681SAndroid Build Coastguard Worker return false;
243*9880d681SAndroid Build Coastguard Worker
244*9880d681SAndroid Build Coastguard Worker // Need to extract the first element.
245*9880d681SAndroid Build Coastguard Worker ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
246*9880d681SAndroid Build Coastguard Worker unsigned Idx = ~0u;
247*9880d681SAndroid Build Coastguard Worker if (CI)
248*9880d681SAndroid Build Coastguard Worker Idx = CI->getZExtValue();
249*9880d681SAndroid Build Coastguard Worker if (Idx != 0)
250*9880d681SAndroid Build Coastguard Worker return false;
251*9880d681SAndroid Build Coastguard Worker
252*9880d681SAndroid Build Coastguard Worker BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0));
253*9880d681SAndroid Build Coastguard Worker if (!RdxStart)
254*9880d681SAndroid Build Coastguard Worker return false;
255*9880d681SAndroid Build Coastguard Worker
256*9880d681SAndroid Build Coastguard Worker Type *VecTy = ReduxRoot->getOperand(0)->getType();
257*9880d681SAndroid Build Coastguard Worker unsigned NumVecElems = VecTy->getVectorNumElements();
258*9880d681SAndroid Build Coastguard Worker if (!isPowerOf2_32(NumVecElems))
259*9880d681SAndroid Build Coastguard Worker return false;
260*9880d681SAndroid Build Coastguard Worker
261*9880d681SAndroid Build Coastguard Worker // We look for a sequence of shuffle,shuffle,add triples like the following
262*9880d681SAndroid Build Coastguard Worker // that builds a pairwise reduction tree.
263*9880d681SAndroid Build Coastguard Worker //
264*9880d681SAndroid Build Coastguard Worker // (X0, X1, X2, X3)
265*9880d681SAndroid Build Coastguard Worker // (X0 + X1, X2 + X3, undef, undef)
266*9880d681SAndroid Build Coastguard Worker // ((X0 + X1) + (X2 + X3), undef, undef, undef)
267*9880d681SAndroid Build Coastguard Worker //
268*9880d681SAndroid Build Coastguard Worker // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
269*9880d681SAndroid Build Coastguard Worker // <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
270*9880d681SAndroid Build Coastguard Worker // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
271*9880d681SAndroid Build Coastguard Worker // <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
272*9880d681SAndroid Build Coastguard Worker // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
273*9880d681SAndroid Build Coastguard Worker // %rdx.shuf.1.0 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef,
274*9880d681SAndroid Build Coastguard Worker // <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef>
275*9880d681SAndroid Build Coastguard Worker // %rdx.shuf.1.1 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef,
276*9880d681SAndroid Build Coastguard Worker // <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
277*9880d681SAndroid Build Coastguard Worker // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1
278*9880d681SAndroid Build Coastguard Worker // %r = extractelement <4 x float> %bin.rdx8, i32 0
279*9880d681SAndroid Build Coastguard Worker if (!matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems)))
280*9880d681SAndroid Build Coastguard Worker return false;
281*9880d681SAndroid Build Coastguard Worker
282*9880d681SAndroid Build Coastguard Worker Opcode = RdxStart->getOpcode();
283*9880d681SAndroid Build Coastguard Worker Ty = VecTy;
284*9880d681SAndroid Build Coastguard Worker
285*9880d681SAndroid Build Coastguard Worker return true;
286*9880d681SAndroid Build Coastguard Worker }
287*9880d681SAndroid Build Coastguard Worker
288*9880d681SAndroid Build Coastguard Worker static std::pair<Value *, ShuffleVectorInst *>
getShuffleAndOtherOprd(BinaryOperator * B)289*9880d681SAndroid Build Coastguard Worker getShuffleAndOtherOprd(BinaryOperator *B) {
290*9880d681SAndroid Build Coastguard Worker
291*9880d681SAndroid Build Coastguard Worker Value *L = B->getOperand(0);
292*9880d681SAndroid Build Coastguard Worker Value *R = B->getOperand(1);
293*9880d681SAndroid Build Coastguard Worker ShuffleVectorInst *S = nullptr;
294*9880d681SAndroid Build Coastguard Worker
295*9880d681SAndroid Build Coastguard Worker if ((S = dyn_cast<ShuffleVectorInst>(L)))
296*9880d681SAndroid Build Coastguard Worker return std::make_pair(R, S);
297*9880d681SAndroid Build Coastguard Worker
298*9880d681SAndroid Build Coastguard Worker S = dyn_cast<ShuffleVectorInst>(R);
299*9880d681SAndroid Build Coastguard Worker return std::make_pair(L, S);
300*9880d681SAndroid Build Coastguard Worker }
301*9880d681SAndroid Build Coastguard Worker
matchVectorSplittingReduction(const ExtractElementInst * ReduxRoot,unsigned & Opcode,Type * & Ty)302*9880d681SAndroid Build Coastguard Worker static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
303*9880d681SAndroid Build Coastguard Worker unsigned &Opcode, Type *&Ty) {
304*9880d681SAndroid Build Coastguard Worker if (!EnableReduxCost)
305*9880d681SAndroid Build Coastguard Worker return false;
306*9880d681SAndroid Build Coastguard Worker
307*9880d681SAndroid Build Coastguard Worker // Need to extract the first element.
308*9880d681SAndroid Build Coastguard Worker ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
309*9880d681SAndroid Build Coastguard Worker unsigned Idx = ~0u;
310*9880d681SAndroid Build Coastguard Worker if (CI)
311*9880d681SAndroid Build Coastguard Worker Idx = CI->getZExtValue();
312*9880d681SAndroid Build Coastguard Worker if (Idx != 0)
313*9880d681SAndroid Build Coastguard Worker return false;
314*9880d681SAndroid Build Coastguard Worker
315*9880d681SAndroid Build Coastguard Worker BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0));
316*9880d681SAndroid Build Coastguard Worker if (!RdxStart)
317*9880d681SAndroid Build Coastguard Worker return false;
318*9880d681SAndroid Build Coastguard Worker unsigned RdxOpcode = RdxStart->getOpcode();
319*9880d681SAndroid Build Coastguard Worker
320*9880d681SAndroid Build Coastguard Worker Type *VecTy = ReduxRoot->getOperand(0)->getType();
321*9880d681SAndroid Build Coastguard Worker unsigned NumVecElems = VecTy->getVectorNumElements();
322*9880d681SAndroid Build Coastguard Worker if (!isPowerOf2_32(NumVecElems))
323*9880d681SAndroid Build Coastguard Worker return false;
324*9880d681SAndroid Build Coastguard Worker
325*9880d681SAndroid Build Coastguard Worker // We look for a sequence of shuffles and adds like the following matching one
326*9880d681SAndroid Build Coastguard Worker // fadd, shuffle vector pair at a time.
327*9880d681SAndroid Build Coastguard Worker //
328*9880d681SAndroid Build Coastguard Worker // %rdx.shuf = shufflevector <4 x float> %rdx, <4 x float> undef,
329*9880d681SAndroid Build Coastguard Worker // <4 x i32> <i32 2, i32 3, i32 undef, i32 undef>
330*9880d681SAndroid Build Coastguard Worker // %bin.rdx = fadd <4 x float> %rdx, %rdx.shuf
331*9880d681SAndroid Build Coastguard Worker // %rdx.shuf7 = shufflevector <4 x float> %bin.rdx, <4 x float> undef,
332*9880d681SAndroid Build Coastguard Worker // <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
333*9880d681SAndroid Build Coastguard Worker // %bin.rdx8 = fadd <4 x float> %bin.rdx, %rdx.shuf7
334*9880d681SAndroid Build Coastguard Worker // %r = extractelement <4 x float> %bin.rdx8, i32 0
335*9880d681SAndroid Build Coastguard Worker
336*9880d681SAndroid Build Coastguard Worker unsigned MaskStart = 1;
337*9880d681SAndroid Build Coastguard Worker Value *RdxOp = RdxStart;
338*9880d681SAndroid Build Coastguard Worker SmallVector<int, 32> ShuffleMask(NumVecElems, 0);
339*9880d681SAndroid Build Coastguard Worker unsigned NumVecElemsRemain = NumVecElems;
340*9880d681SAndroid Build Coastguard Worker while (NumVecElemsRemain - 1) {
341*9880d681SAndroid Build Coastguard Worker // Check for the right reduction operation.
342*9880d681SAndroid Build Coastguard Worker BinaryOperator *BinOp;
343*9880d681SAndroid Build Coastguard Worker if (!(BinOp = dyn_cast<BinaryOperator>(RdxOp)))
344*9880d681SAndroid Build Coastguard Worker return false;
345*9880d681SAndroid Build Coastguard Worker if (BinOp->getOpcode() != RdxOpcode)
346*9880d681SAndroid Build Coastguard Worker return false;
347*9880d681SAndroid Build Coastguard Worker
348*9880d681SAndroid Build Coastguard Worker Value *NextRdxOp;
349*9880d681SAndroid Build Coastguard Worker ShuffleVectorInst *Shuffle;
350*9880d681SAndroid Build Coastguard Worker std::tie(NextRdxOp, Shuffle) = getShuffleAndOtherOprd(BinOp);
351*9880d681SAndroid Build Coastguard Worker
352*9880d681SAndroid Build Coastguard Worker // Check the current reduction operation and the shuffle use the same value.
353*9880d681SAndroid Build Coastguard Worker if (Shuffle == nullptr)
354*9880d681SAndroid Build Coastguard Worker return false;
355*9880d681SAndroid Build Coastguard Worker if (Shuffle->getOperand(0) != NextRdxOp)
356*9880d681SAndroid Build Coastguard Worker return false;
357*9880d681SAndroid Build Coastguard Worker
358*9880d681SAndroid Build Coastguard Worker // Check that shuffle masks matches.
359*9880d681SAndroid Build Coastguard Worker for (unsigned j = 0; j != MaskStart; ++j)
360*9880d681SAndroid Build Coastguard Worker ShuffleMask[j] = MaskStart + j;
361*9880d681SAndroid Build Coastguard Worker // Fill the rest of the mask with -1 for undef.
362*9880d681SAndroid Build Coastguard Worker std::fill(&ShuffleMask[MaskStart], ShuffleMask.end(), -1);
363*9880d681SAndroid Build Coastguard Worker
364*9880d681SAndroid Build Coastguard Worker SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
365*9880d681SAndroid Build Coastguard Worker if (ShuffleMask != Mask)
366*9880d681SAndroid Build Coastguard Worker return false;
367*9880d681SAndroid Build Coastguard Worker
368*9880d681SAndroid Build Coastguard Worker RdxOp = NextRdxOp;
369*9880d681SAndroid Build Coastguard Worker NumVecElemsRemain /= 2;
370*9880d681SAndroid Build Coastguard Worker MaskStart *= 2;
371*9880d681SAndroid Build Coastguard Worker }
372*9880d681SAndroid Build Coastguard Worker
373*9880d681SAndroid Build Coastguard Worker Opcode = RdxOpcode;
374*9880d681SAndroid Build Coastguard Worker Ty = VecTy;
375*9880d681SAndroid Build Coastguard Worker return true;
376*9880d681SAndroid Build Coastguard Worker }
377*9880d681SAndroid Build Coastguard Worker
getInstructionCost(const Instruction * I) const378*9880d681SAndroid Build Coastguard Worker unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const {
379*9880d681SAndroid Build Coastguard Worker if (!TTI)
380*9880d681SAndroid Build Coastguard Worker return -1;
381*9880d681SAndroid Build Coastguard Worker
382*9880d681SAndroid Build Coastguard Worker switch (I->getOpcode()) {
383*9880d681SAndroid Build Coastguard Worker case Instruction::GetElementPtr:
384*9880d681SAndroid Build Coastguard Worker return TTI->getUserCost(I);
385*9880d681SAndroid Build Coastguard Worker
386*9880d681SAndroid Build Coastguard Worker case Instruction::Ret:
387*9880d681SAndroid Build Coastguard Worker case Instruction::PHI:
388*9880d681SAndroid Build Coastguard Worker case Instruction::Br: {
389*9880d681SAndroid Build Coastguard Worker return TTI->getCFInstrCost(I->getOpcode());
390*9880d681SAndroid Build Coastguard Worker }
391*9880d681SAndroid Build Coastguard Worker case Instruction::Add:
392*9880d681SAndroid Build Coastguard Worker case Instruction::FAdd:
393*9880d681SAndroid Build Coastguard Worker case Instruction::Sub:
394*9880d681SAndroid Build Coastguard Worker case Instruction::FSub:
395*9880d681SAndroid Build Coastguard Worker case Instruction::Mul:
396*9880d681SAndroid Build Coastguard Worker case Instruction::FMul:
397*9880d681SAndroid Build Coastguard Worker case Instruction::UDiv:
398*9880d681SAndroid Build Coastguard Worker case Instruction::SDiv:
399*9880d681SAndroid Build Coastguard Worker case Instruction::FDiv:
400*9880d681SAndroid Build Coastguard Worker case Instruction::URem:
401*9880d681SAndroid Build Coastguard Worker case Instruction::SRem:
402*9880d681SAndroid Build Coastguard Worker case Instruction::FRem:
403*9880d681SAndroid Build Coastguard Worker case Instruction::Shl:
404*9880d681SAndroid Build Coastguard Worker case Instruction::LShr:
405*9880d681SAndroid Build Coastguard Worker case Instruction::AShr:
406*9880d681SAndroid Build Coastguard Worker case Instruction::And:
407*9880d681SAndroid Build Coastguard Worker case Instruction::Or:
408*9880d681SAndroid Build Coastguard Worker case Instruction::Xor: {
409*9880d681SAndroid Build Coastguard Worker TargetTransformInfo::OperandValueKind Op1VK =
410*9880d681SAndroid Build Coastguard Worker getOperandInfo(I->getOperand(0));
411*9880d681SAndroid Build Coastguard Worker TargetTransformInfo::OperandValueKind Op2VK =
412*9880d681SAndroid Build Coastguard Worker getOperandInfo(I->getOperand(1));
413*9880d681SAndroid Build Coastguard Worker return TTI->getArithmeticInstrCost(I->getOpcode(), I->getType(), Op1VK,
414*9880d681SAndroid Build Coastguard Worker Op2VK);
415*9880d681SAndroid Build Coastguard Worker }
416*9880d681SAndroid Build Coastguard Worker case Instruction::Select: {
417*9880d681SAndroid Build Coastguard Worker const SelectInst *SI = cast<SelectInst>(I);
418*9880d681SAndroid Build Coastguard Worker Type *CondTy = SI->getCondition()->getType();
419*9880d681SAndroid Build Coastguard Worker return TTI->getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy);
420*9880d681SAndroid Build Coastguard Worker }
421*9880d681SAndroid Build Coastguard Worker case Instruction::ICmp:
422*9880d681SAndroid Build Coastguard Worker case Instruction::FCmp: {
423*9880d681SAndroid Build Coastguard Worker Type *ValTy = I->getOperand(0)->getType();
424*9880d681SAndroid Build Coastguard Worker return TTI->getCmpSelInstrCost(I->getOpcode(), ValTy);
425*9880d681SAndroid Build Coastguard Worker }
426*9880d681SAndroid Build Coastguard Worker case Instruction::Store: {
427*9880d681SAndroid Build Coastguard Worker const StoreInst *SI = cast<StoreInst>(I);
428*9880d681SAndroid Build Coastguard Worker Type *ValTy = SI->getValueOperand()->getType();
429*9880d681SAndroid Build Coastguard Worker return TTI->getMemoryOpCost(I->getOpcode(), ValTy,
430*9880d681SAndroid Build Coastguard Worker SI->getAlignment(),
431*9880d681SAndroid Build Coastguard Worker SI->getPointerAddressSpace());
432*9880d681SAndroid Build Coastguard Worker }
433*9880d681SAndroid Build Coastguard Worker case Instruction::Load: {
434*9880d681SAndroid Build Coastguard Worker const LoadInst *LI = cast<LoadInst>(I);
435*9880d681SAndroid Build Coastguard Worker return TTI->getMemoryOpCost(I->getOpcode(), I->getType(),
436*9880d681SAndroid Build Coastguard Worker LI->getAlignment(),
437*9880d681SAndroid Build Coastguard Worker LI->getPointerAddressSpace());
438*9880d681SAndroid Build Coastguard Worker }
439*9880d681SAndroid Build Coastguard Worker case Instruction::ZExt:
440*9880d681SAndroid Build Coastguard Worker case Instruction::SExt:
441*9880d681SAndroid Build Coastguard Worker case Instruction::FPToUI:
442*9880d681SAndroid Build Coastguard Worker case Instruction::FPToSI:
443*9880d681SAndroid Build Coastguard Worker case Instruction::FPExt:
444*9880d681SAndroid Build Coastguard Worker case Instruction::PtrToInt:
445*9880d681SAndroid Build Coastguard Worker case Instruction::IntToPtr:
446*9880d681SAndroid Build Coastguard Worker case Instruction::SIToFP:
447*9880d681SAndroid Build Coastguard Worker case Instruction::UIToFP:
448*9880d681SAndroid Build Coastguard Worker case Instruction::Trunc:
449*9880d681SAndroid Build Coastguard Worker case Instruction::FPTrunc:
450*9880d681SAndroid Build Coastguard Worker case Instruction::BitCast:
451*9880d681SAndroid Build Coastguard Worker case Instruction::AddrSpaceCast: {
452*9880d681SAndroid Build Coastguard Worker Type *SrcTy = I->getOperand(0)->getType();
453*9880d681SAndroid Build Coastguard Worker return TTI->getCastInstrCost(I->getOpcode(), I->getType(), SrcTy);
454*9880d681SAndroid Build Coastguard Worker }
455*9880d681SAndroid Build Coastguard Worker case Instruction::ExtractElement: {
456*9880d681SAndroid Build Coastguard Worker const ExtractElementInst * EEI = cast<ExtractElementInst>(I);
457*9880d681SAndroid Build Coastguard Worker ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
458*9880d681SAndroid Build Coastguard Worker unsigned Idx = -1;
459*9880d681SAndroid Build Coastguard Worker if (CI)
460*9880d681SAndroid Build Coastguard Worker Idx = CI->getZExtValue();
461*9880d681SAndroid Build Coastguard Worker
462*9880d681SAndroid Build Coastguard Worker // Try to match a reduction sequence (series of shufflevector and vector
463*9880d681SAndroid Build Coastguard Worker // adds followed by a extractelement).
464*9880d681SAndroid Build Coastguard Worker unsigned ReduxOpCode;
465*9880d681SAndroid Build Coastguard Worker Type *ReduxType;
466*9880d681SAndroid Build Coastguard Worker
467*9880d681SAndroid Build Coastguard Worker if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType))
468*9880d681SAndroid Build Coastguard Worker return TTI->getReductionCost(ReduxOpCode, ReduxType, false);
469*9880d681SAndroid Build Coastguard Worker else if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType))
470*9880d681SAndroid Build Coastguard Worker return TTI->getReductionCost(ReduxOpCode, ReduxType, true);
471*9880d681SAndroid Build Coastguard Worker
472*9880d681SAndroid Build Coastguard Worker return TTI->getVectorInstrCost(I->getOpcode(),
473*9880d681SAndroid Build Coastguard Worker EEI->getOperand(0)->getType(), Idx);
474*9880d681SAndroid Build Coastguard Worker }
475*9880d681SAndroid Build Coastguard Worker case Instruction::InsertElement: {
476*9880d681SAndroid Build Coastguard Worker const InsertElementInst * IE = cast<InsertElementInst>(I);
477*9880d681SAndroid Build Coastguard Worker ConstantInt *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
478*9880d681SAndroid Build Coastguard Worker unsigned Idx = -1;
479*9880d681SAndroid Build Coastguard Worker if (CI)
480*9880d681SAndroid Build Coastguard Worker Idx = CI->getZExtValue();
481*9880d681SAndroid Build Coastguard Worker return TTI->getVectorInstrCost(I->getOpcode(),
482*9880d681SAndroid Build Coastguard Worker IE->getType(), Idx);
483*9880d681SAndroid Build Coastguard Worker }
484*9880d681SAndroid Build Coastguard Worker case Instruction::ShuffleVector: {
485*9880d681SAndroid Build Coastguard Worker const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
486*9880d681SAndroid Build Coastguard Worker Type *VecTypOp0 = Shuffle->getOperand(0)->getType();
487*9880d681SAndroid Build Coastguard Worker unsigned NumVecElems = VecTypOp0->getVectorNumElements();
488*9880d681SAndroid Build Coastguard Worker SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
489*9880d681SAndroid Build Coastguard Worker
490*9880d681SAndroid Build Coastguard Worker if (NumVecElems == Mask.size()) {
491*9880d681SAndroid Build Coastguard Worker if (isReverseVectorMask(Mask))
492*9880d681SAndroid Build Coastguard Worker return TTI->getShuffleCost(TargetTransformInfo::SK_Reverse, VecTypOp0,
493*9880d681SAndroid Build Coastguard Worker 0, nullptr);
494*9880d681SAndroid Build Coastguard Worker if (isAlternateVectorMask(Mask))
495*9880d681SAndroid Build Coastguard Worker return TTI->getShuffleCost(TargetTransformInfo::SK_Alternate,
496*9880d681SAndroid Build Coastguard Worker VecTypOp0, 0, nullptr);
497*9880d681SAndroid Build Coastguard Worker }
498*9880d681SAndroid Build Coastguard Worker
499*9880d681SAndroid Build Coastguard Worker return -1;
500*9880d681SAndroid Build Coastguard Worker }
501*9880d681SAndroid Build Coastguard Worker case Instruction::Call:
502*9880d681SAndroid Build Coastguard Worker if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
503*9880d681SAndroid Build Coastguard Worker SmallVector<Value *, 4> Args;
504*9880d681SAndroid Build Coastguard Worker for (unsigned J = 0, JE = II->getNumArgOperands(); J != JE; ++J)
505*9880d681SAndroid Build Coastguard Worker Args.push_back(II->getArgOperand(J));
506*9880d681SAndroid Build Coastguard Worker
507*9880d681SAndroid Build Coastguard Worker FastMathFlags FMF;
508*9880d681SAndroid Build Coastguard Worker if (auto *FPMO = dyn_cast<FPMathOperator>(II))
509*9880d681SAndroid Build Coastguard Worker FMF = FPMO->getFastMathFlags();
510*9880d681SAndroid Build Coastguard Worker
511*9880d681SAndroid Build Coastguard Worker return TTI->getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(),
512*9880d681SAndroid Build Coastguard Worker Args, FMF);
513*9880d681SAndroid Build Coastguard Worker }
514*9880d681SAndroid Build Coastguard Worker return -1;
515*9880d681SAndroid Build Coastguard Worker default:
516*9880d681SAndroid Build Coastguard Worker // We don't have any information on this instruction.
517*9880d681SAndroid Build Coastguard Worker return -1;
518*9880d681SAndroid Build Coastguard Worker }
519*9880d681SAndroid Build Coastguard Worker }
520*9880d681SAndroid Build Coastguard Worker
print(raw_ostream & OS,const Module *) const521*9880d681SAndroid Build Coastguard Worker void CostModelAnalysis::print(raw_ostream &OS, const Module*) const {
522*9880d681SAndroid Build Coastguard Worker if (!F)
523*9880d681SAndroid Build Coastguard Worker return;
524*9880d681SAndroid Build Coastguard Worker
525*9880d681SAndroid Build Coastguard Worker for (BasicBlock &B : *F) {
526*9880d681SAndroid Build Coastguard Worker for (Instruction &Inst : B) {
527*9880d681SAndroid Build Coastguard Worker unsigned Cost = getInstructionCost(&Inst);
528*9880d681SAndroid Build Coastguard Worker if (Cost != (unsigned)-1)
529*9880d681SAndroid Build Coastguard Worker OS << "Cost Model: Found an estimated cost of " << Cost;
530*9880d681SAndroid Build Coastguard Worker else
531*9880d681SAndroid Build Coastguard Worker OS << "Cost Model: Unknown cost";
532*9880d681SAndroid Build Coastguard Worker
533*9880d681SAndroid Build Coastguard Worker OS << " for instruction: " << Inst << "\n";
534*9880d681SAndroid Build Coastguard Worker }
535*9880d681SAndroid Build Coastguard Worker }
536*9880d681SAndroid Build Coastguard Worker }
537