1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Workerimport logging 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e 10*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantizer.xnnpack_quantizer import ( 11*523fa7a6SAndroid Build Coastguard Worker get_symmetric_quantization_config, 12*523fa7a6SAndroid Build Coastguard Worker XNNPACKQuantizer, 13*523fa7a6SAndroid Build Coastguard Worker) 14*523fa7a6SAndroid Build Coastguard Worker 15*523fa7a6SAndroid Build Coastguard Worker 16*523fa7a6SAndroid Build Coastguard Workerdef quantize(model, example_inputs): 17*523fa7a6SAndroid Build Coastguard Worker """This is the official recommended flow for quantization in pytorch 2.0 export""" 18*523fa7a6SAndroid Build Coastguard Worker logging.info(f"Original model: {model}") 19*523fa7a6SAndroid Build Coastguard Worker quantizer = XNNPACKQuantizer() 20*523fa7a6SAndroid Build Coastguard Worker # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel 21*523fa7a6SAndroid Build Coastguard Worker operator_config = get_symmetric_quantization_config(is_per_channel=False) 22*523fa7a6SAndroid Build Coastguard Worker quantizer.set_global(operator_config) 23*523fa7a6SAndroid Build Coastguard Worker m = prepare_pt2e(model, quantizer) 24*523fa7a6SAndroid Build Coastguard Worker # calibration 25*523fa7a6SAndroid Build Coastguard Worker m(*example_inputs) 26*523fa7a6SAndroid Build Coastguard Worker m = convert_pt2e(m) 27*523fa7a6SAndroid Build Coastguard Worker logging.info(f"Quantized model: {m}") 28*523fa7a6SAndroid Build Coastguard Worker # make sure we can export to flat buffer 29*523fa7a6SAndroid Build Coastguard Worker return m 30