1# Copyright 2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""LE Audio - Audio Input Control Service"""
16
17# -----------------------------------------------------------------------------
18# Imports
19# -----------------------------------------------------------------------------
20import logging
21import struct
22
23from dataclasses import dataclass
24from typing import Optional
25
26from bumble import gatt
27from bumble.device import Connection
28from bumble.att import ATT_Error
29from bumble.gatt import (
30    Characteristic,
31    DelegatedCharacteristicAdapter,
32    TemplateService,
33    CharacteristicValue,
34    PackedCharacteristicAdapter,
35    GATT_AUDIO_INPUT_CONTROL_SERVICE,
36    GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
37    GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
38    GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
39    GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
40    GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
41    GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
42)
43from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
44from bumble.utils import OpenIntEnum
45
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Largest change counter value; incrementing past it wraps around to 0.
CHANGE_COUNTER_MAX_VALUE = 0xFF

# Default lower/upper bounds used by GainSettingsProperties.
GAIN_SETTINGS_MIN_VALUE = 0x00
GAIN_SETTINGS_MAX_VALUE = 0xFF
58
59
class ErrorCode(OpenIntEnum):
    """Application error codes reported through `ATT_Error` by this service.

    Cf. 1.6 Application error codes
    """

    # The change counter operand of a write does not match the current state.
    INVALID_CHANGE_COUNTER = 0x80
    # The first byte of a control point write is not a known opcode.
    OPCODE_NOT_SUPPORTED = 0x81
    # Mute/unmute was requested while the mute field is DISABLED.
    MUTE_DISABLED = 0x82
    # The requested gain setting lies outside the advertised min/max range.
    VALUE_OUT_OF_RANGE = 0x83
    # A gain mode switch was requested while in one of the *_ONLY modes.
    GAIN_MODE_CHANGE_NOT_ALLOWED = 0x84
70
71
class Mute(OpenIntEnum):
    """Values of the mute field of the Audio Input State.

    Cf. 2.2.1.2 Mute Field
    """

    NOT_MUTED = 0x00
    MUTED = 0x01
    # While in this state, the mute and unmute procedures are rejected with
    # ErrorCode.MUTE_DISABLED.
    DISABLED = 0x02
80
81
class GainMode(OpenIntEnum):
    """Values of the gain mode field of the Audio Input State.

    Cf. 2.2.1.3 Gain Mode
    """

    # In the two *_ONLY modes, requests to switch gain mode are rejected with
    # ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED.
    MANUAL_ONLY = 0x00
    AUTOMATIC_ONLY = 0x01
    # Switchable modes; the gain setting can only be written in a manual mode.
    MANUAL = 0x02
    AUTOMATIC = 0x03
91
92
class AudioInputStatus(OpenIntEnum):
    '''
    Cf. 3.4 Audio Input Status

    Value of the Audio Input Status characteristic.
    '''

    INACTIVE = 0x00
    ACTIVE = 0x01

    # Misspelled name kept as an enum alias of INACTIVE so that existing callers
    # referencing `AudioInputStatus.INATIVE` keep working.
    INATIVE = 0x00
100
101
class AudioInputControlPointOpCode(OpenIntEnum):
    '''
    Cf. 3.5.1 Audio Input Control Point procedure requirements

    Opcode carried in the first byte of an Audio Input Control Point write.
    '''

    # Opcodes are numbered contiguously starting at 0x01; 0x00 is not assigned.
    SET_GAIN_SETTING = 0x01
    UNMUTE = 0x02
    MUTE = 0x03
    SET_MANUAL_GAIN_MODE = 0x04
    SET_AUTOMATIC_GAIN_MODE = 0x05
112
113
114# -----------------------------------------------------------------------------
@dataclass
class AudioInputState:
    '''
    Cf. 2.2.1 Audio Input State

    Model of the Audio Input State characteristic value. It is serialized as
    four unsigned bytes, in order: gain settings, mute, gain mode and change
    counter.
    '''

    gain_settings: int = 0
    mute: Mute = Mute.NOT_MUTED
    gain_mode: GainMode = GainMode.MANUAL
    change_counter: int = 0
    # Attribute handed to `notify_subscribers`; it must be set before
    # `notify_subscribers_via_connection` is called.
    attribute_value: Optional[CharacteristicValue] = None

    # Step applied by `decrement_gain_settings`. Deliberately left un-annotated
    # so that it is a plain class-level default and NOT a dataclass field: the
    # constructor signature, repr and equality are unchanged, while
    # `decrement_gain_settings` no longer fails with AttributeError when
    # `update_gain_settings_unit` was never called.
    gain_settings_unit = 1

    def __bytes__(self) -> bytes:
        """Serialize the state to its 4-byte representation.

        Raises:
            ValueError: if any field is outside the 0..255 range.
        """
        return bytes(
            [self.gain_settings, self.mute, self.gain_mode, self.change_counter]
        )

    @classmethod
    def from_bytes(cls, data: bytes) -> "AudioInputState":
        """Parse a 4-byte Audio Input State value.

        The mute and gain mode bytes are converted to their enum types so that
        a parsed state has the same field types as a locally constructed one.

        Raises:
            struct.error: if `data` is not exactly 4 bytes long.
        """
        gain_settings, mute, gain_mode, change_counter = struct.unpack("BBBB", data)
        return cls(gain_settings, Mute(mute), GainMode(gain_mode), change_counter)

    def update_gain_settings_unit(self, gain_settings_unit: int) -> None:
        """Set the step used by `decrement_gain_settings`."""
        self.gain_settings_unit = gain_settings_unit

    def increment_gain_settings(self, gain_settings_unit: int) -> None:
        """Raise the gain by `gain_settings_unit` and bump the change counter.

        No range check is performed here.
        """
        self.gain_settings += gain_settings_unit
        self.increment_change_counter()

    def decrement_gain_settings(self) -> None:
        """Lower the gain by the stored unit and bump the change counter.

        No range check is performed here.
        """
        self.gain_settings -= self.gain_settings_unit
        self.increment_change_counter()

    def increment_change_counter(self) -> None:
        """Advance the change counter, wrapping to 0 after the maximum value."""
        self.change_counter = (self.change_counter + 1) % (CHANGE_COUNTER_MAX_VALUE + 1)

    async def notify_subscribers_via_connection(self, connection: Connection) -> None:
        """Notify subscribers of the current serialized state.

        The notification is sent through the device that owns `connection`.
        """
        assert self.attribute_value is not None
        await connection.device.notify_subscribers(
            attribute=self.attribute_value, value=bytes(self)
        )

    def on_read(self, _connection: Optional[Connection]) -> bytes:
        """Read handler: return the serialized state."""
        return bytes(self)
159
160
@dataclass
class GainSettingsProperties:
    '''
    Cf. 3.2 Gain Settings Properties

    Serialized as three unsigned bytes, in order: unit, minimum and maximum.
    '''

    gain_settings_unit: int = 1
    gain_settings_minimum: int = GAIN_SETTINGS_MIN_VALUE
    gain_settings_maximum: int = GAIN_SETTINGS_MAX_VALUE

    @classmethod
    def from_bytes(cls, data: bytes) -> "GainSettingsProperties":
        """Parse a 3-byte Gain Settings Properties value.

        Raises:
            struct.error: if `data` is not exactly 3 bytes long.
        """
        (gain_settings_unit, gain_settings_minimum, gain_settings_maximum) = (
            struct.unpack('BBB', data)
        )
        # The parsed instance must be returned; without the `return`, callers
        # would always receive None.
        return cls(gain_settings_unit, gain_settings_minimum, gain_settings_maximum)

    def __bytes__(self) -> bytes:
        """Serialize to the 3-byte representation.

        Raises:
            ValueError: if any field is outside the 0..255 range.
        """
        return bytes(
            [
                self.gain_settings_unit,
                self.gain_settings_minimum,
                self.gain_settings_maximum,
            ]
        )

    def on_read(self, _connection: Optional[Connection]) -> bytes:
        """Read handler: return the serialized properties."""
        return bytes(self)
191
192
@dataclass
class AudioInputControlPoint:
    '''
    Cf. 3.5.2 Audio Input Control Point

    Write handler for the control point characteristic. Each supported opcode
    is dispatched to a procedure that updates the shared `audio_input_state`
    and, when the state actually changed, bumps its change counter and notifies
    subscribers.
    '''

    audio_input_state: AudioInputState
    gain_settings_properties: GainSettingsProperties

    # ATT "Invalid Attribute Value Length" error code, reported when a write is
    # too short to carry the operands its opcode needs. Un-annotated on purpose
    # so it is a class constant, not a dataclass field.
    _INVALID_ATTRIBUTE_VALUE_LENGTH = 0x0D

    async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
        """Dispatch a control point write to the matching procedure.

        Layout read from `value`: byte 0 is the opcode, byte 1 the change
        counter operand (read for MUTE), byte 2 the gain setting operand (read
        for SET_GAIN_SETTING).

        Raises:
            ATT_Error: with `ErrorCode.OPCODE_NOT_SUPPORTED` for an unknown
                opcode, with the ATT invalid-length code when `value` is too
                short, or with whatever error the dispatched procedure raises.
        """
        assert connection

        # Reject truncated writes with an ATT error instead of letting an
        # IndexError escape from the handler.
        if not value:
            raise ATT_Error(self._INVALID_ATTRIBUTE_VALUE_LENGTH)

        opcode = AudioInputControlPointOpCode(value[0])

        # NOTE(review): only the Mute procedure validates the change counter
        # operand; confirm against the AICS specification whether the other
        # procedures should validate it as well.
        if opcode == AudioInputControlPointOpCode.SET_GAIN_SETTING:
            if len(value) < 3:
                raise ATT_Error(self._INVALID_ATTRIBUTE_VALUE_LENGTH)
            gain_settings_operand = value[2]
            await self._set_gain_settings(connection, gain_settings_operand)
        elif opcode == AudioInputControlPointOpCode.UNMUTE:
            await self._unmute(connection)
        elif opcode == AudioInputControlPointOpCode.MUTE:
            if len(value) < 2:
                raise ATT_Error(self._INVALID_ATTRIBUTE_VALUE_LENGTH)
            change_counter_operand = value[1]
            await self._mute(connection, change_counter_operand)
        elif opcode == AudioInputControlPointOpCode.SET_MANUAL_GAIN_MODE:
            await self._set_manual_gain_mode(connection)
        elif opcode == AudioInputControlPointOpCode.SET_AUTOMATIC_GAIN_MODE:
            await self._set_automatic_gain_mode(connection)
        else:
            logger.error(f"OpCode value is incorrect: {opcode}")
            raise ATT_Error(ErrorCode.OPCODE_NOT_SUPPORTED)

    async def _set_gain_settings(
        self, connection: Connection, gain_settings_operand: int
    ) -> None:
        '''Cf. 3.5.2.1 Set Gain Settings Procedure

        The request is ignored unless the gain mode is MANUAL or MANUAL_ONLY.

        Raises:
            ATT_Error: `ErrorCode.VALUE_OUT_OF_RANGE` when the operand is
                outside the configured minimum/maximum.
        '''

        gain_mode = self.audio_input_state.gain_mode

        # Trace only: this is not an error condition.
        logger.debug(f"set_gain_setting: gain_mode: {gain_mode}")
        if gain_mode not in (GainMode.MANUAL, GainMode.MANUAL_ONLY):
            logger.warning(
                "GainMode should be either MANUAL or MANUAL_ONLY Cf Spec Audio Input Control Service 3.5.2.1"
            )
            return

        if (
            gain_settings_operand < self.gain_settings_properties.gain_settings_minimum
            or gain_settings_operand
            > self.gain_settings_properties.gain_settings_maximum
        ):
            logger.error("gain_settings value out of range")
            raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE)

        if self.audio_input_state.gain_settings != gain_settings_operand:
            self.audio_input_state.gain_settings = gain_settings_operand
            # A gain change is a state change like any other: bump the change
            # counter before notifying, as the other procedures do.
            self.audio_input_state.increment_change_counter()
            await self.audio_input_state.notify_subscribers_via_connection(connection)

    async def _unmute(self, connection: Connection) -> None:
        '''Cf. 3.5.2.2 Unmute procedure

        Raises:
            ATT_Error: `ErrorCode.MUTE_DISABLED` when mute is DISABLED.
        '''

        # Trace only: this is not an error condition.
        logger.debug(f'unmute: {self.audio_input_state.mute}')
        mute = self.audio_input_state.mute
        if mute == Mute.DISABLED:
            logger.error("unmute: Cannot change Mute value, Mute state is DISABLED")
            raise ATT_Error(ErrorCode.MUTE_DISABLED)

        # Already in the requested state: nothing to change or notify.
        if mute == Mute.NOT_MUTED:
            return

        self.audio_input_state.mute = Mute.NOT_MUTED
        self.audio_input_state.increment_change_counter()
        await self.audio_input_state.notify_subscribers_via_connection(connection)

    async def _mute(self, connection: Connection, change_counter_operand: int) -> None:
        '''Cf. 3.5.2.3 Mute procedure

        Raises:
            ATT_Error: `ErrorCode.MUTE_DISABLED` when mute is DISABLED, or
                `ErrorCode.INVALID_CHANGE_COUNTER` when the operand does not
                match the current change counter.
        '''

        change_counter = self.audio_input_state.change_counter
        mute = self.audio_input_state.mute
        if mute == Mute.DISABLED:
            logger.error("mute: Cannot change Mute value, Mute state is DISABLED")
            raise ATT_Error(ErrorCode.MUTE_DISABLED)

        if change_counter != change_counter_operand:
            raise ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)

        # Already in the requested state: nothing to change or notify.
        if mute == Mute.MUTED:
            return

        self.audio_input_state.mute = Mute.MUTED
        self.audio_input_state.increment_change_counter()
        await self.audio_input_state.notify_subscribers_via_connection(connection)

    async def _set_manual_gain_mode(self, connection: Connection) -> None:
        '''Cf. 3.5.2.4 Set Manual Gain Mode procedure

        Raises:
            ATT_Error: `ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED` when the gain
                mode is AUTOMATIC_ONLY or MANUAL_ONLY.
        '''

        gain_mode = self.audio_input_state.gain_mode
        if gain_mode in (GainMode.AUTOMATIC_ONLY, GainMode.MANUAL_ONLY):
            logger.error(f"Cannot change gain_mode, bad state: {gain_mode}")
            raise ATT_Error(ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED)

        # Already in the requested state: nothing to change or notify.
        if gain_mode == GainMode.MANUAL:
            return

        self.audio_input_state.gain_mode = GainMode.MANUAL
        self.audio_input_state.increment_change_counter()
        await self.audio_input_state.notify_subscribers_via_connection(connection)

    async def _set_automatic_gain_mode(self, connection: Connection) -> None:
        '''Cf. 3.5.2.5 Set Automatic Gain Mode

        Raises:
            ATT_Error: `ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED` when the gain
                mode is AUTOMATIC_ONLY or MANUAL_ONLY.
        '''

        gain_mode = self.audio_input_state.gain_mode
        if gain_mode in (GainMode.AUTOMATIC_ONLY, GainMode.MANUAL_ONLY):
            logger.error(f"Cannot change gain_mode, bad state: {gain_mode}")
            raise ATT_Error(ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED)

        # Already in the requested state: nothing to change or notify.
        if gain_mode == GainMode.AUTOMATIC:
            return

        self.audio_input_state.gain_mode = GainMode.AUTOMATIC
        self.audio_input_state.increment_change_counter()
        await self.audio_input_state.notify_subscribers_via_connection(connection)
313
314
@dataclass
class AudioInputDescription:
    """Holder for the Audio Input Description characteristic value.

    Cf. 3.6 Audio Input Description

    The description is kept as a `str` and exchanged as UTF-8 bytes.
    """

    audio_input_description: str = "Bluetooth"
    # Attribute handed to `notify_subscribers` when the description is written.
    attribute_value: Optional[CharacteristicValue] = None

    @classmethod
    def from_bytes(cls, data: bytes):
        """Build an instance from a UTF-8 encoded description."""
        text = data.decode('utf-8')
        return cls(audio_input_description=text)

    def __bytes__(self) -> bytes:
        """Return the description encoded as UTF-8."""
        return self.audio_input_description.encode('utf-8')

    def on_read(self, _connection: Optional[Connection]) -> bytes:
        """Read handler: return the UTF-8 encoded description."""
        return bytes(self)

    async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
        """Write handler: store the new description, then notify subscribers.

        The written bytes are forwarded unchanged in the notification.
        """
        assert connection
        assert self.attribute_value

        self.audio_input_description = value.decode('utf-8')

        device = connection.device
        await device.notify_subscribers(attribute=self.attribute_value, value=value)
342
343
class AICSService(TemplateService):
    """GATT server side of the Audio Input Control Service.

    Builds the six characteristics of the service (state, gain settings
    properties, input type, input status, control point, description) and
    registers them as a non-primary service.

    Args:
        audio_input_state: initial state; a default `AudioInputState` is
            created when omitted.
        gain_settings_properties: gain unit/min/max; a default
            `GainSettingsProperties` is created when omitted.
        audio_input_type: value of the Audio Input Type characteristic.
        audio_input_status: defaults to `AudioInputStatus.ACTIVE` when omitted.
        audio_input_description: a default `AudioInputDescription` is created
            when omitted.
    """

    UUID = GATT_AUDIO_INPUT_CONTROL_SERVICE

    def __init__(
        self,
        audio_input_state: Optional[AudioInputState] = None,
        gain_settings_properties: Optional[GainSettingsProperties] = None,
        audio_input_type: str = "local",
        audio_input_status: Optional[AudioInputStatus] = None,
        audio_input_description: Optional[AudioInputDescription] = None,
    ):
        self.audio_input_state = (
            AudioInputState() if audio_input_state is None else audio_input_state
        )
        self.gain_settings_properties = (
            GainSettingsProperties()
            if gain_settings_properties is None
            else gain_settings_properties
        )
        self.audio_input_status = (
            AudioInputStatus.ACTIVE
            if audio_input_status is None
            else audio_input_status
        )
        self.audio_input_description = (
            AudioInputDescription()
            if audio_input_description is None
            else audio_input_description
        )

        # The control point shares the state and properties objects with the
        # service, so writes to it are reflected in what the service reads.
        self.audio_input_control_point: AudioInputControlPoint = AudioInputControlPoint(
            self.audio_input_state, self.gain_settings_properties
        )

        self.audio_input_state_characteristic = DelegatedCharacteristicAdapter(
            Characteristic(
                uuid=GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
                properties=Characteristic.Properties.READ
                | Characteristic.Properties.NOTIFY,
                permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
                value=CharacteristicValue(read=self.audio_input_state.on_read),
            ),
            encode=bytes,
        )
        # Lets the state object notify on its own characteristic.
        self.audio_input_state.attribute_value = (
            self.audio_input_state_characteristic.value
        )

        self.gain_settings_properties_characteristic = DelegatedCharacteristicAdapter(
            Characteristic(
                uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
                properties=Characteristic.Properties.READ,
                permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
                value=CharacteristicValue(read=self.gain_settings_properties.on_read),
            )
        )

        self.audio_input_type_characteristic = Characteristic(
            uuid=GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
            properties=Characteristic.Properties.READ,
            permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
            value=audio_input_type,
        )

        self.audio_input_status_characteristic = Characteristic(
            uuid=GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
            properties=Characteristic.Properties.READ,
            permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
            value=bytes([self.audio_input_status]),
        )

        self.audio_input_control_point_characteristic = DelegatedCharacteristicAdapter(
            Characteristic(
                uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
                properties=Characteristic.Properties.WRITE,
                permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
                value=CharacteristicValue(
                    write=self.audio_input_control_point.on_write
                ),
            )
        )

        self.audio_input_description_characteristic = DelegatedCharacteristicAdapter(
            Characteristic(
                uuid=GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
                properties=Characteristic.Properties.READ
                | Characteristic.Properties.NOTIFY
                | Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
                permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
                | Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
                value=CharacteristicValue(
                    write=self.audio_input_description.on_write,
                    read=self.audio_input_description.on_read,
                ),
            )
        )
        # Description changes must be notified on the description
        # characteristic itself, not on the control point characteristic.
        self.audio_input_description.attribute_value = (
            self.audio_input_description_characteristic.value
        )

        super().__init__(
            characteristics=[
                self.audio_input_state_characteristic,  # type: ignore
                self.gain_settings_properties_characteristic,  # type: ignore
                self.audio_input_type_characteristic,  # type: ignore
                self.audio_input_status_characteristic,  # type: ignore
                self.audio_input_control_point_characteristic,  # type: ignore
                self.audio_input_description_characteristic,  # type: ignore
            ],
            primary=False,
        )
455
456
457# -----------------------------------------------------------------------------
458# Client
459# -----------------------------------------------------------------------------
class AICSServiceProxy(ProfileServiceProxy):
    """Client-side view of a remote Audio Input Control Service.

    On construction, the characteristics used by the proxy are looked up by
    UUID on the given service proxy; a missing one raises
    `gatt.InvalidServiceError`.
    """

    SERVICE_CLASS = AICSService

    def __init__(self, service_proxy: ServiceProxy) -> None:
        self.service_proxy = service_proxy

        # Audio Input State, decoded into an AudioInputState on read.
        self.audio_input_state = DelegatedCharacteristicAdapter(
            characteristic=self._first_characteristic(
                service_proxy,
                GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
                "Audio Input State Characteristic not found",
            ),
            decode=AudioInputState.from_bytes,
        )

        # Gain settings properties, unpacked as three unsigned bytes.
        self.gain_settings_properties = PackedCharacteristicAdapter(
            self._first_characteristic(
                service_proxy,
                GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
                "Gain Settings Attribute Characteristic not found",
            ),
            'BBB',
        )

        # Audio input status, unpacked as a single unsigned byte.
        self.audio_input_status = PackedCharacteristicAdapter(
            self._first_characteristic(
                service_proxy,
                GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
                "Audio Input Status Characteristic not found",
            ),
            'B',
        )

        # The control point and description are exposed without an adapter.
        self.audio_input_control_point = self._first_characteristic(
            service_proxy,
            GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
            "Audio Input Control Point Characteristic not found",
        )

        self.audio_input_description = self._first_characteristic(
            service_proxy,
            GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
            "Audio Input Description Characteristic not found",
        )

    @staticmethod
    def _first_characteristic(service_proxy: ServiceProxy, uuid, error_message: str):
        """Return the first characteristic matching `uuid`.

        Raises:
            gatt.InvalidServiceError: with `error_message` when the service
                exposes no characteristic with that UUID.
        """
        matches = service_proxy.get_characteristics_by_uuid(uuid)
        if not matches:
            raise gatt.InvalidServiceError(error_message)
        return matches[0]
521