1# Copyright 2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16import os
17import pytest
18import socket
19import unittest
20from unittest.mock import ANY, patch
21
22from bumble.transport.tcp_server import (
23    open_tcp_server_transport,
24    open_tcp_server_transport_with_socket,
25)
26
27
class OpenTcpServerTransportTests(unittest.TestCase):
    """Check the arguments with which the patched ``_create_server`` is awaited.

    ``bumble.transport.tcp_server._create_server`` is replaced by a mock for
    the duration of each test; the tests then assert on how that mock was
    awaited by the transport-opening functions.
    """

    def setUp(self):
        # Replace the module-level ``_create_server`` with a mock and keep
        # both the patcher (to undo it later) and the mock (to assert on).
        create_server_patch = patch('bumble.transport.tcp_server._create_server')
        self.mock_create_server = create_server_patch.start()
        self.patcher = create_server_patch

    def tearDown(self):
        # Undo the patch installed by setUp.
        self.patcher.stop()

    def test_open_with_spec(self):
        # A 'host:port' spec is expected to reach ``_create_server`` as
        # separate ``host`` and ``port`` keyword arguments.
        opening = open_tcp_server_transport('localhost:32100')
        asyncio.run(opening)

        expected_kwargs = {'host': 'localhost', 'port': 32100}
        self.mock_create_server.assert_awaited_once_with(ANY, **expected_kwargs)

    def test_open_with_port_only_spec(self):
        # With '_' in the host position, ``host`` is expected to be None.
        opening = open_tcp_server_transport('_:32100')
        asyncio.run(opening)

        expected_kwargs = {'host': None, 'port': 32100}
        self.mock_create_server.assert_awaited_once_with(ANY, **expected_kwargs)

    def test_open_with_socket(self):
        # A caller-supplied socket is expected to be forwarded as ``sock``.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
            opening = open_tcp_server_transport_with_socket(sock=server_socket)
            asyncio.run(opening)

        # The assertion runs after the ``with`` block has exited; the mock
        # recorded the socket object itself, so the comparison still holds.
        self.mock_create_server.assert_awaited_once_with(ANY, sock=server_socket)
50
51
# Skipped unless the PYTEST_NOSKIP environment variable is set to a non-empty
# value. Environment values are strings, so any non-empty value (even "0")
# enables the test.
@pytest.mark.skipif(
    not os.environ.get('PYTEST_NOSKIP', 0),
    reason='''\
Not hermetic. Should only run manually with
  $ PYTEST_NOSKIP=1 pytest tests
''',
)
def test_open_with_real_socket():
    """Open a server transport on a real socket bound to a local port.

    Unlike the mocked tests, this one calls
    ``open_tcp_server_transport_with_socket`` without patching anything.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
        # Bind to port 0 and read back the port that was actually assigned.
        listener.bind(('localhost', 0))
        bound_address = listener.getsockname()
        port = bound_address[1]
        assert port != 0

        opening = open_tcp_server_transport_with_socket(sock=listener)
        asyncio.run(opening)
65