# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import pytest
import struct
import logging
from unittest import mock

from bumble import device
from bumble.profiles import csip
from .test_utils import TwoDevices

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)

# SIRK types exercised by test_csis. Shared by the pytest parametrization and
# by run(), so direct execution covers exactly the same cases as pytest.
_SIRK_TYPES = (csip.SirkType.ENCRYPTED, csip.SirkType.PLAINTEXT)


# -----------------------------------------------------------------------------
# Known-answer tests for the csip crypto helpers.
#
# Every byte-string vector below is written most-significant byte first and
# then reversed with [::-1]; presumably the csip helpers take and return
# little-endian byte strings (TODO confirm against bumble.profiles.csip).
# -----------------------------------------------------------------------------
def test_s1():
    """Check csip.s1 (salt generation) against a known-answer vector."""
    assert (
        csip.s1(b'SIRKenc'[::-1])
        == bytes.fromhex('6901983f 18149e82 3c7d133a 7d774572')[::-1]
    )


# -----------------------------------------------------------------------------
def test_k1():
    """Check csip.k1 (key derivation) using a salt produced by csip.s1."""
    K = bytes.fromhex('676e1b9b d448696f 061ec622 3ce5ced9')[::-1]
    SALT = csip.s1(b'SIRKenc'[::-1])
    P = b'csis'[::-1]
    assert (
        csip.k1(K, SALT, P)
        == bytes.fromhex('5277453c c094d982 b0e8ee53 2f2d1f8b')[::-1]
    )


# -----------------------------------------------------------------------------
def test_sih():
    """Check csip.sih (SIRK hash of a 3-byte prand) against a known answer."""
    SIRK = bytes.fromhex('457d7d09 21a1fd22 cecd8c86 dd72cccd')[::-1]
    PRAND = bytes.fromhex('69f563')[::-1]
    assert csip.sih(SIRK, PRAND) == bytes.fromhex('1948da')[::-1]


# -----------------------------------------------------------------------------
def test_sef():
    """Check csip.sef (SIRK encryption with key K) against a known answer."""
    SIRK = bytes.fromhex('457d7d09 21a1fd22 cecd8c86 dd72cccd')[::-1]
    K = bytes.fromhex('676e1b9b d448696f 061ec622 3ce5ced9')[::-1]
    assert (
        csip.sef(K, SIRK) == bytes.fromhex('170a3835 e13524a0 7e2562d5 f25fd346')[::-1]
    )


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize('sirk_type', _SIRK_TYPES)
async def test_csis(sirk_type):
    """Serve a CSIS on one device and read it back through a client proxy.

    devices[0] hosts a CoordinatedSetIdentificationService; devices[1]
    discovers it and checks that the SIRK (with its type), set size, member
    lock and member rank read back as configured.

    Args:
        sirk_type: the csip.SirkType the server is configured with; the client
            is expected to report the same type alongside the SIRK.
    """
    SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
    LTK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')

    devices = TwoDevices()
    devices[0].add_service(
        csip.CoordinatedSetIdentificationService(
            set_identity_resolving_key=SIRK,
            set_identity_resolving_key_type=sirk_type,
            coordinated_set_size=2,
            set_member_lock=csip.MemberLock.UNLOCKED,
            set_member_rank=0,
        )
    )

    await devices.setup_connection()

    # Mock encryption: flag both ends of the link as encrypted and stub the
    # LTK lookup so both sides return the same key. Presumably this is what
    # lets the ENCRYPTED SIRK path encrypt/decrypt without a real pairing.
    devices.connections[0].encryption = 1
    devices.connections[1].encryption = 1
    devices[0].get_long_term_key = mock.AsyncMock(return_value=LTK)
    devices[1].get_long_term_key = mock.AsyncMock(return_value=LTK)

    peer = device.Peer(devices.connections[1])
    csis_client = await peer.discover_service_and_create_proxy(
        csip.CoordinatedSetIdentificationProxy
    )

    assert await csis_client.read_set_identity_resolving_key() == (sirk_type, SIRK)
    assert await csis_client.coordinated_set_size.read_value() == struct.pack('B', 2)
    assert await csis_client.set_member_lock.read_value() == struct.pack(
        'B', csip.MemberLock.UNLOCKED
    )
    assert await csis_client.set_member_rank.read_value() == struct.pack('B', 0)


# -----------------------------------------------------------------------------
async def run():
    """Run every test in this module without pytest."""
    test_s1()
    test_k1()
    test_sih()
    test_sef()
    # Under pytest, test_csis is parametrized. Here each SIRK type must be
    # passed explicitly: calling it with no argument raises TypeError.
    for sirk_type in _SIRK_TYPES:
        await test_csis(sirk_type)


# -----------------------------------------------------------------------------
if __name__ == '__main__':
    logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
    asyncio.run(run())