1# Copyright 2021-2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19import asyncio
20import datetime
21import enum
22import functools
23from importlib import resources
24import json
25import os
26import logging
27import pathlib
28from typing import Optional, List, cast
29import weakref
30import struct
31
32import ctypes
33import wasmtime
34import wasmtime.loader
35import liblc3  # type: ignore
36
37import click
38import aiohttp.web
39
40import bumble
41from bumble.core import AdvertisingData
42from bumble.colors import color
43from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
44from bumble.transport import open_transport
45from bumble.profiles import ascs, bap, pacs
46from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
47
48# -----------------------------------------------------------------------------
49# Logging
50# -----------------------------------------------------------------------------
# Module logger; main() configures logging from the BUMBLE_LOGLEVEL
# environment variable.
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Default TCP port of the UI HTTP/WebSocket server (overridable with --ui-port).
DEFAULT_UI_PORT = 7654
57
58
def _sink_pac_record() -> pacs.PacRecord:
    """Build the LC3 PAC record published for the sink direction.

    The sink is the audio this device receives (decoded and forwarded to the
    UI). It accepts 1 or 2 channels, 10 ms frames, 26-240 octets per codec
    frame and up to 2 codec frames per SDU at 8/16/24/32/48 kHz.
    """
    frequencies = (
        bap.SupportedSamplingFrequency.FREQ_8000
        | bap.SupportedSamplingFrequency.FREQ_16000
        | bap.SupportedSamplingFrequency.FREQ_24000
        | bap.SupportedSamplingFrequency.FREQ_32000
        | bap.SupportedSamplingFrequency.FREQ_48000
    )
    capabilities = bap.CodecSpecificCapabilities(
        supported_sampling_frequencies=frequencies,
        supported_frame_durations=(
            bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
        ),
        supported_audio_channel_count=[1, 2],
        min_octets_per_codec_frame=26,
        max_octets_per_codec_frame=240,
        supported_max_codec_frames_per_sdu=2,
    )
    return pacs.PacRecord(
        coding_format=CodingFormat(CodecID.LC3),
        codec_specific_capabilities=capabilities,
    )
79
80
def _source_pac_record() -> pacs.PacRecord:
    """Build the LC3 PAC record published for the source direction.

    The source is the audio this device sends (encoded from the WAV file).
    It is mono only, with 10 ms frames, 30-100 octets per codec frame and a
    single codec frame per SDU at 8/16/24/32/48 kHz.
    """
    frequencies = (
        bap.SupportedSamplingFrequency.FREQ_8000
        | bap.SupportedSamplingFrequency.FREQ_16000
        | bap.SupportedSamplingFrequency.FREQ_24000
        | bap.SupportedSamplingFrequency.FREQ_32000
        | bap.SupportedSamplingFrequency.FREQ_48000
    )
    capabilities = bap.CodecSpecificCapabilities(
        supported_sampling_frequencies=frequencies,
        supported_frame_durations=(
            bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
        ),
        supported_audio_channel_count=[1],
        min_octets_per_codec_frame=30,
        max_octets_per_codec_frame=100,
        supported_max_codec_frames_per_sdu=1,
    )
    return pacs.PacRecord(
        coding_format=CodingFormat(CodecID.LC3),
        codec_specific_capabilities=capabilities,
    )
101
102
103# -----------------------------------------------------------------------------
104# WASM - liblc3
105# -----------------------------------------------------------------------------
# wasmtime store used by modules imported through wasmtime.loader (liblc3
# above); it is passed to every accessor of the module's linear memory below.
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
# Scratch area for this app: everything from the current end of the WASM
# linear memory onwards. Growing the memory by one page makes room for the
# codec states and I/O buffers that are laid out starting at STACK_POINTER.
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Expose the whole WASM linear memory as a ctypes byte array so Python can
# copy data in and out by slicing `memory`. NOTE(review): this view is built
# once from the current data pointer; if the memory were grown again later the
# pointer could move — confirm nothing else grows it.
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
    ctypes.addressof(_memory.data_ptr(store).contents)  # type: ignore
)
114
115
class Liblc3PcmFormat(enum.IntEnum):
    """PCM sample format identifiers passed to liblc3's encode/decode calls.

    The integer values are handed straight to liblc3, so they presumably
    mirror the order of the C ``enum lc3_pcm_format`` — keep them in sync
    with the library.
    """

    S16 = 0  # 16-bit samples; the format this app uses (DEFAULT_PCM_FORMAT)
    S24 = 1
    S24_3LE = 2
    FLOAT = 3
121
122
# Codec state sizes for the largest configuration the PACs advertise
# (10 ms frames at 48 kHz); every encoder/decoder slot below has this size.
MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)

# Layout of the scratch area that starts at STACK_POINTER:
#   [2 decoder state slots][2 encoder state slots]
#   [decode I/O buffer: 8192 bytes][encode I/O buffer: up to end of memory]
DECODER_STACK_POINTER = STACK_POINTER
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
# Host-side PCM: WAV input of the source path and decoded output of the sink.
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2  # must match DEFAULT_PCM_FORMAT (S16)


# Handles returned by liblc3.lc3_setup_encoder/lc3_setup_decoder, indexed by
# channel; filled in by setup_encoders() and setup_decoders().
encoders: List[int] = []
decoders: List[int] = []
137
138
def setup_encoders(
    sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
    """Initialize one liblc3 encoder per channel in the WASM scratch area.

    Each encoder takes PCM at DEFAULT_PCM_SAMPLE_RATE and codes it at
    ``sample_rate_hz``. Handles are stored in ``encoders[:num_channels]``;
    entries beyond ``num_channels`` from an earlier setup are left untouched.

    Args:
        sample_rate_hz: codec sampling frequency in Hz.
        frame_duration_us: codec frame duration in microseconds.
        num_channels: number of encoders to set up.

    Raises:
        ValueError: if ``num_channels`` exceeds the encoder slots reserved
            between ENCODER_STACK_POINTER and DECODE_BUFFER_STACK_POINTER;
            setting up more would overwrite the buffers that follow.
    """
    logger.info(
        "setup_encoders %sHz %sus %schannels",
        sample_rate_hz,
        frame_duration_us,
        num_channels,
    )

    # Number of MAX_ENCODER_SIZE slots reserved for encoder state.
    max_channels = (
        DECODE_BUFFER_STACK_POINTER - ENCODER_STACK_POINTER
    ) // MAX_ENCODER_SIZE
    if num_channels > max_channels:
        raise ValueError(
            f"cannot set up {num_channels} encoders: only {max_channels} "
            "encoder slots are reserved"
        )

    handles = [
        liblc3.lc3_setup_encoder(
            frame_duration_us,
            sample_rate_hz,
            DEFAULT_PCM_SAMPLE_RATE,  # Input (PCM) sample rate
            ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
        )
        for i in range(num_channels)
    ]
    # liblc3 returns a null handle when the parameters are not supported.
    if not all(handles):
        logger.error("setup_encoders failed, handles=%s", handles)
    encoders[:num_channels] = handles
154
155
def setup_decoders(
    sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
    """Initialize one liblc3 decoder per channel in the WASM scratch area.

    Each decoder turns LC3 frames coded at ``sample_rate_hz`` into PCM at
    DEFAULT_PCM_SAMPLE_RATE. Handles are stored in ``decoders[:num_channels]``;
    entries beyond ``num_channels`` from an earlier setup are left untouched.

    Args:
        sample_rate_hz: codec sampling frequency in Hz.
        frame_duration_us: codec frame duration in microseconds.
        num_channels: number of decoders to set up.

    Raises:
        ValueError: if ``num_channels`` exceeds the decoder slots reserved
            between DECODER_STACK_POINTER and ENCODER_STACK_POINTER; setting
            up more would overwrite the encoder state that follows.
    """
    logger.info(
        "setup_decoders %sHz %sus %schannels",
        sample_rate_hz,
        frame_duration_us,
        num_channels,
    )

    # Number of MAX_DECODER_SIZE slots reserved for decoder state.
    max_channels = (
        ENCODER_STACK_POINTER - DECODER_STACK_POINTER
    ) // MAX_DECODER_SIZE
    if num_channels > max_channels:
        raise ValueError(
            f"cannot set up {num_channels} decoders: only {max_channels} "
            "decoder slots are reserved"
        )

    handles = [
        liblc3.lc3_setup_decoder(
            frame_duration_us,
            sample_rate_hz,
            DEFAULT_PCM_SAMPLE_RATE,  # Output (PCM) sample rate
            DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
        )
        for i in range(num_channels)
    ]
    # liblc3 returns a null handle when the parameters are not supported.
    if not all(handles):
        logger.error("setup_decoders failed, handles=%s", handles)
    decoders[:num_channels] = handles
171
172
def decode(
    frame_duration_us: int,
    num_channels: int,
    input_bytes: bytes,
) -> bytes:
    """Decode one LC3 SDU into interleaved PCM.

    The SDU is split evenly into ``num_channels`` codec frames, one per
    channel, stored channel after channel. NOTE: SDUs carrying several codec
    frames per channel are not handled here. Decoders must have been set up
    with setup_decoders() for at least ``num_channels`` channels.

    Args:
        frame_duration_us: codec frame duration in microseconds.
        num_channels: number of channels (codec frames) in the SDU.
        input_bytes: the SDU payload.

    Returns:
        Interleaved DEFAULT_PCM_FORMAT samples at DEFAULT_PCM_SAMPLE_RATE, or
        b'' when ``input_bytes`` is empty or does not fit in the decode buffer.
    """
    if not input_bytes:
        return b''

    input_buffer_offset = DECODE_BUFFER_STACK_POINTER
    input_buffer_size = len(input_bytes)
    input_bytes_per_frame = input_buffer_size // num_channels

    # Decoded PCM is written right after the SDU in the decode buffer.
    output_buffer_offset = input_buffer_offset + input_buffer_size
    output_buffer_size = (
        liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
        * DEFAULT_PCM_BYTES_PER_SAMPLE
        * num_channels
    )

    # The decode buffer ends where the encode buffer begins; anything larger
    # would silently overwrite encoder I/O data.
    region_size = ENCODE_BUFFER_STACK_POINTER - DECODE_BUFFER_STACK_POINTER
    if input_buffer_size + output_buffer_size > region_size:
        logger.error(
            "SDU too large to decode: %d + %d bytes > %d",
            input_buffer_size,
            output_buffer_size,
            region_size,
        )
        return b''

    # Copy into wasm
    memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes  # type: ignore

    for i in range(num_channels):
        res = liblc3.lc3_decode(
            decoders[i],
            input_buffer_offset + input_bytes_per_frame * i,
            input_bytes_per_frame,
            DEFAULT_PCM_FORMAT,
            # Channel i starts at sample i; the stride interleaves channels.
            output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
            num_channels,  # Stride
        )

        if res != 0:
            logger.error("Parsing failed, res=%s", res)

    # Extract decoded data from the output buffer
    return bytes(
        memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
    )
212
213
def encode(
    sdu_length: int,
    num_channels: int,
    stride: int,
    input_bytes: bytes,
) -> bytes:
    """Encode one frame of interleaved PCM into an LC3 SDU.

    Args:
        sdu_length: size of the produced SDU in bytes, split evenly into one
            codec frame per channel.
        num_channels: number of channels to encode; ``encoders[:num_channels]``
            must have been set up with setup_encoders().
        stride: distance, in samples, between consecutive samples of one
            channel in ``input_bytes`` (i.e. the number of interleaved PCM
            channels).
        input_bytes: DEFAULT_PCM_FORMAT samples. liblc3 always reads a full
            frame per channel, so callers should pass a full frame (padded
            with silence if needed).

    Returns:
        ``sdu_length`` bytes of LC3 data, or b'' when ``input_bytes`` is empty.
    """
    if not input_bytes:
        return b''

    input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
    input_buffer_size = len(input_bytes)

    # Copy into wasm
    memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes  # type: ignore

    # The SDU is written right after the PCM in the encode buffer.
    output_buffer_offset = input_buffer_offset + input_buffer_size
    output_buffer_size = sdu_length
    output_frame_size = output_buffer_size // num_channels

    for i in range(num_channels):
        res = liblc3.lc3_encode(
            encoders[i],
            DEFAULT_PCM_FORMAT,
            # Channel i starts at sample i of the interleaved input.
            input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
            stride,
            output_frame_size,
            output_buffer_offset + output_frame_size * i,
        )

        if res != 0:
            logger.error("Encoding failed, res=%s", res)

    # Extract encoded data from the output buffer
    return bytes(
        memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
    )
250
251
async def lc3_source_task(
    filename: str,
    sdu_length: int,
    frame_duration_us: int,
    device: Device,
    cis_handle: int,
) -> None:
    """Stream a 16-bit PCM WAV file to the peer as LC3 ISO SDUs.

    Every ``frame_duration_us`` one frame of PCM is read, encoded into an SDU
    of ``sdu_length`` bytes and sent on ``cis_handle`` as a complete ISO SDU.
    The SDU carries a single LC3 channel: the caller sizes ``sdu_length`` for
    one channel and the source PAC only advertises mono. For multi-channel
    files only the first channel is encoded. The final partial frame is
    padded with silence, and streaming stops at end of file.

    Raises:
        wave.Error: if the file is not a WAV file the ``wave`` module can read.
        ValueError: if the sample rate or sample width does not match
            DEFAULT_PCM_SAMPLE_RATE / DEFAULT_PCM_BYTES_PER_SAMPLE.
    """
    import wave  # Walks the RIFF chunks instead of assuming a 44-byte header.

    with wave.open(filename, 'rb') as wav:
        pcm_num_channel = wav.getnchannels()
        if wav.getframerate() != DEFAULT_PCM_SAMPLE_RATE:
            raise ValueError(
                f"{filename}: sample rate is {wav.getframerate()}Hz, "
                f"expected {DEFAULT_PCM_SAMPLE_RATE}Hz"
            )
        if wav.getsampwidth() != DEFAULT_PCM_BYTES_PER_SAMPLE:
            raise ValueError(
                f"{filename}: {wav.getsampwidth() * 8} bits per sample, "
                f"expected {DEFAULT_PCM_BYTES_PER_SAMPLE * 8}"
            )

        frame_samples = liblc3.lc3_frame_samples(
            frame_duration_us, DEFAULT_PCM_SAMPLE_RATE
        )
        # One frame of interleaved PCM spans every channel of the file.
        frame_bytes = (
            frame_samples * DEFAULT_PCM_BYTES_PER_SAMPLE * pcm_num_channel
        )

        loop = asyncio.get_running_loop()
        frame_interval = frame_duration_us / 1_000_000
        # Deadlines advance by a fixed step on the monotonic loop clock, so
        # sleep overshoot does not accumulate as drift.
        next_deadline = loop.time()
        packet_sequence_number = 0

        while pcm_data := wav.readframes(frame_samples):
            # liblc3 always reads a full frame; pad a short final read.
            pcm_data = pcm_data.ljust(frame_bytes, b'\x00')
            sdu = encode(sdu_length, 1, pcm_num_channel, pcm_data)

            iso_packet = HCI_IsoDataPacket(
                connection_handle=cis_handle,
                # +4: Packet_Sequence_Number and ISO_SDU_Length fields.
                data_total_length=sdu_length + 4,
                packet_sequence_number=packet_sequence_number,
                pb_flag=0b10,  # Complete SDU in a single packet
                packet_status_flag=0,
                iso_sdu_length=sdu_length,
                iso_sdu_fragment=sdu,
            )
            device.host.send_hci_packet(iso_packet)
            # Packet_Sequence_Number is a 16-bit field: wrap around.
            packet_sequence_number = (packet_sequence_number + 1) & 0xFFFF

            next_deadline += frame_interval
            delay = next_deadline - loop.time()
            if delay < 0:
                # Fell behind (e.g. a busy event loop): resynchronize instead
                # of sending a burst of SDUs to catch up.
                next_deadline = loop.time()
                delay = 0
            await asyncio.sleep(delay)
295
296
297# -----------------------------------------------------------------------------
class UiServer:
    """HTTP/WebSocket server backing the speaker's browser UI.

    Serves the bundled static page and exchanges JSON control messages and
    binary audio with the most recently connected page on ``/channel``.
    """

    speaker: weakref.ReferenceType[Speaker]
    port: int

    def __init__(self, speaker: Speaker, port: int) -> None:
        # Weak reference: the Speaker owns this server, not the other way round.
        self.speaker = weakref.ref(speaker)
        self.port = port
        # WebSocket of the active UI page, or None when no page is connected.
        self.channel_socket: Optional[aiohttp.web.WebSocketResponse] = None

    async def start_http(self) -> None:
        """Start the UI HTTP server."""

        app = aiohttp.web.Application()
        app.add_routes(
            [
                aiohttp.web.get('/', self.get_static),
                aiohttp.web.get('/index.html', self.get_static),
                aiohttp.web.get('/channel', self.get_channel),
            ]
        )

        runner = aiohttp.web.AppRunner(app)
        await runner.setup()
        site = aiohttp.web.TCPSite(runner, 'localhost', self.port)
        print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green'))
        await site.start()

    async def get_static(self, request):
        """Serve a bundled UI file; '/' is served as '/index.html'."""
        path = request.path
        if path == '/':
            path = '/index.html'
        if path.endswith('.html'):
            content_type = 'text/html'
        elif path.endswith('.js'):
            content_type = 'text/javascript'
        elif path.endswith('.css'):
            content_type = 'text/css'
        elif path.endswith('.svg'):
            content_type = 'image/svg+xml'
        else:
            content_type = 'text/plain'
        text = (
            resources.files("bumble.apps.lea_unicast")
            .joinpath(pathlib.Path(path).relative_to('/'))
            .read_text(encoding="utf-8")
        )
        return aiohttp.web.Response(text=text, content_type=content_type)

    async def get_channel(self, request):
        """Handle the UI WebSocket until the page disconnects."""
        ws = aiohttp.web.WebSocketResponse()
        await ws.prepare(request)

        # Process messages until the socket is closed.
        self.channel_socket = ws
        try:
            async for message in ws:
                if message.type == aiohttp.WSMsgType.TEXT:
                    logger.debug(f'<<< received message: {message.data}')
                    await self.on_message(message.data)
                elif message.type == aiohttp.WSMsgType.ERROR:
                    logger.debug(
                        f'channel connection closed with exception {ws.exception()}'
                    )
        finally:
            # A newer page may have taken over the channel meanwhile; only
            # clear it if it is still ours.
            if self.channel_socket is ws:
                self.channel_socket = None

        logger.debug('--- channel connection closed')

        return ws

    async def on_message(self, message_str: str):
        """Parse a JSON message from the UI and dispatch it to on_<type>_message."""
        try:
            message = json.loads(message_str)
            message_type = message['type']
        except (json.JSONDecodeError, KeyError, TypeError) as error:
            logger.warning(f'ignoring malformed UI message: {error}')
            return

        message_params = message.get('params', {})
        handler = getattr(self, f'on_{message_type}_message', None)
        if handler is None:
            logger.warning(f'ignoring unknown UI message type: {message_type}')
            return
        await handler(**message_params)

    async def on_hello_message(self):
        """Reply to the UI's 'hello' with version/state and connection info."""
        speaker = self.speaker()
        if speaker is None:
            return

        # NOTE(review): Speaker does not currently define `codec`,
        # `stream_state` or `connection`; read them defensively so a 'hello'
        # does not raise AttributeError. Confirm the UI copes with nulls.
        stream_state = getattr(speaker, 'stream_state', None)
        await self.send_message(
            'hello',
            bumble_version=bumble.__version__,
            codec=getattr(speaker, 'codec', None),
            streamState=stream_state.name if stream_state is not None else None,
        )
        if connection := getattr(speaker, 'connection', None):
            await self.send_message(
                'connection',
                peer_address=connection.peer_address.to_string(False),
                peer_name=connection.peer_name,
            )

    async def send_message(self, message_type: str, **kwargs) -> None:
        """Send a JSON {'type', 'params'} message to the UI, if connected."""
        if self.channel_socket is None:
            return

        message = {'type': message_type, 'params': kwargs}
        await self.channel_socket.send_json(message)

    async def send_audio(self, data: bytes) -> None:
        """Send PCM bytes to the UI, if connected; failures are only logged."""
        if self.channel_socket is None:
            return

        try:
            await self.channel_socket.send_bytes(data)
        except Exception as error:
            logger.warning(f'exception while sending audio packet: {error}')
406
407
408# -----------------------------------------------------------------------------
class Speaker:
    """LE Audio unicast server.

    Received sink audio is decoded and forwarded to the UI, and the source
    direction streams a WAV file.
    """

    def __init__(
        self,
        device_config_path: Optional[str],
        ui_port: int,
        transport: str,
        lc3_input_file_path: str,
    ):
        self.device_config_path = device_config_path
        self.transport = transport
        self.lc3_input_file_path = lc3_input_file_path

        # Create an HTTP server for the UI
        self.ui_server = UiServer(speaker=self, port=ui_port)

    @staticmethod
    def _channel_count(codec_config: 'bap.CodecSpecificConfiguration') -> int:
        """Return the number of LC3 channels of an ASE codec configuration.

        In BAP, an empty Audio Channel Allocation (no location bit set) means
        one mono channel. ``channel_count`` presumably counts location bits
        and would then be 0, which sets up no codec and makes decode() divide
        by zero, so the result is clamped to at least 1.
        """
        return max(1, codec_config.audio_channel_allocation.channel_count)

    async def run(self) -> None:
        """Start the UI server, bring up the device and serve until the HCI
        transport terminates."""
        await self.ui_server.start_http()

        async with await open_transport(self.transport) as hci_transport:
            # Create a device
            if self.device_config_path:
                device_config = DeviceConfiguration.from_file(self.device_config_path)
            else:
                device_config = DeviceConfiguration(
                    name="Bumble LE Headphone",
                    class_of_device=0x244418,
                    keystore="JsonKeyStore",
                    advertising_interval_min=25,
                    advertising_interval_max=25,
                    address=Address('F1:F2:F3:F4:F5:F6'),
                )

            # LE and CIS support are forced on, whatever the config file says.
            device_config.le_enabled = True
            device_config.cis_enabled = True
            self.device = Device.from_config_with_hci(
                device_config, hci_transport.source, hci_transport.sink
            )

            # Published Audio Capabilities for both directions.
            self.device.add_service(
                pacs.PublishedAudioCapabilitiesService(
                    supported_source_context=bap.ContextType(0xFFFF),
                    available_source_context=bap.ContextType(0xFFFF),
                    supported_sink_context=bap.ContextType(0xFFFF),  # All context types
                    available_sink_context=bap.ContextType(0xFFFF),  # All context types
                    sink_audio_locations=(
                        bap.AudioLocation.FRONT_LEFT | bap.AudioLocation.FRONT_RIGHT
                    ),
                    sink_pac=[_sink_pac_record()],
                    source_audio_locations=bap.AudioLocation.FRONT_LEFT,
                    source_pac=[_source_pac_record()],
                )
            )

            # One sink ASE (id 1) and one source ASE (id 2).
            ascs_service = ascs.AudioStreamControlService(
                self.device, sink_ase_id=[1], source_ase_id=[2]
            )
            self.device.add_service(ascs_service)

            advertising_data = bytes(
                AdvertisingData(
                    [
                        (
                            AdvertisingData.COMPLETE_LOCAL_NAME,
                            bytes(device_config.name, 'utf-8'),
                        ),
                        (
                            AdvertisingData.FLAGS,
                            bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
                        ),
                        (
                            AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
                            bytes(pacs.PublishedAudioCapabilitiesService.UUID),
                        ),
                    ]
                )
            ) + bytes(bap.UnicastServerAdvertisingData())

            def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine):
                # Decode an SDU received on a sink ASE and forward the PCM to the UI.
                codec_config = ase.codec_specific_configuration
                assert isinstance(codec_config, bap.CodecSpecificConfiguration)
                pcm = decode(
                    codec_config.frame_duration.us,
                    self._channel_count(codec_config),
                    pdu.iso_sdu_fragment,
                )
                self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))

            def on_ase_state_change(ase: ascs.AseStateMachine) -> None:
                if ase.state == ascs.AseStateMachine.State.STREAMING:
                    codec_config = ase.codec_specific_configuration
                    assert isinstance(codec_config, bap.CodecSpecificConfiguration)
                    assert ase.cis_link
                    if ase.role == ascs.AudioRole.SOURCE:
                        # Stream the WAV file until the CIS disconnects.
                        ase.cis_link.abort_on(
                            'disconnection',
                            lc3_source_task(
                                filename=self.lc3_input_file_path,
                                sdu_length=(
                                    codec_config.codec_frames_per_sdu
                                    * codec_config.octets_per_codec_frame
                                ),
                                frame_duration_us=codec_config.frame_duration.us,
                                device=self.device,
                                cis_handle=ase.cis_link.handle,
                            ),
                        )
                    else:
                        ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
                elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED:
                    # (Re)create the codecs for the newly configured stream.
                    codec_config = ase.codec_specific_configuration
                    assert isinstance(codec_config, bap.CodecSpecificConfiguration)
                    if ase.role == ascs.AudioRole.SOURCE:
                        setup_encoders(
                            codec_config.sampling_frequency.hz,
                            codec_config.frame_duration.us,
                            self._channel_count(codec_config),
                        )
                    else:
                        setup_decoders(
                            codec_config.sampling_frequency.hz,
                            codec_config.frame_duration.us,
                            self._channel_count(codec_config),
                        )

            for ase in ascs_service.ase_state_machines.values():
                ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))

            await self.device.power_on()
            await self.device.create_advertising_set(
                advertising_data=advertising_data,
                auto_restart=True,
                advertising_parameters=AdvertisingParameters(
                    primary_advertising_interval_min=100,
                    primary_advertising_interval_max=100,
                ),
            )

            await hci_transport.source.terminated
548
549
@click.command()
@click.option(
    '--ui-port',
    'ui_port',
    metavar='HTTP_PORT',
    default=DEFAULT_UI_PORT,
    show_default=True,
    help='HTTP port for the UI server',
)
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.argument('transport')
@click.argument('lc3_file')
def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) -> None:
    """Run the speaker."""

    # device_config is None when --device-config is not given; Speaker then
    # falls back to its built-in configuration.
    app = Speaker(
        device_config_path=device_config,
        ui_port=ui_port,
        transport=transport,
        lc3_input_file_path=lc3_file,
    )
    asyncio.run(app.run())
566
567
568# -----------------------------------------------------------------------------
def main():
    # The log level name comes from BUMBLE_LOGLEVEL (default: WARNING).
    level_name = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING')
    logging.basicConfig(level=level_name.upper())
    speaker()
572
573
574# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point; the click command inside main() parses the CLI.
    main()  # pylint: disable=no-value-for-parameter
577