xref: /aosp_15_r20/external/libopus/dnn/torch/osce/data/lpcnet_vocoding_dataset.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30""" Dataset for LPCNet training """
31import os
32
33import yaml
34import torch
35import numpy as np
36from torch.utils.data import Dataset
37
38
# Conversion factors between the 8-bit mu-law range (255) and the 16-bit
# linear range (32768); `scale_1` is the reciprocal of `scale`.
scale = 255.0 / 32768.0
scale_1 = 32768.0 / 255.0


def ulaw2lin(u):
    """Expand mu-law code(s) into linear amplitude(s).

    The code is centred around 128; the sign of the centred value is kept and
    its magnitude m is mapped to ``scale_1 * (256 ** (m / 128) - 1)``.

    Args:
        u: scalar or NumPy array of mu-law codes.

    Returns:
        Linear value(s) with the same shape as ``u``.
    """
    centred = u - 128
    sign = np.sign(centred)
    magnitude = np.abs(centred)
    expanded = np.exp(magnitude / 128. * np.log(256)) - 1
    return sign * scale_1 * expanded
46
47
def lin2ulaw(x):
    """Compress linear amplitude(s) into mu-law code(s).

    The magnitude m of ``x`` is mapped to ``128 * log(1 + scale * m) / log(256)``,
    the sign of ``x`` is re-applied, and the result is rounded, shifted by 128
    and clipped to the range [0, 255].

    Args:
        x: scalar or NumPy array of linear sample values.

    Returns:
        Rounded mu-law code(s) in [0, 255]; the result is not cast to an
        integer dtype.
    """
    sign = np.sign(x)
    magnitude = np.abs(x)
    companded = sign * (128 * np.log(1 + scale * magnitude) / np.log(256))
    return np.clip(128 + np.round(companded), 0, 255)
54
55
def run_lpc(signal, lpcs, frame_length=160):
    """Apply frame-wise linear prediction to a signal.

    For every frame i the corresponding slice of ``signal`` is convolved with
    ``lpcs[i]`` (``np.convolve`` in 'valid' mode) and negated; the per-frame
    results are concatenated into the prediction.

    Args:
        signal: 1-D array. The error computation requires
            ``len(signal) == num_frames * frame_length + lpc_order``.
        lpcs: 2-D array of shape (num_frames, lpc_order), one filter per frame.
        frame_length: number of samples per frame.

    Returns:
        Tuple ``(prediction, error)`` where ``error`` is
        ``signal[lpc_order:] - prediction``.
    """
    num_frames, lpc_order = lpcs.shape

    # A full segment of this size yields exactly `frame_length` outputs
    # from a 'valid' convolution with an `lpc_order`-tap filter.
    window = frame_length + lpc_order - 1

    per_frame = []
    for frame_idx in range(num_frames):
        begin = frame_idx * frame_length
        segment = signal[begin : begin + window]
        per_frame.append(-np.convolve(segment, lpcs[frame_idx], mode='valid'))

    prediction = np.concatenate(per_frame)
    residual = signal[lpc_order:] - prediction

    return prediction, residual
65
class LPCNetVocodingDataset(Dataset):
    """Map-style dataset serving fixed-size chunks of features and signals.

    The dataset directory must contain an ``info.yml`` file with the keys
    ``version``, ``feature_file``, ``feature_dtype``, ``feature_frame_length``,
    ``feature_frame_layout``, ``signal_file``, ``signal_dtype``,
    ``signal_frame_length``, ``signal_frame_layout`` and ``frame_length``.
    Feature and signal files are memory-mapped and reshaped to
    ``(-1, feature_frame_length)`` and ``(-1, signal_frame_length)``.

    Args:
        path_to_dataset: directory containing ``info.yml`` and the data files.
        features: names of the features (keys of ``feature_frame_layout``) to
            return. Defaults to ``['cepstrum', 'periods', 'pitch_corr']``.
            Both item getters read ``'periods'``, so it must be included.
        target: name of the signal returned as ``'target'``.
        frames_per_sample: number of frames per returned item.
        feature_history: extra feature frames prepended to every item.
        feature_lookahead: extra feature frames appended to every item.
        lpc_gamma: LPC coefficient k (1-based) is scaled by ``lpc_gamma ** k``
            before the prediction is computed (version 2 only).
        input_signals: names of the signals stacked into ``'signals'`` by the
            version 1 getter. ``None`` selects every signal listed in
            ``signal_frame_layout`` (in layout order).
            NOTE(review): the original code never defined this attribute, so
            the default is a choice made here - pass an explicit list if the
            target must not be part of the inputs.

    Raises:
        ValueError: if the dataset version is neither 1 nor 2.
    """

    def __init__(self,
                 path_to_dataset,
                 features=None,
                 target='signal',
                 frames_per_sample=100,
                 feature_history=0,
                 feature_lookahead=0,
                 lpc_gamma=1,
                 input_signals=None):

        super().__init__()

        # avoid a mutable default argument shared between instances
        if features is None:
            features = ['cepstrum', 'periods', 'pitch_corr']

        # load dataset info
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            # NOTE(review): FullLoader is more permissive than SafeLoader;
            # consider yaml.safe_load if info.yml may come from an untrusted source.
            dataset = yaml.load(f, yaml.FullLoader)

        # dataset version
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # features
        self.feature_history      = feature_history
        self.feature_lookahead    = feature_lookahead
        # NOTE(review): the reason for the fixed offset of 2 frames is not visible here
        self.frame_offset         = 2 + self.feature_history
        self.frames_per_sample    = frames_per_sample
        self.input_features       = list(features)
        self.input_signals        = None if input_signals is None else list(input_signals)
        self.feature_frame_layout = dataset['feature_frame_layout']
        self.lpc_gamma            = lpc_gamma

        # load feature file
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0, \
            "feature file size is not a multiple of feature_frame_length"
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples in dataset (clamped so that len() never
        # fails on files that are too short to hold a single sample)
        usable_frames = len(self.features) - self.frame_offset - self.feature_lookahead - 1 - 2
        self.dataset_length = max(0, usable_frames // self.frames_per_sample)

        # signals
        self.frame_length               = dataset['frame_length']
        self.signal_frame_layout        = dataset['signal_frame_layout']
        self.target                     = target

        # load signals
        self.signal_file  = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals  = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length  = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))
        assert len(self.signals) == len(self.features) * self.frame_length, \
            "signal file length does not match number of feature frames * frame_length"


    def __getitem__(self, index):
        """Return item ``index``; negative indices count from the end.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        # build a new value instead of `+=` so a tensor index is not modified in place
        position = index + length if index < 0 else index
        if not 0 <= position < length:
            raise IndexError(f"index {index} out of range for dataset of length {length}")

        return self.getitem(position)

    def _extract_features(self, index):
        """Return a dict with the requested feature slices for item ``index``."""
        frame_start = self.frame_offset + index       * self.frames_per_sample - self.feature_history
        frame_stop  = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        sample = dict()
        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods: stored value p becomes int16(50 * p + 100.1)
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        return sample

    def _signal_range(self, index):
        """Return (start, stop) sample positions of item ``index`` in the signal file."""
        signal_start = (self.frame_offset + index       * self.frames_per_sample) * self.frame_length
        signal_stop  = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        return signal_start, signal_stop

    def _stack_features(self, sample):
        """Concatenate all requested non-period features along the last axis."""
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]

        return torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)

    def _add_lpc_signals(self, sample, index):
        """Add 'prediction', 'last_error' and 'error' computed from the LPC features."""
        # frame positions (start one frame early for past excitation)
        frame_start = self.frame_offset + self.frames_per_sample * index - 1
        frame_stop  = self.frame_offset + self.frames_per_sample * (index + 1)

        # feature positions
        lpc_start, lpc_stop = self.feature_frame_layout['lpc']
        lpc_order = lpc_stop - lpc_start

        # LPC weighting
        weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
        lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop] * weights

        # signal position (lpc_order samples as history)
        signal_start = frame_start * self.frame_length - lpc_order + 1
        signal_stop  = frame_stop  * self.frame_length + 1
        noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]

        noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

        # extract signals: skip the extra leading frame; last_error lags by one sample
        offset = self.frame_length
        num_samples = self.frame_length * self.frames_per_sample
        sample['prediction'] = noisy_prediction[offset : offset + num_samples]
        sample['last_error'] = noisy_error[offset - 1 : offset - 1 + num_samples]

        # error between real signal and noisy prediction
        sample['error'] = sample['signal'] - sample['prediction']

    def getitem_v2(self, index):
        """Return item ``index`` of a version 2 dataset.

        Returns:
            dict with 'features' (float tensor), 'periods' (long tensor) and
            'target' (float tensor, selected signal divided by 2**15).
        """
        sample = self._extract_features(index)

        signal_start, signal_stop = self._signal_range(index)

        # last_signal and signal are always expected to be there
        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients present and prediction not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
            self._add_lpc_signals(sample, index)

        features = self._stack_features(sample)
        target  = torch.FloatTensor(sample[self.target]) / 2**15
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'target' : target}

    def getitem_v1(self, index):
        """Return item ``index`` of a version 1 dataset.

        Returns:
            dict with 'features' (float tensor), 'periods' (long tensor),
            'signals' (long tensor, selected input signals stacked along the
            last axis) and 'target' (long tensor).
        """
        sample = self._extract_features(index)

        signal_start, signal_stop = self._signal_range(index)

        # load every signal listed in the layout
        for signal_name, column in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, column]

        # fall back to all available signals if no explicit selection was given
        signal_keys = self.input_signals
        if signal_keys is None:
            signal_keys = list(self.signal_frame_layout)

        features = self._stack_features(sample)
        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in signal_keys], dim=-1)
        target  = torch.LongTensor(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def __len__(self):
        """Return the number of items in the dataset."""
        return self.dataset_length
226