1*08b48e0bSAndroid Build Coastguard Worker /* 2*08b48e0bSAndroid Build Coastguard Worker * Copyright 2016 laf-intel 3*08b48e0bSAndroid Build Coastguard Worker * 4*08b48e0bSAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License"); 5*08b48e0bSAndroid Build Coastguard Worker * you may not use this file except in compliance with the License. 6*08b48e0bSAndroid Build Coastguard Worker * You may obtain a copy of the License at 7*08b48e0bSAndroid Build Coastguard Worker * 8*08b48e0bSAndroid Build Coastguard Worker * https://www.apache.org/licenses/LICENSE-2.0 9*08b48e0bSAndroid Build Coastguard Worker * 10*08b48e0bSAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software 11*08b48e0bSAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS, 12*08b48e0bSAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13*08b48e0bSAndroid Build Coastguard Worker * See the License for the specific language governing permissions and 14*08b48e0bSAndroid Build Coastguard Worker * limitations under the License. 15*08b48e0bSAndroid Build Coastguard Worker */ 16*08b48e0bSAndroid Build Coastguard Worker 17*08b48e0bSAndroid Build Coastguard Worker #include <stdio.h> 18*08b48e0bSAndroid Build Coastguard Worker #include <stdlib.h> 19*08b48e0bSAndroid Build Coastguard Worker #include <unistd.h> 20*08b48e0bSAndroid Build Coastguard Worker 21*08b48e0bSAndroid Build Coastguard Worker #include <list> 22*08b48e0bSAndroid Build Coastguard Worker #include <string> 23*08b48e0bSAndroid Build Coastguard Worker #include <fstream> 24*08b48e0bSAndroid Build Coastguard Worker #include <sys/time.h> 25*08b48e0bSAndroid Build Coastguard Worker 26*08b48e0bSAndroid Build Coastguard Worker #include "llvm/Config/llvm-config.h" 27*08b48e0bSAndroid Build Coastguard Worker 28*08b48e0bSAndroid Build Coastguard Worker #include "llvm/ADT/Statistic.h" 29*08b48e0bSAndroid Build Coastguard Worker #include "llvm/IR/IRBuilder.h" 30*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 31*08b48e0bSAndroid Build Coastguard Worker #include "llvm/Passes/PassPlugin.h" 32*08b48e0bSAndroid Build Coastguard Worker #include "llvm/Passes/PassBuilder.h" 33*08b48e0bSAndroid Build Coastguard Worker #include "llvm/IR/PassManager.h" 34*08b48e0bSAndroid Build Coastguard Worker #else 35*08b48e0bSAndroid Build Coastguard Worker #include "llvm/IR/LegacyPassManager.h" 36*08b48e0bSAndroid Build Coastguard Worker #include "llvm/Transforms/IPO/PassManagerBuilder.h" 37*08b48e0bSAndroid Build Coastguard Worker #endif 38*08b48e0bSAndroid Build Coastguard Worker #include "llvm/IR/Module.h" 39*08b48e0bSAndroid Build Coastguard Worker #include "llvm/Support/Debug.h" 40*08b48e0bSAndroid Build Coastguard Worker #include "llvm/Support/raw_ostream.h" 41*08b48e0bSAndroid Build Coastguard Worker #include "llvm/Transforms/Utils/BasicBlockUtils.h" 42*08b48e0bSAndroid Build Coastguard Worker #include "llvm/Pass.h" 43*08b48e0bSAndroid Build Coastguard Worker #include "llvm/Analysis/ValueTracking.h" 44*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 14 /* how about stable interfaces? */ 45*08b48e0bSAndroid Build Coastguard Worker #include "llvm/Passes/OptimizationLevel.h" 46*08b48e0bSAndroid Build Coastguard Worker #endif 47*08b48e0bSAndroid Build Coastguard Worker 48*08b48e0bSAndroid Build Coastguard Worker #include "llvm/IR/IRBuilder.h" 49*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 4 || \ 50*08b48e0bSAndroid Build Coastguard Worker (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4) 51*08b48e0bSAndroid Build Coastguard Worker #include "llvm/IR/Verifier.h" 52*08b48e0bSAndroid Build Coastguard Worker #include "llvm/IR/DebugInfo.h" 53*08b48e0bSAndroid Build Coastguard Worker #else 54*08b48e0bSAndroid Build Coastguard Worker #include "llvm/Analysis/Verifier.h" 55*08b48e0bSAndroid Build Coastguard Worker #include "llvm/DebugInfo.h" 56*08b48e0bSAndroid Build Coastguard Worker #define nullptr 0 57*08b48e0bSAndroid Build Coastguard Worker #endif 58*08b48e0bSAndroid Build Coastguard Worker 59*08b48e0bSAndroid Build Coastguard Worker #include <set> 60*08b48e0bSAndroid Build Coastguard Worker #include "afl-llvm-common.h" 61*08b48e0bSAndroid Build Coastguard Worker 62*08b48e0bSAndroid Build Coastguard Worker using namespace llvm; 63*08b48e0bSAndroid Build Coastguard Worker 64*08b48e0bSAndroid Build Coastguard Worker namespace { 65*08b48e0bSAndroid Build Coastguard Worker 66*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 67*08b48e0bSAndroid Build Coastguard Worker class SplitSwitchesTransform : public PassInfoMixin<SplitSwitchesTransform> { 68*08b48e0bSAndroid Build Coastguard Worker 69*08b48e0bSAndroid Build Coastguard Worker public: SplitSwitchesTransform()70*08b48e0bSAndroid Build Coastguard Worker SplitSwitchesTransform() { 71*08b48e0bSAndroid Build Coastguard Worker 72*08b48e0bSAndroid Build Coastguard Worker #else 73*08b48e0bSAndroid Build Coastguard Worker class SplitSwitchesTransform : public ModulePass { 74*08b48e0bSAndroid Build Coastguard Worker 75*08b48e0bSAndroid Build Coastguard Worker public: 76*08b48e0bSAndroid Build Coastguard Worker static char ID; 77*08b48e0bSAndroid Build Coastguard Worker SplitSwitchesTransform() : ModulePass(ID) { 78*08b48e0bSAndroid Build Coastguard Worker 79*08b48e0bSAndroid Build Coastguard Worker #endif 80*08b48e0bSAndroid Build Coastguard Worker initInstrumentList(); 81*08b48e0bSAndroid Build Coastguard Worker 82*08b48e0bSAndroid Build Coastguard Worker } 83*08b48e0bSAndroid Build Coastguard Worker 84*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 85*08b48e0bSAndroid Build Coastguard Worker PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); 86*08b48e0bSAndroid Build Coastguard Worker #else 87*08b48e0bSAndroid Build Coastguard Worker bool runOnModule(Module &M) override; 88*08b48e0bSAndroid Build Coastguard Worker 89*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 4 90*08b48e0bSAndroid Build Coastguard Worker StringRef getPassName() const override { 91*08b48e0bSAndroid Build Coastguard Worker 92*08b48e0bSAndroid Build Coastguard Worker #else 93*08b48e0bSAndroid Build Coastguard Worker const char *getPassName() const override { 94*08b48e0bSAndroid Build Coastguard Worker 95*08b48e0bSAndroid Build Coastguard Worker #endif 96*08b48e0bSAndroid Build Coastguard Worker return "splits switch constructs"; 97*08b48e0bSAndroid Build Coastguard Worker 98*08b48e0bSAndroid Build Coastguard Worker } 99*08b48e0bSAndroid Build Coastguard Worker 100*08b48e0bSAndroid Build Coastguard Worker #endif 101*08b48e0bSAndroid Build Coastguard Worker 102*08b48e0bSAndroid Build Coastguard Worker struct CaseExpr { 103*08b48e0bSAndroid Build Coastguard Worker 104*08b48e0bSAndroid Build Coastguard Worker ConstantInt *Val; 105*08b48e0bSAndroid Build Coastguard Worker BasicBlock *BB; 106*08b48e0bSAndroid Build Coastguard Worker 107*08b48e0bSAndroid Build Coastguard Worker CaseExpr(ConstantInt *val = nullptr, BasicBlock *bb = nullptr) 108*08b48e0bSAndroid Build Coastguard Worker : Val(val), BB(bb) { 109*08b48e0bSAndroid Build Coastguard Worker 110*08b48e0bSAndroid Build Coastguard Worker } 111*08b48e0bSAndroid Build Coastguard Worker 112*08b48e0bSAndroid Build Coastguard Worker }; 113*08b48e0bSAndroid Build Coastguard Worker 114*08b48e0bSAndroid Build Coastguard Worker using CaseVector = std::vector<CaseExpr>; 115*08b48e0bSAndroid Build Coastguard Worker 116*08b48e0bSAndroid Build Coastguard Worker private: 117*08b48e0bSAndroid Build Coastguard Worker bool splitSwitches(Module &M); 118*08b48e0bSAndroid Build Coastguard Worker bool transformCmps(Module &M, const bool processStrcmp, 119*08b48e0bSAndroid Build Coastguard Worker const bool processMemcmp); 120*08b48e0bSAndroid Build Coastguard Worker BasicBlock *switchConvert(CaseVector Cases, std::vector<bool> bytesChecked, 121*08b48e0bSAndroid Build Coastguard Worker BasicBlock *OrigBlock, BasicBlock *NewDefault, 122*08b48e0bSAndroid Build Coastguard Worker Value *Val, unsigned level); 123*08b48e0bSAndroid Build Coastguard Worker 124*08b48e0bSAndroid Build Coastguard Worker }; 125*08b48e0bSAndroid Build Coastguard Worker 126*08b48e0bSAndroid Build Coastguard Worker } // namespace 127*08b48e0bSAndroid Build Coastguard Worker 128*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 129*08b48e0bSAndroid Build Coastguard Worker extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK 130*08b48e0bSAndroid Build Coastguard Worker llvmGetPassPluginInfo() { 131*08b48e0bSAndroid Build Coastguard Worker 132*08b48e0bSAndroid Build Coastguard Worker return {LLVM_PLUGIN_API_VERSION, "splitswitches", "v0.1", 133*08b48e0bSAndroid Build Coastguard Worker /* lambda to insert our pass into the pass pipeline. */ 134*08b48e0bSAndroid Build Coastguard Worker [](PassBuilder &PB) { 135*08b48e0bSAndroid Build Coastguard Worker 136*08b48e0bSAndroid Build Coastguard Worker #if 1 137*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR <= 13 138*08b48e0bSAndroid Build Coastguard Worker using OptimizationLevel = typename PassBuilder::OptimizationLevel; 139*08b48e0bSAndroid Build Coastguard Worker #endif 140*08b48e0bSAndroid Build Coastguard Worker PB.registerOptimizerLastEPCallback( 141*08b48e0bSAndroid Build Coastguard Worker [](ModulePassManager &MPM, OptimizationLevel OL) { 142*08b48e0bSAndroid Build Coastguard Worker 143*08b48e0bSAndroid Build Coastguard Worker MPM.addPass(SplitSwitchesTransform()); 144*08b48e0bSAndroid Build Coastguard Worker 145*08b48e0bSAndroid Build Coastguard Worker }); 146*08b48e0bSAndroid Build Coastguard Worker 147*08b48e0bSAndroid Build Coastguard Worker /* TODO LTO registration */ 148*08b48e0bSAndroid Build Coastguard Worker #else 149*08b48e0bSAndroid Build Coastguard Worker using PipelineElement = typename PassBuilder::PipelineElement; 150*08b48e0bSAndroid Build Coastguard Worker PB.registerPipelineParsingCallback([](StringRef Name, 151*08b48e0bSAndroid Build Coastguard Worker ModulePassManager &MPM, 152*08b48e0bSAndroid Build Coastguard Worker ArrayRef<PipelineElement>) { 153*08b48e0bSAndroid Build Coastguard Worker 154*08b48e0bSAndroid Build Coastguard Worker if (Name == "splitswitches") { 155*08b48e0bSAndroid Build Coastguard Worker 156*08b48e0bSAndroid Build Coastguard Worker MPM.addPass(SplitSwitchesTransform()); 157*08b48e0bSAndroid Build Coastguard Worker return true; 158*08b48e0bSAndroid Build Coastguard Worker 159*08b48e0bSAndroid Build Coastguard Worker } else { 160*08b48e0bSAndroid Build Coastguard Worker 161*08b48e0bSAndroid Build Coastguard Worker return false; 162*08b48e0bSAndroid Build Coastguard Worker 163*08b48e0bSAndroid Build Coastguard Worker } 164*08b48e0bSAndroid Build Coastguard Worker 165*08b48e0bSAndroid Build Coastguard Worker }); 166*08b48e0bSAndroid Build Coastguard Worker 167*08b48e0bSAndroid Build Coastguard Worker #endif 168*08b48e0bSAndroid Build Coastguard Worker 169*08b48e0bSAndroid Build Coastguard Worker }}; 170*08b48e0bSAndroid Build Coastguard Worker 171*08b48e0bSAndroid Build Coastguard Worker } 172*08b48e0bSAndroid Build Coastguard Worker 173*08b48e0bSAndroid Build Coastguard Worker #else 174*08b48e0bSAndroid Build Coastguard Worker char SplitSwitchesTransform::ID = 0; 175*08b48e0bSAndroid Build Coastguard Worker #endif 176*08b48e0bSAndroid Build Coastguard Worker 177*08b48e0bSAndroid Build Coastguard Worker /* switchConvert - Transform simple list of Cases into list of CaseRange's */ 178*08b48e0bSAndroid Build Coastguard Worker BasicBlock *SplitSwitchesTransform::switchConvert( 179*08b48e0bSAndroid Build Coastguard Worker CaseVector Cases, std::vector<bool> bytesChecked, BasicBlock *OrigBlock, 180*08b48e0bSAndroid Build Coastguard Worker BasicBlock *NewDefault, Value *Val, unsigned level) { 181*08b48e0bSAndroid Build Coastguard Worker 182*08b48e0bSAndroid Build Coastguard Worker unsigned ValTypeBitWidth = Cases[0].Val->getBitWidth(); 183*08b48e0bSAndroid Build Coastguard Worker IntegerType *ValType = 184*08b48e0bSAndroid Build Coastguard Worker IntegerType::get(OrigBlock->getContext(), ValTypeBitWidth); 185*08b48e0bSAndroid Build Coastguard Worker IntegerType *ByteType = IntegerType::get(OrigBlock->getContext(), 8); 186*08b48e0bSAndroid Build Coastguard Worker unsigned BytesInValue = bytesChecked.size(); 187*08b48e0bSAndroid Build Coastguard Worker std::vector<uint8_t> setSizes; 188*08b48e0bSAndroid Build Coastguard Worker std::vector<std::set<uint8_t> > byteSets(BytesInValue, std::set<uint8_t>()); 189*08b48e0bSAndroid Build Coastguard Worker 190*08b48e0bSAndroid Build Coastguard Worker /* for each of the possible cases we iterate over all bytes of the values 191*08b48e0bSAndroid Build Coastguard Worker * build a set of possible values at each byte position in byteSets */ 192*08b48e0bSAndroid Build Coastguard Worker for (CaseExpr &Case : Cases) { 193*08b48e0bSAndroid Build Coastguard Worker 194*08b48e0bSAndroid Build Coastguard Worker for (unsigned i = 0; i < BytesInValue; i++) { 195*08b48e0bSAndroid Build Coastguard Worker 196*08b48e0bSAndroid Build Coastguard Worker uint8_t byte = (Case.Val->getZExtValue() >> (i * 8)) & 0xFF; 197*08b48e0bSAndroid Build Coastguard Worker byteSets[i].insert(byte); 198*08b48e0bSAndroid Build Coastguard Worker 199*08b48e0bSAndroid Build Coastguard Worker } 200*08b48e0bSAndroid Build Coastguard Worker 201*08b48e0bSAndroid Build Coastguard Worker } 202*08b48e0bSAndroid Build Coastguard Worker 203*08b48e0bSAndroid Build Coastguard Worker /* find the index of the first byte position that was not yet checked. then 204*08b48e0bSAndroid Build Coastguard Worker * save the number of possible values at that byte position */ 205*08b48e0bSAndroid Build Coastguard Worker unsigned smallestIndex = 0; 206*08b48e0bSAndroid Build Coastguard Worker unsigned smallestSize = 257; 207*08b48e0bSAndroid Build Coastguard Worker for (unsigned i = 0; i < byteSets.size(); i++) { 208*08b48e0bSAndroid Build Coastguard Worker 209*08b48e0bSAndroid Build Coastguard Worker if (bytesChecked[i]) continue; 210*08b48e0bSAndroid Build Coastguard Worker if (byteSets[i].size() < smallestSize) { 211*08b48e0bSAndroid Build Coastguard Worker 212*08b48e0bSAndroid Build Coastguard Worker smallestIndex = i; 213*08b48e0bSAndroid Build Coastguard Worker smallestSize = byteSets[i].size(); 214*08b48e0bSAndroid Build Coastguard Worker 215*08b48e0bSAndroid Build Coastguard Worker } 216*08b48e0bSAndroid Build Coastguard Worker 217*08b48e0bSAndroid Build Coastguard Worker } 218*08b48e0bSAndroid Build Coastguard Worker 219*08b48e0bSAndroid Build Coastguard Worker assert(bytesChecked[smallestIndex] == false); 220*08b48e0bSAndroid Build Coastguard Worker 221*08b48e0bSAndroid Build Coastguard Worker /* there are only smallestSize different bytes at index smallestIndex */ 222*08b48e0bSAndroid Build Coastguard Worker 223*08b48e0bSAndroid Build Coastguard Worker Instruction *Shift, *Trunc; 224*08b48e0bSAndroid Build Coastguard Worker Function *F = OrigBlock->getParent(); 225*08b48e0bSAndroid Build Coastguard Worker BasicBlock *NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock", F); 226*08b48e0bSAndroid Build Coastguard Worker Shift = BinaryOperator::Create(Instruction::LShr, Val, 227*08b48e0bSAndroid Build Coastguard Worker ConstantInt::get(ValType, smallestIndex * 8)); 228*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 16 229*08b48e0bSAndroid Build Coastguard Worker Shift->insertInto(NewNode, NewNode->end()); 230*08b48e0bSAndroid Build Coastguard Worker #else 231*08b48e0bSAndroid Build Coastguard Worker NewNode->getInstList().push_back(Shift); 232*08b48e0bSAndroid Build Coastguard Worker #endif 233*08b48e0bSAndroid Build Coastguard Worker 234*08b48e0bSAndroid Build Coastguard Worker if (ValTypeBitWidth > 8) { 235*08b48e0bSAndroid Build Coastguard Worker 236*08b48e0bSAndroid Build Coastguard Worker Trunc = new TruncInst(Shift, ByteType); 237*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 16 238*08b48e0bSAndroid Build Coastguard Worker Trunc->insertInto(NewNode, NewNode->end()); 239*08b48e0bSAndroid Build Coastguard Worker #else 240*08b48e0bSAndroid Build Coastguard Worker NewNode->getInstList().push_back(Trunc); 241*08b48e0bSAndroid Build Coastguard Worker #endif 242*08b48e0bSAndroid Build Coastguard Worker 243*08b48e0bSAndroid Build Coastguard Worker } else { 244*08b48e0bSAndroid Build Coastguard Worker 245*08b48e0bSAndroid Build Coastguard Worker /* not necessary to trunc */ 246*08b48e0bSAndroid Build Coastguard Worker Trunc = Shift; 247*08b48e0bSAndroid Build Coastguard Worker 248*08b48e0bSAndroid Build Coastguard Worker } 249*08b48e0bSAndroid Build Coastguard Worker 250*08b48e0bSAndroid Build Coastguard Worker /* this is a trivial case, we can directly check for the byte, 251*08b48e0bSAndroid Build Coastguard Worker * if the byte is not found go to default. if the byte was found 252*08b48e0bSAndroid Build Coastguard Worker * mark the byte as checked. if this was the last byte to check 253*08b48e0bSAndroid Build Coastguard Worker * we can finally execute the block belonging to this case */ 254*08b48e0bSAndroid Build Coastguard Worker 255*08b48e0bSAndroid Build Coastguard Worker if (smallestSize == 1) { 256*08b48e0bSAndroid Build Coastguard Worker 257*08b48e0bSAndroid Build Coastguard Worker uint8_t byte = *(byteSets[smallestIndex].begin()); 258*08b48e0bSAndroid Build Coastguard Worker 259*08b48e0bSAndroid Build Coastguard Worker /* insert instructions to check whether the value we are switching on is 260*08b48e0bSAndroid Build Coastguard Worker * equal to byte */ 261*08b48e0bSAndroid Build Coastguard Worker ICmpInst *Comp = 262*08b48e0bSAndroid Build Coastguard Worker new ICmpInst(ICmpInst::ICMP_EQ, Trunc, ConstantInt::get(ByteType, byte), 263*08b48e0bSAndroid Build Coastguard Worker "byteMatch"); 264*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 16 265*08b48e0bSAndroid Build Coastguard Worker Comp->insertInto(NewNode, NewNode->end()); 266*08b48e0bSAndroid Build Coastguard Worker #else 267*08b48e0bSAndroid Build Coastguard Worker NewNode->getInstList().push_back(Comp); 268*08b48e0bSAndroid Build Coastguard Worker #endif 269*08b48e0bSAndroid Build Coastguard Worker 270*08b48e0bSAndroid Build Coastguard Worker bytesChecked[smallestIndex] = true; 271*08b48e0bSAndroid Build Coastguard Worker bool allBytesAreChecked = true; 272*08b48e0bSAndroid Build Coastguard Worker 273*08b48e0bSAndroid Build Coastguard Worker for (std::vector<bool>::iterator BCI = bytesChecked.begin(), 274*08b48e0bSAndroid Build Coastguard Worker E = bytesChecked.end(); 275*08b48e0bSAndroid Build Coastguard Worker BCI != E; ++BCI) { 276*08b48e0bSAndroid Build Coastguard Worker 277*08b48e0bSAndroid Build Coastguard Worker if (!*BCI) { 278*08b48e0bSAndroid Build Coastguard Worker 279*08b48e0bSAndroid Build Coastguard Worker allBytesAreChecked = false; 280*08b48e0bSAndroid Build Coastguard Worker break; 281*08b48e0bSAndroid Build Coastguard Worker 282*08b48e0bSAndroid Build Coastguard Worker } 283*08b48e0bSAndroid Build Coastguard Worker 284*08b48e0bSAndroid Build Coastguard Worker } 285*08b48e0bSAndroid Build Coastguard Worker 286*08b48e0bSAndroid Build Coastguard Worker // if (std::all_of(bytesChecked.begin(), bytesChecked.end(), 287*08b48e0bSAndroid Build Coastguard Worker // [](bool b) { return b; })) { 288*08b48e0bSAndroid Build Coastguard Worker 289*08b48e0bSAndroid Build Coastguard Worker if (allBytesAreChecked) { 290*08b48e0bSAndroid Build Coastguard Worker 291*08b48e0bSAndroid Build Coastguard Worker assert(Cases.size() == 1); 292*08b48e0bSAndroid Build Coastguard Worker BranchInst::Create(Cases[0].BB, NewDefault, Comp, NewNode); 293*08b48e0bSAndroid Build Coastguard Worker 294*08b48e0bSAndroid Build Coastguard Worker /* we have to update the phi nodes! */ 295*08b48e0bSAndroid Build Coastguard Worker for (BasicBlock::iterator I = Cases[0].BB->begin(); 296*08b48e0bSAndroid Build Coastguard Worker I != Cases[0].BB->end(); ++I) { 297*08b48e0bSAndroid Build Coastguard Worker 298*08b48e0bSAndroid Build Coastguard Worker if (!isa<PHINode>(&*I)) { continue; } 299*08b48e0bSAndroid Build Coastguard Worker PHINode *PN = cast<PHINode>(I); 300*08b48e0bSAndroid Build Coastguard Worker 301*08b48e0bSAndroid Build Coastguard Worker /* Only update the first occurrence. */ 302*08b48e0bSAndroid Build Coastguard Worker unsigned Idx = 0, E = PN->getNumIncomingValues(); 303*08b48e0bSAndroid Build Coastguard Worker for (; Idx != E; ++Idx) { 304*08b48e0bSAndroid Build Coastguard Worker 305*08b48e0bSAndroid Build Coastguard Worker if (PN->getIncomingBlock(Idx) == OrigBlock) { 306*08b48e0bSAndroid Build Coastguard Worker 307*08b48e0bSAndroid Build Coastguard Worker PN->setIncomingBlock(Idx, NewNode); 308*08b48e0bSAndroid Build Coastguard Worker break; 309*08b48e0bSAndroid Build Coastguard Worker 310*08b48e0bSAndroid Build Coastguard Worker } 311*08b48e0bSAndroid Build Coastguard Worker 312*08b48e0bSAndroid Build Coastguard Worker } 313*08b48e0bSAndroid Build Coastguard Worker 314*08b48e0bSAndroid Build Coastguard Worker } 315*08b48e0bSAndroid Build Coastguard Worker 316*08b48e0bSAndroid Build Coastguard Worker } else { 317*08b48e0bSAndroid Build Coastguard Worker 318*08b48e0bSAndroid Build Coastguard Worker BasicBlock *BB = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, 319*08b48e0bSAndroid Build Coastguard Worker Val, level + 1); 320*08b48e0bSAndroid Build Coastguard Worker BranchInst::Create(BB, NewDefault, Comp, NewNode); 321*08b48e0bSAndroid Build Coastguard Worker 322*08b48e0bSAndroid Build Coastguard Worker } 323*08b48e0bSAndroid Build Coastguard Worker 324*08b48e0bSAndroid Build Coastguard Worker } 325*08b48e0bSAndroid Build Coastguard Worker 326*08b48e0bSAndroid Build Coastguard Worker /* there is no byte which we can directly check on, split the tree */ 327*08b48e0bSAndroid Build Coastguard Worker else { 328*08b48e0bSAndroid Build Coastguard Worker 329*08b48e0bSAndroid Build Coastguard Worker std::vector<uint8_t> byteVector; 330*08b48e0bSAndroid Build Coastguard Worker std::copy(byteSets[smallestIndex].begin(), byteSets[smallestIndex].end(), 331*08b48e0bSAndroid Build Coastguard Worker std::back_inserter(byteVector)); 332*08b48e0bSAndroid Build Coastguard Worker std::sort(byteVector.begin(), byteVector.end()); 333*08b48e0bSAndroid Build Coastguard Worker uint8_t pivot = byteVector[byteVector.size() / 2]; 334*08b48e0bSAndroid Build Coastguard Worker 335*08b48e0bSAndroid Build Coastguard Worker /* we already chose to divide the cases based on the value of byte at index 336*08b48e0bSAndroid Build Coastguard Worker * smallestIndex the pivot value determines the threshold for the decicion; 337*08b48e0bSAndroid Build Coastguard Worker * if a case value 338*08b48e0bSAndroid Build Coastguard Worker * is smaller at this byte index move it to the LHS vector, otherwise to the 339*08b48e0bSAndroid Build Coastguard Worker * RHS vector */ 340*08b48e0bSAndroid Build Coastguard Worker 341*08b48e0bSAndroid Build Coastguard Worker CaseVector LHSCases, RHSCases; 342*08b48e0bSAndroid Build Coastguard Worker 343*08b48e0bSAndroid Build Coastguard Worker for (CaseExpr &Case : Cases) { 344*08b48e0bSAndroid Build Coastguard Worker 345*08b48e0bSAndroid Build Coastguard Worker uint8_t byte = (Case.Val->getZExtValue() >> (smallestIndex * 8)) & 0xFF; 346*08b48e0bSAndroid Build Coastguard Worker 347*08b48e0bSAndroid Build Coastguard Worker if (byte < pivot) { 348*08b48e0bSAndroid Build Coastguard Worker 349*08b48e0bSAndroid Build Coastguard Worker LHSCases.push_back(Case); 350*08b48e0bSAndroid Build Coastguard Worker 351*08b48e0bSAndroid Build Coastguard Worker } else { 352*08b48e0bSAndroid Build Coastguard Worker 353*08b48e0bSAndroid Build Coastguard Worker RHSCases.push_back(Case); 354*08b48e0bSAndroid Build Coastguard Worker 355*08b48e0bSAndroid Build Coastguard Worker } 356*08b48e0bSAndroid Build Coastguard Worker 357*08b48e0bSAndroid Build Coastguard Worker } 358*08b48e0bSAndroid Build Coastguard Worker 359*08b48e0bSAndroid Build Coastguard Worker BasicBlock *LBB, *RBB; 360*08b48e0bSAndroid Build Coastguard Worker LBB = switchConvert(LHSCases, bytesChecked, OrigBlock, NewDefault, Val, 361*08b48e0bSAndroid Build Coastguard Worker level + 1); 362*08b48e0bSAndroid Build Coastguard Worker RBB = switchConvert(RHSCases, bytesChecked, OrigBlock, NewDefault, Val, 363*08b48e0bSAndroid Build Coastguard Worker level + 1); 364*08b48e0bSAndroid Build Coastguard Worker 365*08b48e0bSAndroid Build Coastguard Worker /* insert instructions to check whether the value we are switching on is 366*08b48e0bSAndroid Build Coastguard Worker * equal to byte */ 367*08b48e0bSAndroid Build Coastguard Worker ICmpInst *Comp = 368*08b48e0bSAndroid Build Coastguard Worker new ICmpInst(ICmpInst::ICMP_ULT, Trunc, 369*08b48e0bSAndroid Build Coastguard Worker ConstantInt::get(ByteType, pivot), "byteMatch"); 370*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 16 371*08b48e0bSAndroid Build Coastguard Worker Comp->insertInto(NewNode, NewNode->end()); 372*08b48e0bSAndroid Build Coastguard Worker #else 373*08b48e0bSAndroid Build Coastguard Worker NewNode->getInstList().push_back(Comp); 374*08b48e0bSAndroid Build Coastguard Worker #endif 375*08b48e0bSAndroid Build Coastguard Worker BranchInst::Create(LBB, RBB, Comp, NewNode); 376*08b48e0bSAndroid Build Coastguard Worker 377*08b48e0bSAndroid Build Coastguard Worker } 378*08b48e0bSAndroid Build Coastguard Worker 379*08b48e0bSAndroid Build Coastguard Worker return NewNode; 380*08b48e0bSAndroid Build Coastguard Worker 381*08b48e0bSAndroid Build Coastguard Worker } 382*08b48e0bSAndroid Build Coastguard Worker 383*08b48e0bSAndroid Build Coastguard Worker bool SplitSwitchesTransform::splitSwitches(Module &M) { 384*08b48e0bSAndroid Build Coastguard Worker 385*08b48e0bSAndroid Build Coastguard Worker #if (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR < 7) 386*08b48e0bSAndroid Build Coastguard Worker LLVMContext &C = M.getContext(); 387*08b48e0bSAndroid Build Coastguard Worker #endif 388*08b48e0bSAndroid Build Coastguard Worker 389*08b48e0bSAndroid Build Coastguard Worker std::vector<SwitchInst *> switches; 390*08b48e0bSAndroid Build Coastguard Worker 391*08b48e0bSAndroid Build Coastguard Worker /* iterate over all functions, bbs and instruction and add 392*08b48e0bSAndroid Build Coastguard Worker * all switches to switches vector for later processing */ 393*08b48e0bSAndroid Build Coastguard Worker for (auto &F : M) { 394*08b48e0bSAndroid Build Coastguard Worker 395*08b48e0bSAndroid Build Coastguard Worker if (!isInInstrumentList(&F, MNAME)) continue; 396*08b48e0bSAndroid Build Coastguard Worker 397*08b48e0bSAndroid Build Coastguard Worker for (auto &BB : F) { 398*08b48e0bSAndroid Build Coastguard Worker 399*08b48e0bSAndroid Build Coastguard Worker SwitchInst *switchInst = nullptr; 400*08b48e0bSAndroid Build Coastguard Worker 401*08b48e0bSAndroid Build Coastguard Worker if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) { 402*08b48e0bSAndroid Build Coastguard Worker 403*08b48e0bSAndroid Build Coastguard Worker if (switchInst->getNumCases() < 1) continue; 404*08b48e0bSAndroid Build Coastguard Worker switches.push_back(switchInst); 405*08b48e0bSAndroid Build Coastguard Worker 406*08b48e0bSAndroid Build Coastguard Worker } 407*08b48e0bSAndroid Build Coastguard Worker 408*08b48e0bSAndroid Build Coastguard Worker } 409*08b48e0bSAndroid Build Coastguard Worker 410*08b48e0bSAndroid Build Coastguard Worker } 411*08b48e0bSAndroid Build Coastguard Worker 412*08b48e0bSAndroid Build Coastguard Worker if (!switches.size()) return false; 413*08b48e0bSAndroid Build Coastguard Worker /* 414*08b48e0bSAndroid Build Coastguard Worker if (!be_quiet) 415*08b48e0bSAndroid Build Coastguard Worker errs() << "Rewriting " << switches.size() << " switch statements " 416*08b48e0bSAndroid Build Coastguard Worker << "\n"; 417*08b48e0bSAndroid Build Coastguard Worker */ 418*08b48e0bSAndroid Build Coastguard Worker for (auto &SI : switches) { 419*08b48e0bSAndroid Build Coastguard Worker 420*08b48e0bSAndroid Build Coastguard Worker BasicBlock *CurBlock = SI->getParent(); 421*08b48e0bSAndroid Build Coastguard Worker BasicBlock *OrigBlock = CurBlock; 422*08b48e0bSAndroid Build Coastguard Worker Function *F = CurBlock->getParent(); 423*08b48e0bSAndroid Build Coastguard Worker /* this is the value we are switching on */ 424*08b48e0bSAndroid Build Coastguard Worker Value *Val = SI->getCondition(); 425*08b48e0bSAndroid Build Coastguard Worker BasicBlock *Default = SI->getDefaultDest(); 426*08b48e0bSAndroid Build Coastguard Worker unsigned bitw = Val->getType()->getIntegerBitWidth(); 427*08b48e0bSAndroid Build Coastguard Worker 428*08b48e0bSAndroid Build Coastguard Worker /* 429*08b48e0bSAndroid Build Coastguard Worker if (!be_quiet) 430*08b48e0bSAndroid Build Coastguard Worker errs() << "switch: " << SI->getNumCases() << " cases " << bitw 431*08b48e0bSAndroid Build Coastguard Worker << " bit\n"; 432*08b48e0bSAndroid Build Coastguard Worker */ 433*08b48e0bSAndroid Build Coastguard Worker 434*08b48e0bSAndroid Build Coastguard Worker /* If there is only the default destination or the condition checks 8 bit or 435*08b48e0bSAndroid Build Coastguard Worker * less, don't bother with the code below. */ 436*08b48e0bSAndroid Build Coastguard Worker if (SI->getNumCases() < 2 || bitw % 8 || bitw > 64) { 437*08b48e0bSAndroid Build Coastguard Worker 438*08b48e0bSAndroid Build Coastguard Worker // if (!be_quiet) errs() << "skip switch..\n"; 439*08b48e0bSAndroid Build Coastguard Worker continue; 440*08b48e0bSAndroid Build Coastguard Worker 441*08b48e0bSAndroid Build Coastguard Worker } 442*08b48e0bSAndroid Build Coastguard Worker 443*08b48e0bSAndroid Build Coastguard Worker /* Create a new, empty default block so that the new hierarchy of 444*08b48e0bSAndroid Build Coastguard Worker * if-then statements go to this and the PHI nodes are happy. 445*08b48e0bSAndroid Build Coastguard Worker * if the default block is set as an unreachable we avoid creating one 446*08b48e0bSAndroid Build Coastguard Worker * because will never be a valid target.*/ 447*08b48e0bSAndroid Build Coastguard Worker BasicBlock *NewDefault = nullptr; 448*08b48e0bSAndroid Build Coastguard Worker NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault", F, Default); 449*08b48e0bSAndroid Build Coastguard Worker BranchInst::Create(Default, NewDefault); 450*08b48e0bSAndroid Build Coastguard Worker 451*08b48e0bSAndroid Build Coastguard Worker /* Prepare cases vector. */ 452*08b48e0bSAndroid Build Coastguard Worker CaseVector Cases; 453*08b48e0bSAndroid Build Coastguard Worker for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; 454*08b48e0bSAndroid Build Coastguard Worker ++i) 455*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 5 456*08b48e0bSAndroid Build Coastguard Worker Cases.push_back(CaseExpr(i->getCaseValue(), i->getCaseSuccessor())); 457*08b48e0bSAndroid Build Coastguard Worker #else 458*08b48e0bSAndroid Build Coastguard Worker Cases.push_back(CaseExpr(i.getCaseValue(), i.getCaseSuccessor())); 459*08b48e0bSAndroid Build Coastguard Worker #endif 460*08b48e0bSAndroid Build Coastguard Worker /* bugfix thanks to pbst 461*08b48e0bSAndroid Build Coastguard Worker * round up bytesChecked (in case getBitWidth() % 8 != 0) */ 462*08b48e0bSAndroid Build Coastguard Worker std::vector<bool> bytesChecked((7 + Cases[0].Val->getBitWidth()) / 8, 463*08b48e0bSAndroid Build Coastguard Worker false); 464*08b48e0bSAndroid Build Coastguard Worker BasicBlock *SwitchBlock = 465*08b48e0bSAndroid Build Coastguard Worker switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, 0); 466*08b48e0bSAndroid Build Coastguard Worker 467*08b48e0bSAndroid Build Coastguard Worker /* Branch to our shiny new if-then stuff... */ 468*08b48e0bSAndroid Build Coastguard Worker BranchInst::Create(SwitchBlock, OrigBlock); 469*08b48e0bSAndroid Build Coastguard Worker 470*08b48e0bSAndroid Build Coastguard Worker /* We are now done with the switch instruction, delete it. */ 471*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 16 472*08b48e0bSAndroid Build Coastguard Worker SI->eraseFromParent(); 473*08b48e0bSAndroid Build Coastguard Worker #else 474*08b48e0bSAndroid Build Coastguard Worker CurBlock->getInstList().erase(SI); 475*08b48e0bSAndroid Build Coastguard Worker #endif 476*08b48e0bSAndroid Build Coastguard Worker 477*08b48e0bSAndroid Build Coastguard Worker /* we have to update the phi nodes! */ 478*08b48e0bSAndroid Build Coastguard Worker for (BasicBlock::iterator I = Default->begin(); I != Default->end(); ++I) { 479*08b48e0bSAndroid Build Coastguard Worker 480*08b48e0bSAndroid Build Coastguard Worker if (!isa<PHINode>(&*I)) { continue; } 481*08b48e0bSAndroid Build Coastguard Worker PHINode *PN = cast<PHINode>(I); 482*08b48e0bSAndroid Build Coastguard Worker 483*08b48e0bSAndroid Build Coastguard Worker /* Only update the first occurrence. */ 484*08b48e0bSAndroid Build Coastguard Worker unsigned Idx = 0, E = PN->getNumIncomingValues(); 485*08b48e0bSAndroid Build Coastguard Worker for (; Idx != E; ++Idx) { 486*08b48e0bSAndroid Build Coastguard Worker 487*08b48e0bSAndroid Build Coastguard Worker if (PN->getIncomingBlock(Idx) == OrigBlock) { 488*08b48e0bSAndroid Build Coastguard Worker 489*08b48e0bSAndroid Build Coastguard Worker PN->setIncomingBlock(Idx, NewDefault); 490*08b48e0bSAndroid Build Coastguard Worker break; 491*08b48e0bSAndroid Build Coastguard Worker 492*08b48e0bSAndroid Build Coastguard Worker } 493*08b48e0bSAndroid Build Coastguard Worker 494*08b48e0bSAndroid Build Coastguard Worker } 495*08b48e0bSAndroid Build Coastguard Worker 496*08b48e0bSAndroid Build Coastguard Worker } 497*08b48e0bSAndroid Build Coastguard Worker 498*08b48e0bSAndroid Build Coastguard Worker } 499*08b48e0bSAndroid Build Coastguard Worker 500*08b48e0bSAndroid Build Coastguard Worker verifyModule(M); 501*08b48e0bSAndroid Build Coastguard Worker return true; 502*08b48e0bSAndroid Build Coastguard Worker 503*08b48e0bSAndroid Build Coastguard Worker } 504*08b48e0bSAndroid Build Coastguard Worker 505*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 506*08b48e0bSAndroid Build Coastguard Worker PreservedAnalyses SplitSwitchesTransform::run(Module &M, 507*08b48e0bSAndroid Build Coastguard Worker ModuleAnalysisManager &MAM) { 508*08b48e0bSAndroid Build Coastguard Worker 509*08b48e0bSAndroid Build Coastguard Worker #else 510*08b48e0bSAndroid Build Coastguard Worker bool SplitSwitchesTransform::runOnModule(Module &M) { 511*08b48e0bSAndroid Build Coastguard Worker 512*08b48e0bSAndroid Build Coastguard Worker #endif 513*08b48e0bSAndroid Build Coastguard Worker 514*08b48e0bSAndroid Build Coastguard Worker if ((isatty(2) && getenv("AFL_QUIET") == NULL) || getenv("AFL_DEBUG") != NULL) 515*08b48e0bSAndroid Build Coastguard Worker printf("Running split-switches-pass by [email protected]\n"); 516*08b48e0bSAndroid Build Coastguard Worker else 517*08b48e0bSAndroid Build Coastguard Worker be_quiet = 1; 518*08b48e0bSAndroid Build Coastguard Worker 519*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 520*08b48e0bSAndroid Build Coastguard Worker auto PA = PreservedAnalyses::all(); 521*08b48e0bSAndroid Build Coastguard Worker #endif 522*08b48e0bSAndroid Build Coastguard Worker 523*08b48e0bSAndroid Build Coastguard Worker splitSwitches(M); 524*08b48e0bSAndroid Build Coastguard Worker verifyModule(M); 525*08b48e0bSAndroid Build Coastguard Worker 526*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 527*08b48e0bSAndroid Build Coastguard Worker /* if (modified) { 528*08b48e0bSAndroid Build Coastguard Worker 529*08b48e0bSAndroid Build Coastguard Worker PA.abandon<XX_Manager>(); 530*08b48e0bSAndroid Build Coastguard Worker 531*08b48e0bSAndroid Build Coastguard Worker }*/ 532*08b48e0bSAndroid Build Coastguard Worker 533*08b48e0bSAndroid Build Coastguard Worker return PA; 534*08b48e0bSAndroid Build Coastguard Worker #else 535*08b48e0bSAndroid Build Coastguard Worker return true; 536*08b48e0bSAndroid Build Coastguard Worker #endif 537*08b48e0bSAndroid Build Coastguard Worker 538*08b48e0bSAndroid Build Coastguard Worker } 539*08b48e0bSAndroid Build Coastguard Worker 540*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR < 11 /* use old pass manager */ 541*08b48e0bSAndroid Build Coastguard Worker static void registerSplitSwitchesTransPass(const PassManagerBuilder &, 542*08b48e0bSAndroid Build Coastguard Worker legacy::PassManagerBase &PM) { 543*08b48e0bSAndroid Build Coastguard Worker 544*08b48e0bSAndroid Build Coastguard Worker auto p = new SplitSwitchesTransform(); 545*08b48e0bSAndroid Build Coastguard Worker PM.add(p); 546*08b48e0bSAndroid Build Coastguard Worker 547*08b48e0bSAndroid Build Coastguard Worker } 548*08b48e0bSAndroid Build Coastguard Worker 549*08b48e0bSAndroid Build Coastguard Worker static RegisterStandardPasses RegisterSplitSwitchesTransPass( 550*08b48e0bSAndroid Build Coastguard Worker PassManagerBuilder::EP_OptimizerLast, registerSplitSwitchesTransPass); 551*08b48e0bSAndroid Build Coastguard Worker 552*08b48e0bSAndroid Build Coastguard Worker static RegisterStandardPasses RegisterSplitSwitchesTransPass0( 553*08b48e0bSAndroid Build Coastguard Worker PassManagerBuilder::EP_EnabledOnOptLevel0, registerSplitSwitchesTransPass); 554*08b48e0bSAndroid Build Coastguard Worker 555*08b48e0bSAndroid Build Coastguard Worker #if LLVM_VERSION_MAJOR >= 11 556*08b48e0bSAndroid Build Coastguard Worker static RegisterStandardPasses RegisterSplitSwitchesTransPassLTO( 557*08b48e0bSAndroid Build Coastguard Worker PassManagerBuilder::EP_FullLinkTimeOptimizationLast, 558*08b48e0bSAndroid Build Coastguard Worker registerSplitSwitchesTransPass); 559*08b48e0bSAndroid Build Coastguard Worker #endif 560*08b48e0bSAndroid Build Coastguard Worker #endif 561*08b48e0bSAndroid Build Coastguard Worker 562