xref: /aosp_15_r20/external/iptables/iptables-test.py (revision a71a954618bbadd4a345637e5edcf36eec826889)
1#!/usr/bin/env python3
2#
3# (C) 2012-2013 by Pablo Neira Ayuso <[email protected]>
4#
5# This program is free software; you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation; either version 2 of the License, or
8# (at your option) any later version.
9#
10# This software has been sponsored by Sophos Astaro <http://www.sophos.com>
11#
12
13from __future__ import print_function
14import sys
15import os
16import subprocess
17import argparse
18
# Front-end commands under test. Commands starting with one of these are
# run through the xtables-<variant>-multi binary that main() selects.
IPTABLES = "iptables"
IP6TABLES = "ip6tables"
ARPTABLES = "arptables"
EBTABLES = "ebtables"

# Matching *-save commands, used to dump the ruleset for verification.
IPTABLES_SAVE = "iptables-save"
IP6TABLES_SAVE = "ip6tables-save"
ARPTABLES_SAVE = "arptables-save"
EBTABLES_SAVE = "ebtables-save"
# Disabled alternative: dump via xtables-save with an address family flag.
# IPTABLES_SAVE = ['xtables-save', '-4']
# IP6TABLES_SAVE = ['xtables-save', '-6']

# Directory with the extension sources/libraries and the *.t test files.
EXTENSIONS_PATH = "extensions"
# Log of every executed command and its output.
LOGFILE = "/tmp/iptables-test.log"
# Opened by main(); None until then.
log_file = None

# Only emit color escape sequences when the stream is a terminal.
STDOUT_IS_TTY = sys.stdout.isatty()
STDERR_IS_TTY = sys.stderr.isatty()
37
def maybe_colored(color, text, isatty):
    '''
    Wrap text in an ANSI color sequence if the output goes to a terminal.

    :param color: 'green' or 'red'
    :param text: string to decorate
    :param isatty: whether the destination stream is a terminal
    :returns: the colored string when isatty is true, else text unchanged
    '''
    if not isatty:
        return text
    start = {'green': '\033[92m', 'red': '\033[91m'}[color]
    return start + text + '\033[0m'
47
48
def print_error(reason, filename=None, lineno=None):
    '''
    Prints an error with nice colors, indicating file and line number.

    Writes "<filename>: ERROR: line <lineno> (<reason>)" to stderr. The
    location arguments are optional: a missing filename is shown as
    "<unknown>" and a missing lineno drops the "line N " part (previously
    the documented None defaults raised TypeError).

    :param reason: description of what went wrong
    :param filename: test file the error refers to (optional)
    :param lineno: line number within filename (optional)
    '''
    where = filename if filename is not None else "<unknown>"
    line = "line %d " % lineno if lineno is not None else ""
    print(where + ": " + maybe_colored('red', "ERROR", STDERR_IS_TTY) +
          ": " + line + "(%s)" % (reason,), file=sys.stderr)
55
56
def delete_rule(iptables, rule, filename, lineno, netns=None):
    '''
    Removes an iptables rule

    :param iptables: front-end command, e.g. IPTABLES or EBTABLES
    :param rule: rule arguments as used for "-A" (chain, optional "-t table",
                 matches, target)
    :param filename: test file the rule comes from (for error reports)
    :param lineno: line number of the rule (for error reports)
    :param netns: network namespace to run the command in (or None)
    :returns: 0 on success, -1 if the rule could not be deleted
    '''
    cmd = iptables + " -D " + rule
    ret = execute_cmd(cmd, filename, lineno, netns)
    # Any non-zero status is a failure: iptables exits with 1 or 2 on
    # errors and a crash yields a negative signal number. Checking only
    # for 1 let the other cases leave the rule behind unnoticed.
    if ret != 0:
        print_error("cannot delete: " + cmd, filename, lineno)
        return -1

    return 0
69
70
def _save_command(iptables):
    '''
    Pick the *-save command that dumps what the given front-end wrote.

    :param iptables: command prefix as passed to run_test(): either a bare
                     front-end name (IPTABLES, IP6TABLES, ARPTABLES,
                     EBTABLES) or a two-token "<binary> -4" / "<binary> -6"
    :returns: the matching *_SAVE command, or None if there is none
    '''
    tokens = iptables.split(" ")
    if len(tokens) == 2:
        return {'-4': IPTABLES_SAVE, '-6': IP6TABLES_SAVE}.get(tokens[1])
    if len(tokens) == 1:
        return {IPTABLES: IPTABLES_SAVE,
                IP6TABLES: IP6TABLES_SAVE,
                ARPTABLES: ARPTABLES_SAVE,
                EBTABLES: EBTABLES_SAVE}.get(tokens[0])
    return None


def run_test(iptables, rule, rule_save, res, filename, lineno, netns):
    '''
    Executes an unit test.

    Appends the rule, compares the outcome with res, looks for rule_save in
    the *-save dump and removes the rule again (in netns mode rules that
    loaded are left in place to exercise "ip netns del" with rules present).

    Parameters:
    :param iptables: string with the iptables command to execute
    :param rule: string with iptables arguments for the rule to test
    :param rule_save: string to find the rule in the output of iptables-save
    :param res: expected result of the rule. Valid values: "OK", "FAIL",
                or any other value (e.g. "NOMATCH") meaning "loads, but is
                not found in the dump"
    :param filename: name of the file tested (used for print_error purposes)
    :param lineno: line number being tested (used for print_error purposes)
    :param netns: network namespace to call commands in (or None)
    :returns: 0 if the rule behaved as expected, -1 otherwise
    '''
    cmd = iptables + " -A " + rule
    ret = execute_cmd(cmd, filename, lineno, netns)

    #
    # report failed test
    #
    if ret:
        if res != "FAIL":
            print_error("cannot load: " + cmd, filename, lineno)
            return -1
        # expected to fail: do not report this error
        return 0
    if res == "FAIL":
        print_error("should fail: " + cmd, filename, lineno)
        delete_rule(iptables, rule, filename, lineno, netns)
        return -1

    save_cmd = _save_command(iptables)
    if save_cmd is None:
        # unknown front-end: fail the test instead of crashing on an
        # unbound save command
        print_error("no save command known for: " + iptables,
                    filename, lineno)
        delete_rule(iptables, rule, filename, lineno, netns)
        return -1

    command = EXECUTABLE + " " + save_cmd
    if netns:
        command = "ip netns exec " + netns + " " + command

    proc = subprocess.Popen(command, shell=True,
                            stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = proc.communicate()

    #
    # check for segfaults
    #
    if proc.returncode == -11:
        print_error(command + " segfaults!", filename, lineno)
        delete_rule(iptables, rule, filename, lineno, netns)
        return -1

    # find the rule
    if out.find(rule_save.encode('utf-8')) < 0:
        if res == "OK":
            print_error("cannot find: " + cmd, filename, lineno)
            delete_rule(iptables, rule, filename, lineno, netns)
            return -1
        # Expected not to show up in the dump: do not report this error.
        # The rule did load, though, so remove it (unless netns mode keeps
        # rules on purpose) to avoid leaking it into later tests.
        if not netns:
            delete_rule(iptables, rule, filename, lineno)
        return 0
    if res != "OK":
        print_error("should not match: " + cmd, filename, lineno)
        delete_rule(iptables, rule, filename, lineno, netns)
        return -1

    # Test "ip netns del NETNS" path with rules in place
    if netns:
        return 0

    return delete_rule(iptables, rule, filename, lineno)
167
def execute_cmd(cmd, filename, lineno = 0, netns = None):
    '''
    Executes a command, checking for segfaults and returning the command exit
    code.

    Commands starting with one of the iptables front-ends are prefixed with
    EXECUTABLE. The command line and everything the command prints (stdout
    and stderr) go to log_file, or to stdout if no log file is open.

    :param cmd: string with the command to be executed
    :param filename: name of the file tested (used for print_error purposes)
    :param lineno: line number being tested (used for print_error purposes)
    :param netns: network namespace to run command in
    :returns: exit status of the shell command (-N if killed by signal N)
    '''
    if cmd.startswith(('iptables ', 'ip6tables ', 'ebtables ', 'arptables ')):
        cmd = EXECUTABLE + " " + cmd

    if netns:
        cmd = "ip netns exec " + netns + " " + cmd

    print("command: {}".format(cmd), file=log_file)
    # Flush before the child writes to the same file descriptor; otherwise
    # its output lands in the log ahead of the buffered "command:" line.
    if log_file is not None:
        log_file.flush()
    ret = subprocess.call(cmd, shell=True, universal_newlines=True,
                          stderr=subprocess.STDOUT, stdout=log_file)
    if log_file is not None:
        log_file.flush()

    # generic check for segfaults
    if ret == -11:
        reason = "command segfaults: " + cmd
        print_error(reason, filename, lineno)
    return ret
195
196
def variant_res(res, variant, alt_res=None):
    '''
    Translate a variant-scoped expected result for the variant being run.

    A test spec may tie its expected result to one variant. When that
    variant is the one currently running (EXECUTABLE), @res applies as-is.
    Otherwise @alt_res is used if given, else the inverse of @res.

    :param res: expected result from test spec ("OK", "FAIL" or "NOMATCH")
    :param variant: variant @res is scoped to by test spec ("NFT" or "LEGACY")
    :param alt_res: optional expected result for the alternate variant.
    :returns: the expected result for the running variant
    '''
    binaries = {"NFT": "xtables-nft-multi", "LEGACY": "xtables-legacy-multi"}
    if binaries[variant] == EXECUTABLE:
        # spec talks about the variant under test
        return res
    if alt_res is None:
        inverted = {"OK": "FAIL", "FAIL": "OK", "NOMATCH": "OK"}
        return inverted[res]
    return alt_res
224
def fast_run_possible(filename):
    '''
    Keep things simple, run only for simple test files:
    - no external commands
    - no multiple tables
    - no variant-specific results
    - only the builtin chains run_test_file_fast() orders rules by

    :param filename: path of the .t test file to inspect
    :returns: True if run_test_file_fast() may be used on filename
    '''
    # run_test_file_fast() orders its restore data by these builtin chains;
    # keep files using any other chain on the rule-by-rule slow path.
    fast_chains = ("PREROUTING", "INPUT", "FORWARD", "OUTPUT", "POSTROUTING")
    table = None
    rulecount = 0
    with open(filename) as f:
        for line in f:
            if line[0] == ":":
                chains = line.rstrip()[1:].split(",")
                if any(chain not in fast_chains for chain in chains):
                    return False
                continue
            if line[0] == "#" or len(line.strip()) == 0:
                continue
            if line[0] == "*":
                # only one table declaration, placed before any rule
                if table or rulecount > 0:
                    return False
                table = line.rstrip()[1:]
            if line[0] in ["@", "%"]:
                return False
            # a 4th field scopes the expected result to a variant
            if len(line.split(";")) > 3:
                return False
            rulecount += 1

    return True
248
def run_test_file_fast(iptables, filename, netns):
    '''
    Run a test file, but fast

    All rules expected to load ("OK") are installed with one
    <iptables>-restore call and verified against one <iptables>-save dump;
    rules with any other expected result go through run_test() one by one.
    The table is flushed after the dump.

    :param iptables: front-end command (IPTABLES, IP6TABLES, ...)
    :param filename: name of the file with the test rules
    :param netns: network namespace to perform test run in
    :returns: number of tests on success, -1 on any failure
    '''
    rules = {}
    table = "filter"
    chain_array = []
    tests = 0
    lineno = 0  # stays defined for error reports even for an empty file

    with open(filename) as f:
        for lineno, line in enumerate(f):
            if line[0] == "#" or len(line.strip()) == 0:
                continue

            if line[0] == "*":
                table = line.rstrip()[1:]
                continue

            if line[0] == ":":
                chain_array = line.rstrip()[1:].split(",")
                continue

            if len(chain_array) == 0:
                return -1

            tests += 1
            item = line.split(";")

            for chain in chain_array:
                rule = chain + " " + item[0]

                if item[1] == "=":
                    rule_save = chain + " " + item[0]
                else:
                    rule_save = chain + " " + item[1]

                if iptables == EBTABLES and rule_save.find('-j') < 0:
                    rule_save += " -j CONTINUE"

                res = item[2].rstrip()
                if res != "OK":
                    rule = chain + " -t " + table + " " + item[0]
                    ret = run_test(iptables, rule, rule_save,
                                   res, filename, lineno + 1, netns)

                    if ret < 0:
                        return -1
                    continue

                rules.setdefault(chain, []).append((rule, rule_save))

    # Builtin chains go first in hook order; any other chain follows in the
    # order it was first seen, so its rules are checked rather than dropped.
    hook_order = ["PREROUTING", "INPUT", "FORWARD", "OUTPUT", "POSTROUTING"]
    chain_order = [chain for chain in hook_order if chain in rules]
    chain_order += [chain for chain in rules if chain not in hook_order]

    restore_data = ["*" + table]
    out_expect = []
    for chain in chain_order:
        for rule, rule_save in rules[chain]:
            restore_data.append("-A " + rule)
            out_expect.append("-A " + rule_save)
    restore_data.append("COMMIT")

    out_expect = "\n".join(out_expect)

    # load all rules via iptables_restore

    command = EXECUTABLE + " " + iptables + "-restore"
    if netns:
        command = "ip netns exec " + netns + " " + command

    for line in restore_data:
        print(iptables + "-restore: " + line, file=log_file)

    proc = subprocess.Popen(command, shell = True, text = True,
                            stdin = subprocess.PIPE,
                            stdout = subprocess.PIPE,
                            stderr = subprocess.PIPE)
    restore_input = "\n".join(restore_data) + "\n"
    out, err = proc.communicate(input = restore_input)

    if proc.returncode == -11:
        reason = iptables + "-restore segfaults!"
        print_error(reason, filename, lineno)
        msg = [iptables + "-restore segfault from:"]
        msg.extend(["input: " + l for l in restore_input.split("\n")])
        print("\n".join(msg), file=log_file)
        return -1

    if proc.returncode != 0:
        print("%s-restore returned %d: %s" % (iptables, proc.returncode, err),
              file=log_file)
        return -1

    # find all rules in iptables_save output

    command = EXECUTABLE + " " + iptables + "-save"
    if netns:
        command = "ip netns exec " + netns + " " + command

    proc = subprocess.Popen(command, shell = True,
                            stdin = subprocess.PIPE,
                            stdout = subprocess.PIPE,
                            stderr = subprocess.PIPE)
    out, err = proc.communicate()

    if proc.returncode == -11:
        reason = iptables + "-save segfaults!"
        print_error(reason, filename, lineno)
        return -1

    cmd = iptables + " -F -t " + table
    execute_cmd(cmd, filename, 0, netns)

    out = out.decode('utf-8').rstrip()
    if out.find(out_expect) < 0:
        msg = ["dumps differ!"]
        msg.extend(["expect: " + l for l in out_expect.split("\n")])
        # skip empty lines (e.g. an empty dump) as well as table, chain
        # and comment lines; l[0] used to raise IndexError on them
        msg.extend(["got: " + l for l in out.split("\n")
                    if l and not l.startswith(('*', ':', '#'))])
        print("\n".join(msg), file=log_file)
        return -1

    return tests
378
def run_test_file(filename, netns):
    '''
    Runs a test file

    Simple files are first tried with run_test_file_fast(); if that is not
    possible or fails, every rule is tested individually via run_test().

    :param filename: name of the file with the test rules
    :param netns: network namespace to perform test run in
    :returns: tuple (tests, passed); (0, 0) if the file was skipped
    '''
    #
    # if this is not a test file, skip.
    #
    if not filename.endswith(".t"):
        return 0, 0

    if "libipt_" in filename:
        iptables = IPTABLES
    elif "libip6t_" in filename:
        iptables = IP6TABLES
    elif "libxt_" in filename:
        iptables = IPTABLES
    elif "libarpt_" in filename:
        # only supported with nf_tables backend
        if EXECUTABLE != "xtables-nft-multi":
            return 0, 0
        iptables = ARPTABLES
    elif "libebt_" in filename:
        # only supported with nf_tables backend
        if EXECUTABLE != "xtables-nft-multi":
            return 0, 0
        iptables = EBTABLES
    else:
        # default to iptables if not known prefix
        iptables = IPTABLES

    fast_failed = False
    if fast_run_possible(filename):
        # run_test_file_fast() prefixes its commands with
        # "ip netns exec <netns>", so the namespace must exist meanwhile.
        if netns:
            execute_cmd("ip netns add " + netns, filename)
        tests = run_test_file_fast(iptables, filename, netns)
        if netns:
            execute_cmd("ip netns del " + netns, filename)
        if tests > 0:
            print(filename + ": " + maybe_colored('green', "OK", STDOUT_IS_TTY))
            return tests, tests
        fast_failed = True

    tests = 0
    passed = 0
    table = ""
    chain_array = []
    total_test_passed = True

    with open(filename) as f:
        if netns:
            execute_cmd("ip netns add " + netns, filename)

        for lineno, line in enumerate(f):
            if line[0] == "#" or len(line.strip()) == 0:
                continue

            if line[0] == ":":
                chain_array = line.rstrip()[1:].split(",")
                continue

            # external command invocation, executed as is.
            # detects iptables commands to prefix with EXECUTABLE automatically
            if line[0] in ["@", "%"]:
                external_cmd = line.rstrip()[1:]
                # 1-based line numbers, like the ones passed to run_test()
                execute_cmd(external_cmd, filename, lineno + 1, netns)
                continue

            if line[0] == "*":
                table = line.rstrip()[1:]
                continue

            if len(chain_array) == 0:
                print_error("broken test, missing chain",
                            filename = filename, lineno = lineno + 1)
                total_test_passed = False
                break

            test_passed = True
            tests += 1
            item = line.split(";")

            for chain in chain_array:
                if table == "":
                    rule = chain + " " + item[0]
                else:
                    rule = chain + " -t " + table + " " + item[0]

                if item[1] == "=":
                    rule_save = chain + " " + item[0]
                else:
                    rule_save = chain + " " + item[1]

                res = item[2].rstrip()
                if len(item) > 3:
                    variant = item[3].rstrip()
                    alt_res = item[4].rstrip() if len(item) > 4 else None
                    res = variant_res(res, variant, alt_res)

                ret = run_test(iptables, rule, rule_save,
                               res, filename, lineno + 1, netns)

                if ret < 0:
                    test_passed = False
                    total_test_passed = False
                    break

            if test_passed:
                passed += 1

    if netns:
        execute_cmd("ip netns del " + netns, filename)
    if total_test_passed:
        suffix = ""
        if fast_failed:
            suffix = maybe_colored('red', " but fast mode failed!", STDOUT_IS_TTY)
        print(filename + ": " + maybe_colored('green', "OK", STDOUT_IS_TTY) + suffix)

    return tests, passed
501
502
def show_missing():
    '''
    Show the list of missing test files

    Prints, sorted and one per line, the .t file name expected for every
    lib*.c extension source in EXTENSIONS_PATH that has no test file yet.
    '''
    file_list = os.listdir(EXTENSIONS_PATH)
    # set: constant-time membership test per source file
    testfiles = {i for i in file_list if i.endswith('.t')}

    def test_name(x):
        # libxt_foo.c -> libxt_foo.t
        return x[0:-2] + '.t'

    # sorted: os.listdir() order is arbitrary, keep the report stable
    missing = sorted(test_name(i) for i in file_list
                     if i.startswith('lib') and i.endswith('.c')
                     and test_name(i) not in testfiles)

    print('\n'.join(missing))
518
def spawn_netns():
    '''
    Move the test run into a fresh network namespace.

    First tries the optional python "unshare" module in-process. Failing
    that, re-executes this script under the unshare(1) utility with an
    extra --no-netns argument so the child does not try again; when
    os.execv() succeeds it does not return.

    :returns: True if the process now runs in a new network namespace,
              False if neither method worked
    '''
    # prefer unshare module
    try:
        import unshare
        unshare.unshare(unshare.CLONE_NEWNET)
        return True
    except Exception:
        # module not installed or the call failed: try the external tool.
        # (No bare except: Ctrl-C / SystemExit must not be swallowed.)
        pass

    # sledgehammer style:
    # - call ourselves prefixed by 'unshare -n' if found
    # - pass extra --no-netns parameter to avoid another recursion
    import shutil

    unshare_bin = shutil.which("unshare")
    if unshare_bin is None:
        return False

    # build a new argv instead of mutating sys.argv, so a failed exec
    # leaves the process state untouched
    argv = [unshare_bin, "-n", sys.executable] + sys.argv + ["--no-netns"]
    try:
        os.execv(unshare_bin, argv)
    except OSError:
        pass

    return False
544
545#
546# main
547#
def main():
    '''
    Parse the command line and run the test files once per selected variant
    (xtables-legacy-multi and/or xtables-nft-multi).

    :returns: exit status for sys.exit(): None after --missing, 77 when not
              run as root, 1 if the log file cannot be opened or any unit
              test failed, 0 otherwise
    '''
    # module globals read by execute_cmd(), run_test() and friends
    global EXECUTABLE
    global log_file

    parser = argparse.ArgumentParser(description='Run iptables tests')
    parser.add_argument('filename', nargs='*',
                        metavar='path/to/file.t',
                        help='Run only this test')
    parser.add_argument('-H', '--host', action='store_true',
                        help='Run tests against installed binaries')
    parser.add_argument('-l', '--legacy', action='store_true',
                        help='Test iptables-legacy')
    parser.add_argument('-m', '--missing', action='store_true',
                        help='Check for missing tests')
    parser.add_argument('-n', '--nftables', action='store_true',
                        help='Test iptables-over-nftables')
    parser.add_argument('-N', '--netns', action='store_const',
                        const='____iptables-container-test',
                        help='Test netnamespace path')
    parser.add_argument('--no-netns', action='store_true',
                        help='Do not run testsuite in own network namespace')
    args = parser.parse_args()

    #
    # show list of missing test files
    #
    if args.missing:
        show_missing()
        return

    variants = []
    if args.legacy:
        variants.append("legacy")
    if args.nftables:
        variants.append("nft")
    if not variants:
        variants = ["legacy", "nft"]

    if os.getuid() != 0:
        print("You need to be root to run this, sorry", file=sys.stderr)
        return 77

    if not args.netns and not args.no_netns and not spawn_netns():
        print("Cannot run in own namespace, connectivity might break",
              file=sys.stderr)

    if not args.host:
        # os.environ (unlike os.putenv) keeps os.getenv() in sync, too
        os.environ["XTABLES_LIBDIR"] = os.path.abspath(EXTENSIONS_PATH)
        path = os.path.join(os.path.abspath(os.path.curdir), "iptables")
        if os.environ.get("PATH"):
            path += ":" + os.environ["PATH"]
        os.environ["PATH"] = path

    if args.filename:
        file_list = args.filename
    else:
        file_list = sorted(os.path.join(EXTENSIONS_PATH, i)
                           for i in os.listdir(EXTENSIONS_PATH)
                           if i.endswith('.t'))

    # One log for the whole run: reopening it with 'w' for each variant
    # threw away the log of the previously tested variant.
    try:
        log_file = open(LOGFILE, 'w')
    except IOError:
        print("Couldn't open log file %s" % LOGFILE, file=sys.stderr)
        # nothing was tested; do not let sys.exit() report success
        return 1

    total_test_files = 0
    total_passed = 0
    total_tests = 0
    try:
        for variant in variants:
            EXECUTABLE = "xtables-" + variant + "-multi"

            test_files = 0
            tests = 0
            passed = 0

            for filename in file_list:
                file_tests, file_passed = run_test_file(filename, args.netns)
                if file_tests:
                    tests += file_tests
                    passed += file_passed
                    test_files += 1

            print("%s: %d test files, %d unit tests, %d passed"
                  % (variant, test_files, tests, passed))

            total_passed += passed
            total_tests += tests
            total_test_files = max(total_test_files, test_files)
    finally:
        log_file.close()

    if len(variants) > 1:
        print("total: %d test files, %d unit tests, %d passed"
              % (total_test_files, total_tests, total_passed))
    # Exit statuses are only 8 bits wide on POSIX: returning the negative
    # failure count could wrap around to 0 and report success.
    return 0 if total_passed == total_tests else 1
641
if __name__ == '__main__':
    # main()'s return value becomes the process exit status
    raise SystemExit(main())
644