#!/usr/bin/python3
'''Copyright (c) 2021-2022 Amazon
   Copyright (c) 2018-2019 Mozilla

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

# Train an RDO-VAE quantization model
import tensorflow as tf
strategy = tf.distribute.MultiWorkerMirroredStrategy()
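# MultiWorkerMirroredStrategy reads the cluster layout from the TF_CONFIG
# environment variable; without it, training runs on a single worker.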


import argparse
#from plc_loader import PLCLoader

parser = argparse.ArgumentParser(description='Train a quantization model')

parser.add_argument('features', metavar='<features file>', help='binary features file (float32)')
parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
parser.add_argument('--model', metavar='<model>', default='rdovae', help='rdovae model python definition (without .py)')
group1 = parser.add_mutually_exclusive_group()
group1.add_argument('--quantize', metavar='<input weights>', help='quantize model')
group1.add_argument('--retrain', metavar='<input weights>', help='continue training model')
parser.add_argument('--cond-size', metavar='<units>', default=1024, type=int, help='number of units in conditioning network (default 1024)')
parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
parser.add_argument('--seq-length', metavar='<sequence length>', default=1000, type=int, help='sequence length to use (default 1000)')
parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
parser.add_argument('--logdir', metavar='<log dir>', help='directory for tensorboard log files')
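# Example invocation (illustrative only; the file names are placeholders):
#   python3 train_rdovae.py features.f32 rdovae_out --batch-size 128 --epochs 120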


args = parser.parse_args()

import importlib
rdovae = importlib.import_module(args.model)
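# --model selects the python module defining the network; it must provide
# new_rdovae_model() and the loss/metric functions referenced below.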

import sys
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
import tensorflow.keras.backend as K
import h5py

#gpus = tf.config.experimental.list_physical_devices('GPU')
#if gpus:
#  try:
#    tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)])
#  except RuntimeError as e:
#    print(e)

nb_epochs = args.epochs

# Try reducing batch_size if you run out of memory on your GPU
batch_size = args.batch_size

quantize = args.quantize is not None
retrain = args.retrain is not None

if quantize:
    lr = 0.00003
    decay = 0
    input_model = args.quantize
else:
    lr = 0.001
    decay = 2.5e-5
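# The much lower learning rate in the quantize branch is presumably because
# --quantize fine-tunes an already-trained model rather than training from scratch.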

if args.lr is not None:
    lr = args.lr

if args.decay is not None:
    decay = args.decay

if retrain:
    input_model = args.retrain


opt = Adam(lr, decay=decay, beta_2=0.99)
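# 'decay' is the legacy Keras argument that scales the learning rate by
# 1/(1 + decay*iterations); on TF >= 2.11 this call may require
# tf.keras.optimizers.legacy.Adam instead.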

with strategy.scope():
    model, encoder, decoder, _ = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size, nb_quant=16)
    model.compile(optimizer=opt, loss=[rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.sq1_rate_loss, rdovae.sq2_rate_loss], loss_weights=[.5, .5, 1., .1], metrics={'hard_bits':rdovae.sq_rate_metric})
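    # Two feature-distortion losses plus two quantizer rate losses, combined
    # with the loss_weights above, give a weighted distortion + rate objective.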
    model.summary()

lpc_order = 16

feature_file = args.features
nb_features = model.nb_used_features + lpc_order
nb_used_features = model.nb_used_features
sequence_size = args.seq_length

# Memory-map the float32 feature file so arbitrarily large training data can
# be used without loading it all into RAM.
features = np.memmap(feature_file, dtype='float32', mode='r')
nb_sequences = len(features)//(nb_features*sequence_size)//batch_size*batch_size
features = features[:nb_sequences*sequence_size*nb_features]
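# nb_sequences is rounded down to a whole multiple of batch_size so that
# every batch is full.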

features = np.reshape(features, (nb_sequences, sequence_size, nb_features))
print(features.shape)
features = features[:, :, :nb_used_features]
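# Each frame stores nb_used_features entries followed by lpc_order values
# (LPC coefficients in the shared feature format); the LPC part is dropped
# since this model does not use it.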

#lambda_val = np.repeat(np.random.uniform(.0007, .002, (features.shape[0], 1, 1)), features.shape[1]//2, axis=1)
#quant_id = np.round(10*np.log(lambda_val/.0007)).astype('int16')
#quant_id = quant_id[:,:,0]
quant_id = np.repeat(np.random.randint(16, size=(features.shape[0], 1, 1), dtype='int16'), features.shape[1]//2, axis=1)
lambda_val = .0002*np.exp(quant_id/3.8)
quant_id = quant_id[:,:,0]
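# One random quantizer id in [0, 16) per sequence, repeated across the
# sequence's seq_length//2 coding steps, with a matching rate-distortion
# trade-off lambda_val derived from it.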

# dump models to disk as we go
checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.cond_size, '{epoch:02d}'))

if quantize or retrain:
    # Adapting from an existing model (input_model is set from --quantize or --retrain)
    model.load_weights(input_model)

model.save_weights('{}_{}_initial.h5'.format(args.output, args.cond_size))

callbacks = [checkpoint]
#callbacks = []

if args.logdir is not None:
    logdir = '{}/{}_{}_logs'.format(args.logdir, args.output, args.cond_size)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
    callbacks.append(tensorboard_callback)

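# Inputs are the features plus the per-step quantizer id and lambda; all four
# model outputs are trained against the same feature targets.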
model.fit([features, quant_id, lambda_val], [features, features, features, features], batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)