xref: /btstack/3rd-party/lc3-google/test/encoder.py (revision 4c4eb519208b4224604d94b3ed1931841ddd93bb)
14930cef6SMatthias Ringwald#!/usr/bin/env python3
24930cef6SMatthias Ringwald#
34930cef6SMatthias Ringwald# Copyright 2022 Google LLC
44930cef6SMatthias Ringwald#
54930cef6SMatthias Ringwald# Licensed under the Apache License, Version 2.0 (the "License");
64930cef6SMatthias Ringwald# you may not use this file except in compliance with the License.
74930cef6SMatthias Ringwald# You may obtain a copy of the License at
84930cef6SMatthias Ringwald#
94930cef6SMatthias Ringwald#     http://www.apache.org/licenses/LICENSE-2.0
104930cef6SMatthias Ringwald#
114930cef6SMatthias Ringwald# Unless required by applicable law or agreed to in writing, software
124930cef6SMatthias Ringwald# distributed under the License is distributed on an "AS IS" BASIS,
134930cef6SMatthias Ringwald# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
144930cef6SMatthias Ringwald# See the License for the specific language governing permissions and
154930cef6SMatthias Ringwald# limitations under the License.
164930cef6SMatthias Ringwald#
174930cef6SMatthias Ringwald
184930cef6SMatthias Ringwaldimport numpy as np
194930cef6SMatthias Ringwaldimport scipy.signal as signal
204930cef6SMatthias Ringwaldimport scipy.io.wavfile as wavfile
214930cef6SMatthias Ringwaldimport struct
224930cef6SMatthias Ringwaldimport argparse
234930cef6SMatthias Ringwald
24*4c4eb519SMatthias Ringwaldimport lc3
254930cef6SMatthias Ringwaldimport tables as T, appendix_c as C
264930cef6SMatthias Ringwald
274930cef6SMatthias Ringwaldimport attdet, ltpf
284930cef6SMatthias Ringwaldimport mdct, energy, bwdet, sns, tns, spec
294930cef6SMatthias Ringwaldimport bitstream
304930cef6SMatthias Ringwald
314930cef6SMatthias Ringwald### ------------------------------------------------------------------------ ###
324930cef6SMatthias Ringwald
class Encoder:
    """Python model of the LC3 encoder, built from the per-stage test modules.

    Each frame goes through two passes. `analyse` runs every analysis stage,
    and the stage objects keep their results. `encode` then asks those same
    stage objects to write their parameters into an `nbytes`-long bitstream.
    """

    def __init__(self, dt_ms, sr_hz):
        """Create all analysis stages for one frame duration / sample rate.

        dt_ms -- frame duration in milliseconds (7.5 or 10).
        sr_hz -- sample rate in Hz (8000, 16000, 24000, 32000 or 48000).
        Any other value raises KeyError from the table lookups below.
        """
        duration_ids = {7.5: T.DT_7M5, 10: T.DT_10M}
        dt = duration_ids[dt_ms]

        rate_ids = {
            8000: T.SRATE_8K,
            16000: T.SRATE_16K,
            24000: T.SRATE_24K,
            32000: T.SRATE_32K,
            48000: T.SRATE_48K,
        }
        sr = rate_ids[sr_hz]

        # Number of spectral coefficients kept after the MDCT.
        self.ne = T.NE[dt][sr]

        # Time-domain stages.
        self.attdet = attdet.AttackDetector(dt, sr)
        self.ltpf = ltpf.LtpfAnalysis(dt, sr)

        # Frequency-domain stages.
        self.mdct = mdct.MdctForward(dt, sr)
        self.energy = energy.EnergyBand(dt, sr)
        self.bwdet = bwdet.BandwidthDetector(dt, sr)
        self.sns = sns.SnsAnalysis(dt, sr)
        self.tns = tns.TnsAnalysis(dt)
        self.spec = spec.SpectrumAnalysis(dt, sr)

    def analyse(self, x, nbytes):
        """Run every analysis stage on the PCM frame `x`.

        `nbytes` is the target size of the encoded frame in bytes. The
        per-stage results stay inside the stage objects for `encode`.
        Returns the pitch-present flag reported by the LTPF analysis.
        """
        attack = self.attdet.run(nbytes, x)

        # Taken before a possible ltpf.disable() below, so the flag that is
        # returned is the one the LTPF analysis itself produced.
        pitch_present = self.ltpf.run(x)

        # Forward MDCT, truncated to the first `ne` coefficients.
        spectrum = self.mdct.run(x)[:self.ne]

        band_energy, nn_flag = self.energy.compute(spectrum)
        if nn_flag:
            self.ltpf.disable()

        bandwidth = self.bwdet.run(band_energy)

        spectrum = self.sns.run(band_energy, attack, spectrum)
        spectrum = self.tns.run(spectrum, bandwidth, nn_flag, nbytes)

        # Bits already claimed by the side information of the other stages;
        # the spectrum quantizer receives them in this order.
        side_bits = (
            self.bwdet.get_nbits(),
            self.ltpf.get_nbits(),
            self.sns.get_nbits(),
            self.tns.get_nbits(),
        )
        _xq, _lastnz, spectrum = self.spec.run(
            bandwidth, nbytes, *side_bits, spectrum)

        return pitch_present

    def encode(self, pitch_present, nbytes):
        """Serialise the results of the last `analyse` call.

        Returns whatever `BitstreamWriter.terminate()` produces for an
        `nbytes`-long frame.
        """
        writer = bitstream.BitstreamWriter(nbytes)

        # Side information, in bitstream order.
        self.bwdet.store(writer)
        self.spec.store(writer)
        self.tns.store(writer)
        writer.write_bit(pitch_present)
        self.sns.store(writer)

        # LTPF parameters are only written when a pitch was signalled.
        if pitch_present:
            self.ltpf.store(writer)

        # Quantized spectrum.
        self.spec.encode(writer)

        return writer.terminate()

    def run(self, x, nbytes):
        """Analyse then encode the PCM frame `x` into `nbytes` bytes."""
        return self.encode(self.analyse(x, nbytes), nbytes)
1064930cef6SMatthias Ringwald
1074930cef6SMatthias Ringwald### ------------------------------------------------------------------------ ###
1084930cef6SMatthias Ringwald
def check_appendix_c(dt):
    """Check the C encoder against the Appendix C reference frames.

    dt -- frame duration index (0 .. T.NUM_DT - 1). The reference vectors are
    encoded at 16 kHz with `C.NBYTES[dt]` bytes per frame. Returns True only
    if every encoded frame equals its entry in `C.BYTES_AC[dt]`.
    """
    encoder = lc3.setup_encoder(int(T.DT_MS[dt] * 1000), 16000)
    frame_nbytes = C.NBYTES[dt]

    ok = True
    for index, frame in enumerate(C.X_PCM[dt]):
        # Every frame is encoded so the encoder state keeps advancing. The
        # reference is only consulted while all earlier frames matched.
        encoded = lc3.encode(encoder, frame, frame_nbytes)
        ok = ok and encoded == C.BYTES_AC[dt][index]

    return ok
1214930cef6SMatthias Ringwald
def check():
    """Run the Appendix C check for every frame duration.

    Stops at the first duration that fails and returns that (falsy) result.
    Otherwise returns the result of the last check, or True when there are
    no durations to check.
    """
    ok = True
    for dt in range(T.NUM_DT):
        ok = check_appendix_c(dt)
        if not ok:
            break
    return ok
1304930cef6SMatthias Ringwald
1314930cef6SMatthias Ringwald### ------------------------------------------------------------------------ ###
1324930cef6SMatthias Ringwald
def dump(data):
    """Print `data` in hex, 20 bytes per line, each byte as 2 digits + space."""
    width = 20
    offset = 0
    while offset < len(data):
        chunk = data[offset:offset + width]
        print(''.join(map('{:02x} '.format, chunk)))
        offset += width
1374930cef6SMatthias Ringwald
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='LC3 Encoder Test Framework')
    # All files are opened in binary mode. The WAV input is parsed by scipy,
    # and both outputs receive packed bytes, so the handles opened by
    # argparse are used directly rather than re-opened by name.
    parser.add_argument('wav_file',
        help='Input wave file', type=argparse.FileType('rb'))
    parser.add_argument('--bitrate',
        help='Bitrate in bps', type=int, required=True)
    parser.add_argument('--dt',
        help='Frame duration in ms', type=float, default=10)
    parser.add_argument('--pyout',
        help='Python output file', type=argparse.FileType('wb'))
    parser.add_argument('--cout',
        help='C output file', type=argparse.FileType('wb'))
    args = parser.parse_args()

    if args.bitrate < 16000 or args.bitrate > 320000:
        raise ValueError('Invalid bitrate %d bps' % args.bitrate)

    if args.dt not in (7.5, 10):
        raise ValueError('Invalid frame duration %.1f ms' % args.dt)

    with args.wav_file:
        (sr_hz, pcm) = wavfile.read(args.wav_file)

    # Must match the sample rates accepted by Encoder.
    if sr_hz not in (8000, 16000, 24000, 32000, 48000):
        raise ValueError('Unsupported input samplerate: %d' % sr_hz)
    if pcm.ndim != 1:
        raise ValueError('Only single channel wav file supported')

    ### Setup ###

    enc = Encoder(args.dt, sr_hz)
    enc_c = lc3.setup_encoder(int(args.dt * 1000), sr_hz)

    frame_samples = int((args.dt * sr_hz) / 1000)
    frame_nbytes = int((args.bitrate * args.dt) / (1000 * 8))

    f_py = args.pyout
    f_c = args.cout

    try:

        ### File Header ###

        # Packed fields: magic 0xcc1c, 18 (the packed size of this header),
        # sample rate / 100, bitrate / 100, 1 (presumably the channel count),
        # frame duration in 1/100 ms, a zero field, and the signal length in
        # samples before padding.
        header = struct.pack('=HHHHHHHI', 0xcc1c, 18,
            sr_hz // 100, args.bitrate // 100, 1, int(args.dt * 100), 0,
            len(pcm))

        for f in (f_py, f_c):
            if f: f.write(header)

        ### Encoding loop ###

        # Zero-pad the tail so the signal is a whole number of frames.
        # NOTE(review): np.zeros defaults to float64, so padding upcasts an
        # integer signal while unpadded input keeps its WAV dtype; confirm
        # both encoders are meant to accept either.
        if len(pcm) % frame_samples > 0:
            pcm = np.append(pcm,
                np.zeros(frame_samples - (len(pcm) % frame_samples)))

        for i in range(0, len(pcm), frame_samples):

            print('Encoding frame %d' % (i // frame_samples), end='\r')

            frame_pcm = pcm[i:i+frame_samples]

            data = enc.run(frame_pcm, frame_nbytes)
            data_c = lc3.encode(enc_c, frame_pcm, frame_nbytes)

            # Each frame payload is preceded by its 16-bit byte count.
            for f in (f_py, f_c):
                if f: f.write(struct.pack('=H', frame_nbytes))

            if f_py: f_py.write(data)
            if f_c: f_c.write(data_c)

        print('done ! %16s' % '')

    finally:

        ### Terminate ###

        # Close the outputs even if encoding fails part-way.
        for f in (f_py, f_c):
            if f: f.close()
2104930cef6SMatthias Ringwald
2114930cef6SMatthias Ringwald
2124930cef6SMatthias Ringwald### ------------------------------------------------------------------------ ###
213