#!/usr/bin/env python3
# Copyright (c) PLUMgrid, Inc.
# Licensed under the Apache License, Version 2.0 (the "License")

from ctypes import c_ushort, c_int, c_ulonglong
from netaddr import IPAddress
from bcc import BPF
from pyroute2 import IPRoute
from socket import socket, AF_INET, SOCK_DGRAM
import sys
from time import sleep
from unittest import main, TestCase
from utils import mayFail

# Path of the BPF C source to load. It is popped from argv so that
# unittest.main() does not try to interpret it as a test name.
arg1 = sys.argv.pop(1)

# Parser-state keys. Each key indexes both the "jump" table (which loaded
# program handles that state) and the "stats" table (how often it was hit).
S_EOP = 1
S_ETHER = 2
S_ARP = 3
S_IP = 4

# After sending, the datagram may sit in the kernel (e.g. waiting on ARP
# resolution) before it egresses, so the counters are polled for up to
# _POLL_ATTEMPTS * _POLL_INTERVAL seconds before being asserted on.
_POLL_INTERVAL = 0.05
_POLL_ATTEMPTS = 20


class TestBPFSocket(TestCase):
    """Chain tc BPF classifiers through a jump table and check per-state hits."""

    def setUp(self):
        """Load the parsers, attach the entry classifier to eth0, fill "jump".

        Each kernel-side resource is registered with addCleanup() as soon as
        it exists. Cleanups run in reverse order and still run when a later
        step of setUp raises, so a failed setup does not leave the qdisc or
        filter behind on eth0.
        """
        b = BPF(src_file=arg1.encode(), debug=0)
        self.addCleanup(b.cleanup)
        ether_fn = b.load_func(b"parse_ether", BPF.SCHED_CLS)
        arp_fn = b.load_func(b"parse_arp", BPF.SCHED_CLS)
        ip_fn = b.load_func(b"parse_ip", BPF.SCHED_CLS)
        eop_fn = b.load_func(b"eop", BPF.SCHED_CLS)

        ipr = IPRoute()
        self.addCleanup(ipr.close)
        ifindex = ipr.link_lookup(ifname=b"eth0")[0]
        # Root sfq qdisc (handle 1:) with the Ethernet parser as its bpf
        # classifier. Deleting the qdisc also drops the filters attached to it.
        ipr.tc("add", "sfq", ifindex, "1:")
        self.addCleanup(ipr.tc, "del", "sfq", ifindex, "1:")
        ipr.tc("add-filter", "bpf", ifindex, ":1", fd=ether_fn.fd,
               name=ether_fn.name, parent="1:", action="ok", classid=1)

        # Map each parser state to the program fd that handles it. The table
        # is presumably a BPF prog array used for tail calls; it is defined in
        # the C source (TODO confirm).
        self.jump = b.get_table(b"jump", c_int, c_int)
        self.jump[c_int(S_ARP)] = c_int(arp_fn.fd)
        self.jump[c_int(S_IP)] = c_int(ip_fn.fd)
        self.jump[c_int(S_EOP)] = c_int(eop_fn.fd)
        self.stats = b.get_table(b"stats", c_int, c_ulonglong)

    def _counter(self, state):
        """Return the "stats" counter for *state*, treating a missing entry as 0."""
        try:
            return self.stats[c_int(state)].value
        except KeyError:
            return 0

    def _counters_reached(self):
        """Return True once the IP, ARP and EOP counters meet test_jumps' bounds."""
        return (self._counter(S_IP) > 0
                and self._counter(S_ARP) > 0
                and self._counter(S_EOP) > 1)

    @mayFail("This may fail on github actions environment due to udp packet loss")
    def test_jumps(self):
        """Send one UDP datagram and check that each parser stage was reached."""
        with socket(AF_INET, SOCK_DGRAM) as udp:
            udp.sendto(b"a" * 10, ("172.16.1.1", 5000))

        # Give delayed egress (e.g. pending ARP resolution) a bounded chance
        # to reach the classifiers before asserting.
        for _ in range(_POLL_ATTEMPTS):
            if self._counters_reached():
                break
            sleep(_POLL_INTERVAL)

        self.assertGreater(self._counter(S_IP), 0)
        self.assertGreater(self._counter(S_ARP), 0)
        # More than one hit is expected, presumably one per packet that passed
        # through the chain (ARP + IP) — TODO confirm against the C source.
        self.assertGreater(self._counter(S_EOP), 1)


if __name__ == "__main__":
    main()