1 /* 2 american fuzzy lop++ - LLVM CmpLog instrumentation 3 -------------------------------------------------- 4 5 Written by Andrea Fioraldi <[email protected]> 6 7 Copyright 2015, 2016 Google Inc. All rights reserved. 8 Copyright 2019-2024 AFLplusplus Project. All rights reserved. 9 10 Licensed under the Apache License, Version 2.0 (the "License"); 11 you may not use this file except in compliance with the License. 12 You may obtain a copy of the License at: 13 14 https://www.apache.org/licenses/LICENSE-2.0 15 16 */ 17 18 #include <stdio.h> 19 #include <stdlib.h> 20 #include <unistd.h> 21 22 #include <list> 23 #include <string> 24 #include <fstream> 25 #include <sys/time.h> 26 #include "llvm/Config/llvm-config.h" 27 28 #include "llvm/ADT/Statistic.h" 29 #include "llvm/IR/IRBuilder.h" 30 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 31 #include "llvm/Passes/PassPlugin.h" 32 #include "llvm/Passes/PassBuilder.h" 33 #include "llvm/IR/PassManager.h" 34 #else 35 #include "llvm/IR/LegacyPassManager.h" 36 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 37 #endif 38 #include "llvm/IR/Module.h" 39 #include "llvm/Support/Debug.h" 40 #include "llvm/Support/raw_ostream.h" 41 #if LLVM_VERSION_MAJOR < 17 42 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 43 #endif 44 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 45 #include "llvm/Pass.h" 46 #include "llvm/Analysis/ValueTracking.h" 47 48 #include "llvm/IR/IRBuilder.h" 49 #if LLVM_VERSION_MAJOR >= 4 || \ 50 (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4) 51 #include "llvm/IR/Verifier.h" 52 #include "llvm/IR/DebugInfo.h" 53 #else 54 #include "llvm/Analysis/Verifier.h" 55 #include "llvm/DebugInfo.h" 56 #define nullptr 0 57 #endif 58 59 #include <set> 60 #include "afl-llvm-common.h" 61 62 using namespace llvm; 63 64 namespace { 65 66 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 67 class CmpLogRoutines : public PassInfoMixin<CmpLogRoutines> { 68 69 public: CmpLogRoutines()70 CmpLogRoutines() { 71 72 #else 73 class CmpLogRoutines : public ModulePass { 74 75 public: 76 static char ID; 77 CmpLogRoutines() : ModulePass(ID) { 78 79 #endif 80 81 initInstrumentList(); 82 83 } 84 85 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 86 PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); 87 #else 88 bool runOnModule(Module &M) override; 89 90 #if LLVM_VERSION_MAJOR >= 4 91 StringRef getPassName() const override { 92 93 #else 94 const char *getPassName() const override { 95 96 #endif 97 return "cmplog routines"; 98 99 } 100 101 #endif 102 103 private: 104 bool hookRtns(Module &M); 105 106 }; 107 108 } // namespace 109 110 #if LLVM_MAJOR >= 11 111 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK 112 llvmGetPassPluginInfo() { 113 114 return {LLVM_PLUGIN_API_VERSION, "cmplogroutines", "v0.1", 115 /* lambda to insert our pass into the pass pipeline. */ 116 [](PassBuilder &PB) { 117 118 #if LLVM_VERSION_MAJOR <= 13 119 using OptimizationLevel = typename PassBuilder::OptimizationLevel; 120 #endif 121 PB.registerOptimizerLastEPCallback( 122 [](ModulePassManager &MPM, OptimizationLevel OL) { 123 124 MPM.addPass(CmpLogRoutines()); 125 126 }); 127 128 }}; 129 130 } 131 132 #else 133 char CmpLogRoutines::ID = 0; 134 #endif 135 136 bool CmpLogRoutines::hookRtns(Module &M) { 137 138 std::vector<CallInst *> calls, llvmStdStd, llvmStdC, gccStdStd, gccStdC, 139 Memcmp, Strcmp, Strncmp; 140 LLVMContext &C = M.getContext(); 141 142 Type *VoidTy = Type::getVoidTy(C); 143 // PointerType *VoidPtrTy = PointerType::get(VoidTy, 0); 144 IntegerType *Int8Ty = IntegerType::getInt8Ty(C); 145 IntegerType *Int64Ty = IntegerType::getInt64Ty(C); 146 PointerType *i8PtrTy = PointerType::get(Int8Ty, 0); 147 148 #if LLVM_VERSION_MAJOR >= 9 149 FunctionCallee 150 #else 151 Constant * 152 #endif 153 c = M.getOrInsertFunction("__cmplog_rtn_hook", VoidTy, i8PtrTy, i8PtrTy 154 #if LLVM_VERSION_MAJOR < 5 155 , 156 NULL 157 #endif 158 ); 159 #if LLVM_VERSION_MAJOR >= 9 160 FunctionCallee cmplogHookFn = c; 161 #else 162 Function *cmplogHookFn = cast<Function>(c); 163 #endif 164 165 #if LLVM_VERSION_MAJOR >= 9 166 FunctionCallee 167 #else 168 Constant * 169 #endif 170 c1 = M.getOrInsertFunction("__cmplog_rtn_llvm_stdstring_stdstring", 171 VoidTy, i8PtrTy, i8PtrTy 172 #if LLVM_VERSION_MAJOR < 5 173 , 174 NULL 175 #endif 176 ); 177 #if LLVM_VERSION_MAJOR >= 9 178 FunctionCallee cmplogLlvmStdStd = c1; 179 #else 180 Function *cmplogLlvmStdStd = cast<Function>(c1); 181 #endif 182 183 #if LLVM_VERSION_MAJOR >= 9 184 FunctionCallee 185 #else 186 Constant * 187 #endif 188 c2 = M.getOrInsertFunction("__cmplog_rtn_llvm_stdstring_cstring", VoidTy, 189 i8PtrTy, i8PtrTy 190 #if LLVM_VERSION_MAJOR < 5 191 , 192 NULL 193 #endif 194 ); 195 #if LLVM_VERSION_MAJOR >= 9 196 FunctionCallee cmplogLlvmStdC = c2; 197 #else 198 Function *cmplogLlvmStdC = cast<Function>(c2); 199 #endif 200 201 #if LLVM_VERSION_MAJOR >= 9 202 FunctionCallee 203 #else 204 Constant * 205 #endif 206 c3 = M.getOrInsertFunction("__cmplog_rtn_gcc_stdstring_stdstring", VoidTy, 207 i8PtrTy, i8PtrTy 208 #if LLVM_VERSION_MAJOR < 5 209 , 210 NULL 211 #endif 212 ); 213 #if LLVM_VERSION_MAJOR >= 9 214 FunctionCallee cmplogGccStdStd = c3; 215 #else 216 Function *cmplogGccStdStd = cast<Function>(c3); 217 #endif 218 219 #if LLVM_VERSION_MAJOR >= 9 220 FunctionCallee 221 #else 222 Constant * 223 #endif 224 c4 = M.getOrInsertFunction("__cmplog_rtn_gcc_stdstring_cstring", VoidTy, 225 i8PtrTy, i8PtrTy 226 #if LLVM_VERSION_MAJOR < 5 227 , 228 NULL 229 #endif 230 ); 231 #if LLVM_VERSION_MAJOR >= 9 232 FunctionCallee cmplogGccStdC = c4; 233 #else 234 Function *cmplogGccStdC = cast<Function>(c4); 235 #endif 236 237 #if LLVM_VERSION_MAJOR >= 9 238 FunctionCallee 239 #else 240 Constant * 241 #endif 242 c5 = M.getOrInsertFunction("__cmplog_rtn_hook_n", VoidTy, i8PtrTy, 243 i8PtrTy, Int64Ty 244 #if LLVM_VERSION_MAJOR < 5 245 , 246 NULL 247 #endif 248 ); 249 #if LLVM_VERSION_MAJOR >= 9 250 FunctionCallee cmplogHookFnN = c5; 251 #else 252 Function *cmplogHookFnN = cast<Function>(c5); 253 #endif 254 255 #if LLVM_VERSION_MAJOR >= 9 256 FunctionCallee 257 #else 258 Constant * 259 #endif 260 c6 = M.getOrInsertFunction("__cmplog_rtn_hook_strn", VoidTy, i8PtrTy, 261 i8PtrTy, Int64Ty 262 #if LLVM_VERSION_MAJOR < 5 263 , 264 NULL 265 #endif 266 ); 267 #if LLVM_VERSION_MAJOR >= 9 268 FunctionCallee cmplogHookFnStrN = c6; 269 #else 270 Function *cmplogHookFnStrN = cast<Function>(c6); 271 #endif 272 273 #if LLVM_VERSION_MAJOR >= 9 274 FunctionCallee 275 #else 276 Constant * 277 #endif 278 c7 = M.getOrInsertFunction("__cmplog_rtn_hook_str", VoidTy, i8PtrTy, 279 i8PtrTy 280 #if LLVM_VERSION_MAJOR < 5 281 , 282 NULL 283 #endif 284 ); 285 #if LLVM_VERSION_MAJOR >= 9 286 FunctionCallee cmplogHookFnStr = c7; 287 #else 288 Function *cmplogHookFnStr = cast<Function>(c7); 289 #endif 290 291 GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map"); 292 293 if (!AFLCmplogPtr) { 294 295 AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false, 296 GlobalValue::ExternalWeakLinkage, 0, 297 "__afl_cmp_map"); 298 299 } 300 301 Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0)); 302 303 /* iterate over all functions, bbs and instruction and add suitable calls */ 304 for (auto &F : M) { 305 306 if (!isInInstrumentList(&F, MNAME)) continue; 307 308 for (auto &BB : F) { 309 310 for (auto &IN : BB) { 311 312 CallInst *callInst = nullptr; 313 314 if ((callInst = dyn_cast<CallInst>(&IN))) { 315 316 Function *Callee = callInst->getCalledFunction(); 317 if (!Callee) continue; 318 if (callInst->getCallingConv() != llvm::CallingConv::C) continue; 319 320 FunctionType *FT = Callee->getFunctionType(); 321 std::string FuncName = Callee->getName().str(); 322 323 bool isPtrRtn = FT->getNumParams() >= 2 && 324 !FT->getReturnType()->isVoidTy() && 325 FT->getParamType(0) == FT->getParamType(1) && 326 FT->getParamType(0)->isPointerTy(); 327 328 bool isPtrRtnN = FT->getNumParams() >= 3 && 329 !FT->getReturnType()->isVoidTy() && 330 FT->getParamType(0) == FT->getParamType(1) && 331 FT->getParamType(0)->isPointerTy() && 332 FT->getParamType(2)->isIntegerTy(); 333 if (isPtrRtnN) { 334 335 auto intTyOp = 336 dyn_cast<IntegerType>(callInst->getArgOperand(2)->getType()); 337 if (intTyOp) { 338 339 if (intTyOp->getBitWidth() != 32 && 340 intTyOp->getBitWidth() != 64) { 341 342 isPtrRtnN = false; 343 344 } 345 346 } 347 348 } 349 350 bool isMemcmp = 351 (!FuncName.compare("memcmp") || !FuncName.compare("bcmp") || 352 !FuncName.compare("CRYPTO_memcmp") || 353 !FuncName.compare("OPENSSL_memcmp") || 354 !FuncName.compare("memcmp_const_time") || 355 !FuncName.compare("memcmpct")); 356 isMemcmp &= FT->getNumParams() == 3 && 357 FT->getReturnType()->isIntegerTy(32) && 358 FT->getParamType(0)->isPointerTy() && 359 FT->getParamType(1)->isPointerTy() && 360 FT->getParamType(2)->isIntegerTy(); 361 362 bool isStrcmp = 363 (!FuncName.compare("strcmp") || !FuncName.compare("xmlStrcmp") || 364 !FuncName.compare("xmlStrEqual") || 365 !FuncName.compare("g_strcmp0") || 366 !FuncName.compare("curl_strequal") || 367 !FuncName.compare("strcsequal") || 368 !FuncName.compare("strcasecmp") || 369 !FuncName.compare("stricmp") || 370 !FuncName.compare("ap_cstr_casecmp") || 371 !FuncName.compare("OPENSSL_strcasecmp") || 372 !FuncName.compare("xmlStrcasecmp") || 373 !FuncName.compare("g_strcasecmp") || 374 !FuncName.compare("g_ascii_strcasecmp") || 375 !FuncName.compare("Curl_strcasecompare") || 376 !FuncName.compare("Curl_safe_strcasecompare") || 377 !FuncName.compare("cmsstrcasecmp") || 378 !FuncName.compare("strstr") || 379 !FuncName.compare("g_strstr_len") || 380 !FuncName.compare("ap_strcasestr") || 381 !FuncName.compare("xmlStrstr") || 382 !FuncName.compare("xmlStrcasestr") || 383 !FuncName.compare("g_str_has_prefix") || 384 !FuncName.compare("g_str_has_suffix")); 385 isStrcmp &= 386 FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) && 387 FT->getParamType(0) == FT->getParamType(1) && 388 FT->getParamType(0) == 389 IntegerType::getInt8Ty(M.getContext())->getPointerTo(0); 390 391 bool isStrncmp = (!FuncName.compare("strncmp") || 392 !FuncName.compare("xmlStrncmp") || 393 !FuncName.compare("curl_strnequal") || 394 !FuncName.compare("strncasecmp") || 395 !FuncName.compare("strnicmp") || 396 !FuncName.compare("ap_cstr_casecmpn") || 397 !FuncName.compare("OPENSSL_strncasecmp") || 398 !FuncName.compare("xmlStrncasecmp") || 399 !FuncName.compare("g_ascii_strncasecmp") || 400 !FuncName.compare("Curl_strncasecompare") || 401 !FuncName.compare("g_strncasecmp")); 402 isStrncmp &= 403 FT->getNumParams() == 3 && FT->getReturnType()->isIntegerTy(32) && 404 FT->getParamType(0) == FT->getParamType(1) && 405 FT->getParamType(0) == 406 IntegerType::getInt8Ty(M.getContext())->getPointerTo(0) && 407 FT->getParamType(2)->isIntegerTy(); 408 409 bool isGccStdStringStdString = 410 Callee->getName().find("__is_charIT_EE7__value") != 411 std::string::npos && 412 Callee->getName().find( 413 "St7__cxx1112basic_stringIS2_St11char_traits") != 414 std::string::npos && 415 FT->getNumParams() >= 2 && 416 FT->getParamType(0) == FT->getParamType(1) && 417 FT->getParamType(0)->isPointerTy(); 418 419 bool isGccStdStringCString = 420 Callee->getName().find( 421 "St7__cxx1112basic_stringIcSt11char_" 422 "traitsIcESaIcEE7compareEPK") != std::string::npos && 423 FT->getNumParams() >= 2 && FT->getParamType(0)->isPointerTy() && 424 FT->getParamType(1)->isPointerTy(); 425 426 bool isLlvmStdStringStdString = 427 Callee->getName().find("_ZNSt3__1eqI") != std::string::npos && 428 Callee->getName().find("_12basic_stringI") != std::string::npos && 429 Callee->getName().find("_11char_traits") != std::string::npos && 430 FT->getNumParams() >= 2 && FT->getParamType(0)->isPointerTy() && 431 FT->getParamType(1)->isPointerTy(); 432 433 bool isLlvmStdStringCString = 434 Callee->getName().find("_ZNSt3__1eqI") != std::string::npos && 435 Callee->getName().find("_12basic_stringI") != std::string::npos && 436 FT->getNumParams() >= 2 && FT->getParamType(0)->isPointerTy() && 437 FT->getParamType(1)->isPointerTy(); 438 439 /* 440 { 441 442 fprintf(stderr, "F:%s C:%s argc:%u\n", 443 F.getName().str().c_str(), 444 Callee->getName().str().c_str(), FT->getNumParams()); 445 fprintf(stderr, "ptr0:%u ptr1:%u ptr2:%u\n", 446 FT->getParamType(0)->isPointerTy(), 447 FT->getParamType(1)->isPointerTy(), 448 FT->getNumParams() > 2 ? 449 FT->getParamType(2)->isPointerTy() : 22 ); 450 451 } 452 453 */ 454 455 if (isGccStdStringCString || isGccStdStringStdString || 456 isLlvmStdStringStdString || isLlvmStdStringCString || isMemcmp || 457 isStrcmp || isStrncmp) { 458 459 isPtrRtnN = isPtrRtn = false; 460 461 } 462 463 if (isPtrRtnN) { isPtrRtn = false; } 464 465 if (isPtrRtn) { calls.push_back(callInst); } 466 if (isMemcmp || isPtrRtnN) { Memcmp.push_back(callInst); } 467 if (isStrcmp) { Strcmp.push_back(callInst); } 468 if (isStrncmp) { Strncmp.push_back(callInst); } 469 if (isGccStdStringStdString) { gccStdStd.push_back(callInst); } 470 if (isGccStdStringCString) { gccStdC.push_back(callInst); } 471 if (isLlvmStdStringStdString) { llvmStdStd.push_back(callInst); } 472 if (isLlvmStdStringCString) { llvmStdC.push_back(callInst); } 473 474 } 475 476 } 477 478 } 479 480 } 481 482 if (!calls.size() && !gccStdStd.size() && !gccStdC.size() && 483 !llvmStdStd.size() && !llvmStdC.size() && !Memcmp.size() && 484 Strcmp.size() && Strncmp.size()) 485 return false; 486 487 /* 488 if (!be_quiet) 489 errs() << "Hooking " << calls.size() 490 << " calls with pointers as arguments\n"; 491 */ 492 493 for (auto &callInst : calls) { 494 495 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); 496 497 IRBuilder<> IRB2(callInst->getParent()); 498 IRB2.SetInsertPoint(callInst); 499 500 LoadInst *CmpPtr = IRB2.CreateLoad( 501 #if LLVM_VERSION_MAJOR >= 14 502 PointerType::get(Int8Ty, 0), 503 #endif 504 AFLCmplogPtr); 505 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 506 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 507 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 508 509 IRBuilder<> IRB(ThenTerm); 510 511 std::vector<Value *> args; 512 Value *v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 513 Value *v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 514 args.push_back(v1Pcasted); 515 args.push_back(v2Pcasted); 516 517 IRB.CreateCall(cmplogHookFn, args); 518 519 // errs() << callInst->getCalledFunction()->getName() << "\n"; 520 521 } 522 523 for (auto &callInst : Memcmp) { 524 525 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1), 526 *v3P = callInst->getArgOperand(2); 527 528 IRBuilder<> IRB2(callInst->getParent()); 529 IRB2.SetInsertPoint(callInst); 530 531 LoadInst *CmpPtr = IRB2.CreateLoad( 532 #if LLVM_VERSION_MAJOR >= 14 533 PointerType::get(Int8Ty, 0), 534 #endif 535 AFLCmplogPtr); 536 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 537 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 538 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 539 540 IRBuilder<> IRB(ThenTerm); 541 542 std::vector<Value *> args; 543 Value *v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 544 Value *v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 545 Value *v3Pbitcast = IRB.CreateBitCast( 546 v3P, IntegerType::get(C, v3P->getType()->getPrimitiveSizeInBits())); 547 Value *v3Pcasted = 548 IRB.CreateIntCast(v3Pbitcast, IntegerType::get(C, 64), false); 549 args.push_back(v1Pcasted); 550 args.push_back(v2Pcasted); 551 args.push_back(v3Pcasted); 552 553 IRB.CreateCall(cmplogHookFnN, args); 554 555 // errs() << callInst->getCalledFunction()->getName() << "\n"; 556 557 } 558 559 for (auto &callInst : Strcmp) { 560 561 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); 562 563 IRBuilder<> IRB2(callInst->getParent()); 564 IRB2.SetInsertPoint(callInst); 565 566 LoadInst *CmpPtr = IRB2.CreateLoad( 567 #if LLVM_VERSION_MAJOR >= 14 568 PointerType::get(Int8Ty, 0), 569 #endif 570 AFLCmplogPtr); 571 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 572 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 573 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 574 575 IRBuilder<> IRB(ThenTerm); 576 577 std::vector<Value *> args; 578 Value *v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 579 Value *v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 580 args.push_back(v1Pcasted); 581 args.push_back(v2Pcasted); 582 583 IRB.CreateCall(cmplogHookFnStr, args); 584 585 // errs() << callInst->getCalledFunction()->getName() << "\n"; 586 587 } 588 589 for (auto &callInst : Strncmp) { 590 591 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1), 592 *v3P = callInst->getArgOperand(2); 593 594 IRBuilder<> IRB2(callInst->getParent()); 595 IRB2.SetInsertPoint(callInst); 596 597 LoadInst *CmpPtr = IRB2.CreateLoad( 598 #if LLVM_VERSION_MAJOR >= 14 599 PointerType::get(Int8Ty, 0), 600 #endif 601 AFLCmplogPtr); 602 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 603 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 604 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 605 606 IRBuilder<> IRB(ThenTerm); 607 608 std::vector<Value *> args; 609 Value *v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 610 Value *v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 611 Value *v3Pbitcast = IRB.CreateBitCast( 612 v3P, IntegerType::get(C, v3P->getType()->getPrimitiveSizeInBits())); 613 Value *v3Pcasted = 614 IRB.CreateIntCast(v3Pbitcast, IntegerType::get(C, 64), false); 615 args.push_back(v1Pcasted); 616 args.push_back(v2Pcasted); 617 args.push_back(v3Pcasted); 618 619 IRB.CreateCall(cmplogHookFnStrN, args); 620 621 // errs() << callInst->getCalledFunction()->getName() << "\n"; 622 623 } 624 625 for (auto &callInst : gccStdStd) { 626 627 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); 628 629 IRBuilder<> IRB2(callInst->getParent()); 630 IRB2.SetInsertPoint(callInst); 631 632 LoadInst *CmpPtr = IRB2.CreateLoad( 633 #if LLVM_VERSION_MAJOR >= 14 634 PointerType::get(Int8Ty, 0), 635 #endif 636 AFLCmplogPtr); 637 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 638 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 639 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 640 641 IRBuilder<> IRB(ThenTerm); 642 643 std::vector<Value *> args; 644 Value *v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 645 Value *v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 646 args.push_back(v1Pcasted); 647 args.push_back(v2Pcasted); 648 649 IRB.CreateCall(cmplogGccStdStd, args); 650 651 // errs() << callInst->getCalledFunction()->getName() << "\n"; 652 653 } 654 655 for (auto &callInst : gccStdC) { 656 657 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); 658 659 IRBuilder<> IRB2(callInst->getParent()); 660 IRB2.SetInsertPoint(callInst); 661 662 LoadInst *CmpPtr = IRB2.CreateLoad( 663 #if LLVM_VERSION_MAJOR >= 14 664 PointerType::get(Int8Ty, 0), 665 #endif 666 AFLCmplogPtr); 667 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 668 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 669 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 670 671 IRBuilder<> IRB(ThenTerm); 672 673 std::vector<Value *> args; 674 Value *v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 675 Value *v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 676 args.push_back(v1Pcasted); 677 args.push_back(v2Pcasted); 678 679 IRB.CreateCall(cmplogGccStdC, args); 680 681 // errs() << callInst->getCalledFunction()->getName() << "\n"; 682 683 } 684 685 for (auto &callInst : llvmStdStd) { 686 687 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); 688 689 IRBuilder<> IRB2(callInst->getParent()); 690 IRB2.SetInsertPoint(callInst); 691 692 LoadInst *CmpPtr = IRB2.CreateLoad( 693 #if LLVM_VERSION_MAJOR >= 14 694 PointerType::get(Int8Ty, 0), 695 #endif 696 AFLCmplogPtr); 697 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 698 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 699 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 700 701 IRBuilder<> IRB(ThenTerm); 702 703 std::vector<Value *> args; 704 Value *v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 705 Value *v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 706 args.push_back(v1Pcasted); 707 args.push_back(v2Pcasted); 708 709 IRB.CreateCall(cmplogLlvmStdStd, args); 710 711 // errs() << callInst->getCalledFunction()->getName() << "\n"; 712 713 } 714 715 for (auto &callInst : llvmStdC) { 716 717 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); 718 719 IRBuilder<> IRB2(callInst->getParent()); 720 IRB2.SetInsertPoint(callInst); 721 722 LoadInst *CmpPtr = IRB2.CreateLoad( 723 #if LLVM_VERSION_MAJOR >= 14 724 PointerType::get(Int8Ty, 0), 725 #endif 726 AFLCmplogPtr); 727 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 728 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 729 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 730 731 IRBuilder<> IRB(ThenTerm); 732 733 std::vector<Value *> args; 734 Value *v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 735 Value *v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 736 args.push_back(v1Pcasted); 737 args.push_back(v2Pcasted); 738 739 IRB.CreateCall(cmplogLlvmStdC, args); 740 741 // errs() << callInst->getCalledFunction()->getName() << "\n"; 742 743 } 744 745 return true; 746 747 } 748 749 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 750 PreservedAnalyses CmpLogRoutines::run(Module &M, ModuleAnalysisManager &MAM) { 751 752 #else 753 bool CmpLogRoutines::runOnModule(Module &M) { 754 755 #endif 756 757 if (getenv("AFL_QUIET") == NULL) 758 printf("Running cmplog-routines-pass by [email protected]\n"); 759 else 760 be_quiet = 1; 761 hookRtns(M); 762 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 763 auto PA = PreservedAnalyses::all(); 764 #endif 765 verifyModule(M); 766 767 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 768 return PA; 769 #else 770 return true; 771 #endif 772 773 } 774 775 #if LLVM_VERSION_MAJOR < 11 /* use old pass manager */ 776 static void registerCmpLogRoutinesPass(const PassManagerBuilder &, 777 legacy::PassManagerBase &PM) { 778 779 auto p = new CmpLogRoutines(); 780 PM.add(p); 781 782 } 783 784 static RegisterStandardPasses RegisterCmpLogRoutinesPass( 785 PassManagerBuilder::EP_OptimizerLast, registerCmpLogRoutinesPass); 786 787 static RegisterStandardPasses RegisterCmpLogRoutinesPass0( 788 PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmpLogRoutinesPass); 789 790 #if LLVM_VERSION_MAJOR >= 11 791 static RegisterStandardPasses RegisterCmpLogRoutinesPassLTO( 792 PassManagerBuilder::EP_FullLinkTimeOptimizationLast, 793 registerCmpLogRoutinesPass); 794 #endif 795 #endif 796 797