"""APoT (Additive Powers-of-Two) FX-graph-mode post-training quantization demo.

Quantizes a torchvision ResNet-18's ``torch.nn.Linear`` layers with uniform
and APoT qconfigs at 8-bit and 4-bit precision (k=2), evaluating each variant
on an ImageNet validation set, and compares against the full-precision model
and the eager-mode quantized reference model.
"""

import copy

import torch
import torch.nn as nn
import torch.ao.quantization
from torchvision.models.quantization.resnet import resnet18
from torch.ao.quantization.experimental.quantization_helper import (
    evaluate,
    prepare_data_loaders,
)
# Note that this is temporary, we'll expose these functions to
# torch.ao.quantization after official release.
# FIX: convert_fx was previously used below without being imported
# (the original code suppressed the resulting F821 with noqa comments,
# but the script raised NameError at runtime).
from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx
from torch.ao.quantization.experimental.qconfig import (
    uniform_qconfig_8bit,
    apot_weights_qconfig_8bit,
    apot_qconfig_8bit,
    uniform_qconfig_4bit,
    apot_weights_qconfig_4bit,
    apot_qconfig_4bit,
)

# validation dataset: full ImageNet dataset
data_path = '~/my_imagenet/'

data_loader, data_loader_test = prepare_data_loaders(data_path)
criterion = nn.CrossEntropyLoss()
float_model = resnet18(pretrained=True)
float_model.eval()

# deepcopy the model since we need to keep the original model around
model_to_quantize = copy.deepcopy(float_model)
model_to_quantize.eval()

"""
Prepare models
"""


def calibrate(model, data_loader):
    """Run `model` over `data_loader` in eval mode so inserted observers
    can record activation statistics (labels are ignored)."""
    model.eval()
    with torch.no_grad():
        for image, _ in data_loader:
            model(image)


def prepare_ptq_linear(qconfig):
    """Prepare a fresh copy of the float model for PTQ, applying `qconfig`
    to every ``torch.nn.Linear``, and calibrate it on the test loader.

    Returns the calibrated, observer-instrumented model, ready for
    ``convert_fx``.
    """
    qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
    # fuse modules and insert observers
    prepared_model = prepare_qat_fx(copy.deepcopy(float_model), qconfig_dict)
    calibrate(prepared_model, data_loader_test)  # run calibration on sample data
    return prepared_model


"""
Prepare full precision model
"""
full_precision_model = float_model

top1, top5 = evaluate(full_precision_model, criterion, data_loader_test)
print(f"Model #0 Evaluation accuracy on test dataset: {top1.avg:2.2f}, {top5.avg:2.2f}")

"""
Prepare model with uniform activation, uniform weight
b=8, k=2
"""

prepared_model = prepare_ptq_linear(uniform_qconfig_8bit)
quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model

top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print(f"Model #1 Evaluation accuracy on test dataset (b=8, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")

"""
Prepare model with uniform activation, uniform weight
b=4, k=2
"""

prepared_model = prepare_ptq_linear(uniform_qconfig_4bit)
quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model

# FIX: the original evaluated the undefined name `quantized_model1`
# (self-flagged with noqa: F821); the model converted above is
# `quantized_model`.
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print(f"Model #1 Evaluation accuracy on test dataset (b=4, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")

"""
Prepare model with uniform activation, APoT weight
(b=8, k=2)
"""

prepared_model = prepare_ptq_linear(apot_weights_qconfig_8bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print(f"Model #2 Evaluation accuracy on test dataset (b=8, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")

"""
Prepare model with uniform activation, APoT weight
(b=4, k=2)
"""

prepared_model = prepare_ptq_linear(apot_weights_qconfig_4bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print(f"Model #2 Evaluation accuracy on test dataset (b=4, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")


"""
Prepare model with APoT activation and weight
(b=8, k=2)
"""

prepared_model = prepare_ptq_linear(apot_qconfig_8bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print(f"Model #3 Evaluation accuracy on test dataset (b=8, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")

"""
Prepare model with APoT activation and weight
(b=4, k=2)
"""

prepared_model = prepare_ptq_linear(apot_qconfig_4bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print(f"Model #3 Evaluation accuracy on test dataset (b=4, k=2): {top1.avg:2.2f}, {top5.avg:2.2f}")

"""
Prepare eager mode quantized model
"""
eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
top1, top5 = evaluate(eager_quantized_model, criterion, data_loader_test)
print(f"Eager mode quantized model evaluation accuracy on test dataset: {top1.avg:2.2f}, {top5.avg:2.2f}")