1import torch 2import torchvision 3import torchvision.transforms.transforms as transforms 4import os 5import torch.ao.quantization 6from torchvision.models.quantization.resnet import resnet18 7from torch.autograd import Variable 8 9# Setup warnings 10import warnings 11warnings.filterwarnings( 12 action='ignore', 13 category=DeprecationWarning, 14 module=r'.*' 15) 16warnings.filterwarnings( 17 action='default', 18 module=r'torch.ao.quantization' 19) 20 21""" 22Define helper functions for APoT PTQ and QAT 23""" 24 25# Specify random seed for repeatable results 26_ = torch.manual_seed(191009) 27 28train_batch_size = 30 29eval_batch_size = 50 30 31class AverageMeter: 32 """Computes and stores the average and current value""" 33 def __init__(self, name, fmt=':f'): 34 self.name = name 35 self.fmt = fmt 36 self.reset() 37 38 def reset(self): 39 self.val = 0 40 self.avg = 0.0 41 self.sum = 0 42 self.count = 0 43 44 def update(self, val, n=1): 45 self.val = val 46 self.sum += val * n 47 self.count += n 48 self.avg = self.sum / self.count 49 50 def __str__(self): 51 fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 52 return fmtstr.format(**self.__dict__) 53 54 55def accuracy(output, target, topk=(1,)): 56 """Computes the accuracy over the k top predictions for the specified values of k""" 57 with torch.no_grad(): 58 maxk = max(topk) 59 batch_size = target.size(0) 60 61 _, pred = output.topk(maxk, 1, True, True) 62 pred = pred.t() 63 correct = pred.eq(target.view(1, -1).expand_as(pred)) 64 65 res = [] 66 for k in topk: 67 correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 68 res.append(correct_k.mul_(100.0 / batch_size)) 69 return res 70 71 72def evaluate(model, criterion, data_loader): 73 model.eval() 74 top1 = AverageMeter('Acc@1', ':6.2f') 75 top5 = AverageMeter('Acc@5', ':6.2f') 76 with torch.no_grad(): 77 for image, target in data_loader: 78 output = model(image) 79 loss = criterion(output, target) 80 acc1, acc5 = accuracy(output, target, topk=(1, 5)) 81 top1.update(acc1[0], image.size(0)) 82 top5.update(acc5[0], image.size(0)) 83 print() 84 85 return top1, top5 86 87def load_model(model_file): 88 model = resnet18(pretrained=False) 89 state_dict = torch.load(model_file) 90 model.load_state_dict(state_dict) 91 model.to("cpu") 92 return model 93 94def print_size_of_model(model): 95 if isinstance(model, torch.jit.RecursiveScriptModule): 96 torch.jit.save(model, "temp.p") 97 else: 98 torch.jit.save(torch.jit.script(model), "temp.p") 99 print("Size (MB):", os.path.getsize("temp.p") / 1e6) 100 os.remove("temp.p") 101 102def prepare_data_loaders(data_path): 103 104 normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 105 std=[0.229, 0.224, 0.225]) 106 dataset = torchvision.datasets.ImageNet(data_path, 107 split="train", 108 transform=transforms.Compose([transforms.RandomResizedCrop(224), 109 transforms.RandomHorizontalFlip(), 110 transforms.ToTensor(), 111 normalize])) 112 dataset_test = torchvision.datasets.ImageNet(data_path, 113 split="val", 114 transform=transforms.Compose([transforms.Resize(256), 115 transforms.CenterCrop(224), 116 transforms.ToTensor(), 117 normalize])) 118 119 train_sampler = torch.utils.data.RandomSampler(dataset) 120 test_sampler = torch.utils.data.SequentialSampler(dataset_test) 121 122 data_loader = torch.utils.data.DataLoader( 123 dataset, batch_size=train_batch_size, 124 sampler=train_sampler) 125 126 data_loader_test = torch.utils.data.DataLoader( 127 dataset_test, batch_size=eval_batch_size, 128 sampler=test_sampler) 129 130 return data_loader, data_loader_test 131 132def training_loop(model, criterion, data_loader): 133 optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 134 train_loss, correct, total = 0, 0, 0 135 model.train() 136 for i in range(10): 137 for data, target in data_loader: 138 optimizer.zero_grad() 139 output = model(data) 140 loss = criterion(output, target) 141 loss = Variable(loss, requires_grad=True) 142 loss.backward() 143 optimizer.step() 144 train_loss += loss.item() 145 _, predicted = torch.max(output, 1) 146 total += target.size(0) 147 correct += (predicted == target).sum().item() 148 return train_loss, correct, total 149