1 //===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass identifies and eliminates redundant TLS loads when the related
// option is set. For an example, see the comment at the head of
// TLSVariableHoist.h.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/IR/BasicBlock.h"
16 #include "llvm/IR/Dominators.h"
17 #include "llvm/IR/Function.h"
18 #include "llvm/IR/InstrTypes.h"
19 #include "llvm/IR/Instruction.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/IntrinsicInst.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/Value.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "llvm/Transforms/Scalar.h"
30 #include "llvm/Transforms/Scalar/TLSVariableHoist.h"
31 #include <algorithm>
32 #include <cassert>
33 #include <cstdint>
34 #include <iterator>
35 #include <tuple>
36 #include <utility>
37
38 using namespace llvm;
39 using namespace tlshoist;
40
41 #define DEBUG_TYPE "tlshoist"
42
43 static cl::opt<bool> TLSLoadHoist(
44 "tls-load-hoist", cl::init(false), cl::Hidden,
45 cl::desc("hoist the TLS loads in PIC model to eliminate redundant "
46 "TLS address calculation."));
47
48 namespace {
49
50 /// The TLS Variable hoist pass.
51 class TLSVariableHoistLegacyPass : public FunctionPass {
52 public:
53 static char ID; // Pass identification, replacement for typeid
54
TLSVariableHoistLegacyPass()55 TLSVariableHoistLegacyPass() : FunctionPass(ID) {
56 initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry());
57 }
58
59 bool runOnFunction(Function &Fn) override;
60
getPassName() const61 StringRef getPassName() const override { return "TLS Variable Hoist"; }
62
getAnalysisUsage(AnalysisUsage & AU) const63 void getAnalysisUsage(AnalysisUsage &AU) const override {
64 AU.setPreservesCFG();
65 AU.addRequired<DominatorTreeWrapperPass>();
66 AU.addRequired<LoopInfoWrapperPass>();
67 }
68
69 private:
70 TLSVariableHoistPass Impl;
71 };
72
73 } // end anonymous namespace
74
75 char TLSVariableHoistLegacyPass::ID = 0;
76
77 INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist",
78 "TLS Variable Hoist", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)79 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
80 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
81 INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist",
82 "TLS Variable Hoist", false, false)
83
84 FunctionPass *llvm::createTLSVariableHoistPass() {
85 return new TLSVariableHoistLegacyPass();
86 }
87
88 /// Perform the TLS Variable Hoist optimization for the given function.
runOnFunction(Function & Fn)89 bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) {
90 if (skipFunction(Fn))
91 return false;
92
93 LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n");
94 LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n');
95
96 bool MadeChange =
97 Impl.runImpl(Fn, getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
98 getAnalysis<LoopInfoWrapperPass>().getLoopInfo());
99
100 if (MadeChange) {
101 LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: "
102 << Fn.getName() << '\n');
103 LLVM_DEBUG(dbgs() << Fn);
104 }
105 LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n");
106
107 return MadeChange;
108 }
109
collectTLSCandidate(Instruction * Inst)110 void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) {
111 // Skip all cast instructions. They are visited indirectly later on.
112 if (Inst->isCast())
113 return;
114
115 // Scan all operands.
116 for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) {
117 auto *GV = dyn_cast<GlobalVariable>(Inst->getOperand(Idx));
118 if (!GV || !GV->isThreadLocal())
119 continue;
120
121 // Add Candidate to TLSCandMap (GV --> Candidate).
122 TLSCandMap[GV].addUser(Inst, Idx);
123 }
124 }
125
collectTLSCandidates(Function & Fn)126 void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) {
127 // First, quickly check if there is TLS Variable.
128 Module *M = Fn.getParent();
129
130 bool HasTLS = llvm::any_of(
131 M->globals(), [](GlobalVariable &GV) { return GV.isThreadLocal(); });
132
133 // If non, directly return.
134 if (!HasTLS)
135 return;
136
137 TLSCandMap.clear();
138
139 // Then, collect TLS Variable info.
140 for (BasicBlock &BB : Fn) {
141 // Ignore unreachable basic blocks.
142 if (!DT->isReachableFromEntry(&BB))
143 continue;
144
145 for (Instruction &Inst : BB)
146 collectTLSCandidate(&Inst);
147 }
148 }
149
oneUseOutsideLoop(tlshoist::TLSCandidate & Cand,LoopInfo * LI)150 static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) {
151 if (Cand.Users.size() != 1)
152 return false;
153
154 BasicBlock *BB = Cand.Users[0].Inst->getParent();
155 if (LI->getLoopFor(BB))
156 return false;
157
158 return true;
159 }
160
getNearestLoopDomInst(BasicBlock * BB,Loop * L)161 Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB,
162 Loop *L) {
163 assert(L && "Unexcepted Loop status!");
164
165 // Get the outermost loop.
166 while (Loop *Parent = L->getParentLoop())
167 L = Parent;
168
169 BasicBlock *PreHeader = L->getLoopPreheader();
170
171 // There is unique predecessor outside the loop.
172 if (PreHeader)
173 return PreHeader->getTerminator();
174
175 BasicBlock *Header = L->getHeader();
176 BasicBlock *Dom = Header;
177 for (BasicBlock *PredBB : predecessors(Header))
178 Dom = DT->findNearestCommonDominator(Dom, PredBB);
179
180 assert(Dom && "Not find dominator BB!");
181 Instruction *Term = Dom->getTerminator();
182
183 return Term;
184 }
185
getDomInst(Instruction * I1,Instruction * I2)186 Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1,
187 Instruction *I2) {
188 if (!I1)
189 return I2;
190 return DT->findNearestCommonDominator(I1, I2);
191 }
192
findInsertPos(Function & Fn,GlobalVariable * GV,BasicBlock * & PosBB)193 BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn,
194 GlobalVariable *GV,
195 BasicBlock *&PosBB) {
196 tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
197
198 // We should hoist the TLS use out of loop, so choose its nearest instruction
199 // which dominate the loop and the outside loops (if exist).
200 Instruction *LastPos = nullptr;
201 for (auto &User : Cand.Users) {
202 BasicBlock *BB = User.Inst->getParent();
203 Instruction *Pos = User.Inst;
204 if (Loop *L = LI->getLoopFor(BB)) {
205 Pos = getNearestLoopDomInst(BB, L);
206 assert(Pos && "Not find insert position out of loop!");
207 }
208 Pos = getDomInst(LastPos, Pos);
209 LastPos = Pos;
210 }
211
212 assert(LastPos && "Unexpected insert position!");
213 BasicBlock *Parent = LastPos->getParent();
214 PosBB = Parent;
215 return LastPos->getIterator();
216 }
217
218 // Generate a bitcast (no type change) to replace the uses of TLS Candidate.
genBitCastInst(Function & Fn,GlobalVariable * GV)219 Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn,
220 GlobalVariable *GV) {
221 BasicBlock *PosBB = &Fn.getEntryBlock();
222 BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB);
223 Type *Ty = GV->getType();
224 auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast");
225 CastInst->insertInto(PosBB, Iter);
226 return CastInst;
227 }
228
tryReplaceTLSCandidate(Function & Fn,GlobalVariable * GV)229 bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn,
230 GlobalVariable *GV) {
231
232 tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
233
234 // If only used 1 time and not in loops, we no need to replace it.
235 if (oneUseOutsideLoop(Cand, LI))
236 return false;
237
238 // Generate a bitcast (no type change)
239 auto *CastInst = genBitCastInst(Fn, GV);
240
241 // to replace the uses of TLS Candidate
242 for (auto &User : Cand.Users)
243 User.Inst->setOperand(User.OpndIdx, CastInst);
244
245 return true;
246 }
247
tryReplaceTLSCandidates(Function & Fn)248 bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) {
249 if (TLSCandMap.empty())
250 return false;
251
252 bool Replaced = false;
253 for (auto &GV2Cand : TLSCandMap) {
254 GlobalVariable *GV = GV2Cand.first;
255 Replaced |= tryReplaceTLSCandidate(Fn, GV);
256 }
257
258 return Replaced;
259 }
260
261 /// Optimize expensive TLS variables in the given function.
runImpl(Function & Fn,DominatorTree & DT,LoopInfo & LI)262 bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT,
263 LoopInfo &LI) {
264 if (Fn.hasOptNone())
265 return false;
266
267 if (!TLSLoadHoist && !Fn.getAttributes().hasFnAttr("tls-load-hoist"))
268 return false;
269
270 this->LI = &LI;
271 this->DT = &DT;
272 assert(this->LI && this->DT && "Unexcepted requirement!");
273
274 // Collect all TLS variable candidates.
275 collectTLSCandidates(Fn);
276
277 bool MadeChange = tryReplaceTLSCandidates(Fn);
278
279 return MadeChange;
280 }
281
run(Function & F,FunctionAnalysisManager & AM)282 PreservedAnalyses TLSVariableHoistPass::run(Function &F,
283 FunctionAnalysisManager &AM) {
284
285 auto &LI = AM.getResult<LoopAnalysis>(F);
286 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
287
288 if (!runImpl(F, DT, LI))
289 return PreservedAnalyses::all();
290
291 PreservedAnalyses PA;
292 PA.preserveSet<CFGAnalyses>();
293 return PA;
294 }
295