1# Copyright 2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import os
20import logging
21
22from bumble import hci
23from bumble.profiles import bass
24
25
26# -----------------------------------------------------------------------------
27# Logging
28# -----------------------------------------------------------------------------
# Module-level logger named after this module (not used by the code in this file).
logger = logging.getLogger(__name__)
30
31
32# -----------------------------------------------------------------------------
def basic_operation_check(operation: bass.ControlPointOperation) -> None:
    """Assert that a control point operation survives a serialize/parse round trip.

    The operation is serialized with ``bytes()`` and parsed back through the
    generic ``bass.ControlPointOperation.from_bytes`` entry point. The parsed
    object must be an instance of the same class as ``operation`` and must
    re-serialize to exactly the original bytes.

    Args:
        operation: The operation instance to round-trip.

    Raises:
        AssertionError: If the parsed operation has a different type or does not
            re-serialize to the original bytes.
    """
    serialized = bytes(operation)
    parsed = bass.ControlPointOperation.from_bytes(serialized)
    # Parsing goes through the base class, so also verify that it produced the
    # same concrete operation type. Byte equality alone does not pin down which
    # class the parser dispatched to.
    assert isinstance(parsed, type(operation))
    assert bytes(parsed) == serialized
37
38
39# -----------------------------------------------------------------------------
def test_operations() -> None:
    """Round-trip each kind of BASS control point operation through its codec.

    Covers the two operations constructed without arguments, Add Source and
    Modify Source both with an empty subgroup tuple and with two SubgroupInfo
    entries, and Set Broadcast Code and Remove Source.
    """
    past_not_available = (
        bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE
    )

    # Operations constructed without any arguments.
    basic_operation_check(bass.RemoteScanStoppedOperation())
    basic_operation_check(bass.RemoteScanStartedOperation())

    # Add Source: first with no subgroups, then with two.
    basic_operation_check(
        bass.AddSourceOperation(
            hci.Address("AA:BB:CC:DD:EE:FF"),
            34,
            123456,
            past_not_available,
            456,
            (),
        )
    )
    basic_operation_check(
        bass.AddSourceOperation(
            hci.Address("AA:BB:CC:DD:EE:FF"),
            34,
            123456,
            past_not_available,
            456,
            (
                bass.SubgroupInfo(6677, bytes.fromhex('aabbcc')),
                bass.SubgroupInfo(8899, bytes.fromhex('ddeeff')),
            ),
        )
    )

    # Modify Source: first with no subgroups, then with two.
    basic_operation_check(
        bass.ModifySourceOperation(12, past_not_available, 567, ())
    )
    basic_operation_check(
        bass.ModifySourceOperation(
            12,
            past_not_available,
            567,
            (
                bass.SubgroupInfo(6677, bytes.fromhex('112233')),
                bass.SubgroupInfo(8899, bytes.fromhex('4567')),
            ),
        )
    )

    # Set Broadcast Code with a 16-byte code, then Remove Source.
    basic_operation_check(
        bass.SetBroadcastCodeOperation(
            7, bytes.fromhex('a0a1a2a3a4a5a6a7a8a9aaabacadaeaf')
        )
    )
    basic_operation_check(bass.RemoveSourceOperation(7))
96
97
98# -----------------------------------------------------------------------------
def basic_broadcast_receive_state_check(brs: bass.BroadcastReceiveState) -> None:
    """Assert that a Broadcast Receive State survives a serialize/parse round trip.

    ``bass.BroadcastReceiveState.from_bytes`` must return a value (not ``None``)
    whose serialization is identical to the bytes it was parsed from.
    """
    encoded = bytes(brs)
    decoded = bass.BroadcastReceiveState.from_bytes(encoded)
    assert decoded is not None
    assert bytes(decoded) == encoded
104
105
def test_broadcast_receive_state() -> None:
    """Round-trip Broadcast Receive State values through their codec.

    The two checked states share every field except the BIG encryption status
    and its accompanying code bytes: DECRYPTING with an empty code, and
    BAD_CODE with a 16-byte code.
    """
    subgroups = [
        bass.SubgroupInfo(6677, bytes.fromhex('112233')),
        bass.SubgroupInfo(8899, bytes.fromhex('4567')),
    ]
    state_cls = bass.BroadcastReceiveState

    encryption_cases = (
        (state_cls.BigEncryption.DECRYPTING, b''),
        (
            state_cls.BigEncryption.BAD_CODE,
            bytes.fromhex('a0a1a2a3a4a5a6a7a8a9aaabacadaeaf'),
        ),
    )

    # Each state is constructed and then checked immediately.
    for big_encryption, code in encryption_cases:
        state = state_cls(
            12,
            hci.Address("AA:BB:CC:DD:EE:FF"),
            123,
            123456,
            state_cls.PeriodicAdvertisingSyncState.SYNCHRONIZED_TO_PA,
            big_encryption,
            code,
            subgroups,
        )
        basic_broadcast_receive_state_check(state)
135
136
137# -----------------------------------------------------------------------------
async def run():
    """Run this file's BASS codec tests in order (script entry point)."""
    for check in (test_operations, test_broadcast_receive_state):
        check()
141
142
143# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # The log level comes from BUMBLE_LOGLEVEL (upper-cased) and defaults to INFO.
    log_level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()
    logging.basicConfig(level=log_level)
    asyncio.run(run())
147