1# Copyright 2023 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16import collections
17import enum
18import hci_packets as hci
19import link_layer_packets as ll
20import llcp_packets as llcp
21import py.bluetooth
22import sys
23import typing
24import unittest
25from typing import Optional, Tuple, Union
26from hci_packets import ErrorCode
27
28from ctypes import *
29
# Load the RootCanal FFI shared library. The library name is resolved by the
# dynamic loader, so it must be reachable from the library search path.
rootcanal = cdll.LoadLibrary("lib_rootcanal_ffi.so")
# ffi_controller_new returns the controller instance as an opaque pointer.
# ctypes assumes an int result unless told otherwise, so declare it as void*.
rootcanal.ffi_controller_new.restype = c_void_p

# C signature of the callback used by the controller to send HCI packets:
#   void (*)(int idc, uint8_t const *data, size_t data_len)
SEND_HCI_FUNC = CFUNCTYPE(None, c_int, POINTER(c_ubyte), c_size_t)
# C signature of the callback used by the controller to send LL pdus:
#   void (*)(uint8_t const *data, size_t data_len, int phy, int tx_power)
SEND_LL_FUNC = CFUNCTYPE(None, POINTER(c_ubyte), c_size_t, c_int, c_int)
35
36
class Idc(enum.IntEnum):
    """HCI packet type indicator exchanged with the RootCanal ffi interface.

    The value is passed to ffi_controller_receive_hci when injecting a packet
    and received in the HCI send callback to identify the packet kind."""
    # NOTE(review): the values look like the HCI UART (H4) packet
    # indicators - confirm against the ffi library.
    Cmd = 1
    Acl = 2
    Sco = 3
    Evt = 4
    Iso = 5
43
44
class Phy(enum.IntEnum):
    """Physical transport selector passed to ffi_controller_receive_ll
    along with an injected LL pdu."""
    LowEnergy = 0
    BrEdr = 1
48
49
class LeFeatures:
    """Decoded view of an LE feature mask.

    `mask` keeps the raw integer value; every other attribute is a boolean
    telling whether the matching hci.LLFeaturesBits bit is set in the mask."""

    def __init__(self, le_features: int):
        bits = hci.LLFeaturesBits

        self.mask = le_features
        self.ll_privacy = bool(le_features & bits.LL_PRIVACY)
        self.le_extended_advertising = bool(le_features & bits.LE_EXTENDED_ADVERTISING)
        self.le_periodic_advertising = bool(le_features & bits.LE_PERIODIC_ADVERTISING)
57
58
def generate_rpa(irk: bytes) -> hci.Address:
    """Generate an RPA from the key `irk` using RootCanal's ffi_generate_rpa.

    The native function fills a 6-byte output buffer; the bytes are reversed
    before being wrapped into an hci.Address."""
    # NOTE(review): the reversal presumably converts between the byte order
    # of the native function and the one expected by hci.Address - confirm.
    buffer = bytearray(6)
    # Expose the bytearray to the native function without copying.
    output = (c_char * 6).from_buffer(buffer)
    rootcanal.ffi_generate_rpa(c_char_p(irk), output)
    return hci.Address(bytes(reversed(buffer)))
65
66
class Controller:
    """Binder class over RootCanal's ffi interfaces.

    The methods send_cmd, send_iso, send_ll and send_llcp are used to inject
    HCI or LL packets into the controller, and receive_evt, receive_iso and
    receive_ll to catch outgoing HCI packets or LL pdus."""

    def __init__(self, address: hci.Address):
        """Create a native controller instance configured with `address`."""
        self.address = address

        # Create the queues before the native instance so that the callbacks
        # are usable as soon as the controller exists.
        self.evt_queue = collections.deque()
        self.acl_queue = collections.deque()
        self.iso_queue = collections.deque()
        self.ll_queue = collections.deque()
        self.evt_queue_event = asyncio.Event()
        self.acl_queue_event = asyncio.Event()
        self.iso_queue_event = asyncio.Event()
        self.ll_queue_event = asyncio.Event()

        # Write the callbacks for handling HCI and LL send events.
        # The decorators already turn the functions into ctypes callback
        # objects; they must not be wrapped a second time.
        @SEND_HCI_FUNC
        def send_hci(idc: c_int, data: POINTER(c_ubyte), data_len: c_size_t):
            # Copy the packet out of the native buffer.
            self.receive_hci_(int(idc), bytes(data[:data_len]))

        @SEND_LL_FUNC
        def send_ll(data: POINTER(c_ubyte), data_len: c_size_t, phy: c_int, tx_power: c_int):
            # Copy the pdu out of the native buffer.
            self.receive_ll_(bytes(data[:data_len]), int(phy), int(tx_power))

        # Keep a reference to the callback objects for as long as the
        # controller lives, otherwise they may be garbage collected while the
        # native code still holds the function pointers.
        self.send_hci_callback = send_hci
        self.send_ll_callback = send_ll

        # Create a c++ controller instance.
        self.instance = rootcanal.ffi_controller_new(c_char_p(address.address), self.send_hci_callback,
                                                     self.send_ll_callback)

    def __del__(self):
        # The instance attribute is missing when __init__ failed before the
        # native controller was created; there is nothing to release then.
        instance = getattr(self, "instance", None)
        if instance is not None:
            rootcanal.ffi_controller_delete(c_void_p(instance))

    def receive_hci_(self, idc: int, packet: bytes):
        """Queue an HCI packet sent by the controller.

        Events, ACL and ISO packets are appended to their respective queue and
        the matching asyncio event is set to wake up a pending receiver.
        Packets with any other type indicator are logged and dropped."""
        if idc == Idc.Evt:
            print(f"<-- received HCI event data={len(packet)}[..]")
            self.evt_queue.append(packet)
            self.evt_queue_event.set()
        elif idc == Idc.Acl:
            print(f"<-- received HCI ACL packet data={len(packet)}[..]")
            self.acl_queue.append(packet)
            self.acl_queue_event.set()
        elif idc == Idc.Iso:
            print(f"<-- received HCI ISO packet data={len(packet)}[..]")
            self.iso_queue.append(packet)
            self.iso_queue_event.set()
        else:
            print(f"ignoring HCI packet typ={idc}")

    def receive_ll_(self, packet: bytes, phy: int, tx_power: int):
        """Queue an LL pdu sent by the controller.

        Only the pdu bytes are stored; `phy` and `tx_power` are not recorded."""
        print(f"<-- received LL pdu data={len(packet)}[..]")
        self.ll_queue.append(packet)
        self.ll_queue_event.set()

    def send_cmd(self, cmd: hci.Command):
        """Serialize `cmd` and inject it into the controller as an HCI command."""
        print(f"--> sending HCI command {cmd.__class__.__name__}")
        data = cmd.serialize()
        rootcanal.ffi_controller_receive_hci(c_void_p(self.instance), c_int(Idc.Cmd), c_char_p(data), c_int(len(data)))

    def send_iso(self, iso: hci.Iso):
        """Serialize `iso` and inject it into the controller as an HCI ISO packet."""
        print(f"--> sending HCI iso pdu data={len(iso.payload)}[..]")
        data = iso.serialize()
        rootcanal.ffi_controller_receive_hci(c_void_p(self.instance), c_int(Idc.Iso), c_char_p(data), c_int(len(data)))

    def send_ll(self, pdu: ll.LinkLayerPacket, phy: Phy = Phy.LowEnergy, rssi: int = -90):
        """Serialize `pdu` and inject it into the controller as a received
        LL pdu, with the selected `phy` and `rssi`."""
        print(f"--> sending LL pdu {pdu.__class__.__name__}")
        data = pdu.serialize()
        rootcanal.ffi_controller_receive_ll(c_void_p(self.instance), c_char_p(data), c_int(len(data)), c_int(phy),
                                            c_int(rssi))

    def send_llcp(self,
                  source_address: hci.Address,
                  destination_address: hci.Address,
                  pdu: llcp.LlcpPacket,
                  phy: Phy = Phy.LowEnergy,
                  rssi: int = -90):
        """Wrap the LLCP `pdu` into an ll.Llcp packet between the selected
        addresses and inject it into the controller as a received LL pdu."""
        print(f"--> sending LLCP pdu {pdu.__class__.__name__}")
        ll_pdu = ll.Llcp(source_address=source_address,
                         destination_address=destination_address,
                         payload=pdu.serialize())
        data = ll_pdu.serialize()
        rootcanal.ffi_controller_receive_ll(c_void_p(self.instance), c_char_p(data), c_int(len(data)), c_int(phy),
                                            c_int(rssi))

    async def start(self):
        """Spawn the background task ticking the controller every 5ms."""

        async def timer():
            while True:
                await asyncio.sleep(0.005)
                rootcanal.ffi_controller_tick(c_void_p(self.instance))

        # Spawn the controller timer task.
        self.timer_task = asyncio.create_task(timer())

    def stop(self):
        """Stop the controller timer and check that all queues were drained.

        Raises:
            Exception: if unread events, ISO packets or LL pdus remain in
                the queues; the leftover packets are printed beforehand."""
        # Cancel the controller timer task. Dropping the reference alone does
        # not stop the task, it has to be cancelled explicitly.
        timer_task = getattr(self, "timer_task", None)
        if timer_task is not None:
            timer_task.cancel()
            del self.timer_task

        if self.evt_queue:
            print("evt queue not empty at stop():")
            for packet in self.evt_queue:
                evt = hci.Event.parse_all(packet)
                evt.show()
            raise Exception("evt queue not empty at stop()")

        if self.iso_queue:
            print("iso queue not empty at stop():")
            for packet in self.iso_queue:
                iso = hci.Iso.parse_all(packet)
                iso.show()
            raise Exception("iso queue not empty at stop()")

        if self.ll_queue:
            print("ll queue not empty at stop():")
            # The queue holds the raw pdu bytes, as stored by receive_ll_.
            for packet in self.ll_queue:
                pdu = ll.LinkLayerPacket.parse_all(packet)
                pdu.show()
            raise Exception("ll queue not empty at stop()")

    async def receive_evt(self) -> bytes:
        """Wait for and return the oldest queued HCI event, as raw bytes."""
        while not self.evt_queue:
            await self.evt_queue_event.wait()
            self.evt_queue_event.clear()
        return self.evt_queue.popleft()

    async def receive_iso(self) -> bytes:
        """Wait for and return the oldest queued HCI ISO packet, as raw bytes."""
        while not self.iso_queue:
            await self.iso_queue_event.wait()
            self.iso_queue_event.clear()
        return self.iso_queue.popleft()

    async def expect_evt(self, expected_evt: hci.Event):
        """Wait for the next HCI event and compare it with `expected_evt`.

        Raises:
            Exception: if the received event differs from the expected one;
                both events are printed beforehand."""
        packet = await self.receive_evt()
        evt = hci.Event.parse_all(packet)
        if evt != expected_evt:
            print("received unexpected event")
            print("expected event:")
            expected_evt.show()
            print("received event:")
            evt.show()
            raise Exception(f"unexpected evt {evt.__class__.__name__}")

    async def receive_ll(self) -> bytes:
        """Wait for and return the oldest queued LL pdu, as raw bytes."""
        while not self.ll_queue:
            await self.ll_queue_event.wait()
            self.ll_queue_event.clear()
        return self.ll_queue.popleft()
222
223
class Any:
    """Wildcard value comparing equal to every other value.

    Put an instance of this class in a field of an expected packet to accept
    whatever value the Controller stack returns for that field."""

    def __eq__(self, other) -> bool:
        # The comparison succeeds whatever the other operand is.
        return True

    def __format__(self, format_spec: str) -> str:
        # Always rendered as a placeholder; the format spec is ignored.
        return "_"
234
235
class ControllerTest(unittest.IsolatedAsyncioTestCase):
    """Helper class for writing controller tests using the python bindings.

    The test setup resets the controller with the Reset command and
    configures the event masks to allow all events. The local device address
    is always configured as 11:11:11:11:11:11."""

    # Wildcard instance to use in expected packets for fields whose value
    # is not known in advance, e.g. connection handles.
    Any = Any()

    def setUp(self):
        self.controller = Controller(hci.Address('11:11:11:11:11:11'))

    async def asyncSetUp(self):
        controller = self.controller

        # Start the controller timer.
        await controller.start()

        # Reset the controller and enable all events and LE events.
        controller.send_cmd(hci.Reset())
        await controller.expect_evt(hci.ResetComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))
        controller.send_cmd(hci.SetEventMask(event_mask=0xffffffffffffffff))
        await controller.expect_evt(hci.SetEventMaskComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))
        controller.send_cmd(hci.LeSetEventMask(le_event_mask=0xffffffffffffffff))
        await controller.expect_evt(hci.LeSetEventMaskComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))

        # Load the local supported features to be able to disable tests
        # that rely on unsupported features.
        controller.send_cmd(hci.LeReadLocalSupportedFeaturesPage0())
        evt = await self.expect_cmd_complete(hci.LeReadLocalSupportedFeaturesPage0Complete)
        controller.le_features = LeFeatures(evt.le_features)

    async def expect_evt(self, expected_evt: typing.Union[hci.Event, type], timeout: int = 3) -> hci.Event:
        """Wait for the next HCI event and check it against `expected_evt`.

        `expected_evt` is either an event class, in which case only the type
        of the received event is checked, or an event instance, in which case
        the received event must compare equal to it.

        Returns the received event. Fails the test on mismatch, and raises a
        timeout error if no event is received within `timeout` seconds."""
        packet = await asyncio.wait_for(self.controller.receive_evt(), timeout=timeout)
        evt = hci.Event.parse_all(packet)

        if isinstance(expected_evt, type) and not isinstance(evt, expected_evt):
            print("received unexpected event")
            # expected_evt is the class itself here; __class__ would name
            # its metaclass instead.
            print(f"expected event: {expected_evt.__name__}")
            print("received event:")
            evt.show()
            self.fail(f"unexpected event {evt.__class__.__name__}")

        if isinstance(expected_evt, hci.Event) and evt != expected_evt:
            print("received unexpected event")
            print("expected event:")
            expected_evt.show()
            print("received event:")
            evt.show()
            self.fail(f"unexpected event {evt.__class__.__name__}")

        return evt

    async def expect_cmd_complete(self, expected_evt: type, timeout: int = 3) -> hci.Event:
        """Wait for a command complete event of class `expected_evt` and check
        that it reports the status SUCCESS and one available command packet.

        Returns the received event."""
        evt = await self.expect_evt(expected_evt, timeout=timeout)
        self.assertEqual(evt.status, ErrorCode.SUCCESS)
        self.assertEqual(evt.num_hci_command_packets, 1)
        return evt

    async def expect_iso(self, expected_iso: hci.Iso, timeout: int = 3) -> hci.Iso:
        """Wait for the next HCI ISO packet and check that it compares equal
        to `expected_iso`.

        Returns the received packet. Fails the test on mismatch, and raises a
        timeout error if no packet is received within `timeout` seconds."""
        packet = await asyncio.wait_for(self.controller.receive_iso(), timeout=timeout)
        iso = hci.Iso.parse_all(packet)

        if iso != expected_iso:
            print("received unexpected iso packet")
            print("expected packet:")
            expected_iso.show()
            print("received packet:")
            iso.show()
            self.fail("unexpected iso packet")

        return iso

    async def expect_ll(self,
                        expected_pdus: typing.Union[list, typing.Union[ll.LinkLayerPacket, type]],
                        ignored_pdus: typing.Union[list, type, None] = None,
                        timeout: int = 3) -> ll.LinkLayerPacket:
        """Wait for an LL pdu matching one of `expected_pdus`.

        `expected_pdus` is a single item or a list of items; each item is
        either a pdu class (matched with isinstance) or a pdu instance
        (matched by equality). Received pdus that are instances of one of the
        `ignored_pdus` classes are skipped.

        Returns the first matching pdu. Fails the test when a pdu that is
        neither ignored nor expected is received, and raises a timeout error
        if no matching pdu is received within `timeout` seconds overall."""
        if ignored_pdus is None:
            ignored_pdus = []
        elif not isinstance(ignored_pdus, list):
            ignored_pdus = [ignored_pdus]

        if not isinstance(expected_pdus, list):
            expected_pdus = [expected_pdus]

        async with asyncio.timeout(timeout):
            while True:
                packet = await self.controller.receive_ll()
                pdu = ll.LinkLayerPacket.parse_all(packet)

                # Skip ignored pdus and wait for the next one.
                if any(isinstance(pdu, ignored_pdu) for ignored_pdu in ignored_pdus):
                    continue

                for expected_pdu in expected_pdus:
                    if isinstance(expected_pdu, type) and isinstance(pdu, expected_pdu):
                        return pdu
                    if isinstance(expected_pdu, ll.LinkLayerPacket) and pdu == expected_pdu:
                        return pdu

                print("received unexpected pdu:")
                pdu.show()
                print("expected pdus:")
                for expected_pdu in expected_pdus:
                    if isinstance(expected_pdu, type):
                        print(f"- {expected_pdu.__name__}")
                    if isinstance(expected_pdu, ll.LinkLayerPacket):
                        print(f"- {expected_pdu.__class__.__name__}")
                        expected_pdu.show()

                self.fail(f"unexpected pdu {pdu.__class__.__name__}")

    async def expect_llcp(self,
                          source_address: hci.Address,
                          destination_address: hci.Address,
                          expected_pdu: llcp.LlcpPacket,
                          timeout: int = 3) -> llcp.LlcpPacket:
        """Wait for the next LL pdu and check that it is an LLCP packet sent
        from `source_address` to `destination_address` whose payload compares
        equal to `expected_pdu`.

        Returns the received LLCP pdu. Fails the test on mismatch, and raises
        a timeout error if no pdu is received within `timeout` seconds."""
        packet = await asyncio.wait_for(self.controller.receive_ll(), timeout=timeout)
        pdu = ll.LinkLayerPacket.parse_all(packet)

        if (pdu.type != ll.PacketType.LLCP or pdu.source_address != source_address or
                pdu.destination_address != destination_address):
            print("received unexpected pdu:")
            pdu.show()
            print(f"expected pdu: {source_address} -> {destination_address}")
            expected_pdu.show()
            self.fail(f"unexpected pdu {pdu.__class__.__name__}")

        pdu = llcp.LlcpPacket.parse_all(pdu.payload)
        if pdu != expected_pdu:
            print("received unexpected pdu:")
            pdu.show()
            print("expected pdu:")
            expected_pdu.show()
            self.fail(f"unexpected pdu {pdu.__class__.__name__}")

        return pdu

    async def enable_connected_isochronous_stream_host_support(self):
        """Enable Connected Isochronous Stream Host Support in the LE Feature mask."""
        self.controller.send_cmd(
            hci.LeSetHostFeatureV1(bit_number=hci.LeHostFeatureBits.CONNECTED_ISO_STREAM_HOST_SUPPORT,
                                   bit_value=hci.Enable.ENABLED))

        await self.expect_evt(hci.LeSetHostFeatureV1Complete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))

    async def establish_le_connection_central(self, peer_address: hci.Address) -> int:
        """Establish a connection with the selected peer as Central.
        Returns the ACL connection handle for the opened link."""
        self.controller.send_cmd(
            hci.LeExtendedCreateConnectionV1(initiator_filter_policy=hci.InitiatorFilterPolicy.USE_PEER_ADDRESS,
                                             own_address_type=hci.OwnAddressType.PUBLIC_DEVICE_ADDRESS,
                                             peer_address_type=hci.AddressType.PUBLIC_DEVICE_ADDRESS,
                                             peer_address=peer_address,
                                             initiating_phys=0x1,
                                             initiating_phy_parameters=[
                                                 hci.InitiatingPhyParameters(
                                                     scan_interval=0x200,
                                                     scan_window=0x100,
                                                     connection_interval_min=0x200,
                                                     connection_interval_max=0x200,
                                                     max_latency=0x6,
                                                     supervision_timeout=0xc80,
                                                     min_ce_length=0,
                                                     max_ce_length=0,
                                                 )
                                             ]))

        await self.expect_evt(
            hci.LeExtendedCreateConnectionV1Status(status=ErrorCode.SUCCESS, num_hci_command_packets=1))

        # Inject an advertising pdu from the peer to trigger the connection.
        self.controller.send_ll(ll.LeLegacyAdvertisingPdu(source_address=peer_address,
                                                          advertising_address_type=ll.AddressType.PUBLIC,
                                                          advertising_type=ll.LegacyAdvertisingType.ADV_IND,
                                                          advertising_data=[]),
                                rssi=-16)

        await self.expect_ll(
            ll.LeConnect(source_address=self.controller.address,
                         destination_address=peer_address,
                         initiating_address_type=ll.AddressType.PUBLIC,
                         advertising_address_type=ll.AddressType.PUBLIC,
                         conn_interval=0x200,
                         conn_peripheral_latency=0x6,
                         conn_supervision_timeout=0xc80))

        # Reply on behalf of the peer to complete the connection.
        self.controller.send_ll(
            ll.LeConnectComplete(source_address=peer_address,
                                 destination_address=self.controller.address,
                                 initiating_address_type=ll.AddressType.PUBLIC,
                                 advertising_address_type=ll.AddressType.PUBLIC,
                                 conn_interval=0x200,
                                 conn_peripheral_latency=0x6,
                                 conn_supervision_timeout=0xc80))

        # The connection handle is selected by the controller, accept any
        # value and read it back from the received event.
        connection_complete = await self.expect_evt(
            hci.LeEnhancedConnectionCompleteV1(status=ErrorCode.SUCCESS,
                                               connection_handle=self.Any,
                                               role=hci.Role.CENTRAL,
                                               peer_address_type=hci.AddressType.PUBLIC_DEVICE_ADDRESS,
                                               peer_address=peer_address,
                                               connection_interval=0x200,
                                               peripheral_latency=0x6,
                                               supervision_timeout=0xc80,
                                               central_clock_accuracy=hci.ClockAccuracy.PPM_500))

        acl_connection_handle = connection_complete.connection_handle
        await self.expect_evt(
            hci.LeChannelSelectionAlgorithm(connection_handle=acl_connection_handle,
                                            channel_selection_algorithm=hci.ChannelSelectionAlgorithm.ALGORITHM_1))

        return acl_connection_handle

    async def establish_le_connection_peripheral(self, peer_address: hci.Address) -> int:
        """Establish a connection with the selected peer as Peripheral.
        Returns the ACL connection handle for the opened link."""
        self.controller.send_cmd(
            hci.LeSetAdvertisingParameters(advertising_interval_min=0x200,
                                           advertising_interval_max=0x200,
                                           advertising_type=hci.AdvertisingType.ADV_IND,
                                           own_address_type=hci.OwnAddressType.PUBLIC_DEVICE_ADDRESS,
                                           advertising_channel_map=0x7,
                                           advertising_filter_policy=hci.AdvertisingFilterPolicy.ALL_DEVICES))

        await self.expect_evt(
            hci.LeSetAdvertisingParametersComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))

        self.controller.send_cmd(hci.LeSetAdvertisingEnable(advertising_enable=True))

        await self.expect_evt(hci.LeSetAdvertisingEnableComplete(status=ErrorCode.SUCCESS, num_hci_command_packets=1))

        # Inject a connection request from the peer.
        self.controller.send_ll(ll.LeConnect(source_address=peer_address,
                                             destination_address=self.controller.address,
                                             initiating_address_type=ll.AddressType.PUBLIC,
                                             advertising_address_type=ll.AddressType.PUBLIC,
                                             conn_interval=0x200,
                                             conn_peripheral_latency=0x200,
                                             conn_supervision_timeout=0x200),
                                rssi=-16)

        await self.expect_ll(
            ll.LeConnectComplete(source_address=self.controller.address,
                                 destination_address=peer_address,
                                 conn_interval=0x200,
                                 conn_peripheral_latency=0x200,
                                 conn_supervision_timeout=0x200))

        # The connection handle is selected by the controller, accept any
        # value and read it back from the received event.
        connection_complete = await self.expect_evt(
            hci.LeEnhancedConnectionCompleteV1(status=ErrorCode.SUCCESS,
                                               connection_handle=self.Any,
                                               role=hci.Role.PERIPHERAL,
                                               peer_address_type=hci.AddressType.PUBLIC_DEVICE_ADDRESS,
                                               peer_address=peer_address,
                                               connection_interval=0x200,
                                               peripheral_latency=0x200,
                                               supervision_timeout=0x200,
                                               central_clock_accuracy=hci.ClockAccuracy.PPM_500))

        return connection_complete.connection_handle

    def tearDown(self):
        # Stops the controller timer; raises if the test left unread events,
        # ISO packets or LL pdus in the controller queues.
        self.controller.stop()
493