1 //===- MachineUniformityAnalysis.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "llvm/CodeGen/MachineUniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/CodeGen/MachineCycleAnalysis.h"
12 #include "llvm/CodeGen/MachineDominators.h"
13 #include "llvm/CodeGen/MachineRegisterInfo.h"
14 #include "llvm/CodeGen/MachineSSAContext.h"
15 #include "llvm/CodeGen/TargetInstrInfo.h"
16 #include "llvm/InitializePasses.h"
17
18 using namespace llvm;
19
20 template <>
hasDivergentDefs(const MachineInstr & I) const21 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
22 const MachineInstr &I) const {
23 for (auto &op : I.operands()) {
24 if (!op.isReg() || !op.isDef())
25 continue;
26 if (isDivergent(op.getReg()))
27 return true;
28 }
29 return false;
30 }
31
32 template <>
markDefsDivergent(const MachineInstr & Instr,bool AllDefsDivergent)33 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
34 const MachineInstr &Instr, bool AllDefsDivergent) {
35 bool insertedDivergent = false;
36 const auto &MRI = F.getRegInfo();
37 const auto &TRI = *MRI.getTargetRegisterInfo();
38 for (auto &op : Instr.operands()) {
39 if (!op.isReg() || !op.isDef())
40 continue;
41 if (!op.getReg().isVirtual())
42 continue;
43 assert(!op.getSubReg());
44 if (!AllDefsDivergent) {
45 auto *RC = MRI.getRegClassOrNull(op.getReg());
46 if (RC && !TRI.isDivergentRegClass(RC))
47 continue;
48 }
49 insertedDivergent |= markDivergent(op.getReg());
50 }
51 return insertedDivergent;
52 }
53
54 template <>
initialize()55 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
56 const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
57
58 for (const MachineBasicBlock &block : F) {
59 for (const MachineInstr &instr : block) {
60 auto uniformity = InstrInfo.getInstructionUniformity(instr);
61 if (uniformity == InstructionUniformity::AlwaysUniform) {
62 addUniformOverride(instr);
63 continue;
64 }
65
66 if (uniformity == InstructionUniformity::NeverUniform) {
67 markDefsDivergent(instr, /* AllDefsDivergent = */ false);
68 }
69 }
70 }
71 }
72
73 template <>
pushUsers(Register Reg)74 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
75 Register Reg) {
76 const auto &RegInfo = F.getRegInfo();
77 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
78 if (isAlwaysUniform(UserInstr))
79 continue;
80 if (markDivergent(UserInstr))
81 Worklist.push_back(&UserInstr);
82 }
83 }
84
85 template <>
pushUsers(const MachineInstr & Instr)86 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
87 const MachineInstr &Instr) {
88 assert(!isAlwaysUniform(Instr));
89 if (Instr.isTerminator())
90 return;
91 for (const MachineOperand &op : Instr.operands()) {
92 if (op.isReg() && op.isDef() && op.getReg().isVirtual())
93 pushUsers(op.getReg());
94 }
95 }
96
97 template <>
usesValueFromCycle(const MachineInstr & I,const MachineCycle & DefCycle) const98 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
99 const MachineInstr &I, const MachineCycle &DefCycle) const {
100 assert(!isAlwaysUniform(I));
101 for (auto &Op : I.operands()) {
102 if (!Op.isReg() || !Op.readsReg())
103 continue;
104 auto Reg = Op.getReg();
105 assert(Reg.isVirtual());
106 auto *Def = F.getRegInfo().getVRegDef(Reg);
107 if (DefCycle.contains(Def->getParent()))
108 return true;
109 }
110 return false;
111 }
112
113 // This ensures explicit instantiation of
114 // GenericUniformityAnalysisImpl::ImplDeleter::operator()
115 template class llvm::GenericUniformityInfo<MachineSSAContext>;
116 template struct llvm::GenericUniformityAnalysisImplDeleter<
117 llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
118
119 MachineUniformityInfo
computeMachineUniformityInfo(MachineFunction & F,const MachineCycleInfo & cycleInfo,const MachineDomTree & domTree)120 llvm::computeMachineUniformityInfo(MachineFunction &F,
121 const MachineCycleInfo &cycleInfo,
122 const MachineDomTree &domTree) {
123 assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
124 return MachineUniformityInfo(F, domTree, cycleInfo);
125 }
126
127 namespace {
128
129 /// Legacy analysis pass which computes a \ref MachineUniformityInfo.
130 class MachineUniformityAnalysisPass : public MachineFunctionPass {
131 MachineUniformityInfo UI;
132
133 public:
134 static char ID;
135
136 MachineUniformityAnalysisPass();
137
getUniformityInfo()138 MachineUniformityInfo &getUniformityInfo() { return UI; }
getUniformityInfo() const139 const MachineUniformityInfo &getUniformityInfo() const { return UI; }
140
141 bool runOnMachineFunction(MachineFunction &F) override;
142 void getAnalysisUsage(AnalysisUsage &AU) const override;
143 void print(raw_ostream &OS, const Module *M = nullptr) const override;
144
145 // TODO: verify analysis
146 };
147
148 class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
149 public:
150 static char ID;
151
152 MachineUniformityInfoPrinterPass();
153
154 bool runOnMachineFunction(MachineFunction &F) override;
155 void getAnalysisUsage(AnalysisUsage &AU) const override;
156 };
157
158 } // namespace
159
160 char MachineUniformityAnalysisPass::ID = 0;
161
MachineUniformityAnalysisPass()162 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
163 : MachineFunctionPass(ID) {
164 initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry());
165 }
166
167 INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
168 "Machine Uniformity Info Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)169 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
170 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
171 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
172 "Machine Uniformity Info Analysis", true, true)
173
174 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
175 AU.setPreservesAll();
176 AU.addRequired<MachineCycleInfoWrapperPass>();
177 AU.addRequired<MachineDominatorTree>();
178 MachineFunctionPass::getAnalysisUsage(AU);
179 }
180
runOnMachineFunction(MachineFunction & MF)181 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
182 auto &DomTree = getAnalysis<MachineDominatorTree>().getBase();
183 auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
184 UI = computeMachineUniformityInfo(MF, CI, DomTree);
185 return false;
186 }
187
print(raw_ostream & OS,const Module *) const188 void MachineUniformityAnalysisPass::print(raw_ostream &OS,
189 const Module *) const {
190 OS << "MachineUniformityInfo for function: " << UI.getFunction().getName()
191 << "\n";
192 UI.print(OS);
193 }
194
195 char MachineUniformityInfoPrinterPass::ID = 0;
196
MachineUniformityInfoPrinterPass()197 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
198 : MachineFunctionPass(ID) {
199 initializeMachineUniformityInfoPrinterPassPass(
200 *PassRegistry::getPassRegistry());
201 }
202
203 INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
204 "print-machine-uniformity",
205 "Print Machine Uniformity Info Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)206 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
207 INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
208 "print-machine-uniformity",
209 "Print Machine Uniformity Info Analysis", true, true)
210
211 void MachineUniformityInfoPrinterPass::getAnalysisUsage(
212 AnalysisUsage &AU) const {
213 AU.setPreservesAll();
214 AU.addRequired<MachineUniformityAnalysisPass>();
215 MachineFunctionPass::getAnalysisUsage(AU);
216 }
217
runOnMachineFunction(MachineFunction & F)218 bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
219 MachineFunction &F) {
220 auto &UI = getAnalysis<MachineUniformityAnalysisPass>();
221 UI.print(errs());
222 return false;
223 }
224