xref: /aosp_15_r20/external/libopus/dnn/torch/lpcnet/engine/lpcnet_engine.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import torch
31from tqdm import tqdm
32import sys
33
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    For every batch the initial GRU states are reset to zero, the model is
    evaluated, the loss is back-propagated, and then the optimizer, the
    learning-rate scheduler and the model's sparsifier are stepped (in that
    order).

    Args:
        model: network to train. It must expose ``gru_a_units``,
            ``gru_b_units`` and ``sparsify()``, and be callable as
            ``model(features, periods, signals, gru_states)``.
        criterion: callable invoked as ``criterion(output.permute(0, 2, 1), target)``
            that returns a scalar loss tensor.
        optimizer: optimizer whose ``zero_grad()`` / ``step()`` are called once per batch.
        dataloader: iterable of dict batches with the keys ``'features'``,
            ``'periods'``, ``'signals'`` and ``'target'``; it must also provide
            ``batch_size`` and ``len()``.
        device: device the model, the GRU states and every batch entry are moved to.
        scheduler: learning-rate scheduler, stepped once per batch.
        log_interval: the progress bar is refreshed every ``log_interval`` batches.

    Returns:
        The sum of all batch losses divided by ``len(dataloader)``.
    """

    model.to(device)
    model.train()

    running_loss = 0.0
    previous_running_loss = 0.0
    # number of batches that had been processed at the previous progress update
    batches_at_last_log = 0

    # gru states
    # NOTE(review): the states are sized with dataloader.batch_size, so a
    # smaller final batch (drop_last=False) would presumably not match --
    # confirm that the loader always yields full batches.
    gru_a_state = torch.zeros(1, dataloader.batch_size, model.gru_a_units, device=device)
    gru_b_state = torch.zeros(1, dataloader.batch_size, model.gru_b_units, device=device)
    gru_states = [gru_a_state, gru_b_state]

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # zero out initial gru states (in place, so gru_states sees the reset)
            gru_a_state.zero_()
            gru_b_state.zero_()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output
            output = model(batch['features'], batch['periods'], batch['signals'], gru_states)

            # calculate loss
            loss = criterion(output.permute(0, 2, 1), target)

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate
            scheduler.step()

            # call sparsifier
            model.sparsify()

            # update running loss
            running_loss += float(loss.cpu())

            # update status bar
            if i % log_interval == 0:
                batches_seen = i + 1
                # The first update (i == 0) covers a single batch, later ones
                # cover log_interval batches; dividing by the real window size
                # keeps current_loss a true mean in both cases.
                window = batches_seen - batches_at_last_log
                tepoch.set_postfix(
                    running_loss=f"{running_loss / batches_seen:8.7f}",
                    current_loss=f"{(running_loss - previous_running_loss) / window:8.7f}",
                )
                previous_running_loss = running_loss
                batches_at_last_log = batches_seen

    running_loss /= len(dataloader)

    return running_loss
94
def evaluate(model, criterion, dataloader, device, log_interval=10):
    """Compute the mean batch loss of ``model`` over ``dataloader`` without training.

    The model is switched to eval mode and all batches are processed under
    ``torch.no_grad()``. The initial GRU states are reset to zero before
    every batch.

    Args:
        model: network to evaluate. It must expose ``gru_a_units`` and
            ``gru_b_units`` and be callable as
            ``model(features, periods, signals, gru_states)``.
        criterion: callable invoked as ``criterion(output.permute(0, 2, 1), target)``
            that returns a scalar loss tensor.
        dataloader: iterable of dict batches with the keys ``'features'``,
            ``'periods'``, ``'signals'`` and ``'target'``; it must also provide
            ``batch_size`` and ``len()``.
        device: device the model, the GRU states and every batch entry are moved to.
        log_interval: the progress bar is refreshed every ``log_interval`` batches.

    Returns:
        The sum of all batch losses divided by ``len(dataloader)``.
    """

    model.to(device)
    model.eval()

    running_loss = 0.0
    previous_running_loss = 0.0
    # number of batches that had been processed at the previous progress update
    batches_at_last_log = 0

    # gru states
    # NOTE(review): the states are sized with dataloader.batch_size, so a
    # smaller final batch (drop_last=False) would presumably not match --
    # confirm that the loader always yields full batches.
    gru_a_state = torch.zeros(1, dataloader.batch_size, model.gru_a_units, device=device)
    gru_b_state = torch.zeros(1, dataloader.batch_size, model.gru_b_units, device=device)
    gru_states = [gru_a_state, gru_b_state]

    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

            for i, batch in enumerate(tepoch):

                # zero out initial gru states (in place, so gru_states sees the reset)
                gru_a_state.zero_()
                gru_b_state.zero_()

                # push batch to device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # calculate model output
                output = model(batch['features'], batch['periods'], batch['signals'], gru_states)

                # calculate loss
                loss = criterion(output.permute(0, 2, 1), target)

                # update running loss
                running_loss += float(loss.cpu())

                # update status bar
                if i % log_interval == 0:
                    batches_seen = i + 1
                    # The first update (i == 0) covers a single batch, later
                    # ones cover log_interval batches; dividing by the real
                    # window size keeps current_loss a true mean in both cases.
                    window = batches_seen - batches_at_last_log
                    tepoch.set_postfix(
                        running_loss=f"{running_loss / batches_seen:8.7f}",
                        current_loss=f"{(running_loss - previous_running_loss) / window:8.7f}",
                    )
                    previous_running_loss = running_loss
                    batches_at_last_log = batches_seen

        running_loss /= len(dataloader)

        return running_loss