#!/usr/bin/env python3
# Copyright (c) PLUMgrid, Inc.
# Licensed under the Apache License, Version 2.0 (the "License")

# test program to count the packets sent to a device in a .5
# second period

from ctypes import c_uint, c_ulong, Structure
from netaddr import IPAddress
from bcc import BPF
from subprocess import check_call
import sys
from unittest import main, TestCase

# The first positional argument is mandatory and the second is optional.
# Both are popped off sys.argv so that unittest's main() does not try to
# interpret them as test names.
arg1 = sys.argv.pop(1)
arg2 = sys.argv.pop(1) if len(sys.argv) > 1 else ""

# Handed to get_table() as the key/leaf types.
# NOTE(review): None presumably lets bcc derive the types from the loaded
# program -- confirm against the bcc reference.
Key = None
Leaf = None


class TestBPFSocket(TestCase):
    """Exercise the "stats" table of a socket-filter BPF program on eth0."""

    def setUp(self):
        """Load the program, attach "on_packet" to eth0, keep the table.

        arg1 and arg2 are passed (encoded to bytes) as the first two
        positional arguments of BPF(); presumably the program source and
        an optional header file -- verify against the test harness.
        """
        bpf = BPF(arg1.encode(), arg2.encode(), debug=0)
        packet_filter = bpf.load_func(b"on_packet", BPF.SOCKET_FILTER)
        BPF.attach_raw_socket(packet_filter, b"eth0")
        self.stats = bpf.get_table(b"stats", Key, Leaf)

    def test_ping(self):
        """Ping 100 times, check the counters, then exercise delete/clear."""
        target = "172.16.1.1"
        check_call(["ping", "-f", "-c", "100", target])

        # Debugging aid: uncomment to dump every entry of the table.
        # for key, leaf in self.stats.items():
        #     print(IPAddress(key.sip), "=>", IPAddress(key.dip),
        #           "rx", leaf.rx_pkts, "tx", leaf.tx_pkts)

        key = self.stats.Key(
            IPAddress("172.16.1.2").value,
            IPAddress(target).value,
        )
        counters = self.stats[key]
        self.assertEqual(counters.rx_pkts, 100)
        self.assertEqual(counters.tx_pkts, 100)

        # Once removed, the entry can be neither read nor deleted again.
        del self.stats[key]
        with self.assertRaises(KeyError):
            _ = self.stats[key]
        with self.assertRaises(KeyError):
            del self.stats[key]

        # clear() empties the table both before and after a re-insert.
        self.stats.clear()
        self.assertEqual(len(self.stats), 0)
        self.stats[key] = counters
        self.assertEqual(len(self.stats), 1)
        self.stats.clear()
        self.assertEqual(len(self.stats), 0)

    def test_empty_key(self):
        """Check that a default-constructed (all-zero) key works like any other."""
        table = self.stats

        # test with a 0 key
        table.clear()
        table[table.Key()] = table.Leaf(100, 200)
        table.popitem()  # removes the entry that was just stored

        table[table.Key(10, 20)] = table.Leaf(300, 400)
        with self.assertRaises(KeyError):
            _ = table[table.Key()]
        _, leaf = table.popitem()
        self.assertEqual(leaf.rx_pkts, 300)
        self.assertEqual(leaf.tx_pkts, 400)

        table.clear()
        self.assertEqual(len(table), 0)

        # The zero key plus three keys differing only in the second field
        # must be stored as four distinct entries.
        table[table.Key()] = leaf
        for second_field in (1, 2, 3):
            table[table.Key(0, second_field)] = leaf
        self.assertEqual(len(table), 4)


if __name__ == "__main__":
    main()