xref: /btstack/3rd-party/lc3-google/test/encoder.py (revision 4930cef6e21e6da2d7571b9259c7f0fb8bed3d01)
#!/usr/bin/env python3
#
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
17*4930cef6SMatthias Ringwald
18*4930cef6SMatthias Ringwaldimport numpy as np
19*4930cef6SMatthias Ringwaldimport scipy.signal as signal
20*4930cef6SMatthias Ringwaldimport scipy.io.wavfile as wavfile
21*4930cef6SMatthias Ringwaldimport struct
22*4930cef6SMatthias Ringwaldimport argparse
23*4930cef6SMatthias Ringwald
24*4930cef6SMatthias Ringwaldimport build.lc3 as lc3
25*4930cef6SMatthias Ringwaldimport tables as T, appendix_c as C
26*4930cef6SMatthias Ringwald
27*4930cef6SMatthias Ringwaldimport attdet, ltpf
28*4930cef6SMatthias Ringwaldimport mdct, energy, bwdet, sns, tns, spec
29*4930cef6SMatthias Ringwaldimport bitstream
30*4930cef6SMatthias Ringwald
31*4930cef6SMatthias Ringwald### ------------------------------------------------------------------------ ###
32*4930cef6SMatthias Ringwald
class Encoder:
    """Python model of the LC3 encoder, built from the per-stage test modules.

    Encoding a frame is split in two passes: ``analyse`` runs the signal
    analysis stages (each stage object keeps the state it computed), and
    ``encode`` writes that stored state into a bitstream.
    """

    def __init__(self, dt_ms, sr_hz):
        """Create the analysis stages for one frame duration / sample rate.

        Args:
            dt_ms: Frame duration in milliseconds, 7.5 or 10.
            sr_hz: Sampling rate in Hz: 8000, 16000, 24000, 32000 or 48000.

        Raises:
            KeyError: if ``dt_ms`` or ``sr_hz`` is not one of those values.
        """

        dt = { 7.5: T.DT_7M5, 10: T.DT_10M }[dt_ms]

        sr = {  8000: T.SRATE_8K , 16000: T.SRATE_16K, 24000: T.SRATE_24K,
               32000: T.SRATE_32K, 48000: T.SRATE_48K }[sr_hz]

        # The MDCT output is truncated to this many coefficients in analyse().
        self.ne = T.NE[dt][sr]

        self.attdet = attdet.AttackDetector(dt, sr)
        self.ltpf = ltpf.Ltpf(dt, sr)

        self.mdct = mdct.Mdct(dt, sr)
        self.energy = energy.EnergyBand(dt, sr)
        self.bwdet = bwdet.BandwidthDetector(dt, sr)
        self.sns = sns.SnsAnalysis(dt, sr)
        self.tns = tns.TnsAnalysis(dt)
        self.spec = spec.SpectrumEncoder(dt, sr)

    def analyse(self, x, nbytes):
        """Run all analysis stages on one frame of input samples.

        Args:
            x: One frame of PCM samples.
            nbytes: Size of the encoded frame, in bytes.

        Returns:
            The pitch-present flag returned by the LTPF stage; it must be
            passed on to ``encode``.
        """

        att = self.attdet.run(nbytes, x)

        pitch_present = self.ltpf.run(x)

        x = self.mdct.forward(x)[:self.ne]

        (e, nn_flag) = self.energy.compute(x)
        if nn_flag:
            self.ltpf.disable()

        bw = self.bwdet.run(e)

        x = self.sns.run(e, att, x)

        x = self.tns.run(x, bw, nn_flag, nbytes)

        # encode() reads the quantization result back from self.spec
        # (spec.store / spec.encode take only the bitstream), so the
        # returned tuple is not needed here.
        self.spec.quantize(bw, nbytes,
            self.bwdet.get_nbits(), self.ltpf.get_nbits(),
            self.sns.get_nbits(), self.tns.get_nbits(), x)

        return pitch_present

    def encode(self, pitch_present, nbytes):
        """Serialise the state left by ``analyse`` into an encoded frame.

        Args:
            pitch_present: Flag returned by ``analyse``; when set, the LTPF
                data is written after the SNS parameters.
            nbytes: Size of the encoded frame, in bytes.

        Returns:
            The result of ``BitstreamWriter.terminate()`` (the frame data).
        """

        b = bitstream.BitstreamWriter(nbytes)

        # Side information; the write order defines the bitstream layout.
        self.bwdet.store(b)

        self.spec.store(b)

        self.tns.store(b)

        b.write_bit(pitch_present)

        self.sns.store(b)

        if pitch_present:
            self.ltpf.store_data(b)

        self.spec.encode(b)

        return b.terminate()

    def run(self, x, nbytes):
        """Analyse and encode one frame; return the encoded frame data."""

        pitch_present = self.analyse(x, nbytes)

        data = self.encode(pitch_present, nbytes)

        return data
106*4930cef6SMatthias Ringwald
107*4930cef6SMatthias Ringwald### ------------------------------------------------------------------------ ###
108*4930cef6SMatthias Ringwald
def check_appendix_c(dt):
    """Check the C encoder against the Appendix C reference frames.

    Each PCM frame of ``C.X_PCM[dt]`` is encoded at 16 kHz with the C
    encoder and compared to ``C.BYTES_AC[dt]``. For every mismatching frame,
    the produced bytes and then the expected bytes are hex-dumped.

    Args:
        dt: Frame duration index into the ``T`` / ``C`` tables.

    Returns:
        True if every frame matches its reference.
    """

    ok = True

    # Frame duration is passed as ms * 1000, as in the script entry point.
    enc_c = lc3.setup_encoder(int(T.DT_MS[dt] * 1000), 16000)

    for i, x_pcm in enumerate(C.X_PCM[dt]):

        data = lc3.encode(enc_c, x_pcm, C.NBYTES[dt])
        expected = C.BYTES_AC[dt][i]

        # Judge each frame on its own, so only the frames that actually
        # differ are dumped (not every frame after the first failure).
        frame_ok = data == expected
        if not frame_ok:
            dump(data)
            dump(expected)

        ok = ok and frame_ok

    return ok
124*4930cef6SMatthias Ringwald
def check():
    """Run the Appendix C conformance check for every frame duration.

    Every duration is checked even after an earlier one fails, so all
    mismatches get dumped in a single run.

    Returns:
        True if all frame durations pass.
    """

    results = [check_appendix_c(dt) for dt in range(T.NUM_DT)]

    return all(results)
133*4930cef6SMatthias Ringwald
134*4930cef6SMatthias Ringwald### ------------------------------------------------------------------------ ###
135*4930cef6SMatthias Ringwald
def dump(data):
    """Print ``data`` as lowercase hex, 20 bytes per line.

    Each byte is printed as two hex digits followed by a space, so every
    line ends with a trailing space. Empty input prints nothing.
    """
    width = 20
    for offset in range(0, len(data), width):
        row = data[offset:offset + width]
        print(''.join(f'{value:02x} ' for value in row))
140*4930cef6SMatthias Ringwald
if __name__ == "__main__":

    # Encode a WAV file with both the Python model (Encoder) and the C
    # encoder (lc3), writing each result to its own output stream.

    parser = argparse.ArgumentParser(description='LC3 Encoder Test Framework')
    parser.add_argument('wav_file',
        help='Input wave file', type=argparse.FileType('rb'))
    parser.add_argument('--bitrate',
        help='Bitrate in bps', type=int, required=True)
    parser.add_argument('--dt',
        help='Frame duration in ms', type=float, default=10)
    parser.add_argument('--pyout',
        help='Python output file', type=argparse.FileType('wb'))
    parser.add_argument('--cout',
        help='C output file', type=argparse.FileType('wb'))
    args = parser.parse_args()

    if args.bitrate < 16000 or args.bitrate > 320000:
        raise ValueError('Invalid bitrate %d bps' % args.bitrate)

    if args.dt not in (7.5, 10):
        raise ValueError('Invalid frame duration %.1f ms' % args.dt)

    # Read through the handle argparse already opened, then release it.
    with args.wav_file:
        (sr_hz, pcm) = wavfile.read(args.wav_file)

    if sr_hz not in (8000, 16000, 24000, 32000, 48000):
        raise ValueError('Unsupported input samplerate: %d' % sr_hz)

    # The stream header written below declares a single channel.
    if pcm.ndim != 1:
        raise ValueError('Only mono input is supported')

    ### Setup ###

    enc = Encoder(args.dt, sr_hz)
    enc_c = lc3.setup_encoder(int(args.dt * 1000), sr_hz)

    frame_samples = int((args.dt * sr_hz) / 1000)
    frame_nbytes = int((args.bitrate * args.dt) / (1000 * 8))

    # Write straight to the binary handles opened by argparse instead of
    # reopening the files by name.
    f_py = args.pyout
    f_c = args.cout

    try:

        ### File Header ###

        # len(pcm) is taken before padding: it is the real sample count.
        header = struct.pack('=HHHHHHHI', 0xcc1c, 18,
            sr_hz // 100, args.bitrate // 100, 1, int(args.dt * 100), 0,
            len(pcm))

        for f in (f_py, f_c):
            if f: f.write(header)

        ### Encoding loop ###

        # Zero-pad to a whole number of frames; keep the input dtype so the
        # samples are not silently promoted to float64.
        if len(pcm) % frame_samples > 0:
            npad = frame_samples - (len(pcm) % frame_samples)
            pcm = np.append(pcm, np.zeros(npad, dtype=pcm.dtype))

        for i in range(0, len(pcm), frame_samples):

            print('Encoding frame %d' % (i // frame_samples), end='\r')

            frame_pcm = pcm[i:i+frame_samples]

            data = enc.run(frame_pcm, frame_nbytes)
            data_c = lc3.encode(enc_c, frame_pcm, frame_nbytes)

            # Each frame is prefixed with its 16-bit size.
            for f in (f_py, f_c):
                if f: f.write(struct.pack('=H', frame_nbytes))

            if f_py: f_py.write(data)
            if f_c: f_c.write(data_c)

        print('done ! %16s' % '')

    finally:

        ### Terminate ###

        for f in (f_py, f_c):
            if f: f.close()
211*4930cef6SMatthias Ringwald
212*4930cef6SMatthias Ringwald
213*4930cef6SMatthias Ringwald### ------------------------------------------------------------------------ ###
214