#!/usr/bin/env python
#
# This script generates a BPF program with structure inspired by trace.py. The
# generated program operates on PID-indexed stacks. Generally speaking,
# bookkeeping is done at every intermediate function kprobe/kretprobe to enforce
# the goal of "fail iff this call chain and these predicates".
#
# Top level functions (the ones at the end of the call chain) are responsible
# for creating the pid_struct and deleting it from the map in kprobe and
# kretprobe respectively.
#
# Intermediate functions (between should_fail_whatever and the top level
# functions) are responsible for updating the stack to indicate "I have been
# called and one of my predicate(s) passed" in their entry probes. In their exit
# probes, they do the opposite, popping their stack to maintain correctness.
# This implementation aims to ensure correctness in edge cases like recursive
# calls, so there's some additional information stored in pid_struct for that.
#
# At the bottom level function (should_fail_whatever), we do a simple check to
# ensure all necessary calls/predicates have passed before error injection.
#
# Note: presently there are a few hacks to get around various rewriter/verifier
# issues.
#
# Note: this tool requires:
# - CONFIG_BPF_KPROBE_OVERRIDE
#
# USAGE: inject [-h] [-I header] [-P probability] [-v] [-c count] mode spec
#
# Copyright (c) 2018 Facebook, Inc.
# Licensed under the Apache License, Version 2.0 (the "License")
#
# 16-Mar-2018   Howard McLauchlan   Created this.
import argparse
import re


class Probe:
    """One kprobe/kretprobe handler of the generated BPF program.

    A probe is built from a function signature as written in the spec
    (``func``), the list of ``(predicate, place)`` tuples attached to that
    function (``preds``), the total call-chain length (``length``) and a flag
    telling whether it is the entry (kprobe) or exit (kretprobe) handler.
    """

    # C expression passed to bpf_override_return() for each injection mode.
    errno_mapping = {
        "kmalloc": "-ENOMEM",
        "bio": "-EIO",
        "alloc_page": "true",
    }

    @classmethod
    def configure(cls, mode, probability, count):
        """Store the settings shared by every probe as class attributes."""
        cls.mode = mode
        cls.probability = probability
        cls.count = count

    def __init__(self, func, preds, length, entry):
        # length of call chain
        self.length = length
        self.func = func
        self.preds = preds
        self.is_entry = entry

    def _bail(self, err):
        """Raise ValueError describing a problem with this probe."""
        raise ValueError("error in probe '%s': %s" % (self.func, err))

    def _get_err(self):
        """Return the C expression used as override value for the mode."""
        return Probe.errno_mapping[Probe.mode]

    def _get_if_top(self):
        """Return the map init/cleanup C snippet for a top level function.

        Returns "" when this function is not the top of the call chain.
        """
        # Tuples are appended in increasing place order, so if this function
        # is the top of the chain its first tuple carries place 0.
        if self.preds[0][1] != 0:
            return ""

        if Probe.probability == 1:
            early_pred = "false"
        else:
            early_pred = ("bpf_get_prandom_u32() > %s"
                          % str(int((1 << 32) * Probe.probability)))

        # init the map
        # don't do an early exit here so the singular case works automatically
        # have an early exit for probability option
        enter_text = """
        /*
         * Early exit for probability case
         */
        if (%s)
                return 0;
        /*
         * Top level function init map
         */
        struct pid_struct p_struct = {0, 0};
        m.insert(&pid, &p_struct);
        """ % early_pred

        # kill the entry
        exit_text = """
        /*
         * Top level function clean up map
         */
        m.delete(&pid);
        """

        return enter_text if self.is_entry else exit_text

    def _get_heading(self):
        """Return the C function heading; also sets event and func_name."""
        # we need to insert identifier and ctx into self.func
        # gonna make a lot of formatting assumptions to make this work
        left = self.func.find("(")
        right = self.func.rfind(")")

        # self.event and self.func_name need to be accessible
        self.event = self.func[0:left]
        self.func_name = self.event + ("_entry" if self.is_entry else "_exit")
        func_sig = "struct pt_regs *ctx"

        # assume there's something in there, no guarantee its well formed
        if right > left + 1 and self.is_entry:
            func_sig += ", " + self.func[left + 1:right]

        return "int %s(%s)" % (self.func_name, func_sig)

    def _get_entry_logic(self):
        """Return the C code that pushes onto the per-pid stack on entry."""
        # there is at least one tup(pred, place) for this function
        first_pred, first_place = self.preds[0][0], self.preds[0][1]
        text = """

        if (p->conds_met >= %s)
                return 0;
        if (p->conds_met == %s && %s) {
                p->stack[%s] = p->curr_call;
                p->conds_met++;
        }""" % (self.length, first_place, first_pred, first_place)

        # for each additional pred
        for tup in self.preds[1:]:
            text += """
        else if (p->conds_met == %s && %s) {
                p->stack[%s] = p->curr_call;
                p->conds_met++;
        }
        """ % (tup[1], tup[0], tup[1])
        return text

    def _generate_entry(self):
        """Return the full C source of the entry (kprobe) handler."""
        prog = self._get_heading() + """
{
        u32 pid = bpf_get_current_pid_tgid();
        %s

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        /*
         * preparation for predicate, if necessary
         */
        %s
        /*
         * Generate entry logic
         */
        %s

        p->curr_call++;

        return 0;
}"""

        return prog % (self._get_if_top(), self.prep, self._get_entry_logic())

    # only need to check top of stack
    def _get_exit_logic(self):
        """Return the C code that pops the per-pid stack on exit."""
        text = """
        if (p->conds_met < 1 || p->conds_met >= %s)
                return 0;

        if (p->stack[p->conds_met - 1] == p->curr_call)
                p->conds_met--;
        """
        return text % str(self.length + 1)

    def _generate_exit(self):
        """Return the full C source of the exit (kretprobe) handler."""
        prog = self._get_heading() + """
{
        u32 pid = bpf_get_current_pid_tgid();

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        p->curr_call--;

        /*
         * Generate exit logic
         */
        %s
        %s
        return 0;
}"""

        return prog % (self._get_exit_logic(), self._get_if_top())

    # Special case for should_fail_whatever
    def _generate_bottom(self):
        """Return the C source of the handler that performs the override."""
        pred = self.preds[0][0]
        text = self._get_heading() + """
{
        u32 overridden = 0;
        int zero = 0;
        u32* val;

        val = count.lookup(&zero);
        if (val)
                overridden = *val;

        /*
         * preparation for predicate, if necessary
         */
        %s
        /*
         * If this is the only call in the chain and predicate passes
         */
        if (%s == 1 && %s && overridden < %s) {
                count.atomic_increment(zero);
                bpf_override_return(ctx, %s);
                return 0;
        }
        u32 pid = bpf_get_current_pid_tgid();

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        /*
         * If all conds have been met and predicate passes
         */
        if (p->conds_met == %s && %s && overridden < %s) {
                count.atomic_increment(zero);
                bpf_override_return(ctx, %s);
        }
        return 0;
}"""
        return text % (self.prep, self.length, pred, Probe.count,
                       self._get_err(), self.length - 1, pred, Probe.count,
                       self._get_err())

    # presently parses and replaces STRCMP
    # STRCMP exists because string comparison is inconvenient and somewhat buggy
    # https://github.com/iovisor/bcc/issues/1617
    def _prepare_pred(self):
        """Expand every STRCMP(ptr, literal) in the predicates.

        Sets self.prep to the C statements that compare the string one
        character at a time, and replaces self.preds with a new list in
        which each STRCMP(...) is substituted by the boolean that holds the
        comparison result. The list passed to the constructor is left
        untouched because it is shared with other probes.

        Raises:
            ValueError: if a STRCMP is unterminated or does not have exactly
                two comma separated arguments.
        """
        marker = "STRCMP("
        self.prep = ""
        rewritten = []
        for pred, place in self.preds:
            new_pred = ""
            start = 0
            while start < len(pred):
                ind = pred.find(marker, start)
                if ind == -1:
                    break
                new_pred += pred[start:ind]

                # The closing paren must be searched from the STRCMP itself,
                # not from the previous scan position, otherwise a ")" that
                # precedes the STRCMP is picked up. The first ")" ends the
                # call, so the arguments cannot contain parentheses.
                args_start = ind + len(marker)
                close = pred.find(")", args_start)
                if close == -1:
                    self._bail("unterminated STRCMP in '%s'" % pred)
                args = pred[args_start:close].split(",")
                if len(args) != 2:
                    self._bail("STRCMP expects (pointer, literal), got '%s'"
                               % pred[ind:close + 1])
                ptr, literal = args
                literal = literal.strip()
                start = close + 1

                # x->y->z, some string literal
                # we make unique id with place_ind
                uuid = "%s_%s" % (place, ind)
                unique_bool = "is_true_%s" % uuid
                self.prep += """
        char *str_%s = %s;
        bool %s = true;\n""" % (uuid, ptr.strip(), unique_bool)

                check = "\t%s &= *(str_%s++) == '%%s';\n" % (unique_bool, uuid)

                for ch in literal:
                    self.prep += check % ch
                self.prep += check % r'\0'
                new_pred += unique_bool

            new_pred += pred[start:]
            rewritten.append((new_pred, place))
        self.preds = rewritten

    def generate_program(self):
        """Return the C source of this probe's handler."""
        # generate code to work around various rewriter issues
        self._prepare_pred()

        # special case for bottom
        if self.preds[-1][1] == self.length - 1:
            return self._generate_bottom()

        if self.is_entry:
            return self._generate_entry()
        return self._generate_exit()

    def attach(self, bpf):
        """Attach the handler to its kernel function on the given BPF object.

        generate_program() must have been called first; it sets event and
        func_name.
        """
        if self.is_entry:
            bpf.attach_kprobe(event=self.event,
                              fn_name=self.func_name)
        else:
            bpf.attach_kretprobe(event=self.event,
                                 fn_name=self.func_name)


class Tool:
    """Command line driver: parse the spec, generate and load the program."""

    examples = """
EXAMPLES:
# ./inject.py kmalloc -v 'SyS_mount()'
    Fails all calls to syscall mount
# ./inject.py kmalloc -v '(true) => SyS_mount()(true)'
    Explicit rewriting of above
# ./inject.py kmalloc -v 'mount_subtree() => btrfs_mount()'
    Fails btrfs mounts only
# ./inject.py kmalloc -v 'd_alloc_parallel(struct dentry *parent, const struct \\
    qstr *name)(STRCMP(name->name, 'bananas'))'
    Fails dentry allocations of files named 'bananas'
# ./inject.py kmalloc -v -P 0.01 'SyS_mount()'
    Fails calls to syscall mount with 1% probability
    """
    # add cases as necessary
    error_injection_mapping = {
        "kmalloc": "should_failslab(struct kmem_cache *s, gfp_t gfpflags)",
        "bio": "should_fail_bio(struct bio *bio)",
        "alloc_page":
            "should_fail_alloc_page(gfp_t gfp_mask, unsigned int order)",
    }

    def __init__(self):
        parser = argparse.ArgumentParser(
            description="Fail specified kernel" +
            " functionality when call chain and predicates are met",
            formatter_class=argparse.RawDescriptionHelpFormatter,
            epilog=Tool.examples)
        parser.add_argument(dest="mode",
                            choices=["kmalloc", "bio", "alloc_page"],
                            help="indicate which base kernel function to fail")
        parser.add_argument(metavar="spec", dest="spec",
                            help="specify call chain")
        parser.add_argument("-I", "--include", action="append",
                            metavar="header",
                            help="additional header files to include in the "
                                 "BPF program")
        parser.add_argument("-P", "--probability", default=1,
                            metavar="probability", type=float,
                            help="probability that this call chain will fail")
        parser.add_argument("-v", "--verbose", action="store_true",
                            help="print BPF program")
        # The value is pasted verbatim into the generated C program, so make
        # argparse reject anything that is not an integer.
        parser.add_argument("-c", "--count", action="store", default=-1,
                            type=int,
                            help="Number of fails before bypassing the "
                                 "override")
        self.args = parser.parse_args()

        self.program = ""
        self.spec = self.args.spec
        self.map = {}
        self.probes = []
        self.key = Tool.error_injection_mapping[self.args.mode]

    # create_probes and associated stuff
    def _create_probes(self):
        """Parse the spec and create the entry/exit Probe objects."""
        self._parse_spec()
        Probe.configure(self.args.mode, self.args.probability, self.args.count)
        # self, func, preds, total, entry

        # create all the pair probes
        for fx, preds in self.map.items():

            # do the enter
            self.probes.append(Probe(fx, preds, self.length, True))

            # the bottom level function only gets an entry probe
            if self.key == fx:
                continue

            # do the exit
            self.probes.append(Probe(fx, preds, self.length, False))

    def _parse_frames(self):
        """Split self.spec into a list of (function, predicate) tuples.

        Frames are separated by "=>". A frame holding only a predicate is
        attached to the bottom level function (self.key); a frame holding
        only a function gets the predicate "(true)".

        Raises:
            Exception: on unbalanced parentheses, an empty frame, a frame
                with more than two parenthesised groups, or trailing text
                after the last frame.
        """
        # sentinel
        data = self.spec + '\0'
        start, depth = 0, 0

        frames = []
        cur_frame = []
        i = 0
        last_frame_added = 0

        while i < len(data):
            # improper input
            if depth < 0:
                raise Exception("Check your parentheses")
            c = data[i]
            if c == '(':
                depth += 1
            elif c == ')':
                depth -= 1
            if not depth:
                if c == '\0' or (c == '=' and data[i + 1] == '>'):
                    # This block is closing a chunk. This means cur_frame must
                    # have something in it.
                    if not cur_frame:
                        raise Exception("Cannot parse spec, missing parens")
                    # Anything beyond "function(args)(predicate)" used to be
                    # dropped silently, replacing the predicate with (true).
                    if len(cur_frame) > 2:
                        raise Exception("Cannot parse spec, too many "
                                        "parenthesized groups in frame")
                    if len(cur_frame) == 2:
                        frame = tuple(cur_frame)
                    elif cur_frame[0][0] == '(':
                        frame = self.key, cur_frame[0]
                    else:
                        frame = cur_frame[0], '(true)'
                    frames.append(frame)
                    del cur_frame[:]
                    # skip the '>' of "=>"
                    i += 1
                    start = i + 1
                elif c == ')':
                    cur_frame.append(data[start:i + 1].strip())
                    start = i + 1
                    last_frame_added = start
            i += 1

        # We only permit spaces after the last frame
        if self.spec[last_frame_added:].strip():
            raise Exception("Invalid characters found after last frame")
        # improper input
        if depth:
            raise Exception("Check your parentheses")
        return frames

    def _parse_spec(self):
        """Fill self.map (function -> [(pred, place)]) and self.length.

        Places are numbered from the end of the spec: the last frame gets
        place 0. The bottom level function is appended with "(true)" when
        the spec does not mention it.

        Raises:
            Exception: if a predicate or a function identifier is invalid.
        """
        frames = self._parse_frames()
        frames.reverse()

        absolute_order = 0
        for func, pred in frames:
            if not self._validate_predicate(pred):
                raise Exception("Invalid predicate")
            if not self._validate_identifier(func):
                raise Exception("Invalid function identifier")

            self.map.setdefault(func, []).append((pred, absolute_order))
            absolute_order += 1

        if self.key not in self.map:
            self.map[self.key] = [('(true)', absolute_order)]
            absolute_order += 1

        self.length = absolute_order

    def _validate_identifier(self, func):
        """Return True if the text before the first "(" is a C identifier."""
        # We've already established paren balancing. We will only look for
        # identifier validity here.
        potential_id = func.partition("(")[0]
        # A-Z, not A-z: the latter also matches "[", "\\", "]", "^" and "`".
        # \Z instead of $ so that a trailing newline is not accepted.
        pattern = r'[_a-zA-Z][_a-zA-Z0-9]*\Z'
        return re.match(pattern, potential_id) is not None

    def _validate_predicate(self, pred):
        """Return False if pred starts with "(" but is not balanced."""
        if len(pred) > 0 and pred[0] == "(":
            depth = 1
            for ch in pred[1:]:
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
            if depth != 0:
                # not well formed, break
                return False

        return True

    def _def_pid_struct(self):
        """Return the C definition of struct pid_struct."""
        text = """
struct pid_struct {
    u64 curr_call; /* book keeping to handle recursion */
    u64 conds_met; /* stack pointer */
    u64 stack[%s];
};
""" % self.length
        return text

    def _attach_probes(self):
        """Compile self.program with bcc and attach every probe."""
        # Imported here, where it is needed, so that spec parsing and program
        # generation stay usable on hosts without bcc installed.
        from bcc import BPF

        self.bpf = BPF(text=self.program)
        for p in self.probes:
            p.attach(self.bpf)

    def _generate_program(self):
        """Append the complete BPF C program to self.program."""
        # leave out auto includes for now
        parts = ['#include <linux/mm.h>\n']
        for include in (self.args.include or []):
            parts.append("#include <%s>\n" % include)

        parts.append(self._def_pid_struct())
        parts.append("BPF_HASH(m, u32, struct pid_struct);\n")
        parts.append("BPF_ARRAY(count, u32, 1);\n")

        for p in self.probes:
            parts.append(p.generate_program() + "\n")

        self.program += "".join(parts)

        if self.args.verbose:
            print(self.program)

    def _main_loop(self):
        """Poll until interrupted with Ctrl-C, then exit."""
        while True:
            try:
                self.bpf.perf_buffer_poll()
            except KeyboardInterrupt:
                # exit() is a helper installed by the site module for
                # interactive use; raise SystemExit directly instead.
                raise SystemExit

    def run(self):
        """Run the tool: build probes, generate, attach and wait."""
        self._create_probes()
        self._generate_program()
        self._attach_probes()
        self._main_loop()


if __name__ == "__main__":
    Tool().run()