1# Copyright 2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16import avatar
17
18from avatar import PandoraDevices, BumblePandoraDevice
19from mobly import base_test, signals
20from mobly.asserts import assert_in, assert_not_in  # type: ignore
21
22from pandora.host_pb2 import RANDOM
23from pandora_experimental.gatt_grpc import GATT
24from pandora_experimental.gatt_pb2 import PRIMARY
25
26from bumble.att import UUID
27from bumble.gatt import GATT_VOLUME_CONTROL_SERVICE, GATT_AUDIO_INPUT_CONTROL_SERVICE
28from bumble.profiles.aics import AICSService
29from bumble.profiles.vcp import VolumeControlService
30
31
class AicsTest(base_test.BaseTestClass):
    """Check how the DUT discovers the Audio Input Control Service (AICS).

    Before each test the Bumble reference device registers an AICS instance
    together with a Volume Control Service that lists the AICS instance in its
    included services. The tests then run GATT service discovery from the DUT.
    """

    # Class-level default so that `teardown_class` can safely read
    # `self.devices` even when `setup_class` failed before assigning it.
    devices = None

    def setup_class(self) -> None:
        """Create the Pandora devices and select the DUT and the reference.

        Raises:
            signals.TestAbortClass: if the reference device is not a
                `BumblePandoraDevice`.
        """
        self.devices = PandoraDevices(self)
        # Only the first two devices are used; any extra ones are ignored.
        self.dut, self.ref, *_ = self.devices

        if not isinstance(self.ref, BumblePandoraDevice):
            raise signals.TestAbortClass('Test require Bumble as reference device.')

    def teardown_class(self) -> None:
        """Stop all devices, if they were created by `setup_class`."""
        if self.devices:
            self.devices.stop_all()

    @avatar.asynchronous
    async def setup_test(self) -> None:
        """Reset both devices, then register the services on the reference."""
        # Both resets run concurrently.
        await asyncio.gather(self.dut.reset(), self.ref.reset())

        # The services are added after the reset, once per test.
        aics_service = AICSService()
        volume_control_service = VolumeControlService(included_services=[aics_service])
        self.ref.device.add_service(aics_service)  # type: ignore
        self.ref.device.add_service(volume_control_service)  # type: ignore

    def connect_dut_to_ref(self):
        """Connect the DUT to the advertising reference device over LE.

        Returns:
            The connection reported by the DUT's `ConnectLE` call.

        Raises:
            AssertionError: if `ConnectLE` did not yield a connection.
        """
        advertise = self.ref.host.Advertise(legacy=True, connectable=True)
        try:
            dut_ref_connection = self.dut.host.ConnectLE(public=self.ref.address, own_address_type=RANDOM).connection
            assert dut_ref_connection
        finally:
            # Always stop advertising, including when the connection attempt
            # raised or produced no connection.
            advertise.cancel()  # type: ignore

        return dut_ref_connection

    def test_do_not_discover_aics_as_primary_service(self) -> None:
        """The primary services seen by the DUT include VCS but not AICS."""
        dut_ref_connection = self.connect_dut_to_ref()
        dut_gatt = GATT(self.dut.channel)

        services = dut_gatt.DiscoverServices(dut_ref_connection).services
        # Keep only the services reported as primary.
        uuids = [UUID(service.uuid) for service in services if service.service_type == PRIMARY]

        assert_in(GATT_VOLUME_CONTROL_SERVICE, uuids)
        assert_not_in(GATT_AUDIO_INPUT_CONTROL_SERVICE, uuids)

    def test_gatt_discover_aics_service(self) -> None:
        """The discovered VCS lists AICS among its included services."""
        dut_ref_connection = self.connect_dut_to_ref()
        dut_gatt = GATT(self.dut.channel)

        services = dut_gatt.DiscoverServices(dut_ref_connection).services

        filtered_services = [service for service in services if UUID(service.uuid) == GATT_VOLUME_CONTROL_SERVICE]
        # Exactly one Volume Control Service is expected in the results.
        assert len(filtered_services) == 1, f'expected exactly 1 Volume Control Service, found {len(filtered_services)}'
        vcp_service = filtered_services[0]

        included_services_uuids = [UUID(included_service.uuid) for included_service in vcp_service.included_services]
        assert_in(GATT_AUDIO_INPUT_CONTROL_SERVICE, included_services_uuids)
84