1# Copyright 2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import pytest
20import functools
21import pytest_asyncio
22import logging
23import sys
24
25from bumble import att, device
26from bumble.profiles import hap
27from .test_utils import TwoDevices
28from bumble.keys import PairingKeys
29
30# -----------------------------------------------------------------------------
31# Logging
32# -----------------------------------------------------------------------------
# Module-level logger for this test file, forced to DEBUG verbosity.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Preset records registered with the test server. Their indices are
# non-contiguous and are registered out of index order (1, 50, 5, 78); the
# tests below expect the server to report and step through them in ascending
# index order (foo -> foobar -> bar -> unavailable).
foo_preset = hap.PresetRecord(1, "foo preset")
bar_preset = hap.PresetRecord(50, "bar preset")
foobar_preset = hap.PresetRecord(5, "foobar preset")
# A non-writable, unavailable preset: it is still returned by Read Presets, but
# setting it active is rejected and next/previous navigation skips over it.
# NOTE(review): it reuses the "foobar preset" name; presumably the name is kept
# short so the indication is not truncated at the default ATT MTU (see the MTU
# TODO in the hap_client fixture) — confirm before making it longer.
unavailable_preset = hap.PresetRecord(
    78,
    "foobar preset",
    hap.PresetRecord.Property(
        hap.PresetRecord.Property.Writable.CANNOT_BE_WRITTEN,
        hap.PresetRecord.Property.IsAvailable.IS_UNAVAILABLE,
    ),
)

# Features advertised by the test server; test_init_service checks that the
# client reads back exactly this value.
server_features = hap.HearingAidFeatures(
    hap.HearingAidType.MONAURAL_HEARING_AID,
    hap.PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED,
    hap.IndependentPresets.IDENTICAL_PRESET_RECORD,
    hap.DynamicPresets.PRESET_RECORDS_DOES_NOT_CHANGE,
    hap.WritablePresetsSupport.WRITABLE_PRESET_RECORDS_SUPPORTED,
)

# Seconds assert_queue_is_empty waits for an unexpected late item to show up.
TIMEOUT = 0.1
57
58
async def assert_queue_is_empty(
    queue: asyncio.Queue, timeout: "float | None" = None
):
    """Assert that ``queue`` is empty now and stays empty for a short while.

    The queue must be empty on entry. The helper then waits up to ``timeout``
    seconds for an item to arrive and fails if one does.

    Args:
        queue: the queue to watch, e.g. a proxy's notification queue.
        timeout: how long to wait for a late item, in seconds. Defaults to
            the module-level ``TIMEOUT`` when omitted.

    Raises:
        AssertionError: if the queue already holds items, or if an item
            arrives before the timeout expires.
    """
    assert queue.empty(), f"queue unexpectedly holds {queue.qsize()} item(s)"

    if timeout is None:
        timeout = TIMEOUT

    # asyncio.wait_for raises asyncio.TimeoutError on every supported Python
    # version (since 3.11 it is an alias of the builtin TimeoutError), so no
    # version-specific branch is needed.
    try:
        item = await asyncio.wait_for(queue.get(), timeout)
    except asyncio.TimeoutError:
        return  # Nothing arrived: the queue stayed empty.
    raise AssertionError(f"unexpected item received within {timeout}s: {item!r}")
69
70
71# -----------------------------------------------------------------------------
@pytest_asyncio.fixture
async def hap_client():
    """Yield a subscribed Hearing Access Service proxy connected to a test server.

    Device 0 hosts a ``HearingAccessService`` built from ``server_features``
    and the four module-level presets. Device 1 discovers that service and
    wraps it in a ``HearingAccessServiceProxy``. The proxy has already called
    ``setup_subscription()`` before it is handed to the test.
    """
    devices = TwoDevices()
    server_device = devices[0]

    has_service = hap.HearingAccessService(
        server_device,
        server_features,
        [foo_preset, bar_preset, foobar_preset, unavailable_preset],
    )
    server_device.add_service(has_service)

    await devices.setup_connection()
    # TODO negotiate MTU > 49 to not truncate preset names

    server_connection = devices.connections[0]
    client_connection = devices.connections[1]

    # Mock encryption: mark both ends of the link as encrypted.
    server_connection.encryption = 1  # type: ignore
    client_connection.encryption = 1  # type: ignore

    # Invoke the server's pairing callback for the client with an empty key
    # set. NOTE(review): the trailing True is presumably a "bonded" flag —
    # confirm against bumble.device.Device.on_pairing.
    server_device.on_pairing(
        server_connection, server_connection.peer_address, PairingKeys(), True
    )

    client_peer = device.Peer(client_connection)  # type: ignore
    proxy = await client_peer.discover_service_and_create_proxy(
        hap.HearingAccessServiceProxy
    )
    assert proxy
    await proxy.setup_subscription()

    yield proxy
102
103
104# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_init_service(hap_client: hap.HearingAccessServiceProxy):
    """The client reads back the configured features and foo as the active preset."""
    raw_features = await hap_client.server_features.read_value()
    assert hap.HearingAidFeatures_from_bytes(raw_features) == server_features

    active_index = await hap_client.active_preset_index.read_value()
    assert active_index == foo_preset.index
112
113
114# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_all_presets(hap_client: hap.HearingAccessServiceProxy):
    """Reading from index 1 with count 0xFF returns every preset, sorted by index.

    Each record arrives as a separate Read Preset Response indication. Only the
    final one has its is-last byte set to 1.
    """
    control_point = hap_client.hearing_aid_preset_control_point
    indications = hap_client.preset_control_point_indications

    await control_point.write_value(
        bytes([hap.HearingAidPresetControlPointOpcode.READ_PRESETS_REQUEST, 1, 0xFF])
    )

    # Records come back in ascending index order (1, 5, 50, 78), not in the
    # order they were registered with the service.
    expected_presets = [foo_preset, foobar_preset, bar_preset, unavailable_preset]
    last_position = len(expected_presets) - 1
    for position, preset in enumerate(expected_presets):
        is_last = 1 if position == last_position else 0
        expected = bytes(
            [hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, is_last]
        ) + bytes(preset)
        assert (await indications.get()) == expected

    await assert_queue_is_empty(indications)
134
135
136# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_partial_presets(hap_client: hap.HearingAccessServiceProxy):
    """A Read Presets Request honours both its start index and its count.

    Starting at index 3 with a count of 2 must yield exactly the first two
    presets whose index is >= 3: foobar_preset (5), then bar_preset (50), the
    latter flagged as the last record. unavailable_preset (78) must not be
    sent.
    """
    await hap_client.hearing_aid_preset_control_point.write_value(
        bytes([hap.HearingAidPresetControlPointOpcode.READ_PRESETS_REQUEST, 3, 2])
    )
    # Check the whole indication (opcode + is-last byte + record), as
    # test_read_all_presets does, rather than only the record payload.
    assert (await hap_client.preset_control_point_indications.get()) == bytes(
        [hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, 0]
    ) + bytes(foobar_preset)
    assert (await hap_client.preset_control_point_indications.get()) == bytes(
        [hap.HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE, 1]
    ) + bytes(bar_preset)

    # Without this check, a server that ignored the count and also sent
    # unavailable_preset would still pass.
    await assert_queue_is_empty(hap_client.preset_control_point_indications)
148
149
150# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_active_preset_valid(hap_client: hap.HearingAccessServiceProxy):
    """Activating an available preset notifies its index once and makes it readable."""
    notifications = hap_client.active_preset_index_notification
    request = bytes(
        [hap.HearingAidPresetControlPointOpcode.SET_ACTIVE_PRESET, bar_preset.index]
    )

    await hap_client.hearing_aid_preset_control_point.write_value(request)

    notified_index = await notifications.get()
    assert notified_index == bar_preset.index

    current_index = await hap_client.active_preset_index.read_value()
    assert current_index == bar_preset.index

    # Exactly one notification is expected for a single change.
    await assert_queue_is_empty(notifications)
163
164
165# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_active_preset_invalid(hap_client: hap.HearingAccessServiceProxy):
    """Activating an unavailable preset is rejected with PRESET_OPERATION_NOT_POSSIBLE."""
    control_point = hap_client.hearing_aid_preset_control_point
    request = bytes(
        [
            hap.HearingAidPresetControlPointOpcode.SET_ACTIVE_PRESET,
            unavailable_preset.index,
        ]
    )

    # Write with a response so that the server's error reply is observable
    # here as an ATT_Error.
    with pytest.raises(att.ATT_Error) as excinfo:
        await control_point.write_value(request, with_response=True)

    assert excinfo.value.error_code == hap.ErrorCode.PRESET_OPERATION_NOT_POSSIBLE
179
180
181# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_next_preset(hap_client: hap.HearingAccessServiceProxy):
    """From the initial preset foo (1), "next" moves to foobar (5) in index order."""
    await hap_client.hearing_aid_preset_control_point.write_value(
        bytes([hap.HearingAidPresetControlPointOpcode.SET_NEXT_PRESET])
    )

    expected_index = foobar_preset.index
    notified_index = await hap_client.active_preset_index_notification.get()
    assert notified_index == expected_index

    current_index = await hap_client.active_preset_index.read_value()
    assert current_index == expected_index

    await assert_queue_is_empty(hap_client.active_preset_index_notification)
194
195
196# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_next_preset_will_loop_to_first(
    hap_client: hap.HearingAccessServiceProxy,
):
    """Repeated "next" walks presets in index order and wraps back to the first.

    Starting from foo (1), the sequence is foobar (5), bar (50), then foo
    again. unavailable_preset (78) is skipped, so the wrap-around happens after
    bar rather than after the highest index.
    """
    next_request = bytes([hap.HearingAidPresetControlPointOpcode.SET_NEXT_PRESET])

    for expected in (foobar_preset, bar_preset, foo_preset):
        await hap_client.hearing_aid_preset_control_point.write_value(next_request)

        notified_index = await hap_client.active_preset_index_notification.get()
        assert notified_index == expected.index

        current_index = await hap_client.active_preset_index.read_value()
        assert current_index == expected.index

    await assert_queue_is_empty(hap_client.active_preset_index_notification)
218
219
220# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_previous_preset_will_loop_to_last(
    hap_client: hap.HearingAccessServiceProxy,
):
    """Stepping back from the first preset wraps to the last available one.

    The active preset starts at foo (1). "Previous" lands on bar (50) because
    unavailable_preset (78) is skipped.
    """
    previous_request = bytes(
        [hap.HearingAidPresetControlPointOpcode.SET_PREVIOUS_PRESET]
    )
    await hap_client.hearing_aid_preset_control_point.write_value(previous_request)

    notified_index = await hap_client.active_preset_index_notification.get()
    assert notified_index == bar_preset.index

    current_index = await hap_client.active_preset_index.read_value()
    assert current_index == bar_preset.index

    await assert_queue_is_empty(hap_client.active_preset_index_notification)
233