# xref: /aosp_15_r20/external/pytorch/test/quantization/core/experimental/apot_fx_graph_mode_ptq.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import torch
import torch.nn as nn
import torch.ao.quantization
from torch.ao.quantization.experimental.quantization_helper import (
    evaluate,
    prepare_data_loaders
)
from torch.ao.quantization.quantize_fx import convert_fx
from torchvision.models.quantization.resnet import resnet18
9
# validation dataset: full ImageNet dataset
# NOTE(review): user-specific path; assumes an ImageNet-style directory layout
# under ~/my_imagenet — confirm before running.
data_path = '~/my_imagenet/'

# Build the calibration and test data loaders from the dataset path.
data_loader, data_loader_test = prepare_data_loaders(data_path)
criterion = nn.CrossEntropyLoss()
# Pretrained fp32 baseline; kept in eval mode for inference/calibration.
float_model = resnet18(pretrained=True)
float_model.eval()

# deepcopy the model since we need to keep the original model around
import copy
model_to_quantize = copy.deepcopy(float_model)

model_to_quantize.eval()
23
24"""
25Prepare models
26"""
27
# Note that this is temporary, we'll expose these functions to torch.ao.quantization after the official release
29from torch.ao.quantization.quantize_fx import prepare_qat_fx
30
def calibrate(model, data_loader):
    """Run one inference pass over ``data_loader`` so inserted observers record statistics.

    The model is put into eval mode and no gradients are tracked; targets in
    the loader are ignored.
    """
    model.eval()
    with torch.no_grad():
        for image, _ in data_loader:
            model(image)
36
37from torch.ao.quantization.experimental.qconfig import (
38    uniform_qconfig_8bit,
39    apot_weights_qconfig_8bit,
40    apot_qconfig_8bit,
41    uniform_qconfig_4bit,
42    apot_weights_qconfig_4bit,
43    apot_qconfig_4bit
44)
45
"""
Prepare full precision model
"""
# Baseline: evaluate the unquantized fp32 model so later quantized
# accuracies can be compared against it.
full_precision_model = float_model

top1, top5 = evaluate(full_precision_model, criterion, data_loader_test)
print(f"Model #0 Evaluation accuracy on test dataset: {top1.avg:2.2f}, {top5.avg:2.2f}")
53
54"""
55Prepare model PTQ for specified qconfig for torch.nn.Linear
56"""
def prepare_ptq_linear(qconfig):
    """Instrument a fresh copy of ``float_model`` with observers for Linear layers.

    Applies ``qconfig`` to every ``torch.nn.Linear`` module, fuses modules and
    inserts observers via FX graph mode, then calibrates the observers on the
    test loader. Returns the prepared (observer-instrumented) model; the
    original ``float_model`` is left untouched.
    """
    # Restrict the qconfig to Linear modules only.
    linear_qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
    model_copy = copy.deepcopy(float_model)
    # fuse modules and insert observers
    prepared = prepare_qat_fx(model_copy, linear_qconfig_dict)
    # run calibration on sample data
    calibrate(prepared, data_loader_test)
    return prepared
62
"""
Prepare model with uniform activation, uniform weight
b=8, k=2
"""

prepared_model = prepare_ptq_linear(uniform_qconfig_8bit)
# Convert the calibrated model to a quantized model. `convert_fx` is imported
# at the top of the file (the original referenced it without importing,
# silenced with `# noqa: F821`).
quantized_model = convert_fx(prepared_model)

top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print(f"Model #1 Evaluation accuracy on test dataset (b=8, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
73
"""
Prepare model with uniform activation, uniform weight
b=4, k=2
"""

prepared_model = prepare_ptq_linear(uniform_qconfig_4bit)
# Convert the calibrated model to a quantized model. `convert_fx` is imported
# at the top of the file (the original referenced it without importing,
# silenced with `# noqa: F821`).
quantized_model = convert_fx(prepared_model)

# Fix: evaluate the model converted just above. The original evaluated an
# undefined name `quantized_model1` (silenced with `# noqa: F821`), which
# would raise NameError at runtime.
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print(f"Model #1 Evaluation accuracy on test dataset (b=4, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
84
"""
Prepare model with uniform activation, APoT weight
(b=8, k=2)
"""

# NOTE(review): unlike the uniform cases above, the observer-prepared model is
# evaluated directly — there is no convert_fx step. Presumably the APoT
# observers fake-quantize in the forward pass; confirm this is intended.
prepared_model = prepare_ptq_linear(apot_weights_qconfig_8bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print(f"Model #2 Evaluation accuracy on test dataset (b=8, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
94
"""
Prepare model with uniform activation, APoT weight
(b=4, k=2)
"""

# Same pattern as the 8-bit APoT-weight case: the prepared (observer) model
# is evaluated without a convert step.
prepared_model = prepare_ptq_linear(apot_weights_qconfig_4bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print(f"Model #2 Evaluation accuracy on test dataset (b=4, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
104
105
"""
Prepare model with APoT activation and weight
(b=8, k=2)
"""

# APoT qconfig for both activations and weights; evaluated in prepared
# (observer) form, consistent with the APoT-weight cases above.
prepared_model = prepare_ptq_linear(apot_qconfig_8bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print(f"Model #3 Evaluation accuracy on test dataset (b=8, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
115
"""
Prepare model with APoT activation and weight
(b=4, k=2)
"""

# 4-bit variant of the APoT activation+weight configuration.
prepared_model = prepare_ptq_linear(apot_qconfig_4bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print(f"Model #3 Evaluation accuracy on test dataset (b=4, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")
125
"""
Prepare eager mode quantized model
"""
# Reference point: torchvision's pre-quantized (eager mode) resnet18,
# loaded with quantized weights and evaluated on the same test set.
eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
top1, top5 = evaluate(eager_quantized_model, criterion, data_loader_test)
print(f"Eager mode quantized model evaluation accuracy on test dataset: {top1.avg:2.2f}, {top5.avg:2.2f}")
132