1# Copyright 2021-2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19import asyncio
20import asyncio.subprocess
21from importlib import resources
22import enum
23import json
24import os
25import logging
26import pathlib
27import subprocess
28from typing import Dict, List, Optional
29import weakref
30
31import click
32import aiohttp
33from aiohttp import web
34
35import bumble
36from bumble.colors import color
37from bumble.core import BT_BR_EDR_TRANSPORT, CommandTimeoutError
38from bumble.device import Connection, Device, DeviceConfiguration
39from bumble.hci import HCI_StatusError
40from bumble.pairing import PairingConfig
41from bumble.sdp import ServiceAttribute
42from bumble.transport import open_transport
43from bumble.avdtp import (
44    AVDTP_AUDIO_MEDIA_TYPE,
45    Listener,
46    MediaCodecCapabilities,
47    MediaPacket,
48    Protocol,
49)
50from bumble.a2dp import (
51    MPEG_2_AAC_LC_OBJECT_TYPE,
52    make_audio_sink_service_sdp_records,
53    A2DP_SBC_CODEC_TYPE,
54    A2DP_MPEG_2_4_AAC_CODEC_TYPE,
55    SBC_MONO_CHANNEL_MODE,
56    SBC_DUAL_CHANNEL_MODE,
57    SBC_SNR_ALLOCATION_METHOD,
58    SBC_LOUDNESS_ALLOCATION_METHOD,
59    SBC_STEREO_CHANNEL_MODE,
60    SBC_JOINT_STEREO_CHANNEL_MODE,
61    SbcMediaCodecInformation,
62    AacMediaCodecInformation,
63)
64from bumble.utils import AsyncRunner
65from bumble.codecs import AacAudioRtpPacket
66
67
68# -----------------------------------------------------------------------------
69# Logging
70# -----------------------------------------------------------------------------
# Module logger; the log level is configured in main() from BUMBLE_LOGLEVEL.
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Default TCP port for the UI HTTP server (overridable with --ui-port).
DEFAULT_UI_PORT = 7654
78
79
80# -----------------------------------------------------------------------------
class AudioExtractor:
    """Base class for objects that pull codec bytes out of RTP media packets."""

    @staticmethod
    def create(codec: str):
        """Return an extractor for `codec` ('aac' or 'sbc').

        Raises:
            ValueError: if `codec` is not supported. Failing here, instead of
                returning None, surfaces the problem at construction time rather
                than as an AttributeError on the first received packet.
        """
        if codec == 'aac':
            return AacAudioExtractor()
        if codec == 'sbc':
            return SbcAudioExtractor()
        raise ValueError(f'unsupported codec: {codec}')

    def extract_audio(self, packet: MediaPacket) -> bytes:
        """Return the audio bytes carried by `packet` (implemented by subclasses)."""
        raise NotImplementedError()
91
92
93# -----------------------------------------------------------------------------
class AacAudioExtractor:
    """Turns AAC RTP payloads into an ADTS-framed AAC byte stream."""

    def extract_audio(self, packet: MediaPacket) -> bytes:
        """Parse the RTP payload as AAC and return it re-framed as ADTS."""
        rtp_packet = AacAudioRtpPacket(packet.payload)
        adts_bytes = rtp_packet.to_adts()
        return adts_bytes
97
98
99# -----------------------------------------------------------------------------
class SbcAudioExtractor:
    """Strips the SBC media payload header from RTP payloads."""

    def extract_audio(self, packet: MediaPacket) -> bytes:
        """Return the SBC frame data that follows the one-byte payload header.

        Header bit layout: bit 7 = fragmented, bit 6 = start, bit 5 = last,
        bits 0-3 = number of frames. The header is skipped, not interpreted.

        TODO: support fragmented payloads.
        """
        header_length = 1
        payload = packet.payload
        return payload[header_length:]
110
111
112# -----------------------------------------------------------------------------
class Output:
    """Destination for received audio. Every hook here is a no-op; subclasses
    override only the ones they care about."""

    async def start(self) -> None:
        """Invoked when the sink starts streaming."""

    async def stop(self) -> None:
        """Invoked when the sink stops streaming, and when the speaker exits."""

    async def suspend(self) -> None:
        """Invoked when the sink suspends streaming."""

    async def on_connection(self, connection: Connection) -> None:
        """Invoked when a Bluetooth connection is established."""

    async def on_disconnection(self, reason: int) -> None:
        """Invoked when the Bluetooth connection goes away, with its reason code."""

    def on_rtp_packet(self, packet: MediaPacket) -> None:
        """Invoked synchronously for every received RTP media packet."""
131
132
133# -----------------------------------------------------------------------------
class FileOutput(Output):
    """Output that appends the extracted audio of each RTP packet to a file."""

    filename: str
    codec: str
    extractor: AudioExtractor

    def __init__(self, filename, codec):
        self.filename = filename
        self.codec = codec
        # Build the extractor first so an unsupported codec doesn't leave an
        # opened (and never closed) file behind.
        self.extractor = AudioExtractor.create(codec)
        # Kept open for the output's lifetime: a stream can stop and restart,
        # so the file is flushed (not closed) when streaming stops.
        self.file = open(filename, 'wb')

    async def stop(self) -> None:
        """Flush buffered audio so the file is complete on disk."""
        self.file.flush()

    async def suspend(self) -> None:
        """Flush buffered audio while the stream is paused."""
        self.file.flush()

    def on_rtp_packet(self, packet: MediaPacket) -> None:
        self.file.write(self.extractor.extract_audio(packet))
147
148
149# -----------------------------------------------------------------------------
class QueuedOutput(Output):
    """Output that queues extracted audio and delivers it from a background task.

    `on_rtp_packet` is synchronous; it extracts the audio and enqueues it. A
    single pump task, started by `start()`, feeds queued audio in order to
    `on_audio_packet`, which subclasses override.
    """

    MAX_QUEUE_SIZE = 32768

    packets: asyncio.Queue
    extractor: AudioExtractor
    packet_pump_task: Optional[asyncio.Task]
    started: bool

    def __init__(self, extractor):
        self.extractor = extractor
        self.packets = asyncio.Queue()
        self.packet_pump_task = None
        self.started = False

    async def start(self):
        """Start the packet pump; subsequent calls are no-ops."""
        if self.started:
            return

        # Record that we started so repeated 'start' events don't spawn extra
        # pump tasks competing for (and reordering) the same queue.
        self.started = True
        self.packet_pump_task = asyncio.create_task(self.pump_packets())

    async def pump_packets(self):
        """Deliver queued audio to `on_audio_packet`, forever."""
        while True:
            packet = await self.packets.get()
            try:
                await self.on_audio_packet(packet)
            except Exception:
                # Keep pumping: one failed delivery must not stall the output
                # and let the queue silently fill up.
                logger.exception('exception while delivering audio packet')

    async def on_audio_packet(self, packet: bytes) -> None:
        """Handle one chunk of extracted audio (no-op here)."""

    def on_rtp_packet(self, packet: MediaPacket) -> None:
        # Bound memory use: drop audio once MAX_QUEUE_SIZE items are pending.
        if self.packets.qsize() >= self.MAX_QUEUE_SIZE:
            logger.debug("queue full, dropping")
            return

        self.packets.put_nowait(self.extractor.extract_audio(packet))
184
185
186# -----------------------------------------------------------------------------
class WebSocketOutput(QueuedOutput):
    """Output that forwards audio and stream/connection events to the UI.

    `send_audio(data)` and `send_message(type, **params)` are coroutine
    functions supplied by the caller (the UI server's channel senders).
    """

    def __init__(self, codec, send_audio, send_message):
        extractor = AudioExtractor.create(codec)
        super().__init__(extractor)
        self.send_audio = send_audio
        self.send_message = send_message

    async def on_audio_packet(self, packet: bytes) -> None:
        await self.send_audio(packet)

    async def on_connection(self, connection: Connection) -> None:
        # Asking for the remote name is best effort; an HCI error just leaves
        # the name unknown.
        try:
            await connection.request_remote_name()
        except HCI_StatusError:
            pass

        name = connection.peer_name
        if name is None:
            name = ''
        address = connection.peer_address.to_string(False)
        await self.send_message('connection', peer_address=address, peer_name=name)

    async def on_disconnection(self, reason) -> None:
        await self.send_message('disconnection')

    async def start(self):
        await super().start()
        await self.send_message('start')

    async def suspend(self):
        await super().suspend()
        await self.send_message('suspend')

    async def stop(self):
        await super().stop()
        await self.send_message('stop')
223
224
225# -----------------------------------------------------------------------------
class FfplayOutput(QueuedOutput):
    """Output that plays the audio stream by piping it into an `ffplay` process."""

    subprocess: Optional[asyncio.subprocess.Process]
    ffplay_task: Optional[asyncio.Task]

    def __init__(self, codec: str) -> None:
        super().__init__(AudioExtractor.create(codec))
        self.subprocess = None
        self.ffplay_task = None
        self.codec = codec

    async def start(self):
        """Launch ffplay (only once) and start pumping audio into its stdin."""
        if self.started:
            return

        # Spawn ffplay before starting the pump, so queued audio is not
        # consumed while there is no process to receive it. Use exec with an
        # argument list rather than a shell command string.
        try:
            self.subprocess = await asyncio.create_subprocess_exec(
                'ffplay',
                '-f',
                self.codec,
                'pipe:0',
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
        except OSError as error:
            logger.warning(f'unable to launch ffplay: {error}')
        else:
            self.ffplay_task = asyncio.create_task(self.monitor_ffplay())

        await super().start()
        # Set explicitly so that a later 'start' never spawns a second ffplay.
        self.started = True

    async def stop(self):
        # TODO
        pass

    async def suspend(self):
        # TODO
        pass

    async def monitor_ffplay(self):
        """Log ffplay's stdout/stderr until the process exits and pipes close."""
        process = self.subprocess
        if process is None:
            return

        async def read_stream(name, stream):
            # read(n) returns b'' at EOF without suspending; stop there, or
            # this loop would spin forever and block the event loop.
            while data := await stream.read(4096):
                logger.debug('%s: %s', name, data)

        await asyncio.wait(
            [
                asyncio.create_task(read_stream('[ffplay stdout]', process.stdout)),
                asyncio.create_task(read_stream('[ffplay stderr]', process.stderr)),
                asyncio.create_task(process.wait()),
            ]
        )
        logger.debug("FFPLAY done")

    async def on_audio_packet(self, packet):
        if self.subprocess is None or self.subprocess.stdin is None:
            # ffplay could not be launched: drop the audio.
            return

        try:
            self.subprocess.stdin.write(packet)
            # Apply back-pressure rather than buffering without bound in the
            # pipe transport when ffplay consumes slower than audio arrives.
            await self.subprocess.stdin.drain()
        except Exception as error:
            logger.warning(
                f'!!!! exception while sending audio to ffplay pipe: {error}'
            )
285
286
287# -----------------------------------------------------------------------------
class UiServer:
    """HTTP server for the speaker's browser UI.

    Serves the static UI files and a `/channel` web socket. JSON text messages
    on the socket are dispatched to `on_<type>_message` handlers. Outgoing JSON
    messages and binary audio go to the most recently connected channel.
    """

    speaker: weakref.ReferenceType[Speaker]
    port: int

    def __init__(self, speaker: Speaker, port: int) -> None:
        # Weak reference: the Speaker owns this server, so a strong reference
        # back would create a cycle.
        self.speaker = weakref.ref(speaker)
        self.port = port
        self.channel_socket = None

    async def start_http(self) -> None:
        """Start the UI HTTP server."""

        app = web.Application()
        app.add_routes(
            [
                web.get('/', self.get_static),
                web.get('/speaker.html', self.get_static),
                web.get('/speaker.js', self.get_static),
                web.get('/speaker.css', self.get_static),
                web.get('/logo.svg', self.get_static),
                web.get('/channel', self.get_channel),
            ]
        )

        runner = web.AppRunner(app)
        await runner.setup()
        site = web.TCPSite(runner, 'localhost', self.port)
        print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green'))
        await site.start()

    async def get_static(self, request):
        """Serve one of the packaged UI files ('/' maps to speaker.html)."""
        path = request.path
        if path == '/':
            path = '/speaker.html'
        if path.endswith('.html'):
            content_type = 'text/html'
        elif path.endswith('.js'):
            content_type = 'text/javascript'
        elif path.endswith('.css'):
            content_type = 'text/css'
        elif path.endswith('.svg'):
            content_type = 'image/svg+xml'
        else:
            content_type = 'text/plain'
        text = (
            resources.files("bumble.apps.speaker")
            .joinpath(pathlib.Path(path).relative_to('/'))
            .read_text(encoding="utf-8")
        )
        return aiohttp.web.Response(text=text, content_type=content_type)

    async def get_channel(self, request):
        """Handle the UI web socket until the client disconnects."""
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        # Process messages until the socket is closed.
        self.channel_socket = ws
        try:
            async for message in ws:
                if message.type == aiohttp.WSMsgType.TEXT:
                    logger.debug(f'<<< received message: {message.data}')
                    await self.on_message(message.data)
                elif message.type == aiohttp.WSMsgType.ERROR:
                    logger.debug(
                        f'channel connection closed with exception {ws.exception()}'
                    )
        finally:
            # Only forget the socket if a newer client hasn't replaced it.
            if self.channel_socket is ws:
                self.channel_socket = None
        logger.debug('--- channel connection closed')

        return ws

    async def on_message(self, message_str: str):
        """Parse a JSON message and dispatch it to `on_<type>_message(**params)`.

        Malformed messages and unknown types are logged and ignored rather than
        aborting the channel.
        """
        try:
            message = json.loads(message_str)
            message_type = message['type']
            message_params = message.get('params', {})
        except (ValueError, KeyError, TypeError, AttributeError) as error:
            logger.warning(f'ignoring malformed message: {error}')
            return

        handler = getattr(self, f'on_{message_type}_message', None)
        if handler is None:
            logger.warning(f'no handler for message type: {message_type}')
            return
        await handler(**message_params)

    async def on_hello_message(self):
        """Reply with version/codec/stream state, plus the current connection."""
        speaker = self.speaker()
        if speaker is None:
            return

        await self.send_message(
            'hello',
            bumble_version=bumble.__version__,
            codec=speaker.codec,
            streamState=speaker.stream_state.name,
        )
        if connection := speaker.connection:
            # Send '' for an unknown name, as the 'connection' event does.
            peer_name = '' if connection.peer_name is None else connection.peer_name
            await self.send_message(
                'connection',
                peer_address=connection.peer_address.to_string(False),
                peer_name=peer_name,
            )

    async def send_message(self, message_type: str, **kwargs) -> None:
        """Send a JSON `{'type', 'params'}` message if a channel is connected."""
        if self.channel_socket is None:
            return

        message = {'type': message_type, 'params': kwargs}
        await self.channel_socket.send_json(message)

    async def send_audio(self, data: bytes) -> None:
        """Send binary audio if a channel is connected (errors are logged)."""
        if self.channel_socket is None:
            return

        try:
            await self.channel_socket.send_bytes(data)
        except Exception as error:
            logger.warning(f'exception while sending audio packet: {error}')
400
401# -----------------------------------------------------------------------------
class Speaker:
    """A2DP sink ("speaker") that forwards received audio to a set of outputs.

    Each name in `outputs` becomes an output: '@ffplay' plays through an ffplay
    subprocess, and any other name is a file to write to. A web-socket output
    for the UI is added when `run()` is called.
    """

    class StreamState(enum.Enum):
        IDLE = 0
        STOPPED = 1
        STARTED = 2
        SUSPENDED = 3

    def __init__(self, device_config, transport, codec, discover, outputs, ui_port):
        self.device_config = device_config
        self.transport = transport
        self.codec = codec
        self.discover = discover
        self.ui_port = ui_port
        self.device = None
        self.connection = None
        self.listener = None
        self.packets_received = 0
        self.bytes_received = 0
        self.stream_state = Speaker.StreamState.IDLE
        self.outputs = []
        for output in outputs:
            if output == '@ffplay':
                self.outputs.append(FfplayOutput(codec))
                continue

            # Default to FileOutput
            self.outputs.append(FileOutput(output, codec))

        # Create an HTTP server for the UI
        self.ui_server = UiServer(speaker=self, port=ui_port)

    def sdp_records(self) -> Dict[int, List[ServiceAttribute]]:
        """SDP records advertising the A2DP audio sink service."""
        service_record_handle = 0x00010001
        return {
            service_record_handle: make_audio_sink_service_sdp_records(
                service_record_handle
            )
        }

    def codec_capabilities(self) -> MediaCodecCapabilities:
        """Capabilities for the configured codec ('aac' or 'sbc')."""
        if self.codec == 'aac':
            return self.aac_codec_capabilities()

        if self.codec == 'sbc':
            return self.sbc_codec_capabilities()

        raise RuntimeError('unsupported codec')

    def aac_codec_capabilities(self) -> MediaCodecCapabilities:
        return MediaCodecCapabilities(
            media_type=AVDTP_AUDIO_MEDIA_TYPE,
            media_codec_type=A2DP_MPEG_2_4_AAC_CODEC_TYPE,
            media_codec_information=AacMediaCodecInformation.from_lists(
                object_types=[MPEG_2_AAC_LC_OBJECT_TYPE],
                sampling_frequencies=[48000, 44100],
                channels=[1, 2],
                vbr=1,
                bitrate=256000,
            ),
        )

    def sbc_codec_capabilities(self) -> MediaCodecCapabilities:
        return MediaCodecCapabilities(
            media_type=AVDTP_AUDIO_MEDIA_TYPE,
            media_codec_type=A2DP_SBC_CODEC_TYPE,
            media_codec_information=SbcMediaCodecInformation.from_lists(
                sampling_frequencies=[48000, 44100, 32000, 16000],
                channel_modes=[
                    SBC_MONO_CHANNEL_MODE,
                    SBC_DUAL_CHANNEL_MODE,
                    SBC_STEREO_CHANNEL_MODE,
                    SBC_JOINT_STEREO_CHANNEL_MODE,
                ],
                block_lengths=[4, 8, 12, 16],
                subbands=[4, 8],
                allocation_methods=[
                    SBC_LOUDNESS_ALLOCATION_METHOD,
                    SBC_SNR_ALLOCATION_METHOD,
                ],
                minimum_bitpool_value=2,
                maximum_bitpool_value=53,
            ),
        )

    async def dispatch_to_outputs(self, function):
        """Await `function(output)` for each output, in order.

        A failure in one output is logged and does not prevent the remaining
        outputs from being notified.
        """
        for output in self.outputs:
            try:
                await function(output)
            except Exception:
                logger.exception(f'exception in output {output}')

    def on_bluetooth_connection(self, connection):
        print(f'Connection: {connection}')
        self.connection = connection
        connection.on('disconnection', self.on_bluetooth_disconnection)
        AsyncRunner.spawn(
            self.dispatch_to_outputs(lambda output: output.on_connection(connection))
        )

    def on_bluetooth_disconnection(self, reason):
        print(f'Disconnection ({reason})')
        self.connection = None
        # Become discoverable/connectable again for the next source.
        AsyncRunner.spawn(self.advertise())
        AsyncRunner.spawn(
            self.dispatch_to_outputs(lambda output: output.on_disconnection(reason))
        )

    def on_avdtp_connection(self, protocol):
        print('Audio Stream Open')

        # Add a sink endpoint to the server
        sink = protocol.add_sink(self.codec_capabilities())
        sink.on('start', self.on_sink_start)
        sink.on('stop', self.on_sink_stop)
        sink.on('suspend', self.on_sink_suspend)
        sink.on('configuration', lambda: self.on_sink_configuration(sink.configuration))
        sink.on('rtp_packet', self.on_rtp_packet)
        sink.on('rtp_channel_open', self.on_rtp_channel_open)
        sink.on('rtp_channel_close', self.on_rtp_channel_close)

        # Listen for close events
        protocol.on('close', self.on_avdtp_close)

        # Discover all endpoints on the remote device if requested
        if self.discover:
            AsyncRunner.spawn(self.discover_remote_endpoints(protocol))

    def on_avdtp_close(self):
        print("Audio Stream Closed")

    def on_sink_start(self):
        print("Sink Started\u001b[0K")
        self.stream_state = self.StreamState.STARTED
        AsyncRunner.spawn(self.dispatch_to_outputs(lambda output: output.start()))

    def on_sink_stop(self):
        print("Sink Stopped\u001b[0K")
        self.stream_state = self.StreamState.STOPPED
        AsyncRunner.spawn(self.dispatch_to_outputs(lambda output: output.stop()))

    def on_sink_suspend(self):
        print("Sink Suspended\u001b[0K")
        self.stream_state = self.StreamState.SUSPENDED
        AsyncRunner.spawn(self.dispatch_to_outputs(lambda output: output.suspend()))

    def on_sink_configuration(self, config):
        print("Sink Configuration:")
        print('\n'.join(["  " + str(capability) for capability in config]))

    def on_rtp_channel_open(self):
        print("RTP Channel Open")

    def on_rtp_channel_close(self):
        print("RTP Channel Closed")
        self.stream_state = self.StreamState.IDLE

    def on_rtp_packet(self, packet):
        self.packets_received += 1
        self.bytes_received += len(packet.payload)
        # '\r' keeps the running counters on a single console line.
        print(
            f'[{self.bytes_received} bytes in {self.packets_received} packets] {packet}',
            end='\r',
        )

        for output in self.outputs:
            output.on_rtp_packet(packet)

    async def advertise(self):
        await self.device.set_discoverable(True)
        await self.device.set_connectable(True)

    async def connect(self, address):
        """Connect to a source over BR/EDR, authenticate, encrypt, open AVDTP."""
        # Connect to the source
        print(f'=== Connecting to {address}...')
        connection = await self.device.connect(address, transport=BT_BR_EDR_TRANSPORT)
        print(f'=== Connected to {connection.peer_address}')

        # Request authentication
        print('*** Authenticating...')
        await connection.authenticate()
        print('*** Authenticated')

        # Enable encryption
        print('*** Enabling encryption...')
        await connection.encrypt()
        print('*** Encryption on')

        protocol = await Protocol.connect(connection)
        self.listener.set_server(connection, protocol)
        self.on_avdtp_connection(protocol)

    async def discover_remote_endpoints(self, protocol):
        endpoints = await protocol.discover_remote_endpoints()
        print(f'@@@ Found {len(endpoints)} endpoints')
        for endpoint in endpoints:
            print('@@@', endpoint)

    async def run(self, connect_address):
        """Start the UI, then run the Bluetooth device until the transport ends.

        Outputs are always stopped on the way out, including when the outgoing
        connection times out or an exception propagates.
        """
        await self.ui_server.start_http()
        self.outputs.append(
            WebSocketOutput(
                self.codec, self.ui_server.send_audio, self.ui_server.send_message
            )
        )

        try:
            await self._run_device(connect_address)
        finally:
            await self.dispatch_to_outputs(lambda output: output.stop())

    async def _run_device(self, connect_address):
        """Open the HCI transport, set up the device and serve until it closes."""
        async with await open_transport(self.transport) as (hci_source, hci_sink):
            # Create a device
            device_config = DeviceConfiguration()
            if self.device_config:
                device_config.load_from_file(self.device_config)
            else:
                device_config.name = "Bumble Speaker"
                device_config.class_of_device = 0x240414
                device_config.keystore = "JsonKeyStore"

            device_config.classic_enabled = True
            device_config.le_enabled = False
            self.device = Device.from_config_with_hci(
                device_config, hci_source, hci_sink
            )

            # Setup the SDP to expose the sink service
            self.device.sdp_service_records = self.sdp_records()

            # Don't require MITM when pairing.
            self.device.pairing_config_factory = lambda connection: PairingConfig(
                mitm=False
            )

            # Start the controller
            await self.device.power_on()

            # Print some of the config/properties
            print("Speaker Name:", color(device_config.name, 'yellow'))
            print(
                "Speaker Bluetooth Address:",
                color(
                    self.device.public_address.to_string(with_type_qualifier=False),
                    'yellow',
                ),
            )

            # Listen for Bluetooth connections
            self.device.on('connection', self.on_bluetooth_connection)

            # Create a listener to wait for AVDTP connections
            self.listener = Listener.for_device(self.device)
            self.listener.on('connection', self.on_avdtp_connection)

            print(f'Speaker ready to play, codec={color(self.codec, "cyan")}')

            if connect_address:
                # Connect to the source
                try:
                    await self.connect(connect_address)
                except CommandTimeoutError:
                    print(color("Connection timed out", "red"))
                    return
            else:
                # Start being discoverable and connectable
                print("Waiting for connection...")
                await self.advertise()

            await hci_source.wait_for_termination()
666
667
668# -----------------------------------------------------------------------------
@click.group()
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.pass_context
def speaker_cli(ctx, device_config):
    """Speaker command group; stores shared options in the click context."""
    # `device_config` needs a matching option: without it click cannot supply
    # the parameter and invoking the group fails.
    ctx.ensure_object(dict)
    ctx.obj['device_config'] = device_config
674
675
@click.command()
@click.option(
    '--codec', type=click.Choice(['sbc', 'aac']), default='aac', show_default=True
)
@click.option(
    '--discover', is_flag=True, help='Discover remote endpoints once connected'
)
@click.option(
    '--output',
    multiple=True,
    metavar='NAME',
    help=(
        'Send audio to this named output '
        '(may be used more than once for multiple outputs)'
    ),
)
@click.option(
    '--ui-port',
    'ui_port',
    metavar='HTTP_PORT',
    default=DEFAULT_UI_PORT,
    show_default=True,
    help='HTTP port for the UI server',
)
@click.option(
    '--connect',
    'connect_address',
    metavar='ADDRESS_OR_NAME',
    help='Address or name to connect to',
)
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.argument('transport')
def speaker(
    transport, codec, connect_address, discover, output, ui_port, device_config
):
    """Run the speaker."""

    if '@ffplay' in output:
        # Probe ffplay up front. If it is missing, not executable, or fails,
        # drop the @ffplay output instead of crashing or failing mid-stream.
        try:
            subprocess.run(['ffplay', '-version'], capture_output=True, check=True)
        except (OSError, subprocess.CalledProcessError) as error:
            print(
                color(
                    f'ffplay not available ({error}), @ffplay output will be disabled',
                    'yellow',
                )
            )
            output = [name for name in output if name != '@ffplay']

    asyncio.run(
        Speaker(device_config, transport, codec, discover, output, ui_port).run(
            connect_address
        )
    )
728
729
730# -----------------------------------------------------------------------------
def main():
    """Configure logging from BUMBLE_LOGLEVEL (default WARNING), then run the CLI."""
    level_name = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING')
    logging.basicConfig(level=level_name.upper())
    speaker()
734
735
736# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: run the speaker CLI.
    main()  # pylint: disable=no-value-for-parameter
739