xref: /aosp_15_r20/kernel/tests/net/test/bpf_test.py (revision 2f2c4c7ab4226c71756b9c31670392fdd6887c4f)
1#!/usr/bin/python3
2#
3# Copyright 2016 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import ctypes
18import errno
19import os
20import socket
21import unittest
22
23import bpf
24from bpf import BPF_ADD
25from bpf import BPF_AND
26from bpf import BPF_CGROUP_INET_EGRESS
27from bpf import BPF_CGROUP_INET_INGRESS
28from bpf import BPF_CGROUP_INET_SOCK_CREATE
29from bpf import BPF_DW
30from bpf import BPF_F_RDONLY
31from bpf import BPF_F_WRONLY
32from bpf import BPF_FUNC_get_current_uid_gid
33from bpf import BPF_FUNC_get_socket_cookie
34from bpf import BPF_FUNC_get_socket_uid
35from bpf import BPF_FUNC_ktime_get_boot_ns
36from bpf import BPF_FUNC_ktime_get_ns
37from bpf import BPF_FUNC_map_lookup_elem
38from bpf import BPF_FUNC_map_update_elem
39from bpf import BPF_FUNC_skb_change_head
40from bpf import BPF_JNE
41from bpf import BPF_MAP_TYPE_ARRAY
42from bpf import BPF_MAP_TYPE_HASH
43from bpf import BPF_PROG_TYPE_CGROUP_SKB
44from bpf import BPF_PROG_TYPE_CGROUP_SOCK
45from bpf import BPF_PROG_TYPE_SCHED_CLS
46from bpf import BPF_PROG_TYPE_SOCKET_FILTER
47from bpf import BPF_REG_0
48from bpf import BPF_REG_1
49from bpf import BPF_REG_10
50from bpf import BPF_REG_2
51from bpf import BPF_REG_3
52from bpf import BPF_REG_4
53from bpf import BPF_REG_6
54from bpf import BPF_REG_7
55from bpf import BPF_STX
56from bpf import BPF_W
57from bpf import BPF_XADD
58from bpf import BpfAlu64Imm
59from bpf import BpfExitInsn
60from bpf import BpfFuncCall
61from bpf import BpfJumpImm
62from bpf import BpfLdxMem
63from bpf import BpfLoadMapFd
64from bpf import BpfMov64Imm
65from bpf import BpfMov64Reg
66from bpf import BpfProgAttach
67from bpf import BpfProgAttachSocket
68from bpf import BpfProgDetach
69from bpf import BpfProgGetFdById
70from bpf import BpfProgLoad
71from bpf import BpfProgQuery
72from bpf import BpfRawInsn
73from bpf import BpfStMem
74from bpf import BpfStxMem
75from bpf import CreateMap
76from bpf import DeleteMap
77from bpf import GetFirstKey
78from bpf import GetNextKey
79from bpf import LookupMap
80from bpf import UpdateMap
81import csocket
82import net_test
83import sock_diag
84
# "import ctypes" alone does not load the ctypes.util submodule; import it
# explicitly instead of relying on another module having imported it first.
import ctypes.util  # pylint: disable=wrong-import-position,ungrouped-imports

libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)

# Key size, value size and entry count passed to CreateMap by the tests below.
KEY_SIZE = 4
VALUE_SIZE = 4
TOTAL_ENTRIES = 20
# Uid the tests switch to with net_test.RunAsUid.
TEST_UID = 5432
TEST_GID = 12345
# Offset to store the map key in stack register REG10
key_offset = -8
# Offset to store the map value in stack register REG10
value_offset = -16
96
97
def PrintMapInfo(map_fd):
  """Prints every <key, value> pair of a map to stdout. Debug usage only.

  The walk starts from a key the map is assumed not to contain and follows
  GetNextKey from there. It ends, after printing "no value", as soon as
  GetNextKey or LookupMap raises socket.error.

  Args:
    map_fd: File descriptor of the map to dump.
  """
  # A random key that the map does not contain.
  cursor = 10086
  while True:
    try:
      cursor = GetNextKey(map_fd, cursor).value
      entry = LookupMap(map_fd, cursor)
    except socket.error:
      print("no value")  # pylint: disable=superfluous-parens
      break
    print(repr(cursor) + " : " + repr(entry.value))  # pylint: disable=superfluous-parens
111
112
def SocketUDPLoopBack(packet_count, version, prog_fd):
  """Makes a nonblocking UDP socket send traffic to itself on loopback.

  Every packet sent is read back immediately and compared with what was sent.

  Args:
    packet_count: Number of packets to send and read back.
    version: IP version to use, 4 or 6.
    prog_fd: File descriptor of a BPF program to attach to the socket before
      any traffic is sent, or None to attach nothing.

  Returns:
    The still-open socket. The caller is responsible for closing it.

  Raises:
    AssertionError: If a received payload or source address differs from what
      was sent.
    Exception: Any other error raised while setting up or using the socket
      (callers in this file expect socket.error with EAGAIN or EPERM when a
      BPF program drops the packet). The socket is closed before the error
      propagates.
  """
  family = {4: socket.AF_INET, 6: socket.AF_INET6}[version]
  sock = socket.socket(family, socket.SOCK_DGRAM, 0)
  try:
    if prog_fd is not None:
      BpfProgAttachSocket(sock.fileno(), prog_fd)
    net_test.SetNonBlocking(sock)
    addr = {4: "127.0.0.1", 6: "::1"}[version]
    # Bind to port 0, then read back the address that was actually assigned.
    sock.bind((addr, 0))
    addr = sock.getsockname()
    sockaddr = csocket.Sockaddr(addr)
    for _ in range(packet_count):
      sock.sendto(b"foo", addr)
      data, retaddr = csocket.Recvfrom(sock, 4096, 0)
      # Raise explicitly rather than using "assert" so that the checks are
      # still performed when Python runs with -O.
      if not b"foo" == data:
        raise AssertionError("unexpected payload %r" % (data,))
      if not sockaddr == retaddr:
        raise AssertionError("unexpected source address %r" % (retaddr,))
    return sock
  except Exception:
    # Do not leak the socket on failure; re-raise with the original traceback.
    sock.close()
    raise
133    raise e
134
135
# Register conventions used by the eBPF programs in this file:
# REG0: register storing return value from helper function and the final return
# value of eBPF program.
# REG1 - REG5: temporary registers used for storing values and loading
# parameters into eBPF helper functions. After calling a helper function, the
# values of these registers will be reset.
# REG6 - REG9: registers storing values that will not be cleared when calling
# an eBPF helper function.
# REG10: the stack, which stores values that need to be accessed by address.
# A program can retrieve the address of a value by specifying the position of
# the value in the stack.
def BpfFuncCountPacketInit(map_fd):
  """Builds the main code block of the eBPF packet counting program.

  The block expects the map key to be on the stack at key_offset already (the
  tests in this file put it there with INS_BPF_PARAM_STORE). It looks the key
  up in the map. If the element does not exist yet, the map is updated with a
  new <key, 1> pair and execution falls through to whatever block follows.
  Otherwise the program jumps forward so that a later block can handle the
  existing element.

  Args:
    map_fd: File descriptor of the map used for counting.

  Returns:
    A list of BPF instructions.
  """
  key_pos = BPF_REG_7
  return [
      # Compute the stack address of the preloaded key and keep it in
      # BPF_REG_7, which is not cleared by helper function calls.
      BpfMov64Reg(key_pos, BPF_REG_10),
      BpfAlu64Imm(BPF_ADD, key_pos, key_offset),
      # Load map fd and look up the key in the map
      BpfLoadMapFd(map_fd, BPF_REG_1),
      BpfMov64Reg(BPF_REG_2, key_pos),
      BpfFuncCall(BPF_FUNC_map_lookup_elem),
      # If the map element already exists (lookup result is not 0), jump out
      # of this code block and let a later block handle it. BPF_JNE is the
      # jump operation meant here; the BPF_AND used previously is an ALU
      # operation that shares the same opcode value.
      # NOTE(review): the offset of 10 presumably covers the rest of this
      # block plus a two-instruction accept block appended by the caller, so
      # that the jump lands on INS_PACK_COUNT_UPDATE - confirm against bpf.py.
      BpfJumpImm(BPF_JNE, BPF_REG_0, 0, 10),
      BpfLoadMapFd(map_fd, BPF_REG_1),
      BpfMov64Reg(BPF_REG_2, key_pos),
      # Initialize a new <key, value> pair with value equal to 1 and update
      # the map with it.
      BpfStMem(BPF_W, BPF_REG_10, value_offset, 1),
      BpfMov64Reg(BPF_REG_3, BPF_REG_10),
      BpfAlu64Imm(BPF_ADD, BPF_REG_3, value_offset),
      BpfMov64Imm(BPF_REG_4, 0),
      BpfFuncCall(BPF_FUNC_map_update_elem)
  ]
172
173
# Bpf instructions that set the return value to 0 and exit. The tests in this
# file use it both as a plain program epilogue and as a "block" verdict for
# socket filters, cgroup skb filters and the cgroup socket-create hook.
INS_BPF_EXIT_BLOCK = [
    BpfMov64Imm(BPF_REG_0, 0),
    BpfExitInsn()
]

# Bpf instruction for cgroup bpf filter to accept a packet and exit.
INS_CGROUP_ACCEPT = [
    # Set return value to 1 and exit.
    BpfMov64Imm(BPF_REG_0, 1),
    BpfExitInsn()
]

# Bpf instruction for socket bpf filter to accept a packet and exit.
INS_SK_FILTER_ACCEPT = [
    # Precondition: BPF_REG_6 = sk_buff context
    # Load the packet length from BPF_REG_6 and store it in BPF_REG_0 as the
    # return value.
    BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0),
    BpfExitInsn()
]

# Update an existing map element with +1. This block has no exit instruction
# of its own, so callers append an accept block after it.
INS_PACK_COUNT_UPDATE = [
    # Precondition: BPF_REG_0 = Value retrieved from BPF maps
    # Add one to the corresponding eBPF value field for a specific eBPF key.
    BpfMov64Reg(BPF_REG_2, BPF_REG_0),
    BpfMov64Imm(BPF_REG_1, 1),
    BpfRawInsn(BPF_STX | BPF_XADD | BPF_W, BPF_REG_2, BPF_REG_1, 0, 0),
]

# Store the value in BPF_REG_0 on the stack at key_offset, where
# BpfFuncCountPacketInit expects to find the map key.
INS_BPF_PARAM_STORE = [
    BpfStxMem(BPF_DW, BPF_REG_10, BPF_REG_0, key_offset),
]
207
208
class BpfTest(net_test.NetworkTest):
  """Tests for BPF maps, socket filter programs and BPF helper functions.

  Tests store the map, program and socket they create in self.map_fd,
  self.prog_fd and self.sock so that tearDown can release them even when a
  test fails part-way through.
  """

  def setUp(self):
    super(BpfTest, self).setUp()
    self.map_fd = None
    self.prog_fd = None
    self.sock = None

  def tearDown(self):
    if self.prog_fd is not None:
      os.close(self.prog_fd)
      self.prog_fd = None
    if self.map_fd is not None:
      os.close(self.map_fd)
      self.map_fd = None
    if self.sock is not None:
      self.sock.close()
      self.sock = None
    super(BpfTest, self).tearDown()

  # Update, look up and delete a single element of a hash map.
  def testCreateMap(self):
    key, value = 1, 1
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    UpdateMap(self.map_fd, key, value)
    self.assertEqual(value, LookupMap(self.map_fd, key).value)
    DeleteMap(self.map_fd, key)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key)

  def CheckAllMapEntry(self, nonexistent_key, total_entries, value):
    """Walks self.map_fd with GetNextKey and checks every entry visited.

    Args:
      nonexistent_key: Key to start iterating after. It is passed to the first
        GetNextKey call and is not itself checked.
      total_entries: Number of entries expected after the starting key. Once
        that many have been visited, GetNextKey must fail with ENOENT.
      value: Value every visited entry is expected to hold.
    """
    key = nonexistent_key
    for _ in range(total_entries):
      key = GetNextKey(self.map_fd, key).value
      self.assertGreaterEqual(key, 0)
      self.assertEqual(value, LookupMap(self.map_fd, key).value)
    # All expected entries have been seen, so there must be no next key.
    self.assertRaisesErrno(errno.ENOENT, GetNextKey, self.map_fd, key)

  # Fill a hash map completely and iterate over all of its entries.
  def testIterateMap(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    value = 1024
    for key in range(0, TOTAL_ENTRIES):
      UpdateMap(self.map_fd, key, value)
    for key in range(0, TOTAL_ENTRIES):
      self.assertEqual(value, LookupMap(self.map_fd, key).value)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, 101)
    nonexistent_key = -1
    self.CheckAllMapEntry(nonexistent_key, TOTAL_ENTRIES, value)

  # Start iterating from the key returned by GetFirstKey: exactly
  # TOTAL_ENTRIES - 1 further entries must follow it.
  def testFindFirstMapKey(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    value = 1024
    for key in range(0, TOTAL_ENTRIES):
      UpdateMap(self.map_fd, key, value)
    first_key = GetFirstKey(self.map_fd)
    key = first_key.value
    self.CheckAllMapEntry(key, TOTAL_ENTRIES - 1, value)

  # Write and read back an array map element at an index other than 0.
  def testArrayNonZeroOffset(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_ARRAY, KEY_SIZE, VALUE_SIZE, 2)
    key = 1
    value = 123
    UpdateMap(self.map_fd, key, value)
    self.assertEqual(value, LookupMap(self.map_fd, key).value)

  # A map created with BPF_F_RDONLY must reject updates with EPERM while
  # lookups still work (and report ENOENT for the missing key).
  def testRdOnlyMap(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES, map_flags=BPF_F_RDONLY)
    value = 1024
    key = 1
    self.assertRaisesErrno(errno.EPERM, UpdateMap, self.map_fd, key, value)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key)

  # A map created with BPF_F_WRONLY must accept updates and reject lookups
  # with EPERM.
  def testWrOnlyMap(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES, map_flags=BPF_F_WRONLY)
    value = 1024
    key = 1
    UpdateMap(self.map_fd, key, value)
    self.assertRaisesErrno(errno.EPERM, LookupMap, self.map_fd, key)

  # Load a socket filter that accepts every packet and check that loopback
  # traffic still flows with it attached.
  def testProgLoad(self):
    # Move skb to BPF_REG_6 for further usage
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1)
    ]
    instructions += INS_SK_FILTER_ACCEPT
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    SocketUDPLoopBack(1, 4, self.prog_fd).close()
    SocketUDPLoopBack(1, 6, self.prog_fd).close()

  # A socket filter returning 0 drops the packet, so the nonblocking receive
  # in SocketUDPLoopBack fails with EAGAIN.
  def testPacketBlock(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, INS_BPF_EXIT_BLOCK)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, self.prog_fd)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, self.prog_fd)

  # Count packets under a constant key: after packet_count packets on IPv4 and
  # again on IPv6, the map entry must hold twice packet_count.
  def testPacketCount(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    key = 0xf0f0
    # Set up instruction block with key loaded at BPF_REG_0.
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfMov64Imm(BPF_REG_0, key)
    ]
    # Concatenate the generic packet count bpf program to it.
    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
                     + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
                     + INS_SK_FILTER_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    packet_count = 10
    SocketUDPLoopBack(packet_count, 4, self.prog_fd).close()
    SocketUDPLoopBack(packet_count, 6, self.prog_fd).close()
    self.assertEqual(packet_count * 2, LookupMap(self.map_fd, key).value)

  ##############################################################################
  #
  # Test for presence of kernel patch:
  #
  #   ANDROID: net: bpf: Allow TC programs to call BPF_FUNC_skb_change_head
  #
  # 4.14: https://android-review.googlesource.com/c/kernel/common/+/1237789
  #       commit fe82848d9c1c887d2a84d3738c13e644d01b6d6f
  #
  # 4.19: https://android-review.googlesource.com/c/kernel/common/+/1237788
  #       commit 6e04d94ab72435b45c413daff63520fd724e260e
  #
  # 5.4:  https://android-review.googlesource.com/c/kernel/common/+/1237787
  #       commit d730995e7bc5b4c10cc176235b704a274e6ec16f
  #
  # Upstream in Linux v5.8:
  #   net: bpf: Allow TC programs to call BPF_FUNC_skb_change_head
  #   commit 6f3f65d80dac8f2bafce2213005821fccdce194c
  #
  def testSkbChangeHead(self):
    # long bpf_skb_change_head(struct sk_buff *skb, u32 len, u64 flags)
    instructions = [
        BpfMov64Imm(BPF_REG_2, 14),  # u32 len
        BpfMov64Imm(BPF_REG_3, 0),   # u64 flags
        BpfFuncCall(BPF_FUNC_skb_change_head),
    ] + INS_BPF_EXIT_BLOCK
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions,
                               b"Apache 2.0")
    # No exceptions? Good.

  # Load a program calling bpf_ktime_get_ns() with BpfProgLoad's default
  # license argument.
  def testKtimeGetNsGPL(self):
    instructions = [BpfFuncCall(BPF_FUNC_ktime_get_ns)] + INS_BPF_EXIT_BLOCK
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions)
    # No exceptions? Good.

  ##############################################################################
  #
  # Test for presence of kernel patch:
  #
  #   UPSTREAM: net: bpf: Make bpf_ktime_get_ns() available to non GPL programs
  #
  # 4.14: https://android-review.googlesource.com/c/kernel/common/+/1585269
  #       commit cbb4c73f9eab8f3c8ac29175d45c99ccba382e15
  #
  # 4.19: https://android-review.googlesource.com/c/kernel/common/+/1355243
  #       commit 272e21ccc9a92feeee80aff0587410a314b73c5b
  #
  # 5.4:  https://android-review.googlesource.com/c/kernel/common/+/1355422
  #       commit 45217b91eaaa3a563247c4f470f4cb785de6b1c6
  #
  def testKtimeGetNsApache2(self):
    instructions = [BpfFuncCall(BPF_FUNC_ktime_get_ns)] + INS_BPF_EXIT_BLOCK
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions,
                               b"Apache 2.0")
    # No exceptions? Good.

  ##############################################################################
  #
  # Test for presence of kernel patch:
  #
  #   BACKPORT: bpf: add bpf_ktime_get_boot_ns()
  #
  # 4.14: https://android-review.googlesource.com/c/kernel/common/+/1585587
  #       commit 34073d7a8ee47ca908b56e9a1d14ca0615fdfc09
  #
  # 4.19: https://android-review.googlesource.com/c/kernel/common/+/1585606
  #       commit 4812ec50935dfe59ba9f48a572e278dd0b02af68
  #
  # 5.4:  https://android-review.googlesource.com/c/kernel/common/+/1585252
  #       commit 57b3f4830fb66a6038c4c1c66ca2e138fe8be231
  #
  def testKtimeGetBootNs(self):
    instructions = [
        BpfFuncCall(BPF_FUNC_ktime_get_boot_ns),
    ] + INS_BPF_EXIT_BLOCK
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions,
                               b"Apache 2.0")
    # No exceptions? Good.

  ##############################################################################
  #
  # Test for presence of upstream 5.14 kernel patches:
  #
  # Android12-5.10:
  #   UPSTREAM: net: initialize net->net_cookie at netns setup
  #   https://android-review.git.corp.google.com/c/kernel/common/+/2503195
  #
  #   UPSTREAM: net: retrieve netns cookie via getsocketopt
  #   https://android-review.git.corp.google.com/c/kernel/common/+/2503056
  #
  # (and potentially if you care about kernel ABI)
  #
  #   ANDROID: fix ABI by undoing atomic64_t -> u64 type conversion
  #   https://android-review.git.corp.google.com/c/kernel/common/+/2504335
  #
  # Android13-5.10:
  #   UPSTREAM: net: initialize net->net_cookie at netns setup
  #   https://android-review.git.corp.google.com/c/kernel/common/+/2503795
  #
  #   UPSTREAM: net: retrieve netns cookie via getsocketopt
  #   https://android-review.git.corp.google.com/c/kernel/common/+/2503796
  #
  # (and potentially if you care about kernel ABI)
  #
  #   ANDROID: fix ABI by undoing atomic64_t -> u64 type conversion
  #   https://android-review.git.corp.google.com/c/kernel/common/+/2506895
  #
  @unittest.skipUnless(bpf.HAVE_SO_NETNS_COOKIE, "no SO_NETNS_COOKIE support")
  def testGetNetNsCookie(self):
    sizeof_u64 = 8
    sk = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM, 0)
    try:
      cookie = sk.getsockopt(socket.SOL_SOCKET, bpf.SO_NETNS_COOKIE,
                             sizeof_u64)
    finally:
      # Close the socket even if getsockopt fails, so that it is not leaked.
      sk.close()
    self.assertEqual(len(cookie), 8)
    cookie = int.from_bytes(cookie, "little")
    self.assertGreaterEqual(cookie, 0)

  # Count packets per socket cookie: the entry keyed by the cookie reported by
  # sock_diag must equal the number of packets the socket received.
  def testGetSocketCookie(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Move skb to REG6 for further usage, call helper function to get socket
    # cookie of current skb and return the cookie at REG0 for next code block
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfFuncCall(BPF_FUNC_get_socket_cookie)
    ]
    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
                     + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
                     + INS_SK_FILTER_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    packet_count = 10
    def PacketCountByCookie(version):
      # Keep the socket in self.sock while it is open so that tearDown closes
      # it if one of the checks below fails.
      self.sock = SocketUDPLoopBack(packet_count, version, self.prog_fd)
      cookie = sock_diag.SockDiag.GetSocketCookie(self.sock)
      self.assertEqual(packet_count, LookupMap(self.map_fd, cookie).value)
      self.sock.close()
      self.sock = None
    PacketCountByCookie(4)
    PacketCountByCookie(6)

  # Count packets per socket uid while running as TEST_UID.
  def testGetSocketUid(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Set up the instruction with uid at BPF_REG_0.
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfFuncCall(BPF_FUNC_get_socket_uid)
    ]
    # Concatenate the generic packet count bpf program to it.
    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
                     + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
                     + INS_SK_FILTER_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    packet_count = 10
    uid = TEST_UID
    with net_test.RunAsUid(uid):
      self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 4, self.prog_fd).close()
      self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value)
      # Reset the counter so that the IPv6 run is counted on its own.
      DeleteMap(self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 6, self.prog_fd).close()
      self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value)
492
493
class BpfCgroupTest(net_test.NetworkTest):
  """Tests for BPF programs attached to the cgroup at /sys/fs/cgroup.

  setUp remembers and detaches whatever programs are attached to the ingress,
  egress and socket-create hooks of that cgroup, and tearDown puts them back,
  so that the tests run with no pre-existing program on those hooks.
  """

  @classmethod
  def setUpClass(cls):
    super(BpfCgroupTest, cls).setUpClass()
    # os.open() throws exception on failure
    cls._cg_fd = os.open("/sys/fs/cgroup", os.O_DIRECTORY | os.O_RDONLY)

  @classmethod
  def tearDownClass(cls):
    if cls._cg_fd is not None:
      os.close(cls._cg_fd)
      cls._cg_fd = None
    super(BpfCgroupTest, cls).tearDownClass()

  def setUp(self):
    super(BpfCgroupTest, self).setUp()
    self.prog_fd = None
    self.map_fd = None
    # File descriptors of the programs attached before the test, or None.
    self.cg_inet_ingress = BpfProgGetFdById(
        BpfProgQuery(self._cg_fd, BPF_CGROUP_INET_INGRESS, 0, 0))
    self.cg_inet_egress = BpfProgGetFdById(
        BpfProgQuery(self._cg_fd, BPF_CGROUP_INET_EGRESS, 0, 0))
    self.cg_inet_sock_create = BpfProgGetFdById(
        BpfProgQuery(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE, 0, 0))
    # Compare against None, as tearDown does, so that both agree on whether a
    # program was saved even for a file descriptor of 0.
    if self.cg_inet_ingress is not None:
      BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
    if self.cg_inet_egress is not None:
      BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_EGRESS)
    if self.cg_inet_sock_create is not None:
      BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)

  def _RestoreCgroupProg(self, saved_fd, attach_type):
    """Puts one cgroup hook back into the state setUp found it in.

    Args:
      saved_fd: File descriptor of the program saved by setUp, or None if
        nothing was attached. It is closed once the program is re-attached.
      attach_type: The BPF_CGROUP_* attach type to restore.
    """
    if saved_fd is None:
      BpfProgDetach(self._cg_fd, attach_type)
    else:
      BpfProgAttach(saved_fd, self._cg_fd, attach_type)
      os.close(saved_fd)

  def tearDown(self):
    if self.prog_fd is not None:
      os.close(self.prog_fd)
      self.prog_fd = None
    if self.map_fd is not None:
      os.close(self.map_fd)
      self.map_fd = None
    self._RestoreCgroupProg(self.cg_inet_ingress, BPF_CGROUP_INET_INGRESS)
    self.cg_inet_ingress = None
    self._RestoreCgroupProg(self.cg_inet_egress, BPF_CGROUP_INET_EGRESS)
    self.cg_inet_egress = None
    self._RestoreCgroupProg(self.cg_inet_sock_create,
                            BPF_CGROUP_INET_SOCK_CREATE)
    self.cg_inet_sock_create = None
    super(BpfCgroupTest, self).tearDown()

  # Attaching a program to the cgroup and detaching it again must not raise.
  def testCgroupBpfAttach(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)

  # An ingress program returning 0 drops incoming packets, so the receive in
  # SocketUDPLoopBack fails with EAGAIN until the program is detached.
  def testCgroupIngress(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, None)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, None)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
    SocketUDPLoopBack(1, 4, None).close()
    SocketUDPLoopBack(1, 6, None).close()

  # An egress program returning 0 makes SocketUDPLoopBack fail with EPERM
  # until the program is detached.
  def testCgroupEgress(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_EGRESS)
    self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 4, None)
    self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 6, None)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_EGRESS)
    SocketUDPLoopBack(1, 4, None).close()
    SocketUDPLoopBack(1, 6, None).close()

  # Count incoming packets per socket uid with a cgroup ingress program.
  def testCgroupBpfUid(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Similar to the program used in testGetSocketUid.
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfFuncCall(BPF_FUNC_get_socket_uid)
    ]
    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
                     + INS_CGROUP_ACCEPT + INS_PACK_COUNT_UPDATE
                     + INS_CGROUP_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, instructions)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    packet_count = 20
    uid = TEST_UID
    with net_test.RunAsUid(uid):
      self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 4, None).close()
      self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value)
      # Reset the counter so that the IPv6 run is counted on its own.
      DeleteMap(self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 6, None).close()
      self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)

  def checkSocketCreate(self, family, socktype, sockproto, success):
    """Checks whether creating a socket succeeds or fails as expected.

    Args:
      family: Address family passed to socket.socket().
      socktype: Socket type passed to socket.socket().
      sockproto: Protocol number passed to socket.socket().
      success: True if creation is expected to succeed, False if it is
        expected to fail with socket.error.
    """
    try:
      sock = socket.socket(family, socktype, sockproto)
    except socket.error as e:
      if success:
        self.fail("Failed to create socket family=%d type=%d proto=%d err=%s" %
                  (family, socktype, sockproto, os.strerror(e.errno)))
      return
    # Closed outside the try block so that an error from close() is not
    # mistaken for a failure to create the socket.
    sock.close()
    if not success:
      self.fail("unexpected socket family=%d type=%d proto=%d created, "
                "should be blocked" % (family, socktype, sockproto))

  # Creating a PF_KEY socket must succeed while no socket-create program is
  # attached.
  def testPfKeySocketCreate(self):
    # AF_KEY socket type. See include/linux/socket.h.
    AF_KEY = 15  # pylint: disable=invalid-name

    # PFKEYv2 constants. See include/uapi/linux/pfkeyv2.h.
    PF_KEY_V2 = 2  # pylint: disable=invalid-name

    self.checkSocketCreate(AF_KEY, socket.SOCK_RAW, PF_KEY_V2, True)

  def trySocketCreate(self, success):
    """Runs checkSocketCreate for IPv4/IPv6 datagram and stream sockets.

    Args:
      success: Expected outcome, passed through to checkSocketCreate.
    """
    for family in [socket.AF_INET, socket.AF_INET6]:
      for socktype in [socket.SOCK_DGRAM, socket.SOCK_STREAM]:
        self.checkSocketCreate(family, socktype, 0, success)

  # A socket-create program that returns 0 for TEST_UID and 1 for everyone
  # else must block socket creation for TEST_UID only, and only while attached.
  def testCgroupSocketCreateBlock(self):
    instructions = [
        BpfFuncCall(BPF_FUNC_get_current_uid_gid),
        # Keep only the low bits of the helper's return value for the uid
        # comparison below.
        # NOTE(review): the mask is 28 bits wide rather than 32, presumably
        # because a 0xffffffff immediate would be sign-extended to all ones
        # and mask nothing - confirm before widening it.
        BpfAlu64Imm(BPF_AND, BPF_REG_0, 0xfffffff),
        # Any other uid skips the two-instruction block below and is accepted.
        BpfJumpImm(BPF_JNE, BPF_REG_0, TEST_UID, 2),
    ]
    instructions += INS_BPF_EXIT_BLOCK + INS_CGROUP_ACCEPT

    # setUp detached any pre-existing program, so nothing may be attached now.
    fd = BpfProgGetFdById(
        BpfProgQuery(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE, 0, 0))
    self.assertIsNone(fd)

    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SOCK, instructions)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)

    fd = BpfProgGetFdById(
        BpfProgQuery(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE, 0, 0))
    self.assertIsNotNone(fd)
    try:
      # equality while almost certain is not actually 100% guaranteed:
      self.assertGreaterEqual(fd, self.prog_fd + 1)
    finally:
      # Close the queried descriptor even if the check above fails.
      os.close(fd)
    fd = None

    with net_test.RunAsUid(TEST_UID):
      # Socket creation with target uid should fail
      self.trySocketCreate(False)
    # Socket create with different uid should success
    self.trySocketCreate(True)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
    with net_test.RunAsUid(TEST_UID):
      self.trySocketCreate(True)
659
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  unittest.main()
662