#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools to interact with BPF programs."""

import abc
import collections
import struct

# This comes from syscall(2). Most architectures only support passing 6 args to
# syscalls, but ARM supports passing 7.
MAX_SYSCALL_ARGUMENTS = 7

# The following fields were copied from <linux/bpf_common.h>:

# Instruction classes
BPF_LD = 0x00
BPF_LDX = 0x01
BPF_ST = 0x02
BPF_STX = 0x03
BPF_ALU = 0x04
BPF_JMP = 0x05
BPF_RET = 0x06
BPF_MISC = 0x07

# LD/LDX fields.
# Size
BPF_W = 0x00
BPF_H = 0x08
BPF_B = 0x10
# Mode
BPF_IMM = 0x00
BPF_ABS = 0x20
BPF_IND = 0x40
BPF_MEM = 0x60
BPF_LEN = 0x80
BPF_MSH = 0xa0

# JMP fields.
BPF_JA = 0x00
BPF_JEQ = 0x10
BPF_JGT = 0x20
BPF_JGE = 0x30
BPF_JSET = 0x40

# Source
BPF_K = 0x00
BPF_X = 0x08

BPF_MAXINSNS = 4096

# The following fields were copied from <linux/seccomp.h>:

SECCOMP_RET_KILL_PROCESS = 0x80000000
SECCOMP_RET_KILL_THREAD = 0x00000000
SECCOMP_RET_TRAP = 0x00030000
SECCOMP_RET_ERRNO = 0x00050000
SECCOMP_RET_TRACE = 0x7ff00000
SECCOMP_RET_USER_NOTIF = 0x7fc00000
SECCOMP_RET_LOG = 0x7ffc0000
SECCOMP_RET_ALLOW = 0x7fff0000

SECCOMP_RET_ACTION_FULL = 0xffff0000
SECCOMP_RET_DATA = 0x0000ffff


def arg_offset(arg_index, hi=False):
    """Return the BPF_LD|BPF_W|BPF_ABS addressing-friendly register offset.

    Args:
        arg_index: Zero-based index of the syscall argument.
        hi: When true, return the offset of the second 32-bit word of the
            64-bit argument slot instead of the first one.

    Returns:
        The byte offset of the requested 32-bit word inside the buffer that
        simulate() builds (two 32-bit fields, one 64-bit field, then the
        64-bit arguments).
    """
    offsetof_args = 4 + 4 + 8
    arg_width = 8
    return offsetof_args + arg_width * arg_index + (arg_width // 2) * hi


def simulate(instructions, arch, syscall_number, *args):
    """Simulate a BPF program with the given arguments.

    Args:
        instructions: Sequence of SockFilter-like objects (with |code|, |jt|,
            |jf| and |k| attributes) to execute.
        arch: Architecture number stored in the second 32-bit input word.
        syscall_number: Syscall number stored in the first 32-bit input word.
        *args: Syscall arguments. Missing ones are padded with zeroes and
            any beyond MAX_SYSCALL_ARGUMENTS are dropped.

    Returns:
        A tuple (cost, action), or (cost, 'ERRNO', errno) for ERRNO returns,
        where |cost| is the number of instructions that were executed.

    Raises:
        Exception: if an unknown instruction or return value is found, or if
            execution runs past the end of the program.
    """
    args = ((args + (0, ) *
             (MAX_SYSCALL_ARGUMENTS - len(args)))[:MAX_SYSCALL_ARGUMENTS])
    input_memory = struct.pack('IIQ' + 'Q' * MAX_SYSCALL_ARGUMENTS,
                               syscall_number, arch, 0, *args)

    register = 0
    program_counter = 0
    cost = 0
    while program_counter < len(instructions):
        ins = instructions[program_counter]
        # Jump distances are relative to the instruction that follows.
        program_counter += 1
        cost += 1
        if ins.code == BPF_LD | BPF_W | BPF_ABS:
            register = struct.unpack('I', input_memory[ins.k:ins.k + 4])[0]
        elif ins.code == BPF_JMP | BPF_JA | BPF_K:
            program_counter += ins.k
        elif ins.code == BPF_JMP | BPF_JEQ | BPF_K:
            if register == ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JGT | BPF_K:
            if register > ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JGE | BPF_K:
            if register >= ins.k:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_JMP | BPF_JSET | BPF_K:
            if register & ins.k != 0:
                program_counter += ins.jt
            else:
                program_counter += ins.jf
        elif ins.code == BPF_RET:
            if ins.k == SECCOMP_RET_KILL_PROCESS:
                return (cost, 'KILL_PROCESS')
            if ins.k == SECCOMP_RET_KILL_THREAD:
                return (cost, 'KILL_THREAD')
            if ins.k == SECCOMP_RET_TRAP:
                return (cost, 'TRAP')
            # ERRNO carries the errno value in the low 16 bits, so only the
            # action bits are compared.
            if (ins.k & SECCOMP_RET_ACTION_FULL) == SECCOMP_RET_ERRNO:
                return (cost, 'ERRNO', ins.k & SECCOMP_RET_DATA)
            if ins.k == SECCOMP_RET_TRACE:
                return (cost, 'TRACE')
            if ins.k == SECCOMP_RET_USER_NOTIF:
                return (cost, 'USER_NOTIF')
            if ins.k == SECCOMP_RET_LOG:
                return (cost, 'LOG')
            if ins.k == SECCOMP_RET_ALLOW:
                return (cost, 'ALLOW')
            raise Exception('unknown return %#x' % ins.k)
        else:
            raise Exception('unknown instruction %r' % (ins, ))
    raise Exception('out-of-bounds')


class SockFilter(
        collections.namedtuple('SockFilter', ['code', 'jt', 'jf', 'k'])):
    """A representation of struct sock_filter."""

    __slots__ = ()

    def encode(self):
        """Return an encoded version of the SockFilter."""
        return struct.pack('HBBI', self.code, self.jt, self.jf, self.k)


class AbstractBlock(abc.ABC):
    """A class that implements the visitor pattern."""

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def accept(self, visitor):
        """Visit this block's children (if any) and then the block itself."""


class BasicBlock(AbstractBlock):
    """A concrete implementation of AbstractBlock that has been compiled."""

    def __init__(self, instructions):
        super().__init__()
        self._instructions = instructions

    def accept(self, visitor):
        if visitor.visited(self):
            return
        visitor.visit(self)

    @property
    def instructions(self):
        """The list of SockFilter instructions in this block."""
        return self._instructions

    @property
    def opcodes(self):
        """The encoded instructions, concatenated into a bytes object."""
        return b''.join(i.encode() for i in self._instructions)

    def __eq__(self, o):
        if not isinstance(o, BasicBlock):
            return False
        return self._instructions == o._instructions


class KillProcess(BasicBlock):
    """A BasicBlock that unconditionally returns KILL_PROCESS."""

    def __init__(self):
        super().__init__(
            [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_PROCESS)])


class KillThread(BasicBlock):
    """A BasicBlock that unconditionally returns KILL_THREAD."""

    def __init__(self):
        super().__init__(
            [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_KILL_THREAD)])


class Trap(BasicBlock):
    """A BasicBlock that unconditionally returns TRAP."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRAP)])


class Trace(BasicBlock):
    """A BasicBlock that unconditionally returns TRACE."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_TRACE)])


class UserNotify(BasicBlock):
    """A BasicBlock that unconditionally returns USER_NOTIF."""

    def __init__(self):
        super().__init__(
            [SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_USER_NOTIF)])


class Log(BasicBlock):
    """A BasicBlock that unconditionally returns LOG."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_LOG)])


class ReturnErrno(BasicBlock):
    """A BasicBlock that unconditionally returns the specified errno."""

    def __init__(self, errno):
        super().__init__([
            SockFilter(BPF_RET, 0x00, 0x00,
                       SECCOMP_RET_ERRNO | (errno & SECCOMP_RET_DATA))
        ])
        self.errno = errno


class Allow(BasicBlock):
    """A BasicBlock that unconditionally returns ALLOW."""

    def __init__(self):
        super().__init__([SockFilter(BPF_RET, 0x00, 0x00, SECCOMP_RET_ALLOW)])


class ValidateArch(AbstractBlock):
    """An AbstractBlock that validates the architecture."""

    def __init__(self, next_block):
        super().__init__()
        self.next_block = next_block

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.next_block.accept(visitor)
        visitor.visit(self)


class SyscallEntry(AbstractBlock):
    """An abstract block that represents a syscall comparison in a DAG."""

    def __init__(self, syscall_number, jt, jf, *, op=BPF_JEQ):
        super().__init__()
        self.op = op
        self.syscall_number = syscall_number
        self.jt = jt
        self.jf = jf

    def __lt__(self, o):
        # Defined because we want to compare tuples that contain SyscallEntries.
        return False

    def __gt__(self, o):
        # Defined because we want to compare tuples that contain SyscallEntries.
        return False

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)


class WideAtom(AbstractBlock):
    """A BasicBlock that represents a 32-bit wide atom."""

    def __init__(self, arg_offset, op, value, jt, jf):
        super().__init__()
        self.arg_offset = arg_offset
        self.op = op
        self.value = value
        self.jt = jt
        self.jf = jf

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)


class Atom(AbstractBlock):
    """A BasicBlock that represents an atom (a simple comparison operation)."""

    def __init__(self, arg_index, op, value, jt, jf):
        """Build an Atom.

        Args:
            arg_index: Index of the syscall argument being compared.
            op: Either a textual operator ('==', '!=', '>', '<=', '>=', '<',
                '&', 'in') or an already-lowered jump opcode (BPF_JEQ,
                BPF_JGT, BPF_JGE, BPF_JSET).
            value: The value to compare the argument against.
            jt: Block to continue with when the comparison holds.
            jf: Block to continue with when the comparison does not hold.

        Raises:
            Exception: if |op| is not a known operator.
        """
        super().__init__()
        if op in (BPF_JEQ, BPF_JGT, BPF_JGE, BPF_JSET):
            # Already-lowered opcode, e.g. when copying an existing Atom.
            # |value|, |jt| and |jf| were normalized when the original was
            # built, so they must be taken verbatim.
            pass
        elif op == '==':
            op = BPF_JEQ
        elif op == '!=':
            op = BPF_JEQ
            jt, jf = jf, jt
        elif op == '>':
            op = BPF_JGT
        elif op == '<=':
            op = BPF_JGT
            jt, jf = jf, jt
        elif op == '>=':
            op = BPF_JGE
        elif op == '<':
            op = BPF_JGE
            jt, jf = jf, jt
        elif op == '&':
            op = BPF_JSET
        elif op == 'in':
            op = BPF_JSET
            # The mask is negated, so the comparison will be true when the
            # argument includes a flag that wasn't listed in the original
            # (non-negated) mask. This would be the failure case, so we switch
            # |jt| and |jf|.
            value = (~value) & ((1 << 64) - 1)
            jt, jf = jf, jt
        else:
            raise Exception('Unknown operator %s' % op)

        self.arg_index = arg_index
        self.op = op
        self.jt = jt
        self.jf = jf
        self.value = value

    def accept(self, visitor):
        if visitor.visited(self):
            return
        self.jt.accept(visitor)
        self.jf.accept(visitor)
        visitor.visit(self)


class AbstractVisitor(abc.ABC):
    """An abstract visitor."""

    def __init__(self):
        self._visited = set()

    def visited(self, block):
        """Return whether |block| was seen before, marking it as seen."""
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def process(self, block):
        """Run this visitor over the DAG rooted at |block| and return it."""
        block.accept(self)
        return block

    def visit(self, block):
        """Dispatch |block| to the visit method that matches its type.

        Raises:
            Exception: if the block type is not known.
        """
        # The concrete actions are BasicBlock subclasses, so they must be
        # tested before the generic BasicBlock case.
        if isinstance(block, KillProcess):
            self.visitKillProcess(block)
        elif isinstance(block, KillThread):
            self.visitKillThread(block)
        elif isinstance(block, Trap):
            self.visitTrap(block)
        elif isinstance(block, ReturnErrno):
            self.visitReturnErrno(block)
        elif isinstance(block, Trace):
            self.visitTrace(block)
        elif isinstance(block, UserNotify):
            self.visitUserNotify(block)
        elif isinstance(block, Log):
            self.visitLog(block)
        elif isinstance(block, Allow):
            self.visitAllow(block)
        elif isinstance(block, BasicBlock):
            self.visitBasicBlock(block)
        elif isinstance(block, ValidateArch):
            self.visitValidateArch(block)
        elif isinstance(block, SyscallEntry):
            self.visitSyscallEntry(block)
        elif isinstance(block, WideAtom):
            self.visitWideAtom(block)
        elif isinstance(block, Atom):
            self.visitAtom(block)
        else:
            raise Exception('Unknown block type: %r' % block)

    @abc.abstractmethod
    def visitKillProcess(self, block):
        pass

    @abc.abstractmethod
    def visitKillThread(self, block):
        pass

    @abc.abstractmethod
    def visitTrap(self, block):
        pass

    @abc.abstractmethod
    def visitReturnErrno(self, block):
        pass

    @abc.abstractmethod
    def visitTrace(self, block):
        pass

    @abc.abstractmethod
    def visitUserNotify(self, block):
        pass

    @abc.abstractmethod
    def visitLog(self, block):
        pass

    @abc.abstractmethod
    def visitAllow(self, block):
        pass

    @abc.abstractmethod
    def visitBasicBlock(self, block):
        pass

    @abc.abstractmethod
    def visitValidateArch(self, block):
        pass

    @abc.abstractmethod
    def visitSyscallEntry(self, block):
        pass

    @abc.abstractmethod
    def visitWideAtom(self, block):
        pass

    @abc.abstractmethod
    def visitAtom(self, block):
        pass


class CopyingVisitor(AbstractVisitor):
    """A visitor that copies Blocks."""

    def __init__(self):
        super().__init__()
        # Maps id(original block) to its copy.
        self._mapping = {}

    def process(self, block):
        """Copy the DAG rooted at |block| and return the copy of |block|."""
        self._mapping = {}
        # The visited set must be reset together with the mapping: otherwise
        # blocks seen in a previous run would be skipped and never get an
        # entry in the fresh mapping.
        self._visited = set()
        block.accept(self)
        return self._mapping[id(block)]

    def visitKillProcess(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillProcess()

    def visitKillThread(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = KillThread()

    def visitTrap(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trap()

    def visitReturnErrno(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = ReturnErrno(block.errno)

    def visitTrace(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Trace()

    def visitUserNotify(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = UserNotify()

    def visitLog(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Log()

    def visitAllow(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = Allow()

    def visitBasicBlock(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = BasicBlock(block.instructions)

    def visitValidateArch(self, block):
        assert id(block) not in self._mapping
        # ValidateArch only takes (and only stores) the next block.
        self._mapping[id(block)] = ValidateArch(
            self._mapping[id(block.next_block)])

    def visitSyscallEntry(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = SyscallEntry(
            block.syscall_number,
            self._mapping[id(block.jt)],
            self._mapping[id(block.jf)],
            op=block.op)

    def visitWideAtom(self, block):
        assert id(block) not in self._mapping
        self._mapping[id(block)] = WideAtom(
            block.arg_offset, block.op, block.value,
            self._mapping[id(block.jt)], self._mapping[id(block.jf)])

    def visitAtom(self, block):
        assert id(block) not in self._mapping
        # |block.op| is an already-lowered opcode, which Atom takes verbatim.
        self._mapping[id(block)] = Atom(block.arg_index, block.op, block.value,
                                        self._mapping[id(block.jt)],
                                        self._mapping[id(block.jf)])


class LoweringVisitor(CopyingVisitor):
    """A visitor that lowers Atoms into WideAtoms."""

    def __init__(self, *, arch):
        super().__init__()
        self._bits = arch.bits

    def visitAtom(self, block):
        assert id(block) not in self._mapping

        lo = block.value & 0xFFFFFFFF
        hi = (block.value >> 32) & 0xFFFFFFFF

        lo_block = WideAtom(
            arg_offset(block.arg_index, False), block.op, lo,
            self._mapping[id(block.jt)], self._mapping[id(block.jf)])

        if self._bits == 32:
            self._mapping[id(block)] = lo_block
            return

        if block.op in (BPF_JGE, BPF_JGT):
            # hi_1,lo_1 <op> hi_2,lo_2
            #
            # hi_1 > hi_2 || hi_1 == hi_2 && lo_1 <op> lo_2
            if hi == 0:
                # Special case: it's not needed to check whether |hi_1 == hi_2|,
                # because it's true iff the JGT test fails.
                self._mapping[id(block)] = WideAtom(
                    arg_offset(block.arg_index, True), BPF_JGT, hi,
                    self._mapping[id(block.jt)], lo_block)
                return
            hi_eq_block = WideAtom(
                arg_offset(block.arg_index, True), BPF_JEQ, hi, lo_block,
                self._mapping[id(block.jf)])
            self._mapping[id(block)] = WideAtom(
                arg_offset(block.arg_index, True), BPF_JGT, hi,
                self._mapping[id(block.jt)], hi_eq_block)
            return
        if block.op == BPF_JSET:
            # hi_1,lo_1 & hi_2,lo_2
            #
            # hi_1 & hi_2 || lo_1 & lo_2
            if hi == 0:
                # Special case: |hi_1 & hi_2| will never be True, so jump
                # directly into the |lo_1 & lo_2| case.
                self._mapping[id(block)] = lo_block
                return
            self._mapping[id(block)] = WideAtom(
                arg_offset(block.arg_index, True), block.op, hi,
                self._mapping[id(block.jt)], lo_block)
            return

        assert block.op == BPF_JEQ, block.op

        # hi_1,lo_1 == hi_2,lo_2
        #
        # hi_1 == hi_2 && lo_1 == lo_2
        self._mapping[id(block)] = WideAtom(
            arg_offset(block.arg_index, True), block.op, hi, lo_block,
            self._mapping[id(block.jf)])


class FlatteningVisitor:
    """A visitor that flattens a DAG of Block objects."""

    def __init__(self, *, arch, kill_action):
        self._visited = set()
        self._kill_action = kill_action
        self._instructions = []
        self._arch = arch
        # Maps id(block) to minus the length of the program right after the
        # block was emitted, i.e. its position counted from the end.
        self._offsets = {}

    @property
    def result(self):
        """A BasicBlock with all the instructions emitted so far."""
        return BasicBlock(self._instructions)

    def _distance(self, block):
        """Return how far |block| is from the start of the current program."""
        distance = self._offsets[id(block)] + len(self._instructions)
        assert distance >= 0
        return distance

    def _emit_load_arg(self, offset):
        """Return the instructions that load the 32-bit word at |offset|."""
        return [SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, offset)]

    def _emit_jmp(self, op, value, jt_distance, jf_distance):
        """Return the instructions for a conditional jump.

        |jt| and |jf| are encoded as single bytes ('B') in a SockFilter, so
        distances that do not fit are reached through additional
        unconditional jumps, whose distance lives in the wider |k| field.
        """
        if jt_distance < 0x100 and jf_distance < 0x100:
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance, jf_distance,
                           value),
            ]
        if jt_distance + 1 < 0x100:
            # The true branch has to skip over the extra jump.
            return [
                SockFilter(BPF_JMP | op | BPF_K, jt_distance + 1, 0, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
            ]
        if jf_distance + 1 < 0x100:
            # The false branch has to skip over the extra jump.
            return [
                SockFilter(BPF_JMP | op | BPF_K, 0, jf_distance + 1, value),
                SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance),
            ]
        # Neither fits: use one unconditional jump per branch.
        return [
            SockFilter(BPF_JMP | op | BPF_K, 0, 1, value),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jt_distance + 1),
            SockFilter(BPF_JMP | BPF_JA, 0, 0, jf_distance),
        ]

    def visited(self, block):
        """Return whether |block| was seen before, marking it as seen."""
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def visit(self, block):
        """Prepend the instructions for |block| to the program.

        Raises:
            Exception: if the block type is not known.
        """
        assert id(block) not in self._offsets

        if isinstance(block, BasicBlock):
            instructions = block.instructions
        elif isinstance(block, ValidateArch):
            # NOTE(review): the jump-true distance looks like it assumes that
            # |next_block| is at distance 0 and that the kill action is a
            # single instruction, so that it lands on the trailing load --
            # confirm against the callers.
            instructions = [
                SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 4),
                SockFilter(BPF_JMP | BPF_JEQ | BPF_K,
                           self._distance(block.next_block) + 1, 0,
                           self._arch.arch_nr),
            ] + self._kill_action.instructions + [
                SockFilter(BPF_LD | BPF_W | BPF_ABS, 0, 0, 0),
            ]
        elif isinstance(block, SyscallEntry):
            instructions = self._emit_jmp(block.op, block.syscall_number,
                                          self._distance(block.jt),
                                          self._distance(block.jf))
        elif isinstance(block, WideAtom):
            instructions = (
                self._emit_load_arg(block.arg_offset) + self._emit_jmp(
                    block.op, block.value, self._distance(block.jt),
                    self._distance(block.jf)))
        else:
            raise Exception('Unknown block type: %r' % block)

        # Children are visited before their parents, so the program is built
        # back to front.
        self._instructions = instructions + self._instructions
        self._offsets[id(block)] = -len(self._instructions)
        return


class ArgFilterForwardingVisitor:
    """A visitor that forwards visitation to all arg filters."""

    def __init__(self, visitor):
        self._visited = set()
        self.visitor = visitor

    def visited(self, block):
        """Return whether |block| was seen before, marking it as seen."""
        if id(block) in self._visited:
            return True
        self._visited.add(id(block))
        return False

    def visit(self, block):
        """Forward |block| to the wrapped visitor if it is an arg filter."""
        # All arg filters are BasicBlocks.
        if not isinstance(block, BasicBlock):
            return
        # But the ALLOW, KILL_PROCESS, TRAP, etc. actions are too and we don't
        # want to visit them just yet.
        if isinstance(block, (KillProcess, KillThread, Trap, ReturnErrno,
                              Trace, UserNotify, Log, Allow)):
            return
        block.accept(self.visitor)