xref: /aosp_15_r20/external/AFLplusplus/instrumentation/cmplog-switches-pass.cc (revision 08b48e0b10e97b33e7b60c5b6e2243bd915777f2)
1 /*
2    american fuzzy lop++ - LLVM CmpLog instrumentation
3    --------------------------------------------------
4 
5    Written by Andrea Fioraldi <[email protected]>
6 
7    Copyright 2015, 2016 Google Inc. All rights reserved.
8    Copyright 2019-2024 AFLplusplus Project. All rights reserved.
9 
10    Licensed under the Apache License, Version 2.0 (the "License");
11    you may not use this file except in compliance with the License.
12    You may obtain a copy of the License at:
13 
14      https://www.apache.org/licenses/LICENSE-2.0
15 
16 */
17 
18 #include <stdio.h>
19 #include <stdlib.h>
20 #include <unistd.h>
21 
22 #include <iostream>
23 #include <list>
24 #include <string>
25 #include <fstream>
26 #include <sys/time.h>
27 
28 #include "llvm/Config/llvm-config.h"
29 #include "llvm/ADT/Statistic.h"
30 #include "llvm/IR/IRBuilder.h"
31 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
32   #include "llvm/Passes/PassPlugin.h"
33   #include "llvm/Passes/PassBuilder.h"
34   #include "llvm/IR/PassManager.h"
35 #else
36   #include "llvm/IR/LegacyPassManager.h"
37   #include "llvm/Transforms/IPO/PassManagerBuilder.h"
38 #endif
39 #include "llvm/IR/Module.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Support/raw_ostream.h"
42 #if LLVM_VERSION_MAJOR < 17
43   #include "llvm/Transforms/IPO/PassManagerBuilder.h"
44 #endif
45 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
46 #include "llvm/Pass.h"
47 #include "llvm/Analysis/ValueTracking.h"
48 
49 #include "llvm/IR/IRBuilder.h"
50 #if LLVM_VERSION_MAJOR >= 4 || \
51     (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
52   #include "llvm/IR/Verifier.h"
53   #include "llvm/IR/DebugInfo.h"
54 #else
55   #include "llvm/Analysis/Verifier.h"
56   #include "llvm/DebugInfo.h"
57   #define nullptr 0
58 #endif
59 
60 #include <set>
61 #include "afl-llvm-common.h"
62 
63 using namespace llvm;
64 
65 namespace {
66 
67 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
68 class CmplogSwitches : public PassInfoMixin<CmplogSwitches> {
69 
70  public:
CmplogSwitches()71   CmplogSwitches() {
72 
73 #else
74 class CmplogSwitches : public ModulePass {
75 
76  public:
77   static char ID;
78   CmplogSwitches() : ModulePass(ID) {
79 
80 #endif
81     initInstrumentList();
82 
83   }
84 
85 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
86   PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
87 #else
88   bool runOnModule(Module &M) override;
89 
90   #if LLVM_VERSION_MAJOR < 4
91   const char *getPassName() const override {
92 
93   #else
94   StringRef getPassName() const override {
95 
96   #endif
97     return "cmplog switch split";
98 
99   }
100 
101 #endif
102 
103  private:
104   bool hookInstrs(Module &M);
105 
106 };
107 
108 }  // namespace
109 
110 #if LLVM_MAJOR >= 11
111 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
112 llvmGetPassPluginInfo() {
113 
114   return {LLVM_PLUGIN_API_VERSION, "cmplogswitches", "v0.1",
115           /* lambda to insert our pass into the pass pipeline. */
116           [](PassBuilder &PB) {
117 
118   #if LLVM_VERSION_MAJOR <= 13
119             using OptimizationLevel = typename PassBuilder::OptimizationLevel;
120   #endif
121             PB.registerOptimizerLastEPCallback(
122                 [](ModulePassManager &MPM, OptimizationLevel OL) {
123 
124                   MPM.addPass(CmplogSwitches());
125 
126                 });
127 
128           }};
129 
130 }
131 
132 #else
133 char CmplogSwitches::ID = 0;
134 #endif
135 
/* Removes duplicate elements from [first, last) while keeping the first
   occurrence of every value in its original relative order.

   Returns the new logical end of the range; the elements past it are left in
   an unspecified state and are expected to be erased by the caller. */
template <class Iterator>
Iterator Unique(Iterator first, Iterator last) {

  for (Iterator keep = first; keep != last; ++keep) {

    /* drop every later element that equals the one we keep */
    Iterator rest = keep;
    ++rest;
    last = std::remove(rest, last, *keep);

  }

  return last;

}
150 
151 bool CmplogSwitches::hookInstrs(Module &M) {
152 
153   std::vector<SwitchInst *> switches;
154   LLVMContext              &C = M.getContext();
155 
156   Type        *VoidTy = Type::getVoidTy(C);
157   IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
158   IntegerType *Int16Ty = IntegerType::getInt16Ty(C);
159   IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
160   IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
161 
162 #if LLVM_VERSION_MAJOR >= 9
163   FunctionCallee
164 #else
165   Constant *
166 #endif
167       c1 = M.getOrInsertFunction("__cmplog_ins_hook1", VoidTy, Int8Ty, Int8Ty,
168                                  Int8Ty
169 #if LLVM_VERSION_MAJOR < 5
170                                  ,
171                                  NULL
172 #endif
173       );
174 #if LLVM_VERSION_MAJOR >= 9
175   FunctionCallee cmplogHookIns1 = c1;
176 #else
177   Function *cmplogHookIns1 = cast<Function>(c1);
178 #endif
179 
180 #if LLVM_VERSION_MAJOR >= 9
181   FunctionCallee
182 #else
183   Constant *
184 #endif
185       c2 = M.getOrInsertFunction("__cmplog_ins_hook2", VoidTy, Int16Ty, Int16Ty,
186                                  Int8Ty
187 #if LLVM_VERSION_MAJOR < 5
188                                  ,
189                                  NULL
190 #endif
191       );
192 #if LLVM_VERSION_MAJOR >= 9
193   FunctionCallee cmplogHookIns2 = c2;
194 #else
195   Function *cmplogHookIns2 = cast<Function>(c2);
196 #endif
197 
198 #if LLVM_VERSION_MAJOR >= 9
199   FunctionCallee
200 #else
201   Constant *
202 #endif
203       c4 = M.getOrInsertFunction("__cmplog_ins_hook4", VoidTy, Int32Ty, Int32Ty,
204                                  Int8Ty
205 #if LLVM_VERSION_MAJOR < 5
206                                  ,
207                                  NULL
208 #endif
209       );
210 #if LLVM_VERSION_MAJOR >= 9
211   FunctionCallee cmplogHookIns4 = c4;
212 #else
213   Function *cmplogHookIns4 = cast<Function>(c4);
214 #endif
215 
216 #if LLVM_VERSION_MAJOR >= 9
217   FunctionCallee
218 #else
219   Constant *
220 #endif
221       c8 = M.getOrInsertFunction("__cmplog_ins_hook8", VoidTy, Int64Ty, Int64Ty,
222                                  Int8Ty
223 #if LLVM_VERSION_MAJOR < 5
224                                  ,
225                                  NULL
226 #endif
227       );
228 #if LLVM_VERSION_MAJOR >= 9
229   FunctionCallee cmplogHookIns8 = c8;
230 #else
231   Function *cmplogHookIns8 = cast<Function>(c8);
232 #endif
233 
234   GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map");
235 
236   if (!AFLCmplogPtr) {
237 
238     AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false,
239                                       GlobalValue::ExternalWeakLinkage, 0,
240                                       "__afl_cmp_map");
241 
242   }
243 
244   Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0));
245 
246   /* iterate over all functions, bbs and instruction and add suitable calls */
247   for (auto &F : M) {
248 
249     if (!isInInstrumentList(&F, MNAME)) continue;
250 
251     for (auto &BB : F) {
252 
253       SwitchInst *switchInst = nullptr;
254       if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) {
255 
256         if (switchInst->getNumCases() > 1) { switches.push_back(switchInst); }
257 
258       }
259 
260     }
261 
262   }
263 
264   // unique the collected switches
265   switches.erase(Unique(switches.begin(), switches.end()), switches.end());
266 
267   // Instrument switch values for cmplog
268   if (switches.size()) {
269 
270     if (!be_quiet)
271       errs() << "Hooking " << switches.size() << " switch instructions\n";
272 
273     for (auto &SI : switches) {
274 
275       Value        *Val = SI->getCondition();
276       unsigned int  max_size = Val->getType()->getIntegerBitWidth(), cast_size;
277       unsigned char do_cast = 0;
278 
279       if (!SI->getNumCases() || max_size < 16) {
280 
281         // if (!be_quiet) errs() << "skip trivial switch..\n";
282         continue;
283 
284       }
285 
286       if (max_size % 8) {
287 
288         max_size = (((max_size / 8) + 1) * 8);
289         do_cast = 1;
290 
291       }
292 
293       IRBuilder<> IRB2(SI->getParent());
294       IRB2.SetInsertPoint(SI);
295 
296       LoadInst *CmpPtr = IRB2.CreateLoad(
297 #if LLVM_VERSION_MAJOR >= 14
298           PointerType::get(Int8Ty, 0),
299 #endif
300           AFLCmplogPtr);
301       CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
302       auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null);
303       auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, SI, false);
304 
305       IRBuilder<> IRB(ThenTerm);
306 
307       if (max_size > 128) {
308 
309         if (!be_quiet) {
310 
311           fprintf(stderr,
312                   "Cannot handle this switch bit size: %u (truncating)\n",
313                   max_size);
314 
315         }
316 
317         max_size = 128;
318         do_cast = 1;
319 
320       }
321 
322       // do we need to cast?
323       switch (max_size) {
324 
325         case 8:
326         case 16:
327         case 32:
328         case 64:
329         case 128:
330           cast_size = max_size;
331           break;
332         default:
333           cast_size = 128;
334           do_cast = 1;
335 
336       }
337 
338       Value *CompareTo = Val;
339 
340       if (do_cast) {
341 
342         CompareTo =
343             IRB.CreateIntCast(CompareTo, IntegerType::get(C, cast_size), false);
344 
345       }
346 
347       for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e;
348            ++i) {
349 
350 #if LLVM_VERSION_MAJOR < 5
351         ConstantInt *cint = i.getCaseValue();
352 #else
353         ConstantInt *cint = i->getCaseValue();
354 #endif
355 
356         if (cint) {
357 
358           std::vector<Value *> args;
359           args.push_back(CompareTo);
360 
361           Value *new_param = cint;
362 
363           if (do_cast) {
364 
365             new_param =
366                 IRB.CreateIntCast(cint, IntegerType::get(C, cast_size), false);
367 
368           }
369 
370           if (new_param) {
371 
372             args.push_back(new_param);
373             ConstantInt *attribute = ConstantInt::get(Int8Ty, 1);
374             args.push_back(attribute);
375             if (cast_size != max_size) {
376 
377               ConstantInt *bitsize =
378                   ConstantInt::get(Int8Ty, (max_size / 8) - 1);
379               args.push_back(bitsize);
380 
381             }
382 
383             switch (cast_size) {
384 
385               case 8:
386                 IRB.CreateCall(cmplogHookIns1, args);
387                 break;
388               case 16:
389                 IRB.CreateCall(cmplogHookIns2, args);
390                 break;
391               case 32:
392                 IRB.CreateCall(cmplogHookIns4, args);
393                 break;
394               case 64:
395                 IRB.CreateCall(cmplogHookIns8, args);
396                 break;
397               case 128:
398 #ifdef WORD_SIZE_64
399                 if (max_size == 128) {
400 
401                   IRB.CreateCall(cmplogHookIns16, args);
402 
403                 } else {
404 
405                   IRB.CreateCall(cmplogHookInsN, args);
406 
407                 }
408 
409 #endif
410                 break;
411               default:
412                 break;
413 
414             }
415 
416           }
417 
418         }
419 
420       }
421 
422     }
423 
424   }
425 
426   if (switches.size())
427     return true;
428   else
429     return false;
430 
431 }
432 
433 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
434 PreservedAnalyses CmplogSwitches::run(Module &M, ModuleAnalysisManager &MAM) {
435 
436 #else
437 bool CmplogSwitches::runOnModule(Module &M) {
438 
439 #endif
440 
441   if (getenv("AFL_QUIET") == NULL)
442     printf("Running cmplog-switches-pass by [email protected]\n");
443   else
444     be_quiet = 1;
445   hookInstrs(M);
446 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
447   auto PA = PreservedAnalyses::all();
448 #endif
449   verifyModule(M);
450 
451 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
452   return PA;
453 #else
454   return true;
455 #endif
456 
457 }
458 
459 #if LLVM_VERSION_MAJOR < 11                         /* use old pass manager */
460 static void registerCmplogSwitchesPass(const PassManagerBuilder &,
461                                        legacy::PassManagerBase &PM) {
462 
463   auto p = new CmplogSwitches();
464   PM.add(p);
465 
466 }
467 
468 static RegisterStandardPasses RegisterCmplogSwitchesPass(
469     PassManagerBuilder::EP_OptimizerLast, registerCmplogSwitchesPass);
470 
471 static RegisterStandardPasses RegisterCmplogSwitchesPass0(
472     PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmplogSwitchesPass);
473 
474   #if LLVM_VERSION_MAJOR >= 11
475 static RegisterStandardPasses RegisterCmplogSwitchesPassLTO(
476     PassManagerBuilder::EP_FullLinkTimeOptimizationLast,
477     registerCmplogSwitchesPass);
478   #endif
479 #endif
480 
481