"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import numpy as np
import scipy.signal

def compute_vad_mask(x, fs, stop_db=-70):
    """ Simple energy-based voice activity mask.

    Returns the input truncated to a whole number of 20 ms frames together with
    a sample-wise mask in [0, 1]: frames whose smoothed energy falls below
    frame_energy.max() * 10 ** (stop_db / 20) are marked inactive, and the
    binary decision is smoothed with a normalized half-sine window.
    """

    frame_length = (fs + 49) // 50
    x = x[: frame_length * (len(x) // frame_length)]

    frames = x.reshape(-1, frame_length)
    frame_energy = np.sum(frames ** 2, axis=1)
    frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')

    max_threshold = frame_energy.max() * 10 ** (stop_db / 20)
    vactive = np.ones_like(frames)
    vactive[frame_energy_smooth < max_threshold, :] = 0
    vactive = vactive.reshape(-1)

    filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
    filter = filter / filter.sum()

    mask = np.convolve(vactive, filter, mode='same')

    return x, mask

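# Note: compute_vad_mask (above) and convert_mask (below) are combined in
# _compare when apply_vad=True: the sample-wise activity mask is reduced to one
# weight per analysis frame, and frames with a weight <= 1e-6 are excluded from
# the final error average.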
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """ Downsamples a sample-wise mask to one mean value per analysis frame. """
    num_samples = frame_size + (num_frames - 1) * hop_size
    if len(mask) < num_samples:
        mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype)
    else:
        mask = mask[:num_samples]

    new_mask = np.array([np.mean(mask[i*hop_size : i*hop_size + frame_size]) for i in range(num_frames)])

    return new_mask

def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """ Power spectra of windowed, overlapping frames; shape (num_spectra, window_size // 2 + 1). """
    num_spectra = (len(x) - window_size - hop_size) // hop_size
    window = scipy.signal.get_window(window, window_size)
    N = window_size // 2

    frames = np.concatenate([x[np.newaxis, i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * window
    psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2

    return psd


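# The next helper builds the spectral-spreading matrix used for frequency
# masking in _compare. Qualitatively (to first order in the factors): energy in
# one band raises the masking level of the band above it by roughly up_factor
# and of the band below it by roughly down_factor, with geometric decay for
# more distant bands.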
def frequency_mask(num_bands, up_factor, down_factor):

    up_mask = np.zeros((num_bands, num_bands))
    down_mask = np.zeros((num_bands, num_bands))

    for i in range(num_bands):
        up_mask[i, : i + 1] = up_factor ** np.arange(i, -1, -1)
        down_mask[i, i :] = down_factor ** np.arange(num_bands - i)

    return down_mask @ up_mask


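# rect_fb below builds a rectangular (non-overlapping) filter-bank matrix in
# which row i selects the FFT bins of band i. Small illustrative example
# (values chosen for readability, not the ones used in _compare):
#
#   rect_fb([0, 2, 4], num_bins=5)
#   # -> [[1., 1., 0., 0., 0.],
#   #     [0., 0., 1., 1., 0.]]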
def rect_fb(band_limits, num_bins=None):
    num_bands = len(band_limits) - 1
    if num_bins is None:
        num_bins = band_limits[-1]

    fb = np.zeros((num_bands, num_bins))
    for i in range(num_bands):
        fb[i, band_limits[i]:band_limits[i+1]] = 1

    return fb


def _compare(x, y, apply_vad=False, factor=1):
    """ Modified version of opus_compare for 16 kHz mono signals

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): if True, frames without voice activity are excluded
            from the error average
        factor (int): scale factor for band limits, window size and hop size
            (1 corresponds to 16 kHz analysis)

    Returns:
        float: perceptually weighted error
    """
    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [factor * b for b in [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]]
    window_size = factor * 160
    hop_size = factor * 40
    num_bins = window_size // 2 + 1
    num_bands = len(band_limits) - 1
    fb = rect_fb(band_limits, num_bins=num_bins)

    # trim samples to same size
    num_samples = min(len(x), len(y))
    x = x[:num_samples].copy() * 2**15
    y = y[:num_samples].copy() * 2**15

    psd_x = power_spectrum(x, window_size=window_size, hop_size=hop_size) + 100000
    psd_y = power_spectrum(y, window_size=window_size, hop_size=hop_size) + 100000

    num_frames = psd_x.shape[0]

    # average band energies
    be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)

    # frequency masking
    f_mask = frequency_mask(num_bands, 0.1, 0.03)
    mask_x = be_x @ f_mask.T

    # temporal masking
    for i in range(1, num_frames):
        mask_x[i, :] += (0.5 ** factor) * mask_x[i-1, :]

    # apply mask
    masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
    masked_psd_y = psd_y + 0.1 * (mask_x @ fb)

    # 2-frame average
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]

    # distortion metric
    re = masked_psd_y / masked_psd_x
    #im = re - np.log(re) - 1
    im = np.log(re) ** 2
    Eb = ((im @ fb.T) / np.sum(fb, axis=1))
    Ef = np.mean(Eb ** 1, axis=1)

    if apply_vad:
        _, mask = compute_vad_mask(x, 16000)
        mask = convert_mask(mask, Ef.shape[0])
    else:
        mask = np.ones_like(Ef)

    err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)

    return float(err)

def compare(x, y, apply_vad=False):
    """ Perceptually weighted error (MOC) between a reference signal x and a test signal y. """
    err = np.linalg.norm([_compare(x, y, apply_vad=apply_vad, factor=1)], ord=2)
    return err

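# Programmatic usage sketch (assumes two 16 kHz mono signals scaled to [-1, 1];
# file names are illustrative):
#
#   from scipy.io import wavfile
#   fs_ref, ref = wavfile.read('ref.wav')
#   fs_deg, deg = wavfile.read('deg.wav')
#   moc = compare(ref / 2**15, deg / 2**15, apply_vad=True)
#
# The command-line entry point below does the same with argument parsing and a
# sampling-rate check.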
if __name__ == "__main__":
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()


    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    if fs1 != 16000 or fs2 != 16000:
        raise ValueError('error: encountered sampling frequency different from 16 kHz')

    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)

    print(f"MOC: {err}")