xref: /aosp_15_r20/external/minijail/tools/compiler_unittest.py (revision 4b9c6d91573e8b3a96609339b46361b5476dd0f9)
1*4b9c6d91SCole Faust#!/usr/bin/env python3
2*4b9c6d91SCole Faust# -*- coding: utf-8 -*-
3*4b9c6d91SCole Faust#
4*4b9c6d91SCole Faust# Copyright (C) 2018 The Android Open Source Project
5*4b9c6d91SCole Faust#
6*4b9c6d91SCole Faust# Licensed under the Apache License, Version 2.0 (the "License");
7*4b9c6d91SCole Faust# you may not use this file except in compliance with the License.
8*4b9c6d91SCole Faust# You may obtain a copy of the License at
9*4b9c6d91SCole Faust#
10*4b9c6d91SCole Faust#      http://www.apache.org/licenses/LICENSE-2.0
11*4b9c6d91SCole Faust#
12*4b9c6d91SCole Faust# Unless required by applicable law or agreed to in writing, software
13*4b9c6d91SCole Faust# distributed under the License is distributed on an "AS IS" BASIS,
14*4b9c6d91SCole Faust# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15*4b9c6d91SCole Faust# See the License for the specific language governing permissions and
16*4b9c6d91SCole Faust# limitations under the License.
17*4b9c6d91SCole Faust"""Unittests for the compiler module."""
18*4b9c6d91SCole Faust
19*4b9c6d91SCole Faustfrom __future__ import print_function
20*4b9c6d91SCole Faust
21*4b9c6d91SCole Faustimport os
22*4b9c6d91SCole Faustimport random
23*4b9c6d91SCole Faustimport shutil
24*4b9c6d91SCole Faustimport tempfile
25*4b9c6d91SCole Faustimport unittest
26*4b9c6d91SCole Faustfrom importlib import resources
27*4b9c6d91SCole Faust
28*4b9c6d91SCole Faustimport arch
29*4b9c6d91SCole Faustimport bpf
30*4b9c6d91SCole Faustimport compiler
31*4b9c6d91SCole Faustimport parser  # pylint: disable=wrong-import-order
32*4b9c6d91SCole Faust
# 64-bit architecture description shared by every test case below. It is
# loaded once at import time from the arch_64.json file in the `testdata`
# package resources.
ARCH_64 = arch.Arch.load_from_json_bytes(
    (resources.files('testdata') / 'arch_64.json').read_bytes())
36*4b9c6d91SCole Faust
37*4b9c6d91SCole Faust
class CompileFilterStatementTests(unittest.TestCase):
    """Tests for PolicyCompiler.compile_filter_statement."""

    def setUp(self):
        self.arch = ARCH_64
        self.compiler = compiler.PolicyCompiler(self.arch)

    def _compile(self, line):
        """Parses |line| as a single-statement policy and compiles it.

        Args:
            line: The text of exactly one policy filter statement.

        Returns:
            The block returned by PolicyCompiler.compile_filter_statement.
        """
        with tempfile.NamedTemporaryFile(mode='w') as policy_file:
            policy_file.write(line)
            # The parser reads the policy by file name, so make sure the
            # contents have actually been written out first.
            policy_file.flush()
            policy_parser = parser.PolicyParser(
                self.arch, kill_action=bpf.KillProcess())
            parsed_policy = policy_parser.parse_file(policy_file.name)
            # A TestCase assertion (unlike a bare assert) is not stripped
            # under `python -O` and produces a readable failure message.
            self.assertEqual(len(parsed_policy.filter_statements), 1)
            return self.compiler.compile_filter_statement(
                parsed_policy.filter_statements[0],
                kill_action=bpf.KillProcess())

    def _simulate_read(self, block, *args):
        """Runs |block| through the simulator as a read() call.

        Args:
            block: A compiled block, as returned by _compile().
            *args: The syscall argument values to simulate.

        Returns:
            The tuple returned by block.simulate(). Element 1 is the action
            name (e.g. 'ALLOW'); any elements after it carry the action's
            data (e.g. the errno value for 'ERRNO').
        """
        return block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
                              *args)

    def test_allow(self):
        """Accept lines where the syscall is accepted unconditionally."""
        block = self._compile('read: allow')
        # An unconditional action compiles to no argument filter at all.
        self.assertIsNone(block.filter)
        for arg0 in (0, 1):
            with self.subTest(arg0=arg0):
                self.assertEqual(self._simulate_read(block, arg0)[1], 'ALLOW')

    def test_arg0_eq_generated_code(self):
        """Accept lines with an argument filter with ==."""
        block = self._compile('read: arg0 == 0x100')
        # It might be a bit brittle to check the generated code in each test
        # case instead of just the behavior, but there should be at least one
        # test where this happens.
        self.assertEqual(
            block.filter.instructions,
            [
                bpf.SockFilter(bpf.BPF_LD | bpf.BPF_W | bpf.BPF_ABS, 0, 0,
                               bpf.arg_offset(0, True)),
                # Jump to KILL_PROCESS if the high word does not match.
                bpf.SockFilter(bpf.BPF_JMP | bpf.BPF_JEQ | bpf.BPF_K, 0, 2, 0),
                bpf.SockFilter(bpf.BPF_LD | bpf.BPF_W | bpf.BPF_ABS, 0, 0,
                               bpf.arg_offset(0, False)),
                # Jump to KILL_PROCESS if the low word does not match.
                bpf.SockFilter(bpf.BPF_JMP | bpf.BPF_JEQ | bpf.BPF_K, 1, 0,
                               0x100),
                bpf.SockFilter(bpf.BPF_RET, 0, 0,
                               bpf.SECCOMP_RET_KILL_PROCESS),
                bpf.SockFilter(bpf.BPF_RET, 0, 0, bpf.SECCOMP_RET_ALLOW),
            ])

    def test_arg0_comparison_operators(self):
        """Accept lines with an argument filter with comparison operators."""
        biases = (-1, 0, 1)
        # For each operator, store the expectations of simulating the program
        # against the constant plus each entry from the |biases| array.
        cases = (
            ('==', ('KILL_PROCESS', 'ALLOW', 'KILL_PROCESS')),
            ('!=', ('ALLOW', 'KILL_PROCESS', 'ALLOW')),
            ('<', ('ALLOW', 'KILL_PROCESS', 'KILL_PROCESS')),
            ('<=', ('ALLOW', 'ALLOW', 'KILL_PROCESS')),
            ('>', ('KILL_PROCESS', 'KILL_PROCESS', 'ALLOW')),
            ('>=', ('KILL_PROCESS', 'ALLOW', 'ALLOW')),
        )
        for operator, expectations in cases:
            block = self._compile('read: arg0 %s 0x100' % operator)

            # Check the filter's behavior.
            for bias, expectation in zip(biases, expectations):
                with self.subTest(operator=operator, bias=bias):
                    self.assertEqual(
                        self._simulate_read(block, 0x100 + bias)[1],
                        expectation)

    def test_arg0_mask_operator(self):
        """Accept lines with an argument filter with &."""
        block = self._compile('read: arg0 & 0x3')

        # Entry i is the expected action for arg0 == i: the call is allowed
        # whenever arg0 shares at least one bit with the mask 0x3.
        expectations = (
            'KILL_PROCESS',  # 0b0000
            'ALLOW',  # 0b0001
            'ALLOW',  # 0b0010
            'ALLOW',  # 0b0011
            'KILL_PROCESS',  # 0b0100
            'ALLOW',  # 0b0101
            'ALLOW',  # 0b0110
            'ALLOW',  # 0b0111
            'KILL_PROCESS',  # 0b1000
        )
        for arg0, expectation in enumerate(expectations):
            with self.subTest(arg0=arg0):
                self.assertEqual(self._simulate_read(block, arg0)[1],
                                 expectation)

    def test_arg0_in_operator(self):
        """Accept lines with an argument filter with in."""
        block = self._compile('read: arg0 in 0x3')

        # The 'in' operator only ensures that no bits outside the mask are set,
        # which means that 0 is always allowed.
        expectations = (
            'ALLOW',  # 0b0000
            'ALLOW',  # 0b0001
            'ALLOW',  # 0b0010
            'ALLOW',  # 0b0011
            'KILL_PROCESS',  # 0b0100
            'KILL_PROCESS',  # 0b0101
            'KILL_PROCESS',  # 0b0110
            'KILL_PROCESS',  # 0b0111
            'KILL_PROCESS',  # 0b1000
        )
        for arg0, expectation in enumerate(expectations):
            with self.subTest(arg0=arg0):
                self.assertEqual(self._simulate_read(block, arg0)[1],
                                 expectation)

    def test_arg0_short_gt_ge_comparisons(self):
        """Ensure that the short comparison optimization kicks in."""
        if self.arch.bits == 32:
            # Report the test as skipped rather than silently passing it.
            self.skipTest('requires a 64-bit architecture')
        short_constant_str = '0xdeadbeef'
        short_constant = int(short_constant_str, base=0)
        long_constant_str = '0xbadc0ffee0ddf00d'
        long_constant = int(long_constant_str, base=0)
        biases = (-1, 0, 1)
        # For each operator, store the expectations of simulating the program
        # against the constant plus each entry from the |biases| array.
        cases = (
            ('<', ('ALLOW', 'KILL_PROCESS', 'KILL_PROCESS')),
            ('<=', ('ALLOW', 'ALLOW', 'KILL_PROCESS')),
            ('>', ('KILL_PROCESS', 'KILL_PROCESS', 'ALLOW')),
            ('>=', ('KILL_PROCESS', 'ALLOW', 'ALLOW')),
        )
        for operator, expectations in cases:
            short_block = self._compile(
                'read: arg0 %s %s' % (operator, short_constant_str))
            long_block = self._compile(
                'read: arg0 %s %s' % (operator, long_constant_str))

            # Check that the emitted code is shorter when the high word of the
            # constant is zero.
            self.assertLess(
                len(short_block.filter.instructions),
                len(long_block.filter.instructions))

            # Check the filter's behavior.
            for bias, expectation in zip(biases, expectations):
                with self.subTest(operator=operator, bias=bias):
                    long_result = self._simulate_read(long_block,
                                                      long_constant + bias)
                    self.assertEqual(long_result[1], expectation)
                    short_result = self._simulate_read(short_block,
                                                       short_constant + bias)
                    self.assertEqual(short_result[1], expectation)

    def test_and_or(self):
        """Accept lines with a complex expression in DNF."""
        block = self._compile('read: arg0 == 0 && arg1 == 0 || arg0 == 1')

        cases = (
            ((0, 0), 'ALLOW'),
            ((0, 1), 'KILL_PROCESS'),
            ((1, 0), 'ALLOW'),
            ((1, 1), 'ALLOW'),
        )
        for args, expectation in cases:
            with self.subTest(args=args):
                self.assertEqual(self._simulate_read(block, *args)[1],
                                 expectation)

    def test_trap(self):
        """Accept lines that trap unconditionally."""
        block = self._compile('read: trap')

        self.assertEqual(self._simulate_read(block, 0)[1], 'TRAP')

    def test_ret_errno(self):
        """Accept lines that return errno."""
        block = self._compile('read : arg0 == 0 || arg0 == 1 ; return 1')

        self.assertEqual(self._simulate_read(block, 0)[1:], ('ERRNO', 1))
        self.assertEqual(self._simulate_read(block, 1)[1:], ('ERRNO', 1))
        self.assertEqual(self._simulate_read(block, 2)[1], 'KILL_PROCESS')

    def test_ret_errno_unconditionally(self):
        """Accept lines that return errno unconditionally."""
        block = self._compile('read: return 1')

        self.assertEqual(self._simulate_read(block, 0)[1:], ('ERRNO', 1))

    def test_trace(self):
        """Accept lines that trace unconditionally."""
        block = self._compile('read: trace')

        self.assertEqual(self._simulate_read(block, 0)[1], 'TRACE')

    def test_user_notify(self):
        """Accept lines that notify unconditionally."""
        block = self._compile('read: user-notify')

        self.assertEqual(self._simulate_read(block, 0)[1], 'USER_NOTIF')

    def test_log(self):
        """Accept lines that log unconditionally."""
        block = self._compile('read: log')

        self.assertEqual(self._simulate_read(block, 0)[1], 'LOG')

    def test_mmap_write_xor_exec(self):
        """Accept the idiomatic filter for mmap."""
        # The expression is attached to read() here; only the argument
        # filter itself is under test.
        block = self._compile(
            'read : arg0 in ~PROT_WRITE || arg0 in ~PROT_EXEC')

        # PROT_WRITE | PROT_EXEC (0x2 | 0x4).
        prot_exec_and_write = 6
        # Cover every combination of the low four bits, 0x0 through 0xf
        # inclusive; 0xf also has both PROT_WRITE and PROT_EXEC set.
        for prot in range(0x10):
            expected = ('KILL_PROCESS'
                        if (prot & prot_exec_and_write) == prot_exec_and_write
                        else 'ALLOW')
            with self.subTest(prot=prot):
                self.assertEqual(self._simulate_read(block, prot)[1],
                                 expected)
308*4b9c6d91SCole Faust
309*4b9c6d91SCole Faust
class CompileFileTests(unittest.TestCase):
    """Tests for PolicyCompiler.compile_file."""

    def setUp(self):
        self.arch = ARCH_64
        self.compiler = compiler.PolicyCompiler(self.arch)
        self.tempdir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.tempdir)

    def _write_file(self, filename, contents):
        """Helper to write out a file for testing.

        Args:
            filename: Name of the file, relative to the per-test tempdir.
            contents: Text to write into the file.

        Returns:
            The path of the written file.
        """
        path = os.path.join(self.tempdir, filename)
        with open(path, 'w', encoding='utf-8') as outf:
            outf.write(contents)
        return path

    def _write_read_close_policy(self):
        """Writes a policy allowing read and close, plus a frequency file.

        The frequency file lists close as ten times as frequent as read.

        Returns:
            The path of the policy file.
        """
        self._write_file(
            'test.frequency', """
            read: 1
            close: 10
        """)
        return self._write_file(
            'test.policy', """
            @frequency ./test.frequency
            read: 1
            close: 1
        """)

    def _simulate(self, program, syscall_name, *args):
        """Runs |program| through bpf.simulate for the named syscall.

        Returns:
            The tuple from bpf.simulate(). Element 0 is what the tests compare
            to rank evaluation cost (presumably the number of instructions
            executed; confirm in bpf.py). Element 1 is the action name.
        """
        return bpf.simulate(program.instructions, self.arch.arch_nr,
                            self.arch.syscalls[syscall_name], *args)

    def test_compile(self):
        """Ensure the LINEAR strategy makes the frequent syscall cheaper."""
        path = self._write_read_close_policy()

        program = self.compiler.compile_file(
            path,
            optimization_strategy=compiler.OptimizationStrategy.LINEAR,
            kill_action=bpf.KillProcess())
        # close is listed as more frequent than read, so evaluating it should
        # cost less.
        self.assertGreater(
            self._simulate(program, 'read', 0)[0],
            self._simulate(program, 'close', 0)[0],
        )

    def test_compile_bst(self):
        """Ensure every strategy favors the more frequent syscall.

        Despite its name, this covers every OptimizationStrategy value, not
        just BST, and also checks that both syscalls are still allowed.
        """
        path = self._write_read_close_policy()

        for strategy in compiler.OptimizationStrategy:
            with self.subTest(strategy=strategy):
                program = self.compiler.compile_file(
                    path,
                    optimization_strategy=strategy,
                    kill_action=bpf.KillProcess())
                read_result = self._simulate(program, 'read', 0)
                close_result = self._simulate(program, 'close', 0)
                self.assertGreater(read_result[0], close_result[0])
                self.assertEqual(read_result[1], 'ALLOW')
                self.assertEqual(close_result[1], 'ALLOW')

    def test_compile_empty_file(self):
        """Accept empty files."""
        path = self._write_file(
            'test.policy', """
            @default kill-thread
        """)

        for strategy in compiler.OptimizationStrategy:
            with self.subTest(strategy=strategy):
                program = self.compiler.compile_file(
                    path,
                    optimization_strategy=strategy,
                    kill_action=bpf.KillProcess())
                # With no filter statements, the @default action applies.
                self.assertEqual(
                    self._simulate(program, 'read', 0)[1], 'KILL_THREAD')

    def test_compile_simulate(self):
        """Ensure policy reflects script by testing some random scripts."""
        iterations = 5
        for i in range(iterations):
            num_entries = 64 * (i + 1) // iterations
            # Map randomly chosen syscall names to random frequencies.
            syscalls = dict(
                zip(
                    random.sample(
                        list(self.arch.syscalls.keys()), num_entries),
                    (random.randint(1, 1024) for _ in range(num_entries)),
                ))

            frequency_contents = '\n'.join(
                '%s: %d' % item for item in syscalls.items())
            policy_contents = '@frequency ./test.frequency\n' + '\n'.join(
                '%s: 1' % name for name in syscalls)

            self._write_file('test.frequency', frequency_contents)
            path = self._write_file('test.policy', policy_contents)

            for strategy in compiler.OptimizationStrategy:
                program = self.compiler.compile_file(
                    path,
                    optimization_strategy=strategy,
                    kill_action=bpf.KillProcess())
                for name, number in self.arch.syscalls.items():
                    expected_result = ('ALLOW'
                                       if name in syscalls else 'KILL_PROCESS')
                    self.assertEqual(
                        self._simulate(program, name, 0)[1], expected_result,
                        ('syscall name: %s, syscall number: %d, '
                         'strategy: %s, policy:\n%s') %
                        (name, number, strategy, policy_contents))

    @unittest.skipIf(not int(os.getenv('SLOW_TESTS', '0')), 'slow')
    def test_compile_huge_policy(self):
        """Ensure jumps while compiling a huge policy are still valid."""
        # Given that the BST strategy is O(n^3), don't choose a crazy large
        # value, but it still needs to be around 128 so that we exercise the
        # codegen paths that depend on the length of the jump.
        #
        # Immediate jump offsets in BPF comparison instructions are limited to
        # 256 instructions, so given that every syscall filter consists of a
        # load and jump instructions, with 128 syscalls there will be at least
        # one jump that's further than 256 instructions.
        num_entries = 128
        # random.sample() requires a sequence (Python 3.11+ rejects dict
        # views), so materialize the items first.
        syscalls = dict(
            random.sample(list(self.arch.syscalls.items()), num_entries))
        # Here we force every single filter to be distinct. Otherwise the
        # codegen layer will coalesce filters that compile to the same
        # instructions.
        policy_contents = '\n'.join(
            '%s: arg0 == %d' % s for s in syscalls.items())

        path = self._write_file('test.policy', policy_contents)

        program = self.compiler.compile_file(
            path,
            optimization_strategy=compiler.OptimizationStrategy.BST,
            kill_action=bpf.KillProcess())
        for name, number in self.arch.syscalls.items():
            expected_result = ('ALLOW'
                               if name in syscalls else 'KILL_PROCESS')
            # Each policy line allows arg0 equal to that syscall's own number.
            self.assertEqual(
                self._simulate(program, name, number)[1], expected_result)
            self.assertEqual(
                self._simulate(program, name, number + 1)[1], 'KILL_PROCESS')

    def test_compile_huge_filter(self):
        """Ensure compiling a filter with many disjuncts stays valid."""
        # This is intended to force cases where the AST visitation would result
        # in a combinatorial explosion of calls to Block.accept(). An optimized
        # implementation should be O(n).
        num_entries = 128
        syscalls = {}
        # Here we force every single filter to be distinct. Otherwise the
        # codegen layer will coalesce filters that compile to the same
        # instructions.
        policy_contents = []
        for name in random.sample(
                list(self.arch.syscalls.keys()), num_entries):
            values = random.sample(range(1024), num_entries)
            syscalls[name] = values
            policy_contents.append(
                '%s: %s' % (name, ' || '.join('arg0 == %d' % value
                                              for value in values)))

        path = self._write_file('test.policy', '\n'.join(policy_contents))

        program = self.compiler.compile_file(
            path,
            optimization_strategy=compiler.OptimizationStrategy.LINEAR,
            kill_action=bpf.KillProcess())
        for name, values in syscalls.items():
            self.assertEqual(
                self._simulate(program, name, random.choice(values))[1],
                'ALLOW')
            # All allowed values come from range(1024), so 1025 never matches.
            self.assertEqual(
                self._simulate(program, name, 1025)[1], 'KILL_PROCESS')
507*4b9c6d91SCole Faust
508*4b9c6d91SCole Faust
if __name__ == '__main__':
    # Run every TestCase defined in this module via unittest's CLI runner.
    unittest.main(module='__main__')
511