xref: /btstack/test/mesh/simulator.py (revision cd5f23a3250874824c01a2b3326a9522fea3f99f)
1#!/usr/bin/env python3
2#
3# Simulate network of Bluetooth Controllers
4#
5# Each simulated controller has an HCI H4 interface
6# Network configuration will be stored in a YAML file or similar
7#
8# Copyright 2017 BlueKitchen GmbH
9#
10
11
12import os
13import pty
14import select
15import subprocess
16import sys
17import bisect
18import time
19
# prefer PyCryptodomex, which installs under its own 'Cryptodome' namespace
21try:
22    from Cryptodome.Cipher import AES
23    import Cryptodome.Random as Random
24except ImportError:
    # fallback: the 'Crypto' namespace, provided by PyCryptodome (an almost drop-in replacement for PyCrypto) or by PyCrypto itself
26    try:
27        from Crypto.Cipher import AES
28        import Crypto.Random as Random
29    except ImportError:
30        print("\n[!] PyCryptodome required but not installed (using random value instead)")
31        print("[!] Please install PyCryptodome, e.g. 'pip3 install pycryptodomex' or 'pip3 install pycryptodome'\n")
32
33
def little_endian_read_16(buffer, pos):
    """Return the unsigned 16-bit little-endian value stored at buffer[pos:pos+2].

    Accepts a text string holding one character per byte (as os.read() returns
    on Python 2) as well as bytes/bytearray, whose items are already ints.
    """
    low, high = buffer[pos], buffer[pos + 1]
    if not isinstance(low, int):
        low, high = ord(low), ord(high)
    return low | (high << 8)
36
def as_hex(data):
    """Return data as two-digit lowercase hex bytes, each followed by one space.

    For example the two bytes 0x01 0xab give '01 ab '. Accepts a text string
    holding one character per byte as well as bytes/bytearray (int items).
    """
    return ''.join('{0:02x} '.format(item if isinstance(item, int) else ord(item)) for item in data)
42
# Printable names of the LE Set Advertising Parameters Advertising_Type
# values, indexed by the raw type byte (0..4).
adv_type_names = [
    'ADV_IND',
    'ADV_DIRECT_IND_HIGH',
    'ADV_SCAN_IND',
    'ADV_NONCONN_IND',
    'ADV_DIRECT_IND_LOW',
]

# Pending timers as two parallel lists kept sorted by expiry:
# timers_timeouts[i] is the absolute expiry time in ms (get_time_millis()
# scale) of the (callback, context) pair stored in timers_callbacks[i].
timers_timeouts = []
timers_callbacks = []
46
class H4Parser:
    """Incremental parser for HCI H4 command and ACL packets.

    Bytes are fed one at a time through parse():

    - A packet type byte of 0x01 starts an HCI command: 3-byte header with an
      8-bit parameter length at offset 2.
    - A packet type byte of 0x02 starts an ACL packet: 4-byte header with a
      16-bit little-endian data length at offset 2.

    When header and payload are complete, the registered handler is called as
    handler(packet_type, packet). packet_type is "CMD" or "ACL"; packet holds
    header plus payload, without the packet type byte. Any other packet type
    byte is skipped.
    """

    def __init__(self):
        self.packet_type = "NONE"
        self.handler = None
        self.reset()

    def set_packet_handler(self, handler):
        """Register handler(packet_type, packet), called for every complete packet."""
        self.handler = handler

    def reset(self):
        """Drop any partial packet and wait for the next packet type byte."""
        self.bytes_to_read = 1
        self.buffer = ''
        self.state = "H4_W4_PACKET_TYPE"

    def parse(self, data):
        """Consume a single byte (a length-1 string) of the H4 stream."""
        self.buffer += data
        self.bytes_to_read -= 1
        if self.bytes_to_read > 0:
            return

        if self.state == "H4_W4_PACKET_TYPE":
            # the packet type byte is not part of the packet handed to the handler
            self.buffer = ''
            if data == chr(1):
                self.packet_type = "CMD"
                self.state = "W4_CMD_HEADER"
                self.bytes_to_read = 3
            elif data == chr(2):
                self.packet_type = "ACL"
                self.state = "W4_ACL_HEADER"
                self.bytes_to_read = 4
            else:
                # unsupported packet type: resynchronise on the next byte. Leaving
                # bytes_to_read at 0 would let it go negative and stall the parser.
                self.reset()
            return

        if self.state == "W4_CMD_HEADER":
            self.bytes_to_read = ord(self.buffer[2])
            self.state = "H4_W4_PAYLOAD"
        elif self.state == "W4_ACL_HEADER":
            self.bytes_to_read = little_endian_read_16(self.buffer, 2)
            self.state = "H4_W4_PAYLOAD"

        # Reached when the payload is complete, or right after a header that
        # announces an empty payload.
        if self.state == "H4_W4_PAYLOAD" and self.bytes_to_read == 0:
            self.handler(self.packet_type, self.buffer)
            self.reset()
93                return
94
class HCIController:
    """Simulated Bluetooth controller speaking HCI over an H4 byte stream.

    Host bytes are fed in through parse(). Every complete HCI command is
    answered with a canned or computed Command Complete event, written to the
    file descriptor given to set_fd(). While advertising is enabled, a timer
    periodically builds an LE Advertising Report event from this controller's
    advertising data and passes it to the handler registered with
    set_adv_handler().
    """

    def __init__(self):
        self.fd = -1
        self.random = Random.new()
        self.name = 'BTstack Mesh Simulator'
        self.bd_addr = 'aaaaaa'
        self.parser = H4Parser()
        self.parser.set_packet_handler(self.packet_handler)
        self.adv_enabled = 0
        self.adv_type = 0
        self.adv_interval_min = 0
        self.adv_interval_max = 0
        self.adv_data = ''
        self.scan_enabled = False
        # advertising reports are dropped until set_adv_handler() is called
        self.adv_handler = None
        self.adv_handler_context = None

    def parse(self, data):
        """Feed one byte written by the host into the H4 parser."""
        self.parser.parse(data)

    def set_fd(self, fd):
        """Set the file descriptor HCI events are written to."""
        self.fd = fd

    def set_bd_addr(self, bd_addr):
        """Set the 6-byte BD_ADDR; it is byte-reversed whenever sent to the host."""
        self.bd_addr = bd_addr

    def set_name(self, name):
        """Set the name returned by HCI Read Local Name."""
        self.name = name

    def set_adv_handler(self, adv_handler, adv_handler_context):
        """Register adv_handler(adv_handler_context, event), called for every advertising report."""
        self.adv_handler = adv_handler
        self.adv_handler_context = adv_handler_context

    def is_scanning(self):
        """Return the LE_Scan_Enable value last set by the host (False before any)."""
        return self.scan_enabled

    def emit_command_complete(self, opcode, result):
        """Write an H4 Command Complete event for opcode carrying return parameters result."""
        # H4 type, event code, param len, Num_HCI_Command_Packets, opcode (LE), return parameters
        os.write(self.fd, '\x04\x0e' + chr(3 + len(result)) + chr(1) + chr(opcode & 255) + chr(opcode >> 8) + result)

    def emit_adv_report(self, event_type, rssi):
        """Build an LE Advertising Report for our adv data and pass it to the adv handler."""
        if self.adv_handler is None:
            return
        # H4 type, event code, param len, Subevent_Code, Num_Reports, Event_Type, Address_Type,
        # Address (LE), Length_Data, Data, RSSI
        event = ('\x04\x3e' + chr(12 + len(self.adv_data)) + chr(2) + chr(1) + chr(event_type) + chr(0)
                 + self.bd_addr[::-1] + chr(len(self.adv_data)) + self.adv_data + chr(rssi))
        self.adv_handler(self.adv_handler_context, event)

    def handle_set_adv_enable(self, enable):
        """Start or stop the periodic advertising timer."""
        self.adv_enabled = enable
        print('Node %s adv enable %u' % (self.name, self.adv_enabled))
        # Cancel any pending advertising timer first. Otherwise enabling twice
        # would start a second, parallel chain of advertising timers.
        remove_timer(self.handle_adv_timer, self)
        if self.adv_enabled:
            add_timer(1, self.handle_adv_timer, self)

    def handle_set_adv_data(self, data):
        """Store the advertising data sent in reports."""
        self.adv_data = data
        print('Node %s adv data %s' % (self.name, as_hex(self.adv_data)))

    def handle_set_adv_params(self, interval_min, interval_max, adv_type):
        """Store advertising parameters; intervals arrive in 0.625 ms units and are kept in ms."""
        self.adv_interval_min = interval_min * 0.625
        self.adv_interval_max = interval_max * 0.625
        self.adv_type = adv_type
        print('Node %s adv interval min/max %u/%u ms, type %s' % (self.name, self.adv_interval_min, self.adv_interval_max, adv_type_names[self.adv_type]))

    def handle_adv_timer(self, context):
        """Timer callback: emit one advertising report and re-arm after adv_interval_min ms."""
        if self.adv_enabled:
            # NOTE(review): always reports Event_Type 0 and RSSI 0, independent of self.adv_type
            self.emit_adv_report(0, 0)
            add_timer(self.adv_interval_min, self.handle_adv_timer, self)

    def packet_handler(self, packet_type, packet):
        """Answer one packet from the host.

        For commands, packet holds the opcode (little-endian) at offsets 0-1,
        the parameter length at 2 and the parameters from offset 3. Only
        commands are simulated; unknown opcodes are logged and not answered.
        """
        if packet_type != "CMD":
            # ACL data is not simulated; its handle must not be decoded as an opcode
            return
        opcode = little_endian_read_16(packet, 0)
        if opcode == 0x0c03:
            # reset
            self.emit_command_complete(opcode, '\x00')
            return
        if opcode == 0x1001:
            # read local version information
            self.emit_command_complete(opcode, '\x00\x10\x00\x06\x86\x1d\x06\x0a\x00\x86\x1d')
            return
        if opcode == 0x0c14:
            # read local name
            self.emit_command_complete(opcode, '\x00' + self.name)
            return
        if opcode == 0x1002:
            # read local supported commands
            self.emit_command_complete(opcode, '\x00\xff\xff\xff\x03\xfe\xff\xff\xff\xff\xff\xff\xff\xf3\x0f\xe8\xfe\x3f\xf7\x83\xff\x1c\x00\x00\x00\x61\xf7\xff\xff\x7f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')
            return
        if opcode == 0x1009:
            # read bd_addr
            self.emit_command_complete(opcode, '\x00' + self.bd_addr[::-1])
            return
        if opcode == 0x1005:
            # read buffer size
            self.emit_command_complete(opcode, '\x00\x36\x01\x40\x0a\x00\x08\x00')
            return
        if opcode == 0x1003:
            # read local supported features
            self.emit_command_complete(opcode, '\x00\xff\xff\x8f\xfe\xf8\xff\x5b\x87')
            return
        if opcode == 0x0c01:
            # set event mask
            self.emit_command_complete(opcode, '\x00')
            return
        if opcode == 0x2002:
            # le read buffer size
            self.emit_command_complete(opcode, '\x00\x00\x00\x00')
            return
        if opcode == 0x200f:
            # read whitelist size
            self.emit_command_complete(opcode, '\x00\x19')
            return
        if opcode == 0x200b:
            # set scan parameters
            self.emit_command_complete(opcode, '\x00')
            return
        if opcode == 0x200c:
            # set scan enabled: LE_Scan_Enable is the first parameter byte
            self.scan_enabled = ord(packet[3])
            self.emit_command_complete(opcode, '\x00')
            return
        if opcode == 0x0c6d:
            # write le host supported
            self.emit_command_complete(opcode, '\x00')
            return
        if opcode == 0x2017:
            # LE Encrypt: 16-byte key at offsets 3-18, 16-byte plaintext at 19-34.
            # HCI sends both least significant octet first; AES expects the opposite
            # order, so reverse the inputs and reverse the ciphertext back.
            key = packet[18:2:-1]
            data = packet[34:18:-1]
            # PyCryptodome requires an explicit mode; ECB was PyCrypto's default
            cipher = AES.new(key, AES.MODE_ECB)
            result = cipher.encrypt(data)
            # return parameters: Status, Encrypted_Data
            self.emit_command_complete(opcode, '\x00' + result[::-1])
            return
        if opcode == 0x2018:
            # LE Rand: Status, 8 random bytes
            self.emit_command_complete(opcode, '\x00' + self.random.read(8))
            return
        if opcode == 0x2006:
            # Set Adv Params: interval min at 3-4, interval max at 5-6, Advertising_Type at 7
            self.handle_set_adv_params(little_endian_read_16(packet, 3), little_endian_read_16(packet, 5), ord(packet[7]))
            self.emit_command_complete(opcode, '\x00')
            return
        if opcode == 0x2008:
            # Set Adv Data: length byte at 3, data follows
            adv_data_len = ord(packet[3])
            self.handle_set_adv_data(packet[4:4 + adv_data_len])
            self.emit_command_complete(opcode, '\x00')
            return
        if opcode == 0x200a:
            # Set Adv Enable
            self.handle_set_adv_enable(ord(packet[3]))
            self.emit_command_complete(opcode, '\x00')
            return
        print("Opcode 0x%04x not handled!" % opcode)
242
class Node:
    """A simulated network node: an HCIController plus the host process using it.

    start_process() opens a pseudo-terminal pair, attaches the controller to
    the master side and launches './mesh -d <slave tty>' as the host.
    """

    def __init__(self):
        self.name = 'node'
        self.master = -1
        self.slave = -1
        self.slave_ttyname = ''
        # subprocess.Popen handle of the host process, set by start_process()
        self.process = None
        self.controller = HCIController()

    def set_name(self, name):
        """Set the node name; it is also the controller's local name."""
        self.controller.set_name(name)
        self.name = name

    def get_name(self):
        return self.name

    def set_bd_addr(self, bd_addr):
        """Set the controller's BD_ADDR."""
        self.controller.set_bd_addr(bd_addr)

    def start_process(self):
        """Open a pty pair, wire the controller to its master side and start the host."""
        print('Node: %s' % self.name)
        (self.master, self.slave) = pty.openpty()
        self.slave_ttyname = os.ttyname(self.slave)
        print('- tty %s' % self.slave_ttyname)
        print('- fd %u' % self.master)
        self.controller.set_fd(self.master)
        # Keep the handle so the child can be waited for or terminated later.
        # Discarding it makes Python 3 warn about a still-running subprocess.
        self.process = subprocess.Popen(['./mesh', '-d', self.slave_ttyname])

    def get_master(self):
        """Return the pty master fd the controller uses (-1 before start_process())."""
        return self.master

    def parse(self, c):
        """Feed one byte read from the pty master to the controller."""
        self.controller.parse(c)

    def set_adv_handler(self, adv_handler, adv_handler_context):
        self.controller.set_adv_handler(adv_handler, adv_handler_context)

    def inject_packet(self, event):
        """Write an H4 packet (e.g. a forwarded advertising report) to the host."""
        os.write(self.master, event)

    def is_scanning(self):
        return self.controller.is_scanning()
285
def get_time_millis():
    """Return the current time in whole milliseconds, for timer scheduling.

    Uses time.monotonic() when available (Python 3.3+), so wall-clock
    adjustments cannot stall or rush pending timers. Otherwise falls back to
    time.time(). Callers should only compare or subtract returned values.
    """
    clock = getattr(time, 'monotonic', time.time)
    return int(round(clock() * 1000))
288
def add_timer(timeout_ms, callback, context):
    """Schedule callback(context) to run timeout_ms milliseconds from now.

    The entry goes into the parallel timers_timeouts / timers_callbacks lists
    at the position that keeps them sorted by expiry. It lands after any
    existing entries with the same expiry, so those keep insertion order.
    """
    expires_at = get_time_millis() + timeout_ms
    slot = bisect.bisect_right(timers_timeouts, expires_at)
    timers_timeouts.insert(slot, expires_at)
    timers_callbacks.insert(slot, (callback, context))
297
def remove_timer(callback, context):
    """Cancel every pending timer registered with this (callback, context) pair.

    Does nothing when no such timer is pending. Both parallel timer lists are
    rewritten in place, so they stay aligned and existing references stay valid.
    """
    target = (callback, context)
    keep = [i for i, entry in enumerate(timers_callbacks) if entry != target]
    if len(keep) == len(timers_callbacks):
        return
    timers_timeouts[:] = [timers_timeouts[i] for i in keep]
    timers_callbacks[:] = [timers_callbacks[i] for i in keep]
304
def run(nodes):
    """Drive the simulation: fire expired timers and feed host bytes to the nodes.

    Each iteration runs every timer that has expired (in expiry order), then
    waits in select() until the next timer is due or a host has written data.
    A node whose pty read returns EOF or fails is no longer polled. Returns
    only once no pty remains to be polled and no timer is pending.
    """
    nodes_by_fd = {node.get_master(): node for node in nodes}
    while nodes_by_fd or timers_timeouts:
        # process expired timers; callbacks may schedule new ones
        time_ms = get_time_millis()
        while timers_timeouts and timers_timeouts[0] < time_ms:
            timers_timeouts.pop(0)
            (callback, context) = timers_callbacks.pop(0)
            callback(context)
        read_fds = list(nodes_by_fd)
        if timers_timeouts:
            # fresh clock reading: the callbacks above may have taken a while
            timeout = max(0.0, (timers_timeouts[0] - get_time_millis()) / 1000.0)
            read_ready = select.select(read_fds, [], [], timeout)[0]
        else:
            read_ready = select.select(read_fds, [], [])[0]
        for fd in read_ready:
            try:
                data = os.read(fd, 1024)
            except OSError:
                # e.g. EIO when the pty slave side has been closed
                data = ''
            if not data:
                # stop polling, otherwise select() would report this fd forever
                del nodes_by_fd[fd]
                continue
            node = nodes_by_fd[fd]
            # the H4 parser consumes exactly one byte per call
            for pos in range(len(data)):
                node.parse(data[pos:pos + 1])
326
def adv_handler(src_node, event):
    """Forward an advertising report from src_node to every other node that is scanning.

    event is the H4 LE Advertising Report built by the source controller. It
    is written unchanged to each receiving node's pty.
    """
    # The AD payload is event[14:-1]. This skips 14 bytes (H4 type, event code,
    # length, subevent, report count, event type, address type, 6-byte address,
    # data length) and drops the trailing RSSI byte.
    ad_payload = event[14:-1]
    listeners = (node for node in nodes if src_node != node and node.is_scanning())
    for listener in listeners:
        print('Adv %s -> %s - %s' % (src_node.get_name(), listener.get_name(), as_hex(ad_payload)))
        listener.inject_packet(event)
337
# Build the simulated network and run it. Guarded so that importing this
# module does not spawn host processes or block in run().
# TODO: parse the network configuration from a file passed on the command line
# instead of using the fixed two-node setup below.
if __name__ == "__main__":
    nodes = []
    for node_name, node_bd_addr in (('node_1', 'aaaaaa'), ('node_2', 'bbbbbb')):
        node = Node()
        node.set_name(node_name)
        node.set_bd_addr(node_bd_addr)
        # the node itself is the handler context, so adv_handler() knows the sender
        node.set_adv_handler(adv_handler, node)
        node.start_process()
        nodes.append(node)
    run(nodes)
356