xref: /btstack/test/sbc/sbc_decoder.py (revision cd5f23a3250874824c01a2b3326a9522fea3f99f)
1#!/usr/bin/env python3
2import numpy as np
3import wave
4import struct
5import sys
6from sbc import *
7from sbc_synthesis_v1 import *
8
# Synthesis history buffer, indexed as V[ch][i]: one row of 10*2*8 values for each of 2 channels.
V = np.zeros(shape = (2, 10*2*8))
# Placeholder matrixing table; sbc_init_synthesis_sig() rebinds N to a (2*M, M) cosine table.
N = np.zeros(shape = (16,8))
# Running total of the time spent inside sbc_frame_synthesis(); used for the per-frame average printed at the end.
total_time_ms = 0
# Non-zero selects mSBC parsing in sbc_unpack_frame(); the command-line entry point resets it from the file extension.
mSBC_enabled = 1
# The last two bytes rejected by sbc_unpack_frame() while looking for the mSBC syncword (older byte first).
H2_first_byte = 0
H2_second_byte = 0
15
def find_syncword(h2_first_byte, h2_second_byte):
    """Check a two-byte H2 header candidate and return its sequence value.

    The candidate is accepted only when the first byte is 1, the low nibble
    of the second byte is 8, and the two 2-bit halves of the high nibble are
    equal and both 0 or both 3.

    Args:
        h2_first_byte: first candidate byte.
        h2_second_byte: second candidate byte.

    Returns:
        The accepted 2-bit value (0 or 3), or -1 when the candidate is
        rejected.
    """
    if h2_first_byte != 1:
        return -1

    high_nibble = h2_second_byte >> 4
    low_nibble = h2_second_byte & 0x0F
    # A low nibble other than 8 used to fall through to `return sn0` with
    # sn0 never assigned, which raised UnboundLocalError. Reject it instead.
    if low_nibble != 8:
        return -1

    sn0 = high_nibble & 3
    sn1 = high_nibble >> 2
    # NOTE(review): this accepts only second bytes 0x08 and 0xF8; confirm
    # against the H2 header definition whether 0x38 and 0xC8 should be
    # accepted as well.
    if sn0 != sn1 or sn0 not in (0, 3):
        return -1

    return sn0
33
def sbc_unpack_frame(fin, available_bytes, frame):
    """Parse one frame from `fin` into `frame`.

    Reads the syncword, the header fields (fixed values in mSBC mode, read
    from the stream otherwise), the CRC byte, the joint-stereo flags, the
    scale factors and finally the quantized audio samples.

    Args:
        fin: input stream handed to get_bits().
        available_bytes: number of bytes still unread in the input.
        frame: frame object that receives the parsed fields.

    Returns:
        0 when a frame was unpacked, -1 when the syncword or the CRC does
        not match.

    Raises:
        TypeError: when `available_bytes` is 0.
    """
    global H2_first_byte, H2_second_byte

    if available_bytes == 0:
        print ("no available_bytes")
        raise TypeError

    # 173 (0xAD) is expected in mSBC mode, 156 (0x9C) otherwise.
    frame.syncword = get_bits(fin, 8)
    expected_syncword = 173 if mSBC_enabled else 156
    if frame.syncword != expected_syncword:
        if mSBC_enabled:
            # Remember the two most recently rejected bytes.
            H2_first_byte = H2_second_byte
            H2_second_byte = frame.syncword
        return -1

    if mSBC_enabled:
        # mSBC uses a fixed configuration; only 16 reserved bits are read.
        frame.sampling_frequency = 0    # == 16 kHz
        frame.nr_blocks = 15
        frame.channel_mode = MONO
        frame.allocation_method = LOUDNESS
        frame.nr_subbands = 8
        frame.bitpool = 26
        frame.reserved_for_future_use = get_bits(fin, 16)
    else:
        frame.sampling_frequency = get_bits(fin, 2)
        frame.nr_blocks = nr_blocks[get_bits(fin, 2)]
        frame.channel_mode = get_bits(fin, 2)
        frame.allocation_method = get_bits(fin, 1)
        frame.nr_subbands = nr_subbands[get_bits(fin, 1)]
        frame.bitpool = get_bits(fin, 8)

    frame.nr_channels = 1 if frame.channel_mode == MONO else 2

    frame.crc_check = get_bits(fin, 8)

    frame.init(frame.nr_blocks, frame.nr_subbands, frame.nr_channels)

    # Joint stereo: one flag per subband except the last, then one RFA bit.
    if frame.channel_mode == JOINT_STEREO:
        for sb in range(frame.nr_subbands - 1):
            frame.join[sb] = get_bits(fin, 1)
        get_bits(fin, 1)

    # 4-bit scale factor for every (channel, subband) pair.
    frame.scale_factor = np.zeros(shape=(frame.nr_channels, frame.nr_subbands), dtype = np.int32)
    for ch in range(frame.nr_channels):
        for sb in range(frame.nr_subbands):
            frame.scale_factor[ch, sb] = get_bits(fin, 4)

    if mSBC_enabled:
        crc = calculate_crc_mSBC(frame)
    else:
        crc = calculate_crc(frame)
    if crc != frame.crc_check:
        print ("CRC mismatch: calculated %d, expected %d" % (crc, frame.crc_check))
        return -1

    # scalefactor = 2 ** (scale_factor + 1)
    frame.scalefactor = np.zeros(shape=(frame.nr_channels, frame.nr_subbands), dtype = np.int32)
    for ch in range(frame.nr_channels):
        for sb in range(frame.nr_subbands):
            frame.scalefactor[ch, sb] = 1 << (frame.scale_factor[ch][sb] + 1)

    frame.bits = sbc_bit_allocation(frame)

    # Each sample is read with the bit width allocated to its subband.
    frame.audio_sample = np.ndarray(shape=(frame.nr_blocks, frame.nr_channels, frame.nr_subbands), dtype = np.uint16)
    for blk in range(frame.nr_blocks):
        for ch in range(frame.nr_channels):
            for sb in range(frame.nr_subbands):
                frame.audio_sample[blk, ch, sb] = get_bits(fin, frame.bits[ch][sb])

    drop_remaining_bits()
    return 0
118
def sbc_reconstruct_subband_samples(frame):
    """Dequantize the audio samples of `frame` into `frame.sb_sample`.

    Fills `frame.levels` with 2**bits - 1 per (channel, subband), converts
    every quantized sample to a subband sample scaled by `frame.scalefactor`,
    and, for joint-stereo frames, replaces the two channels of each joined
    subband by their sum and difference.

    Args:
        frame: frame holding bits, audio_sample, scalefactor, channel_mode
            and (for joint stereo) join.

    Returns:
        Always 0.
    """
    blocks = frame.nr_blocks
    channels = frame.nr_channels
    subbands = frame.nr_subbands

    frame.levels = np.zeros(shape=(channels, subbands), dtype = np.int32)
    frame.sb_sample = np.zeros(shape=(blocks, channels, subbands))

    for ch in range(channels):
        for sb in range(subbands):
            frame.levels[ch, sb] = pow(2.0, frame.bits[ch][sb]) - 1

    for blk in range(blocks):
        for ch in range(channels):
            for sb in range(subbands):
                level_count = frame.levels[ch][sb]
                if level_count <= 0:
                    # no bits were allocated to this subband
                    frame.sb_sample[blk, ch, sb] = 0
                    continue
                quantized = frame.audio_sample[blk][ch][sb]
                scale = frame.scalefactor[ch][sb]
                frame.sb_sample[blk, ch, sb] = scale * ((quantized*2.0+1.0) / level_count -1.0 )

    if frame.channel_mode != JOINT_STEREO:
        return 0

    # joint stereo: joined subbands become the sum / difference of the channels
    for blk in range(blocks):
        for sb in range(subbands):
            if frame.join[sb] != 1:
                continue
            first = frame.sb_sample[blk][0][sb]
            second = frame.sb_sample[blk][1][sb]
            frame.sb_sample[blk, 0, sb] = first + second
            frame.sb_sample[blk, 1, sb] = first - second

    return 0
149
150
def sbc_frame_synthesis_sig(frame, ch, blk, proto_table):
    """Synthesize the PCM samples of one block for one channel.

    Uses the module-level history buffer V and matrixing table N: shifts the
    history of channel `ch`, matrixes the subband samples of block `blk`
    into its head, builds the windowing vector, applies `proto_table`
    scaled by -M, and sums ten taps per output sample.

    Args:
        frame: frame providing nr_subbands and sb_sample; receives X and pcm.
        ch: channel index.
        blk: block index.
        proto_table: window coefficients, at least 10 * nr_subbands entries.
    """
    global V, N
    nr_sb = frame.nr_subbands
    window_len = 10 * nr_sb
    shift = 2 * nr_sb
    history_len = 2 * window_len

    subband = np.array([frame.sb_sample[blk][ch][i] for i in range(nr_sb)], dtype=float)
    gathered = np.zeros(window_len)
    windowed = np.zeros(window_len)
    frame.X = np.zeros(nr_sb)

    # shift the history by 2*M, moving from the top down
    for dst in reversed(range(shift, history_len)):
        V[ch, dst] = V[ch, dst - shift]

    # matrixing: fill the first 2*M history entries
    for k in range(shift):
        acc = 0.0
        for i in range(nr_sb):
            acc += N[k][i] * subband[i]
        V[ch, k] = acc

    # build the windowing vector from alternating history segments
    for seg in range(5):
        even_dst = seg * shift
        odd_dst = (seg * 2 + 1) * nr_sb
        even_src = seg * 2 * shift
        odd_src = (seg * 4 + 3) * nr_sb
        for j in range(nr_sb):
            gathered[even_dst + j] = V[ch, even_src + j]
            gathered[odd_dst + j] = V[ch, odd_src + j]

    for i in range(window_len):
        coefficient = proto_table[i] * (-nr_sb)
        windowed[i] = gathered[i] * coefficient

    # each output sample is the sum of 10 taps spaced M apart
    first_sample = blk * nr_sb
    for j in range(nr_sb):
        for tap in range(j, window_len, nr_sb):
            frame.X[j] += windowed[tap]
        frame.pcm[ch][first_sample + j] = np.int16(frame.X[j])
189
190
def sbc_frame_synthesis_v1(frame, ch, blk, proto_table):
    """Synthesize the PCM samples of one block for one channel ("V1" variant).

    Same history shift and matrixing as the SIG variant, but the matrix
    comes from matrix_N() and the windowing reads the history directly
    through remap_V() with the sign supplied by VSGN().

    Args:
        frame: frame providing nr_subbands and sb_sample; receives X and pcm.
        ch: channel index.
        blk: block index.
        proto_table: window coefficients, at least 10 * nr_subbands entries.
    """
    global V
    N = matrix_N()

    nr_sb = frame.nr_subbands
    window_len = 10 * nr_sb
    shift = 2 * nr_sb
    history_len = 2 * window_len

    subband = np.array([frame.sb_sample[blk][ch][i] for i in range(nr_sb)], dtype=float)
    windowed = np.zeros(window_len)
    frame.X = np.zeros(nr_sb)

    # shift the history by 2*M, moving from the top down
    for dst in reversed(range(shift, history_len)):
        V[ch, dst] = V[ch, dst - shift]

    # matrixing: fill the first 2*M history entries
    for k in range(shift):
        acc = 0.0
        for i in range(nr_sb):
            acc += N[k][i] * subband[i]
        V[ch, k] = acc

    for i in range(window_len):
        coefficient = proto_table[i] * (-nr_sb)
        windowed[i] = coefficient * VSGN(i, shift) * V[ch][remap_V(i)]

    # each output sample is the sum of 10 taps spaced M apart
    first_sample = blk * nr_sb
    for j in range(nr_sb):
        for tap in range(j, window_len, nr_sb):
            frame.X[j] += windowed[tap]
        frame.pcm[ch][first_sample + j] = np.int16(frame.X[j])
225
226
def sbc_frame_synthesis(frame, ch, blk, proto_table, implementation = "SIG"):
    """Run the selected synthesis on one block and add its duration to total_time_ms.

    Args:
        frame: frame being decoded.
        ch: channel index.
        blk: block index.
        proto_table: window coefficients passed through to the synthesis.
        implementation: "SIG" or "V1"; anything else prints a message and
            exits with status 1.
    """
    global total_time_ms

    started = time_ms()
    if implementation == "SIG":
        synthesize = sbc_frame_synthesis_sig
    elif implementation == "V1":
        synthesize = sbc_frame_synthesis_v1
    else:
        print ("synthesis %s not implemented" % implementation)
        exit(1)
    synthesize(frame, ch, blk, proto_table)

    finished = time_ms()
    total_time_ms += finished - started
241
242
def sbc_init_synthesis_sig(M):
    """Rebind the module-level N to the (2*M, M) cosine matrixing table.

    Entry [k][i] is cos((i + 0.5) * (k + M/2) * pi / M).

    Args:
        M: number of subbands.
    """
    global N
    row_count = M << 1
    half = M/2

    table = np.zeros(shape = (row_count, M))
    for k in range(row_count):
        table[k] = [np.cos((i+0.5)*(k+half)*np.pi/M) for i in range(M)]
    N = table
251
252
253
def sbc_init_sythesis(nr_subbands, implementation = "SIG"):
    """Initialise the selected synthesis implementation for `nr_subbands`.

    Args:
        nr_subbands: number of subbands handed to the initialiser.
        implementation: "SIG" or "V1"; anything else prints a message and
            exits with status 1.
    """
    if implementation == "SIG":
        initialise = sbc_init_synthesis_sig
    elif implementation == "V1":
        initialise = sbc_init_synthesis_v1
    else:
        print ("synthesis %s not implemented" % implementation)
        exit(1)
    initialise(nr_subbands)
262
263
def sbc_synthesis(frame, implementation = "SIG"):
    """Synthesize every block of every channel of `frame`.

    Args:
        frame: frame with reconstructed subband samples.
        implementation: synthesis variant passed to sbc_frame_synthesis().

    Returns:
        nr_blocks * nr_subbands on success, -1 when the frame has neither
        4 nor 8 subbands.
    """
    subband_count = frame.nr_subbands
    if subband_count != 4 and subband_count != 8:
        return -1

    # pick the window table matching the subband count
    if subband_count == 4:
        proto_table = Proto_4_40
    else:
        proto_table = Proto_8_80

    for ch in range(frame.nr_channels):
        for blk in range(frame.nr_blocks):
            sbc_frame_synthesis(frame, ch, blk, proto_table, implementation)

    return frame.nr_blocks * frame.nr_subbands
276
def sbc_decode(frame, implementation = "SIG"):
    """Reconstruct the subband samples of `frame`, then synthesize its PCM.

    Args:
        frame: unpacked frame.
        implementation: synthesis variant passed to sbc_synthesis().

    Returns:
        The negative status of the reconstruction step if it failed,
        otherwise whatever sbc_synthesis() returns.
    """
    status = sbc_reconstruct_subband_samples(frame)
    if status < 0:
        return status
    return sbc_synthesis(frame, implementation)
282
283
def write_wav_file(fout, frame):
    """Append the PCM samples of `frame` to `fout`, interleaved by channel.

    Samples are written as signed 16-bit little-endian values, sample by
    sample with all channels of one sample index adjacent.

    Args:
        fout: object with a writeframes(bytes) method (e.g. wave.Wave_write).
        frame: frame providing nr_subbands, nr_blocks, nr_channels and pcm,
            where pcm[ch][i] is sample i of channel ch.

    A sample that does not fit into 16 bits prints diagnostics and exits
    with status 1.
    """
    packed_samples = []

    for i in range(frame.nr_subbands * frame.nr_blocks):
        for ch in range(frame.nr_channels):
            try:
                # '<h': WAV data is little-endian regardless of the host.
                packed_samples.append(struct.pack('<h', frame.pcm[ch][i]))
            except struct.error:
                print (frame)
                print (i, frame.pcm[ch][i], frame.pcm[ch])
                sys.exit(1)

    # struct.pack() returns bytes, so the separator must be bytes as well;
    # ''.join() raised TypeError here under Python 3.
    fout.writeframes(b''.join(packed_samples))
299
300
301
if __name__ == "__main__":
    # Command-line entry point: decode an .sbc / .msbc file into a WAV file.
    usage = '''
    Usage: ./sbc_decoder.py input.(msbc|sbc) implementation[default=SIG, V1]
    '''

    if len(sys.argv) < 2:
        print(usage)
        sys.exit(1)
    try:
        # The file extension selects the stream type and the output name.
        # Only the trailing extension is replaced (str.replace would rewrite
        # every occurrence in the path).
        mSBC_enabled = 0
        infile = sys.argv[1]
        if infile.endswith('.sbc'):
            wavfile = infile[:-len('.sbc')] + '-decoded-py.wav'
        elif infile.endswith('.msbc'):
            wavfile = infile[:-len('.msbc')] + '-decoded.wav'
            mSBC_enabled = 1
        else:
            print(usage)
            sys.exit(1)

        print ("input file: ", infile)
        print ("output file: ", wavfile)
        print ("mSBC enabled: ", mSBC_enabled)

        fout = None
        frame_count = 0

        # "SIG" is the default and may also be given explicitly, as the
        # usage text advertises.
        implementation = "SIG"
        if len(sys.argv) == 3:
            implementation = sys.argv[2]
            if implementation not in ("SIG", "V1"):
                print ("synthesis %s not implemented" % implementation)
                sys.exit(1)

        print ("\nSynthesis implementation: %s\n" % implementation)

        with open (infile, 'rb') as fin:
            try:
                fin.seek(0, 2)
                file_size = fin.tell()
                fin.seek(0, 0)

                while True:
                    frame = SBCFrame()
                    if frame_count % 200 == 0:
                        print ("== Frame %d == offset %d" % (frame_count, fin.tell()))

                    err = sbc_unpack_frame(fin, file_size - fin.tell(), frame)
                    if err:
                        # not a valid frame at this position, keep scanning
                        continue

                    if frame_count == 0:
                        sbc_init_sythesis(frame.nr_subbands, implementation)
                        print (frame)

                    sbc_decode(frame, implementation)

                    if frame_count == 0:
                        # The first decoded frame defines the WAV format.
                        fout = wave.open(wavfile, 'w')
                        fout.setnchannels(frame.nr_channels)
                        fout.setsampwidth(2)
                        fout.setframerate(sampling_frequencies[frame.sampling_frequency])
                        fout.setnframes(0)
                        # setcomptype is a method; assigning to it only
                        # shadowed the method without setting anything.
                        fout.setcomptype('NONE', 'not compressed')

                        print (frame.pcm)

                    write_wav_file(fout, frame)
                    frame_count += 1

            except TypeError as err:
                # sbc_unpack_frame() raises TypeError once the input is
                # exhausted; this is the normal way out of the loop above.
                if fout is None:
                    print (err)
                else:
                    fout.close()
                    fout = None
                    if frame_count > 0:
                        print ("DONE, SBC file %s decoded into WAV file %s " % (infile, wavfile))
                        print ("Average sythesis time per frame: %d ms/frame" % (total_time_ms/frame_count))
                    else:
                        print ("No frame found")
                sys.exit(0)
            finally:
                # Close the WAV file on every other exit path as well.
                if fout is not None:
                    fout.close()

    except IOError as e:
        print(e)
        print(usage)
        sys.exit(1)
400
401
402
403
404
405