xref: /aosp_15_r20/external/libopus/dnn/torch/osce/utils/spec.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import math as m
31import numpy as np
32import scipy
33import scipy.fftpack
34import torch
35
def erb(f):
    """Map a frequency value to the 'erb' scale used by this module.

    Evaluates ``24.7 * (4.37 * f + 1)``, i.e. an affine function of ``f``.
    Only arithmetic operators are used, so scalars and NumPy arrays both work.

    NOTE(review): the published Glasberg-Moore bandwidth formula has this
    form with f given in kHz; confirm which unit callers are meant to pass.
    """
    affine_term = 4.37 * f + 1
    return 24.7 * affine_term
38
def inv_erb(e):
    """Invert ``erb``: recover f from e = 24.7 * (4.37 * f + 1)."""
    normalized = e / 24.7 - 1
    return normalized / 4.37
41
def bark(f):
    """Map a frequency value to the 'bark' scale as ``6 * asinh(f / 600)``.

    Uses ``math.asinh``, so ``f`` must be a scalar.
    """
    ratio = f / 600
    return 6 * m.asinh(ratio)
44
def inv_bark(b):
    """Invert ``bark``: return ``600 * sinh(b / 6)`` for a scalar ``b``."""
    scaled = b / 6
    return 600 * m.sinh(scaled)
47
48
# Registry of warped frequency scales selectable by name.
# Each entry is [forward map (frequency -> scale), inverse map (scale -> frequency)].
scale_dict = dict(
    bark=[bark, inv_bark],
    erb=[erb, inv_erb],
)
53
def gen_filterbank(N, Fs=16000, keep_size=False):
    """Build a row-normalized matrix of smoothing weights on a linear frequency grid.

    The grid has N + 1 points, ``k / N * Fs / 2`` for k = 0..N. Each output row
    is centred on one grid point and weights every grid point according to its
    distance from that centre, measured in units of the ERB evaluated at the
    weighted (input) frequency.

    Args:
        N: number of grid intervals; the input axis has N + 1 points.
        Fs: sampling rate used to scale the grid.
        keep_size: if True the matrix has N + 1 rows, otherwise the first N.

    Returns:
        torch tensor of shape (N + 1 if keep_size else N, N + 1); each row sums to 1.
    """
    num_rows = N + 1 if keep_size else N

    # shared frequency grid; output centres are simply its first num_rows points
    grid = np.arange(N + 1, dtype='float32') / N * Fs / 2
    in_freq = grid[np.newaxis, :]
    out_freq = grid[:num_rows, np.newaxis]

    # ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
    erb_width = 24.7 + .108 * in_freq
    delta = np.abs(in_freq - out_freq) / erb_width

    # response in dB: parabolic for delta < 0.5, linear roll-off elsewhere
    inside = (delta < .5).astype('float32')
    response_db = -12 * inside * delta**2 + (1 - inside) * (3 - 12 * delta)

    # convert dB to linear scale, then normalize every row to unit sum
    response = 10.**(response_db / 10.)
    row_sums = response.sum(axis=1)
    response = response / row_sums[:, np.newaxis]

    return torch.from_numpy(response)
67
def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False):
    """Create a bank of triangular (linearly interpolating) band filters over FFT bins.

    Args:
        num_bands: number of bands; forced to 18 (with a printed warning) for scale='opus'.
        n_fft: FFT size; the bank covers the n_fft // 2 + 1 non-negative frequency bins.
        fs: sampling rate.
        scale: 'opus' for the fixed 18-band layout, otherwise a key of ``scale_dict``.
        round_center_bins: round band centres to integer bin positions.
        return_upper: append mirrored columns so the bank spans all n_fft bins.
        normalize: scale every row to unit sum.

    Returns:
        numpy array of shape (num_bands, n_fft // 2 + 1), or (num_bands, n_fft)
        when return_upper is set.
    """
    num_bins = n_fft // 2 + 1
    bin_width = fs / n_fft
    f_lo = 0
    f_hi = fs / n_fft * (num_bins - 1)

    if scale != 'opus':
        to_scale, from_scale = scale_dict[scale]

        s_lo = to_scale(f_lo)
        s_hi = to_scale(f_hi)

        # inner centres are spaced by (s_hi - s_lo) / num_bands on the warped
        # scale; the two outer centres are pinned to f_lo and f_hi
        inner_freqs = [from_scale(s_lo + i * (s_hi - s_lo) / (num_bands)) for i in range(1, num_bands - 1)]
        center_freqs = np.array([f_lo] + inner_freqs + [f_hi])
        center_bins = (center_freqs - f_lo) / bin_width
    else:
        # band edges tabulated for a 5 ms frame, rescaled to the n_fft resolution
        bins_5ms = [0,  1,  2,  3,  4,  5,  6,  7,  8, 10, 12, 14, 16, 20, 24, 28, 34, 40]
        fac = 1000 * n_fft / fs / 5
        if num_bands != 18:
            print("warning: requested Opus filter bank with num_bands != 18. Adjusting num_bands.")
            num_bands = 18
        center_bins = np.array([fac * edge for edge in bins_5ms])

    if round_center_bins:
        center_bins = np.round(center_bins)

    filter_bank = np.zeros((num_bands, num_bins))

    # walk the bins once; `band` is the index of the centre at or below the
    # current bin and advances by at most one per bin
    band = 0
    for k in range(num_bins):
        upper = center_bins[band + 1]
        if k > upper:
            band += 1
            upper = center_bins[band + 1]
        lower = center_bins[band]

        # linear cross-fade between the two neighbouring bands
        weight = (upper - k) / (upper - lower)
        filter_bank[band, k] = weight
        filter_bank[band + 1, k] = 1 - weight

    if return_upper:
        # mirror bins 1..extend (reversed) onto the upper part of the spectrum
        extend = n_fft - num_bins
        mirrored = filter_bank[:, 1:extend + 1][:, ::-1]
        filter_bank = np.concatenate((filter_bank, mirrored), axis=1)

    if normalize:
        filter_bank = filter_bank / filter_bank.sum(axis=1, keepdims=True)

    return filter_bank
115
116
def compressed_log_spec(pspec):
    """Return a floored log10 version of a 1-D band spectrum.

    Each band is converted with ``log10(p + 1e-9)`` and then bounded below by
    both the running maximum of the previous outputs (initially -2) and by a
    follower that drops 2.5 per band.

    NOTE(review): because the lower bound is the running maximum itself, the
    output is non-decreasing along the band axis - confirm this is intended.

    Args:
        pspec: 1-D array of band values (indexed along its first axis).

    Returns:
        array shaped and typed like pspec holding the compressed log values.
    """
    out = np.zeros_like(pspec)

    running_max = -2
    follower = -2

    for idx in range(pspec.shape[-1]):
        decayed = follower - 2.5
        level = np.log10(pspec[idx] + 1e-9)
        level = max(running_max, max(decayed, level))
        out[idx] = level
        running_max = max(running_max, level)
        follower = max(decayed, level)

    return out
133
def log_spectrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False, power=1):
    """Compute a log spectrum from LPC coefficients.

    The coefficients along the last axis of ``a`` define the polynomial
    ``[1, -a_1, ..., -a_order]``. Its FFT magnitude (raised to ``power``) is
    inverted, optionally projected onto the filter bank ``fb`` and converted
    to the log domain.

    Args:
        a: array of LPC coefficients, shape (..., order).
        fb: optional filter bank; it is applied as ``S @ fb.T`` so its second
            dimension has to equal n_fft // 2 + 1.
        n_fft: FFT size; must be larger than order + 1.
        eps: regularizer added before the inversion and before the logarithm.
        gamma: the k-th coefficient is multiplied by gamma ** k (k = 1..order).
        compress: use ``compressed_log_spec`` along the last axis instead of
            the natural logarithm.
        power: exponent applied to the magnitude response.

    Returns:
        array of shape (..., n_fft // 2 + 1), or (..., fb.shape[0]) if fb is given.
    """
    order = a.shape[-1]
    assert order + 1 < n_fft

    # per-coefficient weights gamma ** k for k = 1..order
    weights = (gamma ** np.arange(1, order + 1)).astype(np.float32)
    a = a * weights

    # zero-padded polynomial [1, -a_1, ..., -a_order, 0, ...]
    poly = np.zeros(a.shape[:-1] + (n_fft,))
    poly[..., 0] = 1
    poly[..., 1:order + 1] = -a

    num_bins = n_fft // 2 + 1
    magnitude = np.abs(np.fft.fft(poly, axis=-1)[..., :num_bins]) ** power

    envelope = 1 / (magnitude + eps)
    if fb is not None:
        envelope = np.matmul(envelope, fb.T)

    if compress:
        return np.apply_along_axis(compressed_log_spec, -1, envelope)
    return np.log(envelope + eps)
161
def cepstrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False):
    """Compute a cepstrum from LPC coefficients.

    The log spectrum returned by ``log_spectrum_from_lpc`` (called with its
    default ``power``) is transformed with an orthonormal type-2 DCT along
    the last axis.
    """
    log_spec = log_spectrum_from_lpc(a, fb, n_fft, eps, gamma, compress)
    return scipy.fftpack.dct(log_spec, type=2, norm='ortho')
170
171
172
def log_spectrum(x, frame_size, fb=None, window=None, power=1):
    """Calculate the log magnitude spectrum on 50% overlapping frames.

    Frames of length ``frame_size`` are taken every ``frame_size // 2``
    samples; rows of the result are in time order.

    Args:
        x: 1-D numpy array of samples; ``2 * len(x)`` must be a multiple of
            ``frame_size``. Integer input is framed as float64 so that a
            floating-point window can be applied.
        frame_size: even frame length, also used as FFT size.
        fb: optional filter bank applied as ``X @ fb.T``; its second
            dimension has to equal frame_size // 2 + 1.
        window: optional array with frame_size elements multiplied onto every frame.
        power: exponent applied to the magnitude spectrum.

    Returns:
        array of shape (num_frames, frame_size // 2 + 1), or
        (num_frames, fb.shape[0]) if fb is given, holding log(X + 1e-9).

    Raises:
        AssertionError: if the length or frame size constraints are violated.
    """

    assert (2 * len(x)) % frame_size == 0
    assert frame_size % 2 == 0

    n = len(x)
    hop = frame_size // 2
    num_even = n // frame_size
    num_odd  = (n - hop) // frame_size
    num_bins = hop + 1

    # frames starting at multiples of frame_size, and those offset by half a frame
    x_even = x[:num_even * frame_size].reshape(-1, frame_size)
    x_odd  = x[hop : hop + frame_size * num_odd].reshape(-1, frame_size)

    # Keep floating/complex dtypes as they are; anything else (e.g. int16 PCM)
    # is promoted to float64, otherwise the in-place windowing below fails
    # with a ufunc casting error. The FFT would promote such input anyway.
    frame_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64

    # interleave both frame sets into time order (fresh buffer, x is not modified)
    x_unfold = np.empty((x_even.size + x_odd.size), dtype=frame_dtype).reshape((-1, frame_size))
    x_unfold[::2, :] = x_even
    x_unfold[1::2, :] = x_odd

    if window is not None:
        x_unfold *= window.reshape(1, -1)

    X = np.abs(np.fft.fft(x_unfold, n=frame_size, axis=-1))[:, :num_bins] ** power

    if fb is not None:
        X = np.matmul(X, fb.T)

    return np.log(X + 1e-9)
201
202
def cepstrum(x, frame_size, fb=None, window=None):
    """Calculate the cepstrum on 50% overlapping frames.

    Applies an orthonormal type-2 DCT along the last axis of the log spectrum
    returned by ``log_spectrum`` (called with its default ``power``).
    """
    log_spec = log_spectrum(x, frame_size, fb, window)
    return scipy.fftpack.dct(log_spec, type=2, norm='ortho')
210    return cepstrum