import torch
import torchvision
import torchvision.transforms as transforms
import os
import torch.ao.quantization
from torchvision.models.quantization.resnet import resnet18

# Setup warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.ao.quantization'
)

21"""
22Define helper functions for APoT PTQ and QAT
23"""
24
25# Specify random seed for repeatable results
26_ = torch.manual_seed(191009)
27
28train_batch_size = 30
29eval_batch_size = 50
30
class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0.0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


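# Quick illustrative check of accuracy() on random data; the helper name and the
# shapes (8 samples, 10 classes) are arbitrary and not part of the original utilities.
def _example_accuracy_check():
    logits = torch.randn(8, 10)
    labels = torch.randint(0, 10, (8,))
    acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
    print(f"top-1: {acc1.item():.2f}%  top-5: {acc5.item():.2f}%")
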
def evaluate(model, criterion, data_loader):
    """Evaluates model on data_loader and returns top-1/top-5 accuracy AverageMeters."""
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))

    return top1, top5

def load_model(model_file):
    """Loads a float ResNet18 (quantizable torchvision variant) from a saved state dict onto CPU."""
    model = resnet18(pretrained=False)
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to("cpu")
    return model

def print_size_of_model(model):
    """Prints the on-disk size (MB) of the TorchScript-serialized model."""
    if isinstance(model, torch.jit.RecursiveScriptModule):
        torch.jit.save(model, "temp.p")
    else:
        torch.jit.save(torch.jit.script(model), "temp.p")
    print("Size (MB):", os.path.getsize("temp.p") / 1e6)
    os.remove("temp.p")

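# Illustrative sketch only, not the APoT flow exercised by the tests in this
# directory: standard eager-mode post-training quantization, shown to demonstrate
# how print_size_of_model can compare float vs. int8 model sizes. The function
# name and the 'fbgemm' qconfig choice are this example's own; `model` is assumed
# to be the quantizable torchvision ResNet18 returned by load_model above, and
# `data_loader` a calibration loader.
def example_size_comparison(model, data_loader):
    print_size_of_model(model)  # float32 baseline
    model.eval()
    model.fuse_model()  # fuse conv/bn/relu modules before quantization
    model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
    torch.ao.quantization.prepare(model, inplace=True)
    with torch.no_grad():
        for i, (image, _) in enumerate(data_loader):
            model(image)  # run a few batches so observers collect statistics
            if i >= 10:
                break
    torch.ao.quantization.convert(model, inplace=True)
    print_size_of_model(model)  # quantized (int8) size
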
def prepare_data_loaders(data_path):
    """Builds ImageNet train/val DataLoaders with standard ResNet preprocessing."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    dataset = torchvision.datasets.ImageNet(
        data_path,
        split="train",
        transform=transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    dataset_test = torchvision.datasets.ImageNet(
        data_path,
        split="val",
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=train_batch_size,
        sampler=train_sampler)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=eval_batch_size,
        sampler=test_sampler)

    return data_loader, data_loader_test

def training_loop(model, criterion, data_loader):
    """Trains model for 10 epochs with Adam and returns (cumulative loss, correct, total)."""
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    train_loss, correct, total = 0, 0, 0
    model.train()
    for epoch in range(10):
        for data, target in data_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return train_loss, correct, total
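
# Minimal end-to-end usage sketch. The checkpoint filename and ImageNet root are
# placeholders (not provided by this module); adjust them to a local setup.
if __name__ == "__main__":
    float_model = load_model("resnet18_pretrained_float.pth")
    data_loader, data_loader_test = prepare_data_loaders("/path/to/imagenet")
    criterion = torch.nn.CrossEntropyLoss()

    print_size_of_model(float_model)
    top1, top5 = evaluate(float_model, criterion, data_loader_test)
    print(f"Float model: {top1} {top5}")

    train_loss, correct, total = training_loop(float_model, criterion, data_loader)
    print(f"After training: loss={train_loss:.3f}, acc={100.0 * correct / total:.2f}%")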