xref: /aosp_15_r20/external/libopus/dnn/training_tf2/dataloader.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1import numpy as np
2from tensorflow.keras.utils import Sequence
3from ulaw import lin2ulaw
4
def lpc2rc(lpc):
    """Convert LPC coefficients to reflection coefficients.

    Runs the step-down recursion along the last axis: at each stage the
    last remaining coefficient is taken as the reflection coefficient for
    that order, and the remaining coefficients are reduced to the
    next-lower-order predictor.

    Args:
        lpc: array-like whose last axis holds the prediction coefficients.
            Any number of leading axes is accepted (the original
            implementation required exactly three axes). The input array is
            not modified.

    Returns:
        Array with the same shape as ``lpc`` containing the reflection
        coefficients along the last axis. Floating/complex inputs keep
        their dtype; other dtypes (e.g. integers) are promoted to float64
        so that fractional results are not truncated.

    Note:
        A reflection coefficient with magnitude exactly 1 makes the
        ``1 - k*k`` denominator zero; NumPy then yields inf/nan (with a
        runtime warning) rather than raising.
    """
    lpc = np.asarray(lpc)
    # Integer (or bool) storage would silently truncate the coefficients
    # when they are written into the output array.
    if not np.issubdtype(lpc.dtype, np.inexact):
        lpc = lpc.astype(np.float64)

    order = lpc.shape[-1]
    # Every slot along the last axis is overwritten in the loop below.
    rc = np.zeros_like(lpc)
    for i in range(order, 0, -1):
        # The last coefficient of the order-i predictor is the i-th
        # reflection coefficient.
        rc[..., i-1] = lpc[..., -1]
        # Keep a trailing axis of length 1 so it broadcasts against the
        # i-1 remaining coefficients.
        ki = rc[..., i-1:i]
        # Step down to the order-(i-1) predictor; `-2::-1` is the
        # remaining coefficients in reversed order.
        lpc = (lpc[..., :-1] - ki*lpc[..., -2::-1])/(1 - ki*ki)
    return rc
14
class LPCNetLoader(Sequence):
    """Keras ``Sequence`` that serves shuffled training batches.

    Each batch is built from three arrays that are indexed in lockstep along
    their first axis: ``data``, ``features`` and ``periods``. The last 16
    entries of the final ``features`` axis are treated as LPC coefficients
    and are split off from the rest of the features.

    Args:
        data: array indexed as ``[sequence, time, channel]``. Channel 0 is
            returned as a model input; the remaining channels are returned
            as the first output.
        features: array indexed as ``[sequence, frame, feature]``. Everything
            except the last 16 entries of the final axis is returned as a
            model input.
        periods: array indexed with three axes, returned unchanged (apart
            from batch selection) as a model input.
        batch_size: number of sequences per batch.
        e2e: when true, the LPC slice is converted with ``lpc2rc`` and
            appended to the outputs; otherwise the LPC slice is appended to
            the inputs.
        lookahead: shifts the frame window used for the LPC slice; frames
            ``4 - lookahead`` up to (not including) the last ``lookahead``
            frames are used. NOTE(review): the constant 4 presumably matches
            padding frames in the feature layout -- confirm against the
            feature extraction code. Values above 4 make the start index
            negative, which changes the slice meaning.
    """

    def __init__(self, data, features, periods, batch_size, e2e=False, lookahead=2):
        # Newer Keras versions expect Sequence/PyDataset subclasses to run
        # the base initializer; on older versions this is a no-op.
        super().__init__()
        self.batch_size = batch_size
        # Only whole batches are served; trailing sequences that do not fill
        # a batch are dropped from all three arrays.
        shortest = np.minimum(np.minimum(data.shape[0], features.shape[0]), periods.shape[0])
        self.nb_batches = shortest//self.batch_size
        nb_sequences = self.nb_batches*self.batch_size
        self.data = data[:nb_sequences, :]
        self.features = features[:nb_sequences, :]
        self.periods = periods[:nb_sequences, :]
        self.e2e = e2e
        self.lookahead = lookahead
        # Create the initial shuffled order.
        self.on_epoch_end()

    def on_epoch_end(self):
        """Draw a fresh random ordering of all sequences."""
        self.indices = np.arange(self.nb_batches*self.batch_size)
        np.random.shuffle(self.indices)

    def __getitem__(self, index):
        """Return batch number ``index`` as an ``(inputs, outputs)`` tuple.

        ``inputs`` is ``[in_data, features, periods]`` (plus the LPC slice
        when ``e2e`` is false); ``outputs`` is ``[out_data]`` (plus the
        reflection coefficients when ``e2e`` is true).
        """
        # Sequence indices belonging to this batch in the current shuffle.
        batch = self.indices[index*self.batch_size:(index+1)*self.batch_size]

        data = self.data[batch, :, :]
        in_data = data[:, :, :1]
        out_data = data[:, :, 1:]
        # Last 16 feature entries are the LPC coefficients, handled below.
        features = self.features[batch, :, :-16]
        periods = self.periods[batch, :, :]
        outputs = [out_data]
        inputs = [in_data, features, periods]

        if self.lookahead > 0:
            lpc = self.features[batch, 4-self.lookahead:-self.lookahead, -16:]
        else:
            # A stop of "-0" would yield an empty slice, so the
            # no-lookahead case is spelled out separately.
            lpc = self.features[batch, 4:, -16:]

        if self.e2e:
            outputs.append(lpc2rc(lpc))
        else:
            inputs.append(lpc)
        return (inputs, outputs)

    def __len__(self):
        """Return the number of whole batches per epoch."""
        return self.nb_batches
50