xref: /aosp_15_r20/external/crosvm/third_party/minijail/tools/generate_seccomp_policy.py (revision 4b9c6d91573e8b3a96609339b46361b5476dd0f9)
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2016 The Android Open Source Project
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#      http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18# This script will take any number of trace files generated by strace(1)
19# and output a system call filtering policy suitable for use with Minijail.
20
21"""Tool to generate a minijail seccomp filter from strace or audit output."""
22
23from __future__ import print_function
24
25import argparse
26import collections
27import os
28import re
29import sys
30
31# auparse may not be installed and is currently optional.
32try:
33    import auparse
34except ImportError:
35    auparse = None
36
37
38NOTICE = """# Copyright (C) 2018 The Android Open Source Project
39#
40# Licensed under the Apache License, Version 2.0 (the "License");
41# you may not use this file except in compliance with the License.
42# You may obtain a copy of the License at
43#
44#      http://www.apache.org/licenses/LICENSE-2.0
45#
46# Unless required by applicable law or agreed to in writing, software
47# distributed under the License is distributed on an "AS IS" BASIS,
48# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
49# See the License for the specific language governing permissions and
50# limitations under the License.
51"""
52
# Filter expression written for syscalls that need no argument inspection.
ALLOW = '1'

# Matches one strace line. An optional leading PID tag (either a bracketed
# "[...]" tag or a bare number) is skipped, then group 1 captures the syscall
# name and group 2 the raw argument text, which stops at the closing
# parenthesis or at the start of a trailing "<unfinished ...>" marker.
LINE_RE = re.compile(r'^\s*(?:\[[^]]*\]|\d+)?\s*([a-zA-Z0-9_]+)\(([^)<]*)')

# Socket related syscalls that are recorded as 'socketcall' when the trace
# file name suggests a 32-bit x86 target.
SOCKETCALLS = {
    'accept',
    'bind',
    'connect',
    'getpeername',
    'getsockname',
    'getsockopt',
    'listen',
    'recv',
    'recvfrom',
    'recvmsg',
    'send',
    'sendmsg',
    'sendto',
    'setsockopt',
    'shutdown',
    'socket',
    'socketpair',
}

# Private ARM syscalls, keyed by syscall number. The numbers are listed in
# ARM specific unistd.h headers such as Linux's
# arch/arm/include/uapi/asm/unistd.h.
PRIVATE_ARM_SYSCALLS = {
    0xf0001: 'ARM_breakpoint',
    0xf0002: 'ARM_cacheflush',
    0xf0003: 'ARM_usr26',
    0xf0004: 'ARM_usr32',
    0xf0005: 'ARM_set_tls',
}

# Pairs the index of the syscall argument to inspect with the set of values
# observed for that argument.
ArgInspectionEntry = collections.namedtuple(
    'ArgInspectionEntry',
    ['arg_index', 'value_set'],
)
77
78
79# pylint: disable=too-few-public-methods
class BucketInputFiles(argparse.Action):
    """Buckets input files using simple content based heuristics.

    Each input file is scanned line by line and classified by the first line
    that looks like either strace output or an audit record. Files in which
    no line matches are treated as strace logs.

    The action stores two lists on the parsed namespace:
      audit_logs: Mutually exclusive list of audit log filenames.
      traces: Mutually exclusive list of strace log filenames.
    """

    # A syscall-like name, a parenthesised argument list and the ' = ' that
    # precedes the return value.
    _STRACE_LINE_RE = re.compile(r'[a-z]+[0-9]*\(.+\) += ')
    # The record types this tool consumes from audit logs.
    _AUDIT_LINE_RE = re.compile(r'type=(SYSCALL|SECCOMP)')

    def __call__(self, parser, namespace, values, option_string=None):
        """Classifies every filename in |values| and stores both buckets.

        Args:
          parser: The ArgumentParser invoking this action; used to report a
              missing input file via parser.error().
          namespace: The namespace that receives 'audit_logs' and 'traces'.
          values: The input filenames given on the command line.
          option_string: Unused.
        """
        audit_logs = []
        traces = []

        for filename in values:
            if not os.path.exists(filename):
                parser.error(f'Input file {filename} not found.')
            with open(filename, mode='r', encoding='utf8') as input_file:
                # Iterate lazily: the verdict usually comes from one of the
                # first lines, so there is no need to load the whole log.
                for line in input_file:
                    if self._STRACE_LINE_RE.search(line):
                        traces.append(filename)
                        break
                    if self._AUDIT_LINE_RE.search(line):
                        audit_logs.append(filename)
                        break
                else:
                    # Treat it as an strace log to retain legacy behaviour and
                    # also just in case the strace regex is imperfect.
                    traces.append(filename)

        setattr(namespace, 'audit_logs', audit_logs)
        setattr(namespace, 'traces', traces)
112# pylint: enable=too-few-public-methods
113
114
def parse_args(argv):
    """Parses |argv| and cross-checks the audit related options.

    Args:
      argv: The command line arguments, without the program name.

    Returns:
      The parsed options. The input files end up bucketed in opts.traces and
      opts.audit_logs by the BucketInputFiles action.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='output informational messages to stderr',
    )
    parser.add_argument(
        '--frequency',
        type=argparse.FileType('w'),
        help='frequency file',
    )
    parser.add_argument(
        '--policy',
        type=argparse.FileType('w'),
        default=sys.stdout,
        help='policy file',
    )
    parser.add_argument(
        'input-logs',
        action=BucketInputFiles,
        nargs='+',
        help='strace and/or audit logs',
    )
    parser.add_argument(
        '--audit-comm',
        type=str,
        metavar='PROCESS_NAME',
        help='relevant process name from the audit.log files',
    )
    opts = parser.parse_args(argv)

    if opts.audit_logs:
        # Audit logs need both the optional auparse bindings and the name of
        # the process whose records should be considered.
        if not auparse:
            parser.error(
                'Python bindings for the audit subsystem were not found.\n'
                'Please install the python3-audit (sometimes python-audit)'
                ' package for your distro to process audit logs: '
                f'{opts.audit_logs}')
        if not opts.audit_comm:
            parser.error(
                f'--audit-comm is required when using audit logs as input:'
                f' {opts.audit_logs}')
    elif opts.audit_comm:
        # A process name without any audit log hints at a misdetected input.
        parser.error(
            '--audit-comm was specified yet none of the input files '
            'matched our hueristic for an audit log')

    return opts
145
146
def get_seccomp_bpf_filter(syscall, entry):
    """Returns a minijail seccomp-bpf filter expression for the syscall.

    Args:
      syscall: Name of the syscall the expression is generated for.
      entry: An object with an 'arg_index' attribute (index of the inspected
          argument) and a 'value_set' attribute (the observed string values
          of that argument). The value set is not modified.

    Returns:
      The atoms of the expression joined by ' || '. The per-value atoms are
      emitted in sorted order so that the same input always yields the same
      policy text, regardless of set iteration order. An empty value set for
      a syscall other than the memory protection case yields ''.
    """
    arg_index = entry.arg_index
    arg_values = entry.value_set
    atoms = []
    if syscall in ('mmap', 'mmap2', 'mprotect') and arg_index == 2:
        # See if there is at least one instance of any of these syscalls trying
        # to map memory with both PROT_EXEC and PROT_WRITE. If there isn't, we
        # can craft a concise expression to forbid this.
        write_and_exec = {'PROT_EXEC', 'PROT_WRITE'}
        maps_write_and_exec = any(
            write_and_exec.issubset(p.strip() for p in arg_value.split('|'))
            for arg_value in arg_values
        )
        if not maps_write_and_exec:
            atoms.extend(['arg2 in ~PROT_EXEC', 'arg2 in ~PROT_WRITE'])
            # The concise expression replaces the per-value atoms.
            arg_values = ()
    atoms.extend(
        f'arg{arg_index} == {arg_value}' for arg_value in sorted(arg_values)
    )
    return ' || '.join(atoms)
166
167
def parse_trace_file(trace_filename, syscalls, arg_inspection):
    """Parses one file produced by strace.

    Args:
      trace_filename: Path of the strace log. The name itself is also used
          as a hint for whether socket syscalls go through 'socketcall'.
      syscalls: Mapping from syscall name to number of occurrences; updated
          in place.
      arg_inspection: Mapping from syscall name to an entry naming the
          argument to record; the entry's value set is updated in place.
    """
    # Filename based heuristic for 32-bit x86 traces.
    named_i386 = 'i386' in trace_filename
    named_x86_without_64 = ('x86' in trace_filename and
                            '64' not in trace_filename)
    uses_socketcall = named_i386 or named_x86_without_64

    private_arm_names = PRIVATE_ARM_SYSCALLS.values()

    with open(trace_filename, encoding='utf8') as trace_file:
        for line in trace_file:
            parsed = LINE_RE.match(line)
            if parsed is None:
                continue

            name = parsed.group(1)
            raw_args = parsed.group(2)

            if uses_socketcall and name in SOCKETCALLS:
                name = 'socketcall'

            # strace omits the 'ARM_' prefix on all private ARM syscalls, so
            # it is added back here. These syscalls are exclusive to ARM,
            # which is why no filename based arch check is needed.
            arm_name = f'ARM_{name}'
            if arm_name in private_arm_names:
                name = arm_name

            syscalls[name] += 1

            if name not in arg_inspection:
                continue

            # Record the value of the argument selected for this syscall.
            entry = arg_inspection[name]
            stripped_args = [piece.strip() for piece in raw_args.split(',')]
            entry.value_set.add(stripped_args[entry.arg_index])
198
199
def parse_audit_log(audit_log, audit_comm, syscalls, arg_inspection):
    """Parses one audit.log file generated by the Linux audit subsystem.

    Args:
      audit_log: The audit log to parse, handed to auparse.AuParser as an
          AUSOURCE_FILE source. main() passes a filename here.
      audit_comm: Name of the process whose records should be considered.
      syscalls: Mapping from syscall name to number of occurrences; updated
          in place.
      arg_inspection: Mapping from syscall name to an entry naming the
          argument to record; the entry's value set is updated in place.

    Raises:
      ValueError: If the file does not parse as an audit log, or if a
          relevant record lacks the 'syscall' field or the inspected
          argument field.
    """

    unknown_syscall_re = re.compile(r'unknown-syscall\((?P<syscall_num>\d+)\)')

    au = auparse.AuParser(auparse.AUSOURCE_FILE, audit_log)
    # Quick validity check for whether this parses as a valid audit log. The
    # first event should have at least one record.
    if not au.first_record():
        # |audit_log| is normally a plain filename, which has no 'name'
        # attribute; fall back to the value itself so that the intended
        # ValueError is raised instead of an AttributeError.
        log_name = getattr(audit_log, 'name', audit_log)
        raise ValueError(f'Unable to parse audit log file {log_name}')

    # Iterate through events where _any_ contained record matches
    # ((type == SECCOMP || type == SYSCALL) && comm == audit_comm).
    au.search_add_item('type', '=', 'SECCOMP', auparse.AUSEARCH_RULE_CLEAR)
    au.search_add_item('type', '=', 'SYSCALL', auparse.AUSEARCH_RULE_OR)
    au.search_add_item('comm', '=', f'"{audit_comm}"',
                       auparse.AUSEARCH_RULE_AND)

    # auparse_find_field(3) will ignore preceding fields in the record and
    # at the same time happily cross record boundaries when looking for the
    # field. This helper method always seeks the cursor back to the first
    # field in the record and stops searching before crossing over to the
    # next record; making the search far less error prone.
    # Also implicitly seeks the internal 'cursor' to the matching field
    # for any subsequent calls like auparse_interpret_field.
    def _find_field_in_current_record(name):
        au.first_field()
        while True:
            if au.get_field_name() == name:
                return au.get_field_str()
            if not au.next_field():
                return None

    while au.search_next_event():
        # The event may have multiple records. Loop through all.
        au.first_record()
        for _ in range(au.get_num_records()):
            event_type = _find_field_in_current_record('type')
            comm = _find_field_in_current_record('comm')
            # Some of the records in this event may not be relevant
            # despite the event-specific search filter. Skip those.
            if (event_type not in ('SECCOMP', 'SYSCALL') or
                    comm != f'"{audit_comm}"'):
                au.next_record()
                continue

            if not _find_field_in_current_record('syscall'):
                raise ValueError(f'Could not find field "syscall" in event of '
                                 f'type {event_type}')
            # Interpret the syscall field that's under our 'cursor' following
            # the find. Interpreting fields yields human friendly names
            # instead of integers. E.g '16' -> 'ioctl'.
            syscall = au.interpret_field()

            # TODO(crbug/1172449): Add these syscalls to upstream
            # audit-userspace and remove this workaround.
            # This is redundant but safe for non-ARM architectures due to the
            # disjoint set of private syscall numbers.
            match = unknown_syscall_re.match(syscall)
            if match:
                syscall_num = int(match.group('syscall_num'))
                syscall = PRIVATE_ARM_SYSCALLS.get(syscall_num, syscall)

            if ((syscall in arg_inspection and event_type == 'SECCOMP') or
                    (syscall not in arg_inspection and
                     event_type == 'SYSCALL')):
                # Skip SECCOMP records for syscalls that require argument
                # inspection. Similarly, skip SYSCALL records for syscalls
                # that do not require argument inspection. Technically such
                # records wouldn't exist per our setup instructions but audit
                # sometimes lets a few records slip through.
                au.next_record()
                continue

            if event_type == 'SYSCALL':
                # Only syscalls that require argument inspection get here.
                arg_field_name = f'a{arg_inspection[syscall].arg_index}'
                if not _find_field_in_current_record(arg_field_name):
                    raise ValueError(f'Could not find field "{arg_field_name}" '
                                     f'in event of type {event_type}')
                # Interpret the arg field that's under our 'cursor' following
                # the find. This may yield a more human friendly name.
                # E.g '5401' -> 'TCGETS'.
                arg_inspection[syscall].value_set.add(au.interpret_field())

            syscalls[syscall] += 1
            au.next_record()
284
285
def main(argv=None):
    """Generates the policy (and optional frequency file) from the inputs.

    Args:
      argv: Command line arguments without the program name. Defaults to
          sys.argv[1:] when None.
    """
    opts = parse_args(sys.argv[1:] if argv is None else argv)

    syscalls = collections.defaultdict(int)

    # Syscalls whose policy line depends on one inspected argument, paired
    # with the index of that argument.
    inspected_args = (
        ('socket', 0),    # int domain
        ('ioctl', 1),     # int request
        ('prctl', 0),     # int option
        ('mmap', 2),      # int prot
        ('mmap2', 2),     # int prot
        ('mprotect', 2),  # int prot
    )
    arg_inspection = {
        name: ArgInspectionEntry(index, set())
        for name, index in inspected_args
    }

    if opts.verbose:
        # Report how the input files were bucketed, in case the filetype
        # detection heuristics are wonky.
        for message in (
                'Generating a seccomp policy using these input files:',
                f'Strace logs: {opts.traces}',
                f'Audit logs: {opts.audit_logs}',
        ):
            print(message, file=sys.stderr)

    for trace_filename in opts.traces:
        parse_trace_file(trace_filename, syscalls, arg_inspection)

    for audit_log in opts.audit_logs:
        parse_audit_log(audit_log, opts.audit_comm, syscalls, arg_inspection)

    # Make sure the basic set is always part of the policy.
    for basic_syscall in ('restart_syscall', 'exit', 'exit_group',
                          'rt_sigreturn'):
        syscalls.setdefault(basic_syscall, 1)

    want_frequency = opts.frequency is not None

    # With a frequency file, sort alphabetically to make it easier for humans
    # to see which calls are in use (and if necessary manually add a new
    # syscall to the list). Without one, put the most frequent calls first to
    # make the common case fast (by checking frequent calls earlier).
    if want_frequency:
        sorted_syscalls = sorted(syscalls)
    else:
        sorted_syscalls = sorted(syscalls, key=syscalls.get, reverse=True)

    print(NOTICE, file=opts.policy)
    if want_frequency:
        print(NOTICE, file=opts.frequency)

    for syscall in sorted_syscalls:
        entry = arg_inspection.get(syscall)
        if entry is None:
            arg_filter = ALLOW
        else:
            arg_filter = get_seccomp_bpf_filter(syscall, entry)
        print(f'{syscall}: {arg_filter}', file=opts.policy)
        if want_frequency:
            print(f'{syscall}: {syscalls[syscall]}', file=opts.frequency)
355
356
# Script entry point: run main() on the command line arguments and hand its
# return value to sys.exit(). main() has no return statement, so a normal run
# exits with status 0.
if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
359