1 /* 2 american fuzzy lop++ - LLVM CmpLog instrumentation 3 -------------------------------------------------- 4 5 Written by Andrea Fioraldi <[email protected]> 6 7 Copyright 2015, 2016 Google Inc. All rights reserved. 8 Copyright 2019-2024 AFLplusplus Project. All rights reserved. 9 10 Licensed under the Apache License, Version 2.0 (the "License"); 11 you may not use this file except in compliance with the License. 12 You may obtain a copy of the License at: 13 14 https://www.apache.org/licenses/LICENSE-2.0 15 16 */ 17 18 #include <stdio.h> 19 #include <stdlib.h> 20 #include <unistd.h> 21 22 #include <iostream> 23 #include <list> 24 #include <string> 25 #include <fstream> 26 #include <sys/time.h> 27 28 #include "llvm/Config/llvm-config.h" 29 #include "llvm/ADT/Statistic.h" 30 #include "llvm/IR/IRBuilder.h" 31 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 32 #include "llvm/Passes/PassPlugin.h" 33 #include "llvm/Passes/PassBuilder.h" 34 #include "llvm/IR/PassManager.h" 35 #else 36 #include "llvm/IR/LegacyPassManager.h" 37 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 38 #endif 39 #include "llvm/IR/Module.h" 40 #include "llvm/Support/Debug.h" 41 #include "llvm/Support/raw_ostream.h" 42 #if LLVM_VERSION_MAJOR < 17 43 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 44 #endif 45 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 46 #include "llvm/Pass.h" 47 #include "llvm/Analysis/ValueTracking.h" 48 49 #include "llvm/IR/IRBuilder.h" 50 #if LLVM_VERSION_MAJOR >= 4 || \ 51 (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4) 52 #include "llvm/IR/Verifier.h" 53 #include "llvm/IR/DebugInfo.h" 54 #else 55 #include "llvm/Analysis/Verifier.h" 56 #include "llvm/DebugInfo.h" 57 #define nullptr 0 58 #endif 59 60 #include <set> 61 #include "afl-llvm-common.h" 62 63 using namespace llvm; 64 65 namespace { 66 67 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 68 class CmplogSwitches : public PassInfoMixin<CmplogSwitches> { 69 70 public: CmplogSwitches()71 CmplogSwitches() { 72 73 #else 74 class CmplogSwitches : public ModulePass { 75 76 public: 77 static char ID; 78 CmplogSwitches() : ModulePass(ID) { 79 80 #endif 81 initInstrumentList(); 82 83 } 84 85 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 86 PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); 87 #else 88 bool runOnModule(Module &M) override; 89 90 #if LLVM_VERSION_MAJOR < 4 91 const char *getPassName() const override { 92 93 #else 94 StringRef getPassName() const override { 95 96 #endif 97 return "cmplog switch split"; 98 99 } 100 101 #endif 102 103 private: 104 bool hookInstrs(Module &M); 105 106 }; 107 108 } // namespace 109 110 #if LLVM_MAJOR >= 11 111 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK 112 llvmGetPassPluginInfo() { 113 114 return {LLVM_PLUGIN_API_VERSION, "cmplogswitches", "v0.1", 115 /* lambda to insert our pass into the pass pipeline. */ 116 [](PassBuilder &PB) { 117 118 #if LLVM_VERSION_MAJOR <= 13 119 using OptimizationLevel = typename PassBuilder::OptimizationLevel; 120 #endif 121 PB.registerOptimizerLastEPCallback( 122 [](ModulePassManager &MPM, OptimizationLevel OL) { 123 124 MPM.addPass(CmplogSwitches()); 125 126 }); 127 128 }}; 129 130 } 131 132 #else 133 char CmplogSwitches::ID = 0; 134 #endif 135 136 template <class Iterator> 137 Iterator Unique(Iterator first, Iterator last) { 138 139 while (first != last) { 140 141 Iterator next(first); 142 last = std::remove(++next, last, *first); 143 first = next; 144 145 } 146 147 return last; 148 149 } 150 151 bool CmplogSwitches::hookInstrs(Module &M) { 152 153 std::vector<SwitchInst *> switches; 154 LLVMContext &C = M.getContext(); 155 156 Type *VoidTy = Type::getVoidTy(C); 157 IntegerType *Int8Ty = IntegerType::getInt8Ty(C); 158 IntegerType *Int16Ty = IntegerType::getInt16Ty(C); 159 IntegerType *Int32Ty = IntegerType::getInt32Ty(C); 160 IntegerType *Int64Ty = IntegerType::getInt64Ty(C); 161 162 #if LLVM_VERSION_MAJOR >= 9 163 FunctionCallee 164 #else 165 Constant * 166 #endif 167 c1 = M.getOrInsertFunction("__cmplog_ins_hook1", VoidTy, Int8Ty, Int8Ty, 168 Int8Ty 169 #if LLVM_VERSION_MAJOR < 5 170 , 171 NULL 172 #endif 173 ); 174 #if LLVM_VERSION_MAJOR >= 9 175 FunctionCallee cmplogHookIns1 = c1; 176 #else 177 Function *cmplogHookIns1 = cast<Function>(c1); 178 #endif 179 180 #if LLVM_VERSION_MAJOR >= 9 181 FunctionCallee 182 #else 183 Constant * 184 #endif 185 c2 = M.getOrInsertFunction("__cmplog_ins_hook2", VoidTy, Int16Ty, Int16Ty, 186 Int8Ty 187 #if LLVM_VERSION_MAJOR < 5 188 , 189 NULL 190 #endif 191 ); 192 #if LLVM_VERSION_MAJOR >= 9 193 FunctionCallee cmplogHookIns2 = c2; 194 #else 195 Function *cmplogHookIns2 = cast<Function>(c2); 196 #endif 197 198 #if LLVM_VERSION_MAJOR >= 9 199 FunctionCallee 200 #else 201 Constant * 202 #endif 203 c4 = M.getOrInsertFunction("__cmplog_ins_hook4", VoidTy, Int32Ty, Int32Ty, 204 Int8Ty 205 #if LLVM_VERSION_MAJOR < 5 206 , 207 NULL 208 #endif 209 ); 210 #if LLVM_VERSION_MAJOR >= 9 211 FunctionCallee cmplogHookIns4 = c4; 212 #else 213 Function *cmplogHookIns4 = cast<Function>(c4); 214 #endif 215 216 #if LLVM_VERSION_MAJOR >= 9 217 FunctionCallee 218 #else 219 Constant * 220 #endif 221 c8 = M.getOrInsertFunction("__cmplog_ins_hook8", VoidTy, Int64Ty, Int64Ty, 222 Int8Ty 223 #if LLVM_VERSION_MAJOR < 5 224 , 225 NULL 226 #endif 227 ); 228 #if LLVM_VERSION_MAJOR >= 9 229 FunctionCallee cmplogHookIns8 = c8; 230 #else 231 Function *cmplogHookIns8 = cast<Function>(c8); 232 #endif 233 234 GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map"); 235 236 if (!AFLCmplogPtr) { 237 238 AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false, 239 GlobalValue::ExternalWeakLinkage, 0, 240 "__afl_cmp_map"); 241 242 } 243 244 Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0)); 245 246 /* iterate over all functions, bbs and instruction and add suitable calls */ 247 for (auto &F : M) { 248 249 if (!isInInstrumentList(&F, MNAME)) continue; 250 251 for (auto &BB : F) { 252 253 SwitchInst *switchInst = nullptr; 254 if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) { 255 256 if (switchInst->getNumCases() > 1) { switches.push_back(switchInst); } 257 258 } 259 260 } 261 262 } 263 264 // unique the collected switches 265 switches.erase(Unique(switches.begin(), switches.end()), switches.end()); 266 267 // Instrument switch values for cmplog 268 if (switches.size()) { 269 270 if (!be_quiet) 271 errs() << "Hooking " << switches.size() << " switch instructions\n"; 272 273 for (auto &SI : switches) { 274 275 Value *Val = SI->getCondition(); 276 unsigned int max_size = Val->getType()->getIntegerBitWidth(), cast_size; 277 unsigned char do_cast = 0; 278 279 if (!SI->getNumCases() || max_size < 16) { 280 281 // if (!be_quiet) errs() << "skip trivial switch..\n"; 282 continue; 283 284 } 285 286 if (max_size % 8) { 287 288 max_size = (((max_size / 8) + 1) * 8); 289 do_cast = 1; 290 291 } 292 293 IRBuilder<> IRB2(SI->getParent()); 294 IRB2.SetInsertPoint(SI); 295 296 LoadInst *CmpPtr = IRB2.CreateLoad( 297 #if LLVM_VERSION_MAJOR >= 14 298 PointerType::get(Int8Ty, 0), 299 #endif 300 AFLCmplogPtr); 301 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 302 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 303 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, SI, false); 304 305 IRBuilder<> IRB(ThenTerm); 306 307 if (max_size > 128) { 308 309 if (!be_quiet) { 310 311 fprintf(stderr, 312 "Cannot handle this switch bit size: %u (truncating)\n", 313 max_size); 314 315 } 316 317 max_size = 128; 318 do_cast = 1; 319 320 } 321 322 // do we need to cast? 323 switch (max_size) { 324 325 case 8: 326 case 16: 327 case 32: 328 case 64: 329 case 128: 330 cast_size = max_size; 331 break; 332 default: 333 cast_size = 128; 334 do_cast = 1; 335 336 } 337 338 Value *CompareTo = Val; 339 340 if (do_cast) { 341 342 CompareTo = 343 IRB.CreateIntCast(CompareTo, IntegerType::get(C, cast_size), false); 344 345 } 346 347 for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; 348 ++i) { 349 350 #if LLVM_VERSION_MAJOR < 5 351 ConstantInt *cint = i.getCaseValue(); 352 #else 353 ConstantInt *cint = i->getCaseValue(); 354 #endif 355 356 if (cint) { 357 358 std::vector<Value *> args; 359 args.push_back(CompareTo); 360 361 Value *new_param = cint; 362 363 if (do_cast) { 364 365 new_param = 366 IRB.CreateIntCast(cint, IntegerType::get(C, cast_size), false); 367 368 } 369 370 if (new_param) { 371 372 args.push_back(new_param); 373 ConstantInt *attribute = ConstantInt::get(Int8Ty, 1); 374 args.push_back(attribute); 375 if (cast_size != max_size) { 376 377 ConstantInt *bitsize = 378 ConstantInt::get(Int8Ty, (max_size / 8) - 1); 379 args.push_back(bitsize); 380 381 } 382 383 switch (cast_size) { 384 385 case 8: 386 IRB.CreateCall(cmplogHookIns1, args); 387 break; 388 case 16: 389 IRB.CreateCall(cmplogHookIns2, args); 390 break; 391 case 32: 392 IRB.CreateCall(cmplogHookIns4, args); 393 break; 394 case 64: 395 IRB.CreateCall(cmplogHookIns8, args); 396 break; 397 case 128: 398 #ifdef WORD_SIZE_64 399 if (max_size == 128) { 400 401 IRB.CreateCall(cmplogHookIns16, args); 402 403 } else { 404 405 IRB.CreateCall(cmplogHookInsN, args); 406 407 } 408 409 #endif 410 break; 411 default: 412 break; 413 414 } 415 416 } 417 418 } 419 420 } 421 422 } 423 424 } 425 426 if (switches.size()) 427 return true; 428 else 429 return false; 430 431 } 432 433 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 434 PreservedAnalyses CmplogSwitches::run(Module &M, ModuleAnalysisManager &MAM) { 435 436 #else 437 bool CmplogSwitches::runOnModule(Module &M) { 438 439 #endif 440 441 if (getenv("AFL_QUIET") == NULL) 442 printf("Running cmplog-switches-pass by [email protected]\n"); 443 else 444 be_quiet = 1; 445 hookInstrs(M); 446 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 447 auto PA = PreservedAnalyses::all(); 448 #endif 449 verifyModule(M); 450 451 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 452 return PA; 453 #else 454 return true; 455 #endif 456 457 } 458 459 #if LLVM_VERSION_MAJOR < 11 /* use old pass manager */ 460 static void registerCmplogSwitchesPass(const PassManagerBuilder &, 461 legacy::PassManagerBase &PM) { 462 463 auto p = new CmplogSwitches(); 464 PM.add(p); 465 466 } 467 468 static RegisterStandardPasses RegisterCmplogSwitchesPass( 469 PassManagerBuilder::EP_OptimizerLast, registerCmplogSwitchesPass); 470 471 static RegisterStandardPasses RegisterCmplogSwitchesPass0( 472 PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmplogSwitchesPass); 473 474 #if LLVM_VERSION_MAJOR >= 11 475 static RegisterStandardPasses RegisterCmplogSwitchesPassLTO( 476 PassManagerBuilder::EP_FullLinkTimeOptimizationLast, 477 registerCmplogSwitchesPass); 478 #endif 479 #endif 480 481