xref: /aosp_15_r20/external/bcc/tools/inject.py (revision 387f9dfdfa2baef462e92476d413c7bc2470293e)
1#!/usr/bin/env python
2#
3# This script generates a BPF program with structure inspired by trace.py. The
4# generated program operates on PID-indexed stacks. Generally speaking,
5# bookkeeping is done at every intermediate function kprobe/kretprobe to enforce
6# the goal of "fail iff this call chain and these predicates".
7#
# Top-level functions (the ones at the end of the call chain) are responsible for
# creating the pid_struct in their kprobe and deleting it from the map in their
# kretprobe.
11#
# Intermediate functions (between should_fail_whatever and the top-level
# functions) are responsible for updating the stack to indicate "I have been
# called and one of my predicates passed" in their entry probes. In their exit
# probes, they do the opposite, popping their stack to maintain correctness.
# This implementation aims to ensure correctness in edge cases like recursive
# calls, so there's some additional information stored in pid_struct for that.
18#
19# At the bottom level function(should_fail_whatever), we do a simple check to
20# ensure all necessary calls/predicates have passed before error injection.
21#
22# Note: presently there are a few hacks to get around various rewriter/verifier
23# issues.
24#
25# Note: this tool requires:
26# - CONFIG_BPF_KPROBE_OVERRIDE
27#
28# USAGE: inject [-h] [-I header] [-P probability] [-v] mode spec
29#
30# Copyright (c) 2018 Facebook, Inc.
31# Licensed under the Apache License, Version 2.0 (the "License")
32#
33# 16-Mar-2018   Howard McLauchlan   Created this.
34
35import argparse
36import re
37from bcc import BPF
38
39
40class Probe:
41    errno_mapping = {
42        "kmalloc": "-ENOMEM",
43        "bio": "-EIO",
44        "alloc_page" : "true",
45    }
46
47    @classmethod
48    def configure(cls, mode, probability, count):
49        cls.mode = mode
50        cls.probability = probability
51        cls.count = count
52
53    def __init__(self, func, preds, length, entry):
54        # length of call chain
55        self.length = length
56        self.func = func
57        self.preds = preds
58        self.is_entry = entry
59
60    def _bail(self, err):
61        raise ValueError("error in probe '%s': %s" %
62                (self.spec, err))
63
64    def _get_err(self):
65        return Probe.errno_mapping[Probe.mode]
66
67    def _get_if_top(self):
68        # ordering guarantees that if this function is top, the last tup is top
69        chk = self.preds[0][1] == 0
70        if not chk:
71            return ""
72
73        if Probe.probability == 1:
74            early_pred = "false"
75        else:
76            early_pred = "bpf_get_prandom_u32() > %s" % str(int((1<<32)*Probe.probability))
77        # init the map
78        # don't do an early exit here so the singular case works automatically
79        # have an early exit for probability option
80        enter = """
81        /*
82         * Early exit for probability case
83         */
84        if (%s)
85               return 0;
86        /*
87         * Top level function init map
88         */
89        struct pid_struct p_struct = {0, 0};
90        m.insert(&pid, &p_struct);
91        """ % early_pred
92
93        # kill the entry
94        exit = """
95        /*
96         * Top level function clean up map
97         */
98        m.delete(&pid);
99        """
100
101        return enter if self.is_entry else exit
102
103    def _get_heading(self):
104
105        # we need to insert identifier and ctx into self.func
106        # gonna make a lot of formatting assumptions to make this work
107        left = self.func.find("(")
108        right = self.func.rfind(")")
109
110        # self.event and self.func_name need to be accessible
111        self.event = self.func[0:left]
112        self.func_name = self.event + ("_entry" if self.is_entry else "_exit")
113        func_sig = "struct pt_regs *ctx"
114
115        # assume there's something in there, no guarantee its well formed
116        if right > left + 1 and self.is_entry:
117            func_sig += ", " + self.func[left + 1:right]
118
119        return "int %s(%s)" % (self.func_name, func_sig)
120
121    def _get_entry_logic(self):
122        # there is at least one tup(pred, place) for this function
123        text = """
124
125        if (p->conds_met >= %s)
126                return 0;
127        if (p->conds_met == %s && %s) {
128                p->stack[%s] = p->curr_call;
129                p->conds_met++;
130        }"""
131        text = text % (self.length, self.preds[0][1], self.preds[0][0],
132                self.preds[0][1])
133
134        # for each additional pred
135        for tup in self.preds[1:]:
136            text += """
137        else if (p->conds_met == %s && %s) {
138                p->stack[%s] = p->curr_call;
139                p->conds_met++;
140        }
141            """ % (tup[1], tup[0], tup[1])
142        return text
143
144    def _generate_entry(self):
145        prog = self._get_heading() + """
146{
147        u32 pid = bpf_get_current_pid_tgid();
148        %s
149
150        struct pid_struct *p = m.lookup(&pid);
151
152        if (!p)
153                return 0;
154
155        /*
156         * preparation for predicate, if necessary
157         */
158         %s
159        /*
160         * Generate entry logic
161         */
162        %s
163
164        p->curr_call++;
165
166        return 0;
167}"""
168
169        prog = prog % (self._get_if_top(), self.prep, self._get_entry_logic())
170        return prog
171
172    # only need to check top of stack
173    def _get_exit_logic(self):
174        text = """
175        if (p->conds_met < 1 || p->conds_met >= %s)
176                return 0;
177
178        if (p->stack[p->conds_met - 1] == p->curr_call)
179                p->conds_met--;
180        """
181        return text % str(self.length + 1)
182
183    def _generate_exit(self):
184        prog = self._get_heading() + """
185{
186        u32 pid = bpf_get_current_pid_tgid();
187
188        struct pid_struct *p = m.lookup(&pid);
189
190        if (!p)
191                return 0;
192
193        p->curr_call--;
194
195        /*
196         * Generate exit logic
197         */
198        %s
199        %s
200        return 0;
201}"""
202
203        prog = prog % (self._get_exit_logic(), self._get_if_top())
204
205        return prog
206
207    # Special case for should_fail_whatever
208    def _generate_bottom(self):
209        pred = self.preds[0][0]
210        text = self._get_heading() + """
211{
212        u32 overridden = 0;
213        int zero = 0;
214        u32* val;
215
216        val = count.lookup(&zero);
217        if (val)
218            overridden = *val;
219
220        /*
221         * preparation for predicate, if necessary
222         */
223         %s
224        /*
225         * If this is the only call in the chain and predicate passes
226         */
227        if (%s == 1 && %s && overridden < %s) {
228                count.atomic_increment(zero);
229                bpf_override_return(ctx, %s);
230                return 0;
231        }
232        u32 pid = bpf_get_current_pid_tgid();
233
234        struct pid_struct *p = m.lookup(&pid);
235
236        if (!p)
237                return 0;
238
239        /*
240         * If all conds have been met and predicate passes
241         */
242        if (p->conds_met == %s && %s && overridden < %s) {
243                count.atomic_increment(zero);
244                bpf_override_return(ctx, %s);
245        }
246        return 0;
247}"""
248        return text % (self.prep, self.length, pred, Probe.count,
249                self._get_err(), self.length - 1, pred, Probe.count,
250                self._get_err())
251
252    # presently parses and replaces STRCMP
253    # STRCMP exists because string comparison is inconvenient and somewhat buggy
254    # https://github.com/iovisor/bcc/issues/1617
255    def _prepare_pred(self):
256        self.prep = ""
257        for i in range(len(self.preds)):
258            new_pred = ""
259            pred = self.preds[i][0]
260            place = self.preds[i][1]
261            start, ind = 0, 0
262            while start < len(pred):
263                ind = pred.find("STRCMP(", start)
264                if ind == -1:
265                    break
266                new_pred += pred[start:ind]
267                # 7 is len("STRCMP(")
268                start = pred.find(")", start + 7) + 1
269
270                # then ind ... start is STRCMP(...)
271                ptr, literal = pred[ind + 7:start - 1].split(",")
272                literal = literal.strip()
273
274                # x->y->z, some string literal
275                # we make unique id with place_ind
276                uuid = "%s_%s" % (place, ind)
277                unique_bool = "is_true_%s" % uuid
278                self.prep += """
279        char *str_%s = %s;
280        bool %s = true;\n""" % (uuid, ptr.strip(), unique_bool)
281
282                check = "\t%s &= *(str_%s++) == '%%s';\n" % (unique_bool, uuid)
283
284                for ch in literal:
285                    self.prep += check % ch
286                self.prep += check % r'\0'
287                new_pred += unique_bool
288
289            new_pred += pred[start:]
290            self.preds[i] = (new_pred, place)
291
292    def generate_program(self):
293        # generate code to work around various rewriter issues
294        self._prepare_pred()
295
296        # special case for bottom
297        if self.preds[-1][1] == self.length - 1:
298            return self._generate_bottom()
299
300        return self._generate_entry() if self.is_entry else self._generate_exit()
301
302    def attach(self, bpf):
303        if self.is_entry:
304            bpf.attach_kprobe(event=self.event,
305                    fn_name=self.func_name)
306        else:
307            bpf.attach_kretprobe(event=self.event,
308                    fn_name=self.func_name)
309
310
311class Tool:
312
313    examples ="""
314EXAMPLES:
315# ./inject.py kmalloc -v 'SyS_mount()'
316    Fails all calls to syscall mount
317# ./inject.py kmalloc -v '(true) => SyS_mount()(true)'
318    Explicit rewriting of above
319# ./inject.py kmalloc -v 'mount_subtree() => btrfs_mount()'
320    Fails btrfs mounts only
321# ./inject.py kmalloc -v 'd_alloc_parallel(struct dentry *parent, const struct \\
322    qstr *name)(STRCMP(name->name, 'bananas'))'
323    Fails dentry allocations of files named 'bananas'
324# ./inject.py kmalloc -v -P 0.01 'SyS_mount()'
325    Fails calls to syscall mount with 1% probability
326    """
327    # add cases as necessary
328    error_injection_mapping = {
329        "kmalloc": "should_failslab(struct kmem_cache *s, gfp_t gfpflags)",
330        "bio": "should_fail_bio(struct bio *bio)",
331        "alloc_page": "should_fail_alloc_page(gfp_t gfp_mask, unsigned int order)",
332    }
333
334    def __init__(self):
335        parser = argparse.ArgumentParser(description="Fail specified kernel" +
336                " functionality when call chain and predicates are met",
337                formatter_class=argparse.RawDescriptionHelpFormatter,
338                epilog=Tool.examples)
339        parser.add_argument(dest="mode", choices=["kmalloc", "bio", "alloc_page"],
340                help="indicate which base kernel function to fail")
341        parser.add_argument(metavar="spec", dest="spec",
342                help="specify call chain")
343        parser.add_argument("-I", "--include", action="append",
344                metavar="header",
345                help="additional header files to include in the BPF program")
346        parser.add_argument("-P", "--probability", default=1,
347                metavar="probability", type=float,
348                help="probability that this call chain will fail")
349        parser.add_argument("-v", "--verbose", action="store_true",
350                help="print BPF program")
351        parser.add_argument("-c", "--count", action="store", default=-1,
352                help="Number of fails before bypassing the override")
353        self.args = parser.parse_args()
354
355        self.program = ""
356        self.spec = self.args.spec
357        self.map = {}
358        self.probes = []
359        self.key = Tool.error_injection_mapping[self.args.mode]
360
361    # create_probes and associated stuff
362    def _create_probes(self):
363        self._parse_spec()
364        Probe.configure(self.args.mode, self.args.probability, self.args.count)
365        # self, func, preds, total, entry
366
367        # create all the pair probes
368        for fx, preds in self.map.items():
369
370            # do the enter
371            self.probes.append(Probe(fx, preds, self.length, True))
372
373            if self.key == fx:
374                continue
375
376            # do the exit
377            self.probes.append(Probe(fx, preds, self.length, False))
378
379    def _parse_frames(self):
380        # sentinel
381        data = self.spec + '\0'
382        start, count = 0, 0
383
384        frames = []
385        cur_frame = []
386        i = 0
387        last_frame_added = 0
388
389        while i < len(data):
390            # improper input
391            if count < 0:
392                raise Exception("Check your parentheses")
393            c = data[i]
394            count += c == '('
395            count -= c == ')'
396            if not count:
397                if c == '\0' or (c == '=' and data[i + 1] == '>'):
398                    # This block is closing a chunk. This means cur_frame must
399                    # have something in it.
400                    if not cur_frame:
401                        raise Exception("Cannot parse spec, missing parens")
402                    if len(cur_frame) == 2:
403                        frame = tuple(cur_frame)
404                    elif cur_frame[0][0] == '(':
405                        frame = self.key, cur_frame[0]
406                    else:
407                        frame = cur_frame[0], '(true)'
408                    frames.append(frame)
409                    del cur_frame[:]
410                    i += 1
411                    start = i + 1
412                elif c == ')':
413                    cur_frame.append(data[start:i + 1].strip())
414                    start = i + 1
415                    last_frame_added = start
416            i += 1
417
418        # We only permit spaces after the last frame
419        if self.spec[last_frame_added:].strip():
420            raise Exception("Invalid characters found after last frame");
421        # improper input
422        if count:
423            raise Exception("Check your parentheses")
424        return frames
425
426    def _parse_spec(self):
427        frames = self._parse_frames()
428        frames.reverse()
429
430        absolute_order = 0
431        for f in frames:
432            # default case
433            func, pred = f[0], f[1]
434
435            if not self._validate_predicate(pred):
436                raise Exception("Invalid predicate")
437            if not self._validate_identifier(func):
438                raise Exception("Invalid function identifier")
439            tup = (pred, absolute_order)
440
441            if func not in self.map:
442                self.map[func] = [tup]
443            else:
444                self.map[func].append(tup)
445
446            absolute_order += 1
447
448        if self.key not in self.map:
449            self.map[self.key] = [('(true)', absolute_order)]
450            absolute_order += 1
451
452        self.length = absolute_order
453
454    def _validate_identifier(self, func):
455        # We've already established paren balancing. We will only look for
456        # identifier validity here.
457        paren_index = func.find("(")
458        potential_id = func[:paren_index]
459        pattern = '[_a-zA-z][_a-zA-Z0-9]*$'
460        if re.match(pattern, potential_id):
461            return True
462        return False
463
464    def _validate_predicate(self, pred):
465
466        if len(pred) > 0 and pred[0] == "(":
467            open = 1
468            for i in range(1, len(pred)):
469                if pred[i] == "(":
470                    open += 1
471                elif pred[i] == ")":
472                    open -= 1
473            if open != 0:
474                # not well formed, break
475                return False
476
477        return True
478
479    def _def_pid_struct(self):
480        text = """
481struct pid_struct {
482    u64 curr_call; /* book keeping to handle recursion */
483    u64 conds_met; /* stack pointer */
484    u64 stack[%s];
485};
486""" % self.length
487        return text
488
489    def _attach_probes(self):
490        self.bpf = BPF(text=self.program)
491        for p in self.probes:
492            p.attach(self.bpf)
493
494    def _generate_program(self):
495        # leave out auto includes for now
496        self.program += '#include <linux/mm.h>\n'
497        for include in (self.args.include or []):
498            self.program += "#include <%s>\n" % include
499
500        self.program += self._def_pid_struct()
501        self.program += "BPF_HASH(m, u32, struct pid_struct);\n"
502        self.program += "BPF_ARRAY(count, u32, 1);\n"
503
504        for p in self.probes:
505            self.program += p.generate_program() + "\n"
506
507        if self.args.verbose:
508            print(self.program)
509
510    def _main_loop(self):
511        while True:
512            try:
513                self.bpf.perf_buffer_poll()
514            except KeyboardInterrupt:
515                exit()
516
517    def run(self):
518        self._create_probes()
519        self._generate_program()
520        self._attach_probes()
521        self._main_loop()
522
523
524if __name__ == "__main__":
525    Tool().run()
526