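"""Training and evaluation loops for the OSCE vocoder models.

`train_one_epoch` runs a single optimization pass over a dataloader;
`evaluate` computes the average loss without gradient tracking.
"""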
import torch
from tqdm import tqdm
import sys


def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
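    """Run one training epoch of `model` over `dataloader`.

    Each batch is moved to `device`, the loss against batch['target'] is
    backpropagated, and both `optimizer` and `scheduler` are stepped once per
    batch. Progress is reported every `log_interval` batches. Returns the
    average loss per batch.
    """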

    model.to(device)
    model.train()

    running_loss = 0
    previous_running_loss = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output
            output = model(batch['features'], batch['periods'])

            # calculate loss
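            # the model may return either a single tensor or a list of outputs;
            # for a list, the loss is averaged over all outputs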
            if isinstance(output, list):
                loss = torch.zeros(1, device=device)
                for y in output:
                    loss = loss + criterion(target, y.squeeze(1))
                loss = loss / len(output)
            else:
                loss = criterion(target, output.squeeze(1))

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate
            scheduler.step()

            # update running loss
            running_loss += float(loss.cpu())

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                previous_running_loss = running_loss

    running_loss /= len(dataloader)

    return running_loss


def evaluate(model, criterion, dataloader, device, log_interval=10):
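    """Evaluate `model` over `dataloader` without gradient tracking.

    Each batch is moved to `device` and the loss against batch['target'] is
    accumulated. Progress is reported every `log_interval` batches. Returns
    the average loss per batch.
    """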

    model.to(device)
    model.eval()

    running_loss = 0
    previous_running_loss = 0

    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

            for i, batch in enumerate(tepoch):

                # push batch to device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # calculate model output
                output = model(batch['features'], batch['periods'])

                # calculate loss
                loss = criterion(target, output.squeeze(1))

                # update running loss
                running_loss += float(loss.cpu())

                # update status bar
                if i % log_interval == 0:
                    tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                    previous_running_loss = running_loss

    running_loss /= len(dataloader)

    return running_loss