"""Training and evaluation loops with a tqdm status bar."""

import sys

import torch
from tqdm import tqdm


def _move_batch_to_device(batch, device):
    """Move every value of the ``batch`` mapping to ``device`` in place.

    Args:
        batch: Mapping whose values all expose ``.to(device)``.
        device: Target device passed to each value's ``.to``.

    Returns:
        The same ``batch`` object, for convenience.
    """
    for key in batch:
        batch[key] = batch[key].to(device)
    return batch


def _compute_loss(criterion, output, target):
    """Return the loss of one batch.

    When the model returns several outputs (a list or tuple), the criterion is
    evaluated on each of them and the results are averaged; otherwise it is
    evaluated on the single output. Dimension 1 of every output is squeezed
    before it is handed to the criterion.

    Args:
        criterion: Callable invoked as ``criterion(target, prediction)``.
        output: Model output, or a list/tuple of model outputs.
        target: Target tensor of the batch.

    Returns:
        The (averaged) loss tensor.
    """
    # NOTE(review): the criterion receives (target, prediction), which is the
    # reverse of the usual torch.nn loss convention (input, target). This is
    # kept as in the original; confirm it matches the criterion in use.
    if isinstance(output, (list, tuple)):
        total = sum(criterion(target, y.squeeze(1)) for y in output)
        return total / len(output)
    return criterion(target, output.squeeze(1))


def _update_status(tepoch, running_loss, num_batches, window_loss, window_batches):
    """Show the epoch-average and the most recent window-average loss."""
    tepoch.set_postfix(
        running_loss=f"{running_loss / num_batches:8.7f}",
        current_loss=f"{window_loss / window_batches:8.7f}",
    )


def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch must be a mapping with the keys ``'features'``, ``'periods'``
    and ``'target'``. The model is called as
    ``model(batch['features'], batch['periods'])``.

    Args:
        model: Module to train; it is moved to ``device`` and put in train mode.
        criterion: Loss callable, invoked as ``criterion(target, prediction)``.
        optimizer: Optimizer stepped once per batch.
        dataloader: Iterable of batches.
        device: Device the model and the batches are moved to.
        scheduler: Learning-rate scheduler stepped once per batch, or ``None``
            to leave the learning rate untouched.
        log_interval: The status bar is refreshed every ``log_interval``
            batches.

    Returns:
        The mean batch loss over the epoch as a float, or ``nan`` when the
        dataloader yielded no batches.
    """
    model.to(device)
    model.train()

    running_loss = 0.0
    num_batches = 0
    # State of the running totals at the last status-bar refresh.
    logged_loss = 0.0
    logged_batches = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):
            optimizer.zero_grad()

            _move_batch_to_device(batch, device)
            target = batch['target']

            output = model(batch['features'], batch['periods'])
            loss = _compute_loss(criterion, output, target)

            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            running_loss += loss.item()
            num_batches = i + 1

            if i % log_interval == 0:
                # Divide by the number of batches actually seen since the last
                # refresh; at i == 0 that is 1, not log_interval.
                _update_status(
                    tepoch,
                    running_loss,
                    num_batches,
                    running_loss - logged_loss,
                    num_batches - logged_batches,
                )
                logged_loss = running_loss
                logged_batches = num_batches

    # Average over the batches that were processed instead of len(dataloader),
    # which is unavailable for some iterable-style loaders and zero when empty.
    if num_batches == 0:
        return float("nan")
    return running_loss / num_batches


def evaluate(model, criterion, dataloader, device, log_interval=10):
    """Compute the mean loss of ``model`` over ``dataloader`` without gradients.

    Each batch must be a mapping with the keys ``'features'``, ``'periods'``
    and ``'target'``. The model is called as
    ``model(batch['features'], batch['periods'])``. As in ``train_one_epoch``,
    a list or tuple of model outputs is scored by averaging the criterion over
    all of them.

    Args:
        model: Module to evaluate; it is moved to ``device`` and put in eval
            mode (the mode is not restored afterwards).
        criterion: Loss callable, invoked as ``criterion(target, prediction)``.
        dataloader: Iterable of batches.
        device: Device the model and the batches are moved to.
        log_interval: The status bar is refreshed every ``log_interval``
            batches.

    Returns:
        The mean batch loss as a float, or ``nan`` when the dataloader yielded
        no batches.
    """
    model.to(device)
    model.eval()

    running_loss = 0.0
    num_batches = 0
    # State of the running totals at the last status-bar refresh.
    logged_loss = 0.0
    logged_batches = 0

    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
            for i, batch in enumerate(tepoch):
                _move_batch_to_device(batch, device)
                target = batch['target']

                output = model(batch['features'], batch['periods'])
                loss = _compute_loss(criterion, output, target)

                running_loss += loss.item()
                num_batches = i + 1

                if i % log_interval == 0:
                    # Divide by the number of batches actually seen since the
                    # last refresh; at i == 0 that is 1, not log_interval.
                    _update_status(
                        tepoch,
                        running_loss,
                        num_batches,
                        running_loss - logged_loss,
                        num_batches - logged_batches,
                    )
                    logged_loss = running_loss
                    logged_batches = num_batches

    # Average over the batches that were processed instead of len(dataloader),
    # which is unavailable for some iterable-style loaders and zero when empty.
    if num_batches == 0:
        return float("nan")
    return running_loss / num_batches