# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import logging from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) def quantize(model, example_inputs): """This is the official recommended flow for quantization in pytorch 2.0 export""" logging.info(f"Original model: {model}") quantizer = XNNPACKQuantizer() # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel operator_config = get_symmetric_quantization_config(is_per_channel=False) quantizer.set_global(operator_config) m = prepare_pt2e(model, quantizer) # calibration m(*example_inputs) m = convert_pt2e(m) logging.info(f"Quantized model: {m}") # make sure we can export to flat buffer return m