xref: /aosp_15_r20/external/executorch/examples/xnnpack/quantization/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport logging
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
10*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantizer.xnnpack_quantizer import (
11*523fa7a6SAndroid Build Coastguard Worker    get_symmetric_quantization_config,
12*523fa7a6SAndroid Build Coastguard Worker    XNNPACKQuantizer,
13*523fa7a6SAndroid Build Coastguard Worker)
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Workerdef quantize(model, example_inputs):
17*523fa7a6SAndroid Build Coastguard Worker    """This is the official recommended flow for quantization in pytorch 2.0 export"""
18*523fa7a6SAndroid Build Coastguard Worker    logging.info(f"Original model: {model}")
19*523fa7a6SAndroid Build Coastguard Worker    quantizer = XNNPACKQuantizer()
20*523fa7a6SAndroid Build Coastguard Worker    # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
21*523fa7a6SAndroid Build Coastguard Worker    operator_config = get_symmetric_quantization_config(is_per_channel=False)
22*523fa7a6SAndroid Build Coastguard Worker    quantizer.set_global(operator_config)
23*523fa7a6SAndroid Build Coastguard Worker    m = prepare_pt2e(model, quantizer)
24*523fa7a6SAndroid Build Coastguard Worker    # calibration
25*523fa7a6SAndroid Build Coastguard Worker    m(*example_inputs)
26*523fa7a6SAndroid Build Coastguard Worker    m = convert_pt2e(m)
27*523fa7a6SAndroid Build Coastguard Worker    logging.info(f"Quantized model: {m}")
28*523fa7a6SAndroid Build Coastguard Worker    # make sure we can export to flat buffer
29*523fa7a6SAndroid Build Coastguard Worker    return m
30