xref: /aosp_15_r20/external/minijail/tools/bpf.py (revision 4b9c6d91573e8b3a96609339b46361b5476dd0f9)
1*4b9c6d91SCole Faust#!/usr/bin/env python3
2*4b9c6d91SCole Faust# -*- coding: utf-8 -*-
3*4b9c6d91SCole Faust#
4*4b9c6d91SCole Faust# Copyright (C) 2018 The Android Open Source Project
5*4b9c6d91SCole Faust#
6*4b9c6d91SCole Faust# Licensed under the Apache License, Version 2.0 (the "License");
7*4b9c6d91SCole Faust# you may not use this file except in compliance with the License.
8*4b9c6d91SCole Faust# You may obtain a copy of the License at
9*4b9c6d91SCole Faust#
10*4b9c6d91SCole Faust#      http://www.apache.org/licenses/LICENSE-2.0
11*4b9c6d91SCole Faust#
12*4b9c6d91SCole Faust# Unless required by applicable law or agreed to in writing, software
13*4b9c6d91SCole Faust# distributed under the License is distributed on an "AS IS" BASIS,
14*4b9c6d91SCole Faust# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15*4b9c6d91SCole Faust# See the License for the specific language governing permissions and
16*4b9c6d91SCole Faust# limitations under the License.
17*4b9c6d91SCole Faust"""Tools to interact with BPF programs."""
18*4b9c6d91SCole Faust
19*4b9c6d91SCole Faustimport abc
20*4b9c6d91SCole Faustimport collections
21*4b9c6d91SCole Faustimport struct
22*4b9c6d91SCole Faust
# Maximum number of syscall arguments handled by this module. Per syscall(2),
# most architectures pass at most 6 arguments to a syscall; ARM can pass 7.
# NOTE(review): the kernel's struct seccomp_data only holds args[6] — confirm
# that offsets for a 7th argument are never emitted into real filters.
MAX_SYSCALL_ARGUMENTS = 7

# ---------------------------------------------------------------------------
# Classic BPF encoding, mirrored from <linux/bpf_common.h>.
# ---------------------------------------------------------------------------

# Instruction classes (low 3 bits of the opcode).
BPF_LD = 0x00
BPF_LDX = 0x01
BPF_ST = 0x02
BPF_STX = 0x03
BPF_ALU = 0x04
BPF_JMP = 0x05
BPF_RET = 0x06
BPF_MISC = 0x07

# Load operand size (BPF_LD / BPF_LDX).
BPF_W = 0x00  # 32-bit word
BPF_H = 0x08  # 16-bit half word
BPF_B = 0x10  # 8-bit byte

# Load addressing mode (BPF_LD / BPF_LDX).
BPF_IMM = 0x00
BPF_ABS = 0x20
BPF_IND = 0x40
BPF_MEM = 0x60
BPF_LEN = 0x80
BPF_MSH = 0xa0

# Jump operations (BPF_JMP).
BPF_JA = 0x00
BPF_JEQ = 0x10
BPF_JGT = 0x20
BPF_JGE = 0x30
BPF_JSET = 0x40

# Operand source: immediate constant (K) or the X register.
BPF_K = 0x00
BPF_X = 0x08

# Upper bound on program length.
BPF_MAXINSNS = 4096

# ---------------------------------------------------------------------------
# seccomp filter return values, mirrored from <linux/seccomp.h>.
# ---------------------------------------------------------------------------

SECCOMP_RET_KILL_PROCESS = 0x80000000
SECCOMP_RET_KILL_THREAD = 0x00000000
SECCOMP_RET_TRAP = 0x00030000
SECCOMP_RET_ERRNO = 0x00050000
SECCOMP_RET_TRACE = 0x7ff00000
SECCOMP_RET_USER_NOTIF = 0x7fc00000
SECCOMP_RET_LOG = 0x7ffc0000
SECCOMP_RET_ALLOW = 0x7fff0000

# Masks splitting a return value into its action and its 16-bit data payload.
SECCOMP_RET_ACTION_FULL = 0xffff0000
SECCOMP_RET_DATA = 0x0000ffff
78*4b9c6d91SCole Faust
79*4b9c6d91SCole Faust
def arg_offset(arg_index, hi=False):
    """Return the byte offset of one 32-bit word of a syscall argument.

    The offset is suitable as the operand of a BPF_LD|BPF_W|BPF_ABS load.
    The argument array starts after a 32-bit syscall number, a 32-bit arch
    and a 64-bit instruction pointer, and every argument occupies 64 bits.
    A truthy |hi| selects the word 4 bytes after the start of the argument.
    NOTE(review): that is the upper half only on little-endian layouts.
    """
    args_start = 4 + 4 + 8  # nr (u32) + arch (u32) + instruction pointer (u64)
    slot_size = 8  # each argument is a u64
    half_slot = slot_size // 2
    return args_start + slot_size * arg_index + half_slot * hi
85*4b9c6d91SCole Faust
86*4b9c6d91SCole Faust
def simulate(instructions, arch, syscall_number, *args):
    """Interpret a BPF program against a synthetic syscall record.

    The record is packed natively as: syscall_number (u32), arch (u32), a zero
    instruction pointer (u64), then exactly MAX_SYSCALL_ARGUMENTS u64
    arguments. Missing arguments are zero-filled and surplus ones dropped.

    Returns a tuple (cost, action) where cost is the number of instructions
    executed and action is one of 'KILL_PROCESS', 'KILL_THREAD', 'TRAP',
    'TRACE', 'USER_NOTIF', 'LOG' or 'ALLOW'. For an errno return the tuple
    is (cost, 'ERRNO', errno).

    Raises Exception on an unsupported opcode, an unrecognized return value,
    or when execution runs past the last instruction.
    """
    padding = (0, ) * (MAX_SYSCALL_ARGUMENTS - len(args))
    args = (args + padding)[:MAX_SYSCALL_ARGUMENTS]
    data = struct.pack('IIQ' + 'Q' * MAX_SYSCALL_ARGUMENTS, syscall_number,
                       arch, 0, *args)

    # Conditional jumps: opcode -> predicate over (accumulator, operand).
    branch_tests = {
        BPF_JMP | BPF_JEQ | BPF_K: lambda acc, k: acc == k,
        BPF_JMP | BPF_JGT | BPF_K: lambda acc, k: acc > k,
        BPF_JMP | BPF_JGE | BPF_K: lambda acc, k: acc >= k,
        BPF_JMP | BPF_JSET | BPF_K: lambda acc, k: (acc & k) != 0,
    }
    # Return values that map to an action by exact match. None of them has
    # SECCOMP_RET_ERRNO as its action, so the errno test can come afterwards.
    exact_actions = {
        SECCOMP_RET_KILL_PROCESS: 'KILL_PROCESS',
        SECCOMP_RET_KILL_THREAD: 'KILL_THREAD',
        SECCOMP_RET_TRAP: 'TRAP',
        SECCOMP_RET_TRACE: 'TRACE',
        SECCOMP_RET_USER_NOTIF: 'USER_NOTIF',
        SECCOMP_RET_LOG: 'LOG',
        SECCOMP_RET_ALLOW: 'ALLOW',
    }

    accumulator = 0
    pc = 0
    executed = 0
    while pc < len(instructions):
        ins = instructions[pc]
        pc += 1
        executed += 1
        code = ins.code
        if code == BPF_LD | BPF_W | BPF_ABS:
            accumulator, = struct.unpack('I', data[ins.k:ins.k + 4])
        elif code == BPF_JMP | BPF_JA | BPF_K:
            pc += ins.k
        elif code in branch_tests:
            taken = branch_tests[code](accumulator, ins.k)
            pc += ins.jt if taken else ins.jf
        elif code == BPF_RET:
            value = ins.k
            if value in exact_actions:
                return (executed, exact_actions[value])
            if (value & SECCOMP_RET_ACTION_FULL) == SECCOMP_RET_ERRNO:
                return (executed, 'ERRNO', value & SECCOMP_RET_DATA)
            raise Exception('unknown return %#x' % value)
        else:
            raise Exception('unknown instruction %r' % (ins, ))
    raise Exception('out-of-bounds')
146*4b9c6d91SCole Faust
147*4b9c6d91SCole Faust
class SockFilter(
        collections.namedtuple('SockFilter', ['code', 'jt', 'jf', 'k'])):
    """Immutable Python mirror of the kernel's struct sock_filter.

    Fields: code (opcode, packed as u16), jt / jf (jump offsets taken when a
    conditional jump is true / false, packed as u8) and k (operand, packed as
    u32).
    """

    __slots__ = ()

    def encode(self):
        """Serialize the instruction with struct format 'HBBI' (native order)."""
        code, jt, jf, k = self
        return struct.pack('HBBI', code, jt, jf, k)
157*4b9c6d91SCole Faust
158*4b9c6d91SCole Faust
class AbstractBlock(abc.ABC):
    """Common base of every node in the filter DAG (visitor-pattern host).

    Concrete subclasses implement accept(), which hands the block to a
    visitor.
    """

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def accept(self, visitor):
        """Let |visitor| process this block; must be overridden."""
168*4b9c6d91SCole Faust
169*4b9c6d91SCole Faust
class BasicBlock(AbstractBlock):
    """A concrete AbstractBlock holding already-compiled SockFilter instructions.

    A BasicBlock has no successor blocks: accept() visits only the block
    itself.
    """

    def __init__(self, instructions):
        super().__init__()
        # Stored as given (not copied); callers may pass a live list.
        self._instructions = instructions

    def accept(self, visitor):
        if visitor.visited(self):
            return
        visitor.visit(self)

    @property
    def instructions(self):
        """The instruction sequence passed to the constructor."""
        return self._instructions

    @property
    def opcodes(self):
        """The concatenated binary encoding of every instruction."""
        return b''.join(i.encode() for i in self._instructions)

    def __eq__(self, o):
        # Two BasicBlocks are equal when their instructions are equal.
        # Returning NotImplemented (rather than False) for other types lets
        # Python try the reflected comparison before falling back to identity.
        if not isinstance(o, BasicBlock):
            return NotImplemented
        return self._instructions == o._instructions

    # Defining __eq__ already disables hashing; state it explicitly because the
    # instructions may be a mutable list, which makes a stable hash impossible.
    __hash__ = None

    def __repr__(self):
        return '<%s instructions=%r>' % (type(self).__name__,
                                         self._instructions)
194*4b9c6d91SCole Faust
195*4b9c6d91SCole Faust
class KillProcess(BasicBlock):
    """Terminal BasicBlock whose single instruction returns KILL_PROCESS."""

    def __init__(self):
        ret = SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_PROCESS)
        super().__init__([ret])
202*4b9c6d91SCole Faust
203*4b9c6d91SCole Faust
class KillThread(BasicBlock):
    """Terminal BasicBlock whose single instruction returns KILL_THREAD."""

    def __init__(self):
        ret = SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_THREAD)
        super().__init__([ret])
210*4b9c6d91SCole Faust
211*4b9c6d91SCole Faust
class Trap(BasicBlock):
    """Terminal BasicBlock whose single instruction returns TRAP."""

    def __init__(self):
        ret = SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRAP)
        super().__init__([ret])
217*4b9c6d91SCole Faust
218*4b9c6d91SCole Faust
class Trace(BasicBlock):
    """Terminal BasicBlock whose single instruction returns TRACE."""

    def __init__(self):
        ret = SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRACE)
        super().__init__([ret])
224*4b9c6d91SCole Faust
225*4b9c6d91SCole Faust
class UserNotify(BasicBlock):
    """Terminal BasicBlock whose single instruction returns USER_NOTIF."""

    def __init__(self):
        ret = SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_USER_NOTIF)
        super().__init__([ret])
231*4b9c6d91SCole Faust
232*4b9c6d91SCole Faust
class Log(BasicBlock):
    """Terminal BasicBlock whose single instruction returns LOG."""

    def __init__(self):
        ret = SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_LOG)
        super().__init__([ret])
238*4b9c6d91SCole Faust
239*4b9c6d91SCole Faust
class ReturnErrno(BasicBlock):
    """Terminal BasicBlock that returns SECCOMP_RET_ERRNO carrying |errno|.

    Only the bits of errno inside SECCOMP_RET_DATA are encoded into the
    instruction; the value exactly as given is kept in self.errno.
    """

    def __init__(self, errno):
        action = SECCOMP_RET_ERRNO | (errno & SECCOMP_RET_DATA)
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, action)])
        self.errno = errno
249*4b9c6d91SCole Faust
250*4b9c6d91SCole Faust
class Allow(BasicBlock):
    """Terminal BasicBlock whose single instruction returns ALLOW."""

    def __init__(self):
        ret = SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_ALLOW)
        super().__init__([ret])
256*4b9c6d91SCole Faust
257*4b9c6d91SCole Faust
class ValidateArch(AbstractBlock):
    """An AbstractBlock marking the architecture check in front of |next_block|.

    Only the successor is stored here; the architecture to compare against is
    not an attribute of this block.
    """

    def __init__(self, next_block):
        super().__init__()
        self.next_block = next_block

    def accept(self, visitor):
        # Post-order: the successor is handed to the visitor before this block.
        if not visitor.visited(self):
            self.next_block.accept(visitor)
            visitor.visit(self)
270*4b9c6d91SCole Faust
271*4b9c6d91SCole Faust
class SyscallEntry(AbstractBlock):
    """An AbstractBlock representing a syscall-number comparison in a DAG.

    Holds the comparison opcode |op| (a BPF jump operation, BPF_JEQ by
    default), the |syscall_number| operand, and the successor blocks |jt|
    (comparison true) and |jf| (comparison false).
    """

    def __init__(self, syscall_number, jt, jf, *, op=BPF_JEQ):
        super().__init__()
        self.op = op
        self.syscall_number = syscall_number
        self.jt = jt
        self.jf = jf

    # Defined because we want to compare tuples that contain SyscallEntries:
    # entries are treated as unordered relative to each other, so such tuple
    # comparisons never raise TypeError.
    def __lt__(self, o):
        return False

    def __gt__(self, o):
        return False

    def accept(self, visitor):
        if visitor.visited(self):
            return
        # Post-order: both successors are visited before this block.
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)
304*4b9c6d91SCole Faust
305*4b9c6d91SCole Faust
class WideAtom(AbstractBlock):
    """An AbstractBlock comparing a single 32-bit word of the syscall data.

    |arg_offset| is the byte offset of the word to load, |op| a BPF jump
    opcode, and |value| the operand. Control continues at |jt| when the
    comparison holds and at |jf| otherwise.
    """

    def __init__(self, arg_offset, op, value, jt, jf):
        super().__init__()
        self.jt = jt
        self.jf = jf
        self.arg_offset = arg_offset
        self.op = op
        self.value = value

    def accept(self, visitor):
        if visitor.visited(self):
            return
        # Post-order: true branch, then false branch, then this block.
        for successor in (self.jt, self.jf):
            successor.accept(visitor)
        visitor.visit(self)
323*4b9c6d91SCole Faust
324*4b9c6d91SCole Faust
class Atom(AbstractBlock):
    """An AbstractBlock holding one comparison of a whole syscall argument.

    The textual operator is normalized at construction into one of BPF_JEQ,
    BPF_JGT, BPF_JGE or BPF_JSET. Negated operators ('!=', '<=', '<', 'in')
    are expressed by swapping the true/false successors. For 'in', the value
    is additionally replaced by its 64-bit complement, so that the JSET test
    fires when the argument has a bit outside the allowed mask.
    """

    def __init__(self, arg_index, op, value, jt, jf):
        super().__init__()
        # Checked in order: (operator, BPF jump opcode, swap successors?).
        operators = (
            ('==', BPF_JEQ, False),
            ('!=', BPF_JEQ, True),
            ('>', BPF_JGT, False),
            ('<=', BPF_JGT, True),
            ('>=', BPF_JGE, False),
            ('<', BPF_JGE, True),
            ('&', BPF_JSET, False),
            ('in', BPF_JSET, True),
        )
        for symbol, bpf_op, swap in operators:
            if op == symbol:
                break
        else:
            raise Exception('Unknown operator %s' % op)

        if symbol == 'in':
            # Negate the mask within 64 bits: any set bit now means "a flag
            # outside the allowed set", which is the failure case.
            value = (~value) & ((1 << 64) - 1)
        if swap:
            jt, jf = jf, jt

        self.arg_index = arg_index
        self.op = bpf_op
        self.jt = jt
        self.jf = jf
        self.value = value

    def accept(self, visitor):
        if not visitor.visited(self):
            self.jt.accept(visitor)
            self.jf.accept(visitor)
            visitor.visit(self)
370*4b9c6d91SCole Faust
371*4b9c6d91SCole Faust
class AbstractVisitor(abc.ABC):
    """Base class for visitors walking the block DAG.

    Blocks call visited() before descending, so each block (tracked by id())
    is handled once. visit() then dispatches to the visitX method matching the
    block's type.
    """

    def __init__(self):
        self._visited = set()

    def visited(self, block):
        """Return True if |block| was seen before; otherwise record it."""
        key = id(block)
        already_seen = key in self._visited
        self._visited.add(key)
        return already_seen

    def process(self, block):
        """Walk the DAG rooted at |block| and return |block| itself."""
        block.accept(self)
        return block

    def visit(self, block):
        """Dispatch |block| to the matching visitX method."""
        # isinstance() is tried in order, so the BasicBlock subclasses must be
        # listed before BasicBlock itself.
        handlers = (
            (KillProcess, 'visitKillProcess'),
            (KillThread, 'visitKillThread'),
            (Trap, 'visitTrap'),
            (ReturnErrno, 'visitReturnErrno'),
            (Trace, 'visitTrace'),
            (UserNotify, 'visitUserNotify'),
            (Log, 'visitLog'),
            (Allow, 'visitAllow'),
            (BasicBlock, 'visitBasicBlock'),
            (ValidateArch, 'visitValidateArch'),
            (SyscallEntry, 'visitSyscallEntry'),
            (WideAtom, 'visitWideAtom'),
            (Atom, 'visitAtom'),
        )
        for block_type, method_name in handlers:
            if isinstance(block, block_type):
                getattr(self, method_name)(block)
                return
        raise Exception('Unknown block type: %r' % block)

    @abc.abstractmethod
    def visitKillProcess(self, block):
        """Handle a KillProcess block."""

    @abc.abstractmethod
    def visitKillThread(self, block):
        """Handle a KillThread block."""

    @abc.abstractmethod
    def visitTrap(self, block):
        """Handle a Trap block."""

    @abc.abstractmethod
    def visitReturnErrno(self, block):
        """Handle a ReturnErrno block."""

    @abc.abstractmethod
    def visitTrace(self, block):
        """Handle a Trace block."""

    @abc.abstractmethod
    def visitUserNotify(self, block):
        """Handle a UserNotify block."""

    @abc.abstractmethod
    def visitLog(self, block):
        """Handle a Log block."""

    @abc.abstractmethod
    def visitAllow(self, block):
        """Handle an Allow block."""

    @abc.abstractmethod
    def visitBasicBlock(self, block):
        """Handle a generic BasicBlock."""

    @abc.abstractmethod
    def visitValidateArch(self, block):
        """Handle a ValidateArch block."""

    @abc.abstractmethod
    def visitSyscallEntry(self, block):
        """Handle a SyscallEntry block."""

    @abc.abstractmethod
    def visitWideAtom(self, block):
        """Handle a WideAtom block."""

    @abc.abstractmethod
    def visitAtom(self, block):
        """Handle an Atom block."""
469*4b9c6d91SCole Faust
470*4b9c6d91SCole Faust
class CopyingVisitor(AbstractVisitor):
    """A visitor that copies a block DAG, preserving shared nodes.

    Blocks are visited in post-order (successors first), so the copies of a
    block's successors already exist in self._mapping when the block itself
    is copied. Since each block is visited once, a node reachable through
    several paths is copied exactly once.
    """

    # Atom stores its normalized BPF jump opcode. These textual operators map
    # back to the same opcode in Atom.__init__ without swapping jt/jf or
    # rewriting the value, so re-constructing with them yields an exact copy.
    _ATOM_OPERATORS = {
        BPF_JEQ: '==',
        BPF_JGT: '>',
        BPF_JGE: '>=',
        BPF_JSET: '&',
    }

    def __init__(self):
        super().__init__()
        self._mapping = {}

    def process(self, block):
        """Return a copy of the DAG rooted at |block|."""
        # Reset both tables. A stale visited set from an earlier call would
        # make accept() skip blocks, leaving them missing from _mapping.
        self._visited = set()
        self._mapping = {}
        block.accept(self)
        return self._mapping[id(block)]

    def visitKillProcess(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillProcess()

    def visitKillThread(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillThread()

    def visitTrap(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trap()

    def visitReturnErrno(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = ReturnErrno(block.errno)

    def visitTrace(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trace()

    def visitUserNotify(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = UserNotify()

    def visitLog(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Log()

    def visitAllow(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Allow()

    def visitBasicBlock(self, block):
        assert id(block) not in self._mapping
        # Shallow copy: the new block shares the instruction sequence.
        self._mapping[id(block)] = BasicBlock(block.instructions)

    def visitValidateArch(self, block):
        assert id(block) not in self._mapping
        # ValidateArch only takes (and stores) its successor block.
        self._mapping[id(block)] = ValidateArch(
            self._mapping[id(block.next_block)])

    def visitSyscallEntry(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = SyscallEntry(
            block.syscall_number,
            self._mapping[id(block.jt)],
            self._mapping[id(block.jf)],
            op=block.op)

    def visitWideAtom(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = WideAtom(
            block.arg_offset, block.op, block.value, self._mapping[id(
                block.jt)], self._mapping[id(block.jf)])

    def visitAtom(self, block):
        assert id(block) not in self._mapping
        # block.op is already a BPF opcode; translate it back to a textual
        # operator Atom.__init__ accepts without re-applying negation.
        operator = self._ATOM_OPERATORS.get(block.op, block.op)
        self._mapping[id(block)] = Atom(block.arg_index, operator,
                                        block.value,
                                        self._mapping[id(block.jt)],
                                        self._mapping[id(block.jf)])
543*4b9c6d91SCole Faust
544*4b9c6d91SCole Faust
class LoweringVisitor(CopyingVisitor):
    """A CopyingVisitor that lowers every Atom into one or more WideAtoms.

    An Atom compares a whole argument; a WideAtom compares a single 32-bit
    word. When arch.bits == 32 only the low word is compared. Otherwise the
    comparison is expanded into tests on the high and low words.
    """

    def __init__(self, *, arch):
        super().__init__()
        self._bits = arch.bits

    def visitAtom(self, block):
        assert id(block) not in self._mapping

        on_true = self._mapping[id(block.jt)]
        on_false = self._mapping[id(block.jf)]
        lo_word = block.value & 0xFFFFFFFF
        hi_word = (block.value >> 32) & 0xFFFFFFFF
        hi_offset = arg_offset(block.arg_index, True)

        lo_check = WideAtom(
            arg_offset(block.arg_index, False), block.op, lo_word, on_true,
            on_false)

        if self._bits == 32:
            lowered = lo_check
        elif block.op in (BPF_JGE, BPF_JGT):
            # (hi_1, lo_1) <op> (hi_2, lo_2)  <=>
            #     hi_1 > hi_2 || (hi_1 == hi_2 && lo_1 <op> lo_2)
            if hi_word == 0:
                # When "hi_1 > 0" fails, hi_1 must be 0, so the equality test
                # on the high word can be skipped.
                lowered = WideAtom(hi_offset, BPF_JGT, hi_word, on_true,
                                   lo_check)
            else:
                hi_equal = WideAtom(hi_offset, BPF_JEQ, hi_word, lo_check,
                                    on_false)
                lowered = WideAtom(hi_offset, BPF_JGT, hi_word, on_true,
                                   hi_equal)
        elif block.op == BPF_JSET:
            # (hi_1, lo_1) & (hi_2, lo_2)  <=>  (hi_1 & hi_2) || (lo_1 & lo_2)
            if hi_word == 0:
                # hi_1 & 0 is never non-zero; test the low word directly.
                lowered = lo_check
            else:
                lowered = WideAtom(hi_offset, block.op, hi_word, on_true,
                                   lo_check)
        else:
            assert block.op == BPF_JEQ, block.op
            # (hi_1, lo_1) == (hi_2, lo_2)  <=>  hi_1 == hi_2 && lo_1 == lo_2
            lowered = WideAtom(hi_offset, block.op, hi_word, lo_check,
                               on_false)

        self._mapping[id(block)] = lowered
606*4b9c6d91SCole Faust
607*4b9c6d91SCole Faust
class FlatteningVisitor:
    """A visitor that flattens a DAG of Block objects into one BasicBlock.

    Every call to visit() *prepends* the block's instructions to the program
    built so far, and jump offsets are computed against blocks that have
    already been emitted. Blocks must therefore be visited so that every jump
    target is visited before any block that jumps to it; all emitted jumps
    then point forward.

    Args:
        arch: Architecture descriptor; its ``arch_nr`` is compared against
            the loaded architecture word in the ValidateArch prologue.
        kill_action: Block whose ``instructions`` are emitted when the
            architecture check fails.
    """

    def __init__(self, *, arch, kill_action):
        self._visited = set()
        self._kill_action = kill_action
        # Instructions emitted so far; each newly visited block is prepended.
        self._instructions = []
        self._arch = arch
        # id(block) -> -(program length right after the block was emitted),
        # i.e. where the block starts, measured backwards from the end.
        self._offsets = {}

    @property
    def result(self):
        """The flattened program as a single BasicBlock."""
        return BasicBlock(self._instructions)

    def _distance(self, block):
        """Return the jump offset to the already-emitted *block*.

        This is the index of *block*'s first instruction in the current
        program. It equals the number of instructions to skip from the end of
        whatever is about to be prepended. Raises KeyError if *block* has not
        been emitted yet.
        """
        distance = self._offsets[id(block)] + len(self._instructions)
        assert distance >= 0
        return distance

    def _emit_load_arg(self, offset):
        """Return the instruction that loads the word at *offset* into A."""
        return [SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, offset)]

    def _emit_jmp(self, op, value, jt_distance, jf_distance):
        """Return a conditional jump on ``A <op> value`` to two targets.

        The jt/jf fields of a conditional jump can only hold offsets below
        0x100. A target that is farther away is reached through an
        unconditional BPF_JA, whose offset is carried in the k field. All
        distances are measured from the end of the returned sequence.
        """
        if jt_distance < 0x100 and jf_distance < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance, jf_distance,
                           value),
            ]
        if jt_distance + 1 < 0x100:
            # True branch skips the trailing JA; false branch falls into it.
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance + 1, 0, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
            ]
        if jf_distance + 1 < 0x100:
            # False branch skips the trailing JA; true branch falls into it.
            return [
                SockFilter(BPF_JMP | op | BPF_K, 0, jf_distance + 1, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance),
            ]
        # Both targets are far away: route each branch through its own JA.
        return [
            SockFilter(BPF_JMP | op | BPF_K, 0, 1, value),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance + 1),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
        ]

    def _emit_validate_arch(self, block):
        """Return the architecture check followed by the syscall-number load.

        On an architecture mismatch the kill action runs. On a match, control
        reaches the reload of the syscall number and then
        ``block.next_block``.
        """
        kill_instructions = self._kill_action.instructions
        next_distance = self._distance(block.next_block)
        instructions = [
            # Load the architecture word (offset 4).
            SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 4),
            # Match: jump over the whole kill action onto the reload below.
            # The offset tracks the kill action's real length instead of
            # assuming it is a single instruction.
            SockFilter(BPF_JMP | BPF_JEQ | BPF_K, len(kill_instructions), 0,
                       self._arch.arch_nr),
        ] + kill_instructions + [
            # Reload A with the syscall number (offset 0) for dispatch.
            SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 0),
        ]
        if next_distance:
            # next_block is not immediately after this prologue, so falling
            # through would land on the wrong instruction; jump explicitly.
            instructions.append(
                SockFilter(BPF_JMP | BPF_JA, 0, 0, next_distance))
        return instructions

    def visited(self, block):
        """Return True if *block* was seen before; otherwise mark it seen."""
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def visit(self, block):
        """Emit *block*'s instructions in front of the program built so far.

        Raises:
            Exception: if *block* is not a type this visitor can lower.
        """
        assert id(block) not in self._offsets

        if isinstance(block, BasicBlock):
            instructions = block.instructions
        elif isinstance(block, ValidateArch):
            instructions = self._emit_validate_arch(block)
        elif isinstance(block, SyscallEntry):
            instructions = self._emit_jmp(block.op, block.syscall_number,
                                          self._distance(block.jt),
                                          self._distance(block.jf))
        elif isinstance(block, WideAtom):
            instructions = (
                self._emit_load_arg(block.arg_offset) + self._emit_jmp(
                    block.op, block.value, self._distance(block.jt),
                    self._distance(block.jf)))
        else:
            # Wrap in a 1-tuple so a tuple-valued block formats correctly.
            raise Exception('Unknown block type: %r' % (block,))

        self._instructions = instructions + self._instructions
        self._offsets[id(block)] = -len(self._instructions)
687*4b9c6d91SCole Faust
688*4b9c6d91SCole Faust
class ArgFilterForwardingVisitor:
    """Visitor that hands every argument-filter block to a wrapped visitor.

    Argument filters are BasicBlocks. The terminal action blocks (Allow,
    KillProcess, Trap, ...) are BasicBlocks too, so they are excluded
    explicitly and are not forwarded at this stage.

    Args:
        visitor: The visitor that each argument-filter block is passed to
            via ``block.accept(visitor)``.
    """

    def __init__(self, visitor):
        self.visitor = visitor
        # ids of blocks already reported through visited().
        self._visited = set()

    def visited(self, block):
        """Return True if *block* was seen before; otherwise record it."""
        key = id(block)
        already_seen = key in self._visited
        if not already_seen:
            self._visited.add(key)
        return already_seen

    def visit(self, block):
        """Forward *block* to the wrapped visitor if it is an arg filter."""
        # The `and` short-circuits, so the action types are only consulted
        # for BasicBlocks.
        if isinstance(block, BasicBlock) and not isinstance(
                block, (KillProcess, KillThread, Trap, ReturnErrno, Trace,
                        UserNotify, Log, Allow)):
            block.accept(self.visitor)
714