xref: /aosp_15_r20/external/libopus/dnn/torch/osce/engine/engine.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1import torch
2from tqdm import tqdm
3import sys
4
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module called as ``model(signals, features, periods, numbits)``;
            may return a single tensor or a list of tensors (one per stage).
        criterion: callable mapping ``(target, prediction)`` to a scalar loss tensor.
        optimizer: optimizer updating ``model.parameters()``.
        dataloader: iterable of dict batches with tensor values under the keys
            'signals', 'features', 'periods', 'numbits' and 'target'.
        device: device the model and batches are moved to.
        scheduler: learning-rate scheduler, stepped once per batch.
        log_interval: number of batches between progress-bar updates.

    Returns:
        float: accumulated loss averaged over ``len(dataloader)`` batches.
    """

    model.to(device)
    model.train()

    running_loss = 0
    previous_running_loss = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output; signals are permuted to (batch, channels, time)
            output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

            # calculate loss; multi-stage models return one output per stage,
            # whose losses are averaged into a single scalar
            if isinstance(output, list):
                loss = torch.zeros(1, device=device)
                for y in output:
                    loss = loss + criterion(target, y.squeeze(1))
                loss = loss / len(output)
            else:
                loss = criterion(target, output.squeeze(1))

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate (per-batch schedule)
            scheduler.step()

            # sparsification (no-op for models without a sparsifier)
            if hasattr(model, 'sparsifier'):
                model.sparsifier()

            # update running loss; .item() replaces float(loss.cpu()) round-trip
            running_loss += loss.item()

            # update status bar; the first update (i == 0) has accumulated only
            # one batch, so normalize current_loss by the actual batch count
            # instead of always dividing by log_interval
            if i % log_interval == 0:
                batches_since_update = log_interval if i else 1
                tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/batches_since_update:8.7f}")
                previous_running_loss = running_loss

    running_loss /= len(dataloader)

    return running_loss
65
def evaluate(model, criterion, dataloader, device, log_interval=10):
    """Evaluate the model over ``dataloader`` without gradients; return mean per-batch loss.

    Args:
        model: module called as ``model(signals, features, periods, numbits)``;
            may return a single tensor or a list of tensors (one per stage).
        criterion: callable mapping ``(target, prediction)`` to a scalar loss tensor.
        dataloader: iterable of dict batches with tensor values under the keys
            'signals', 'features', 'periods', 'numbits' and 'target'.
        device: device the model and batches are moved to.
        log_interval: number of batches between progress-bar updates.

    Returns:
        float: accumulated loss averaged over ``len(dataloader)`` batches.
    """

    model.to(device)
    model.eval()

    running_loss = 0
    previous_running_loss = 0

    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

            for i, batch in enumerate(tepoch):

                # push batch to device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # calculate model output; signals are permuted to (batch, channels, time)
                output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

                # calculate loss; handle multi-stage list outputs the same way
                # as train_one_epoch (previously this crashed on list outputs)
                if isinstance(output, list):
                    loss = torch.zeros(1, device=device)
                    for y in output:
                        loss = loss + criterion(target, y.squeeze(1))
                    loss = loss / len(output)
                else:
                    loss = criterion(target, output.squeeze(1))

                # update running loss; .item() replaces float(loss.cpu()) round-trip
                running_loss += loss.item()

                # update status bar; the first update (i == 0) has accumulated
                # only one batch, so normalize by the actual batch count
                if i % log_interval == 0:
                    batches_since_update = log_interval if i else 1
                    tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/batches_since_update:8.7f}")
                    previous_running_loss = running_loss

    running_loss /= len(dataloader)

    return running_loss