1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import itertools
20import logging
21import os
22import pytest
23
24from unittest.mock import AsyncMock, MagicMock, patch
25
26from bumble.controller import Controller
27from bumble.core import BT_BR_EDR_TRANSPORT, BT_PERIPHERAL_ROLE, BT_CENTRAL_ROLE
28from bumble.link import LocalLink
29from bumble.device import Device, Peer
30from bumble.host import Host
31from bumble.gatt import Service, Characteristic
32from bumble.transport import AsyncPipeSink
33from bumble.pairing import PairingConfig, PairingDelegate
34from bumble.smp import (
35    SMP_PAIRING_NOT_SUPPORTED_ERROR,
36    SMP_CONFIRM_VALUE_FAILED_ERROR,
37    OobContext,
38    OobLegacyContext,
39)
40from bumble.core import ProtocolError
41from bumble.keys import PairingKeys
42
43
44# -----------------------------------------------------------------------------
45# Logging
46# -----------------------------------------------------------------------------
# Logger for this test module, named after the module itself.
logger = logging.getLogger(name=__name__)
48
49
50# -----------------------------------------------------------------------------
class TwoDevices:
    """Two bumble Devices whose controllers share a single LocalLink.

    Device ``i`` is hosted on controller ``i``; both controllers are attached
    to ``self.link``. ``connections[i]`` holds the connection last reported to
    device ``i`` through :meth:`on_connection` (``None`` until then), and
    ``paired[i]`` is a future that :meth:`on_paired` resolves with the
    PairingKeys reported for side ``i``.
    """

    def __init__(self):
        self.connections = [None, None]

        addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0']
        self.link = LocalLink()
        self.controllers = [
            Controller(controller_name, link=self.link, public_address=address)
            for controller_name, address in zip(('C1', 'C2'), addresses)
        ]
        self.devices = [
            Device(
                address=address,
                host=Host(controller, AsyncPipeSink(controller)),
            )
            for address, controller in zip(addresses, self.controllers)
        ]

        # One pairing-result future per device, created on the current loop.
        loop = asyncio.get_event_loop()
        self.paired = [loop.create_future() for _ in self.devices]

    def on_connection(self, which, connection):
        """Store ``connection`` as the connection seen by device ``which``."""
        self.connections[which] = connection

    def on_paired(self, which: int, keys: PairingKeys):
        """Resolve device ``which``'s pairing future with ``keys``.

        ``Future.set_result`` raises ``asyncio.InvalidStateError`` if this
        side's future was already resolved.
        """
        future = self.paired[which]
        future.set_result(keys)
82
83
84# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_connection():
    """Connect device 0 to device 1 by its random address and check that
    both devices recorded a connection."""
    two_devices = TwoDevices()
    initiator, target = two_devices.devices

    # Record the connection object each device reports; ``index`` is bound
    # as a default argument so each lambda keeps its own device index.
    for index, device in enumerate(two_devices.devices):
        device.on(
            'connection',
            lambda connection, index=index: two_devices.on_connection(
                index, connection
            ),
        )

    for device in two_devices.devices:
        await device.power_on()

    await initiator.connect(target.random_address)

    # Both sides must have observed the connection.
    for connection in two_devices.connections:
        assert connection is not None
108
109
110# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'responder_role,',
    (BT_CENTRAL_ROLE, BT_PERIPHERAL_ROLE),
)
async def test_self_classic_connection(responder_role):
    """Open a BR/EDR connection where the accepting device requests
    ``responder_role``, then switch roles and disconnect.

    Args:
        responder_role: role that device 1 asks for when accepting.
    """
    two_devices = TwoDevices()
    initiator, responder = two_devices.devices

    # Capture each device's connection; ``index`` is bound per lambda.
    for index, device in enumerate(two_devices.devices):
        device.on(
            'connection',
            lambda connection, index=index: two_devices.on_connection(
                index, connection
            ),
        )

    # Classic (BR/EDR) connections must be enabled on both sides.
    for device in two_devices.devices:
        device.classic_enabled = True

    for device in two_devices.devices:
        await device.power_on()

    # Device 0 connects while device 1 concurrently accepts it.
    await asyncio.gather(
        initiator.connect(responder.public_address, transport=BT_BR_EDR_TRANSPORT),
        responder.accept(initiator.public_address, responder_role),
    )

    initiator_connection, responder_connection = two_devices.connections
    assert initiator_connection is not None
    assert responder_connection is not None

    # The accepting side got the role it asked for; the other side did not.
    assert initiator_connection.role != responder_role
    assert responder_connection.role == responder_role

    # Have the initiator take over the responder's role.
    await initiator_connection.switch_role(responder_role)

    assert initiator_connection.role == responder_role
    assert responder_connection.role != responder_role

    await initiator_connection.disconnect()
162
163
164# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_gatt():
    """Exercise GATT discovery and reads from device 0 against device 1.

    Covers discovery of an unknown service UUID (nothing found), discovery
    of a service and one of its characteristics (whose value is read back),
    and discovery of a service's included services.
    """
    two_devices = TwoDevices()
    client, server = two_devices.devices

    # GATT database hosted by device 1.
    read_only = Characteristic(
        '3A143AD7-D4A7-436B-97D6-5B62C315E833',
        Characteristic.Properties.READ,
        Characteristic.READABLE,
        bytes((1, 2, 3)),
    )
    read_write = Characteristic(
        '9557CCE2-DB37-46EB-94C4-50AE5B9CB0F8',
        Characteristic.Properties.READ | Characteristic.Properties.WRITE,
        Characteristic.READABLE | Characteristic.WRITEABLE,
        bytes((4, 5, 6)),
    )
    read_write_no_response = Characteristic(
        '84FC1A2E-C52D-4A2D-B8C3-8855BAB86638',
        Characteristic.Properties.READ
        | Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
        Characteristic.READABLE | Characteristic.WRITEABLE,
        bytes((7, 8, 9)),
    )
    # NOTE(review): this reuses read_write_no_response's UUID. It lives in a
    # different service and is never looked up by UUID below, so the test is
    # unaffected, but confirm the duplication is intentional.
    notifying = Characteristic(
        '84FC1A2E-C52D-4A2D-B8C3-8855BAB86638',
        Characteristic.Properties.READ
        | Characteristic.Properties.NOTIFY
        | Characteristic.Properties.INDICATE,
        Characteristic.READABLE,
        bytes((1, 1, 1)),
    )

    primary = Service(
        '8140E247-04F0-42C1-BC34-534C344DAFCA',
        [read_only, read_write, read_write_no_response],
    )
    notify_service = Service('97210A0F-1875-4D05-9E5D-326EB171257A', [notifying])
    # Only referenced as an included service; not added to the server itself.
    sixteen_bit_service = Service('1853', [])
    including_service = Service(
        '3A12C182-14E2-4FE0-8C5B-65D7C569F9DB',
        [],
        included_services=[notify_service, sixteen_bit_service],
    )
    server.add_services([primary, notify_service, including_service])

    await client.power_on()
    await server.power_on()

    peer = Peer(await client.connect(server.random_address))

    # An unknown service UUID discovers nothing and the peer lists nothing.
    bogus_uuid = 'A0AA6007-0B48-4BBE-80AC-0DE9AAF541EA'
    assert await peer.discover_services([bogus_uuid]) == []
    assert len(peer.get_services_by_uuid(bogus_uuid)) == 0

    # Discover the primary service, one of its characteristics, and read it.
    found = await peer.discover_service(primary.uuid)
    assert len(found) == 1
    services = peer.get_services_by_uuid(primary.uuid)
    assert len(services) == 1
    service = services[0]
    assert service.uuid == primary.uuid

    found = await peer.discover_characteristics([read_only.uuid], service)
    assert len(found) == 1
    characteristics = peer.get_characteristics_by_uuid(read_only.uuid)
    assert len(characteristics) == 1
    characteristic = characteristics[0]
    assert characteristic.uuid == read_only.uuid
    value = await peer.read_value(characteristic)
    assert value is not None
    assert value == read_only.value

    # Discover the services included by including_service.
    found = await peer.discover_service(including_service.uuid)
    assert len(found) == 1
    included = await peer.discover_included_services(found[0])
    assert len(included) == 2
    # The included-service UUID is only present for 16-bit Bluetooth UUIDs,
    # so only the 16-bit one is compared.
    assert included[1].uuid.to_bytes() == sixteen_bit_service.uuid.to_bytes()
244
245
246# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_gatt_long_read():
    """Read back characteristic values of every length from 0 to 512 bytes.

    The longer values presumably require multi-request (long) reads. TODO
    confirm against the negotiated MTU.
    """
    two_devices = TwoDevices()
    client, server = two_devices.devices

    # Characteristic i holds the i-byte value 0, 1, 2, ... (each byte mod 256).
    characteristics = [
        Characteristic(
            f'3A143AD7-D4A7-436B-97D6-5B62C315{i:04X}',
            Characteristic.Properties.READ,
            Characteristic.READABLE,
            bytes(x & 255 for x in range(i)),
        )
        for i in range(513)
    ]

    service = Service('8140E247-04F0-42C1-BC34-534C344DAFCA', characteristics)
    server.add_service(service)

    await client.power_on()
    await server.power_on()

    peer = Peer(await client.connect(server.random_address))

    found_services = await peer.discover_service(service.uuid)
    assert len(found_services) == 1
    found_characteristics = await found_services[0].discover_characteristics()
    assert len(found_characteristics) == len(characteristics)

    # Discovery order is expected to match definition order.
    for expected, found in zip(characteristics, found_characteristics):
        assert await found.read_value() == expected.value
284
285
286# -----------------------------------------------------------------------------
async def _test_self_smp_with_configs(pairing_config1, pairing_config2):
    """Connect two fresh devices and pair them, initiated by device 0.

    Args:
        pairing_config1: PairingConfig for device 0, or None to leave that
            device's pairing config factory untouched.
        pairing_config2: same, for device 1.

    Any exception raised while pairing propagates to the caller. On success
    the connection must be encrypted and both sides must have reported keys.
    """
    two_devices = TwoDevices()
    initiator, responder = two_devices.devices

    await initiator.power_on()
    await responder.power_on()

    # Record each device's connection; ``index`` is bound per lambda.
    for index, device in enumerate(two_devices.devices):
        device.on(
            'connection',
            lambda connection, index=index: two_devices.on_connection(
                index, connection
            ),
        )

    connection = await initiator.connect(responder.random_address)
    assert not connection.is_encrypted

    # Resolve each side's ``paired`` future when its connection reports keys.
    for index, side_connection in enumerate(two_devices.connections):
        side_connection.on(
            'pairing',
            lambda keys, index=index: two_devices.on_paired(index, keys),
        )

    if pairing_config1:
        initiator.pairing_config_factory = lambda connection: pairing_config1
    if pairing_config2:
        responder.pairing_config_factory = lambda connection: pairing_config2

    await initiator.pair(connection)
    assert connection.is_encrypted
    for paired_future in two_devices.paired:
        assert await paired_future is not None
332
333
334# -----------------------------------------------------------------------------
# IO capabilities that test_self_smp combines initiator-by-responder.
IO_CAP = [
    PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT,
    PairingDelegate.IoCapability.KEYBOARD_INPUT_ONLY,
    PairingDelegate.IoCapability.DISPLAY_OUTPUT_ONLY,
    PairingDelegate.IoCapability.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
    PairingDelegate.IoCapability.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
]
# Values tried for the Secure Connections flag.
SC = [False, True]
# Values tried for the MITM-protection flag.
MITM = [False, True]
# Key distribution is a 4-bit bitmask, so enumerate all 2**4 values.
KEY_DIST = range(1 << 4)
346
347
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'io_caps, sc, mitm, key_dist',
    itertools.chain(
        itertools.product([IO_CAP], SC, MITM, [15]),
        itertools.product(
            [[PairingDelegate.IoCapability.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT]],
            SC,
            MITM,
            KEY_DIST,
        ),
    ),
)
async def test_self_smp(io_caps, sc, mitm, key_dist):
    """Pair every initiator/responder config combination built from io_caps.

    For each side, one PairingConfig is built per entry of ``io_caps`` (all
    using ``sc``, ``mitm`` and ``key_dist``), plus ``None`` for "keep the
    device's default config". Every initiator config is then paired against
    every responder config.

    Args:
        io_caps: IO capabilities to build delegates for.
        sc: Secure Connections flag for every built config.
        mitm: MITM-protection flag for every built config.
        key_dist: key-distribution bitmask for both initiator and responder.
    """

    class Delegate(PairingDelegate):
        """In-process delegate that trades numbers with its peer delegate.

        ``number`` is a future resolved with whatever this delegate displays.
        The peer delegate awaits it to compare against, or to enter, the same
        value.
        """

        def __init__(
            self,
            name,
            io_capability,
            local_initiator_key_distribution,
            local_responder_key_distribution,
        ):
            super().__init__(
                io_capability,
                local_initiator_key_distribution,
                local_responder_key_distribution,
            )
            self.name = name
            self.reset()

        def reset(self):
            """Forget the peer delegate and arm a fresh ``number`` future."""
            self.peer_delegate = None
            self.number = asyncio.get_running_loop().create_future()

        async def compare_numbers(self, number, digits):
            if self.peer_delegate is None:
                logger.warning(f'[{self.name}] no peer delegate')
                return False
            # Publish our number (forwarding the requested digit count), then
            # accept only if the peer displayed the same value.
            await self.display_number(number, digits=digits)
            logger.debug(f'[{self.name}] waiting for peer number')
            peer_number = await self.peer_delegate.number
            logger.debug(f'[{self.name}] comparing numbers: {number} and {peer_number}')
            return number == peer_number

        async def get_number(self):
            if self.peer_delegate is None:
                logger.warning(f'[{self.name}] no peer delegate')
                return 0
            if (
                self.peer_delegate.io_capability
                == PairingDelegate.IoCapability.KEYBOARD_INPUT_ONLY
            ):
                # Presumably a keyboard-only peer never publishes a displayed
                # number to wait on, so use a fixed passkey; a keyboard-only
                # peer asked for a number picks the same one.
                peer_number = 6789
            else:
                logger.debug(f'[{self.name}] waiting for peer number')
                peer_number = await self.peer_delegate.number
            logger.debug(f'[{self.name}] returning number: {peer_number}')
            return peer_number

        async def display_number(self, number, digits):
            logger.debug(f'[{self.name}] displaying number: {number}')
            # Make the displayed value available to the peer delegate.
            self.number.set_result(number)

        def __str__(self):
            return f'Delegate(name={self.name}, io_capability={self.io_capability})'

    def build_configs(name):
        # ``None`` first: that side keeps the device's default pairing config.
        return [None] + [
            PairingConfig(sc, mitm, True, Delegate(name, io_cap, key_dist, key_dist))
            for io_cap in io_caps
        ]

    initiator_configs = build_configs('Initiator')
    responder_configs = build_configs('Responder')

    for pairing_config1, pairing_config2 in itertools.product(
        initiator_configs, responder_configs
    ):
        logger.info(
            f'########## self_smp with {pairing_config1} and {pairing_config2}'
        )
        # Delegates are reused across runs, so re-arm them before each pairing.
        if pairing_config1:
            pairing_config1.delegate.reset()
        if pairing_config2:
            pairing_config2.delegate.reset()
        if pairing_config1 and pairing_config2:
            pairing_config1.delegate.peer_delegate = pairing_config2.delegate
            pairing_config2.delegate.peer_delegate = pairing_config1.delegate

        await _test_self_smp_with_configs(pairing_config1, pairing_config2)
436
437
438# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_smp_reject():
    """Pairing must fail with SMP_PAIRING_NOT_SUPPORTED_ERROR when device 1's
    delegate declines the pairing request."""

    class RejectingDelegate(PairingDelegate):
        """No-I/O delegate whose ``accept`` always declines."""

        def __init__(self):
            super().__init__(PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT)

        async def accept(self):
            return False

    rejecting_pairing_config = PairingConfig(delegate=RejectingDelegate())

    # pytest.raises also fails the test if pairing unexpectedly succeeds.
    with pytest.raises(ProtocolError) as error:
        await _test_self_smp_with_configs(None, rejecting_pairing_config)
    assert error.value.error_code == SMP_PAIRING_NOT_SUPPORTED_ERROR
457
458
459# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_smp_wrong_pin():
    """Pairing must fail with SMP_CONFIRM_VALUE_FAILED_ERROR when both sides
    reject the numeric comparison."""

    class WrongPinDelegate(PairingDelegate):
        """Display + keyboard delegate that never confirms a number match."""

        def __init__(self):
            super().__init__(
                PairingDelegate.IoCapability.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT
            )

        async def compare_numbers(self, number, digits):
            return False

    wrong_pin_pairing_config = PairingConfig(mitm=True, delegate=WrongPinDelegate())

    # Both devices use the same rejecting config. pytest.raises also fails
    # the test if pairing unexpectedly succeeds.
    with pytest.raises(ProtocolError) as error:
        await _test_self_smp_with_configs(
            wrong_pin_pairing_config, wrong_pin_pairing_config
        )
    assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
482
483
484# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_smp_over_classic():
    """Cross-transport key derivation (CTKD) started on a BR/EDR connection.

    The link does not implement Classic SSP/encryption, so each device's link
    key lookup and each connection's encryption state are mocked. The test
    then pairs from device 0, waits for both sides to report pairing keys,
    intends to check that no SMP phase-2 commands were sent, and finally
    checks that each device's keystore has a link key stored for its peer.
    """
    # Create two devices, each with a controller, attached to the same link
    two_devices = TwoDevices()

    # Attach listeners
    two_devices.devices[0].on(
        'connection', lambda connection: two_devices.on_connection(0, connection)
    )
    two_devices.devices[1].on(
        'connection', lambda connection: two_devices.on_connection(1, connection)
    )

    # Enable Classic connections
    two_devices.devices[0].classic_enabled = True
    two_devices.devices[1].classic_enabled = True

    # Start
    await two_devices.devices[0].power_on()
    await two_devices.devices[1].power_on()

    # Connect the two devices
    await asyncio.gather(
        two_devices.devices[0].connect(
            two_devices.devices[1].public_address, transport=BT_BR_EDR_TRANSPORT
        ),
        two_devices.devices[1].accept(two_devices.devices[0].public_address),
    )

    # Check the post conditions
    assert two_devices.connections[0] is not None
    assert two_devices.connections[1] is not None

    # Mock connection
    # TODO: Implement Classic SSP and encryption in link relayer
    # Both devices return the same fixed link key for any lookup.
    LINK_KEY = bytes.fromhex('287ad379dca402530a39f1f43047b835')
    two_devices.devices[0].get_link_key = AsyncMock(return_value=LINK_KEY)
    two_devices.devices[1].get_link_key = AsyncMock(return_value=LINK_KEY)
    # Mark both connections as encrypted (presumably 1 == encryption on;
    # TODO confirm against bumble's encryption constants).
    two_devices.connections[0].encryption = 1
    two_devices.connections[1].encryption = 1

    # Resolve each side's ``paired`` future when its connection reports keys.
    two_devices.connections[0].on(
        'pairing', lambda keys: two_devices.on_paired(0, keys)
    )
    two_devices.connections[1].on(
        'pairing', lambda keys: two_devices.on_paired(1, keys)
    )

    # Mock SMP
    # NOTE(review): `patch` only rebinds the module attribute
    # `bumble.smp.Session`. If the SMP manager already holds a reference to the
    # real Session class (e.g. taken when the devices were created above),
    # the real session's calls never reach these mocks. Also, attributes set
    # on the class mock are distinct from those on instances it creates
    # (`MockSmpSession.return_value`). Either way the assert_not_called()
    # checks below may pass vacuously. TODO confirm.
    with patch('bumble.smp.Session', spec=True) as MockSmpSession:
        MockSmpSession.send_pairing_confirm_command = MagicMock()
        MockSmpSession.send_pairing_dhkey_check_command = MagicMock()
        MockSmpSession.send_public_key_command = MagicMock()
        MockSmpSession.send_pairing_random_command = MagicMock()

        # Start CTKD
        await two_devices.connections[0].pair()
        # Wait until both sides have reported pairing keys.
        await asyncio.gather(*two_devices.paired)

        # Phase 2 commands should not be invoked
        MockSmpSession.send_pairing_confirm_command.assert_not_called()
        MockSmpSession.send_pairing_dhkey_check_command.assert_not_called()
        MockSmpSession.send_public_key_command.assert_not_called()
        MockSmpSession.send_pairing_random_command.assert_not_called()

    # Each device's keystore entry for its peer must carry a link key.
    for i in range(2):
        assert (
            await two_devices.devices[i].keystore.get(
                str(two_devices.connections[i].peer_address)
            )
        ).link_key
556
557
558# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_smp_public_address():
    """Pair two devices whose (shared) config requests a PUBLIC identity
    address type, with MITM, SC and bonding enabled and the encryption,
    identity, signing and link keys distributed."""
    key_distribution = (
        PairingDelegate.KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY
        | PairingDelegate.KeyDistribution.DISTRIBUTE_IDENTITY_KEY
        | PairingDelegate.KeyDistribution.DISTRIBUTE_SIGNING_KEY
        | PairingDelegate.KeyDistribution.DISTRIBUTE_LINK_KEY
    )
    delegate = PairingDelegate(
        PairingDelegate.IoCapability.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
        key_distribution,
    )
    pairing_config = PairingConfig(
        mitm=True,
        sc=True,
        bonding=True,
        identity_address_type=PairingConfig.AddressType.PUBLIC,
        delegate=delegate,
    )

    # The same config (and delegate) is used by both devices.
    await _test_self_smp_with_configs(pairing_config, pairing_config)
576
577
578# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_smp_oob_sc():
    """Secure Connections pairing with out-of-band (OOB) data.

    Covers OOB data known to both sides, OOB data known to only one side (in
    both initiator orders), and a mismatched peer OOB value that must make
    pairing fail with SMP_CONFIRM_VALUE_FAILED_ERROR in both initiator orders.
    """
    oob_context_1 = OobContext()
    oob_context_2 = OobContext()

    # Each side has its own OOB context plus the other side's shared data.
    pairing_config_1 = PairingConfig(
        mitm=True,
        sc=True,
        bonding=True,
        oob=PairingConfig.OobConfig(oob_context_1, oob_context_2.share(), None),
    )

    pairing_config_2 = PairingConfig(
        mitm=True,
        sc=True,
        bonding=True,
        oob=PairingConfig.OobConfig(oob_context_2, oob_context_1.share(), None),
    )

    await _test_self_smp_with_configs(pairing_config_1, pairing_config_2)

    # One-sided OOB: this config has its own context but no peer OOB data.
    pairing_config_3 = PairingConfig(
        mitm=True,
        sc=True,
        bonding=True,
        oob=PairingConfig.OobConfig(oob_context_2, None, None),
    )

    await _test_self_smp_with_configs(pairing_config_1, pairing_config_3)
    await _test_self_smp_with_configs(pairing_config_3, pairing_config_1)

    # Mismatch: this config treats its own context's shared data as the
    # peer's, which does not correspond to pairing_config_1's context.
    pairing_config_4 = PairingConfig(
        mitm=True,
        sc=True,
        bonding=True,
        oob=PairingConfig.OobConfig(oob_context_2, oob_context_2.share(), None),
    )

    with pytest.raises(ProtocolError) as error:
        await _test_self_smp_with_configs(pairing_config_1, pairing_config_4)
    assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR

    # Bind the ExceptionInfo again so this assertion checks the second
    # failure rather than the first one.
    with pytest.raises(ProtocolError) as error:
        await _test_self_smp_with_configs(pairing_config_4, pairing_config_1)
    assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
624
625
626# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_smp_oob_legacy():
    """Pair a legacy-only (sc=False) config with an SC-capable config that
    share one legacy OOB context, in both initiator orders.

    Presumably pairing falls back to legacy OOB because one side has
    sc=False. TODO confirm.
    """
    legacy_context = OobLegacyContext()

    legacy_only_config = PairingConfig(
        mitm=True,
        sc=False,
        bonding=True,
        oob=PairingConfig.OobConfig(None, None, legacy_context),
    )
    sc_capable_config = PairingConfig(
        mitm=True,
        sc=True,
        bonding=True,
        oob=PairingConfig.OobConfig(OobContext(), None, legacy_context),
    )

    for first_config, second_config in (
        (legacy_only_config, sc_capable_config),
        (sc_capable_config, legacy_only_config),
    ):
        await _test_self_smp_with_configs(first_config, second_config)
647
648
649# -----------------------------------------------------------------------------
async def run_test_self():
    """Run every test in this module sequentially, outside of pytest.

    Parametrized tests are called once per argument set, using the same
    values their ``pytest.mark.parametrize`` decorators list. Calling them
    without arguments would raise TypeError.
    """
    await test_self_connection()
    for responder_role in (BT_CENTRAL_ROLE, BT_PERIPHERAL_ROLE):
        await test_self_classic_connection(responder_role)
    await test_self_gatt()
    await test_self_gatt_long_read()

    # Mirror the parametrization of test_self_smp.
    smp_cases = itertools.chain(
        itertools.product([IO_CAP], SC, MITM, [15]),
        itertools.product(
            [[PairingDelegate.IoCapability.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT]],
            SC,
            MITM,
            KEY_DIST,
        ),
    )
    for io_caps, sc, mitm, key_dist in smp_cases:
        await test_self_smp(io_caps, sc, mitm, key_dist)

    await test_self_smp_reject()
    await test_self_smp_wrong_pin()
    await test_self_smp_over_classic()
    await test_self_smp_public_address()
    await test_self_smp_oob_sc()
    await test_self_smp_oob_legacy()
661
662
663# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # BUMBLE_LOGLEVEL (default INFO, case-insensitive) sets the log level.
    log_level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()
    logging.basicConfig(level=log_level)
    asyncio.run(run_test_self())
667