import sys

import torch
from tqdm import tqdm


def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):

    model.to(device)
    model.train()

    running_loss = 0
    previous_running_loss = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output
            output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

            # calculate loss
            if isinstance(output, list):
                loss = torch.zeros(1, device=device)
                for y in output:
                    loss = loss + criterion(target, y.squeeze(1))
                loss = loss / len(output)
            else:
                loss = criterion(target, output.squeeze(1))

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate
            scheduler.step()

            # sparsification
            if hasattr(model, 'sparsifier'):
                model.sparsifier()

            # update running loss
            running_loss += float(loss.cpu())

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}",
                                   current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                previous_running_loss = running_loss

    running_loss /= len(dataloader)

    return running_loss


def evaluate(model, criterion, dataloader, device, log_interval=10):

    model.to(device)
    model.eval()

    running_loss = 0
    previous_running_loss = 0

    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

            for i, batch in enumerate(tepoch):

                # push batch to device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # calculate model output
                output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

                # calculate loss
                loss = criterion(target, output.squeeze(1))

                # update running loss
                running_loss += float(loss.cpu())

                # update status bar
                if i % log_interval == 0:
                    tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}",
                                       current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                    previous_running_loss = running_loss

    running_loss /= len(dataloader)

    return running_loss
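

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumption-heavy, not part of the original module):
# the toy model, dataset, shapes, and hyperparameters below are illustrative
# placeholders chosen only so that train_one_epoch / evaluate can be smoke
# tested end to end. A real setup is expected to supply its own model and a
# dataloader yielding the same batch keys
# ('signals', 'features', 'periods', 'numbits', 'target').
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    from torch.utils.data import DataLoader, Dataset

    class _ToyDataset(Dataset):
        """Random data with the batch keys the training loop expects."""

        def __init__(self, length=64, num_samples=160, feature_dim=20):
            self.length = length
            self.num_samples = num_samples
            self.feature_dim = feature_dim

        def __len__(self):
            return self.length

        def __getitem__(self, index):
            return {
                'signals':  torch.randn(self.num_samples, 1),
                'features': torch.randn(self.num_samples, self.feature_dim),
                'periods':  torch.randint(32, 256, (self.num_samples,)),
                'numbits':  torch.randint(1, 8, (self.num_samples,)).float(),
                'target':   torch.randn(self.num_samples),
            }

    class _ToyModel(torch.nn.Module):
        """1x1 convolution over the concatenated signal and features;
        ignores periods and numbits. Stands in for the real model."""

        def __init__(self, feature_dim=20):
            super().__init__()
            self.proj = torch.nn.Conv1d(1 + feature_dim, 1, kernel_size=1)

        def forward(self, signals, features, periods, numbits):
            # signals: (batch, 1, time), features: (batch, time, feature_dim)
            x = torch.cat([signals, features.permute(0, 2, 1)], dim=1)
            return self.proj(x)  # (batch, 1, time)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = _ToyModel()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # constant learning rate; a real setup would plug in its own schedule
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
    dataloader = DataLoader(_ToyDataset(), batch_size=8, shuffle=True)

    for epoch in range(2):
        train_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)
        val_loss = evaluate(model, criterion, dataloader, device)
        print(f"epoch {epoch + 1}: train loss {train_loss:.6f}, validation loss {val_loss:.6f}")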