# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import argparse
import copy
import logging
import time

import torch
from executorch.exir import EdgeCompileConfig
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.extension.export_util.utils import export_to_edge, save_pte_program
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization import (  # @manual
    default_per_channel_symmetric_qnnpack_qconfig,
    QConfigMapping,
)
from torch.ao.quantization.backend_config import get_executorch_backend_config
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

from ...models import MODEL_NAME_TO_MODEL
from ...models.model_factory import EagerModelFactory

from .. import MODEL_NAME_TO_OPTIONS
from .utils import quantize


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
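
# Example invocation (illustrative; the model name must be a key of
# MODEL_NAME_TO_OPTIONS, and the library path comes from the CMake build
# described in the error message in main()):
#   python -m examples.xnnpack.quantization.example -m ic3 \
#       -s cmake-out/kernels/quantized/libquantized_ops_aot_lib.so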


def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_inputs):
    """Sanity check: verify the PT2E XNNPACKQuantizer flow against the FX graph
    mode quantization flow."""

    if model_name in ["edsr", "mobilebert"]:
        # EDSR has control flow that torch.fx.symbolic_trace cannot trace;
        # mobilebert is likewise not symbolically traceable.
        return
    if model_name == "ic3":
        # Skip the comparison for inception_v3: a mul op with a Scalar input is
        # quantized differently in fx, and we don't want to replicate that
        # behavior in XNNPACKQuantizer.
        return

    model.eval()
    m_copy = copy.deepcopy(model)
    m = model

    # 1. pytorch 2.0 export quantization flow (recommended/default flow)
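    # export_for_training captures the model as a pre-autograd ATen graph (an
    # ExportedProgram); .module() returns the GraphModule that the quantizer
    # annotates and prepare_pt2e instruments with observers.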
    m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module()
    quantizer = XNNPACKQuantizer()
    quantization_config = get_symmetric_quantization_config(is_per_channel=True)
    quantizer.set_global(quantization_config)
    m = prepare_pt2e(m, quantizer)
    # calibration: run the example inputs through the observed model so the
    # observers record the value ranges used to compute quantization parameters
    after_prepare_result = m(*example_inputs)
    logging.info(f"prepare_pt2e: {m}")
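    # convert_pt2e folds the observed statistics into quantization parameters
    # and replaces the observers with quantize/dequantize ops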
    m = convert_pt2e(m)
    after_quant_result = m(*example_inputs)

    # 2. the previous fx graph mode quantization reference flow
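    # default_per_channel_symmetric_qnnpack_qconfig matches the per-channel
    # symmetric config used above, so the two flows should produce nearly
    # identical quantized models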
    qconfig = default_per_channel_symmetric_qnnpack_qconfig
    qconfig_mapping = QConfigMapping().set_global(qconfig)
    backend_config = get_executorch_backend_config()
    m_fx = prepare_fx(
        m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
    )
    after_prepare_result_fx = m_fx(*example_inputs)
    logging.info(f"prepare_fx: {m_fx}")
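    # _convert_to_reference_decomposed_fx lowers the observed model to a
    # reference quantized model with decomposed quantize/dequantize ops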
    m_fx = _convert_to_reference_decomposed_fx(m_fx, backend_config=backend_config)
    after_quant_result_fx = m_fx(*example_inputs)

    # 3. compare results
    if model_name == "dl3":
        # dl3 output format: {"out": a, "aux": b}
        after_prepare_result = after_prepare_result["out"]
        after_prepare_result_fx = after_prepare_result_fx["out"]
        after_quant_result = after_quant_result["out"]
        after_quant_result_fx = after_quant_result_fx["out"]
    logging.info(f"m: {m}")
    logging.info(f"m_fx: {m_fx}")
    logging.info(
        f"prepare sqnr: {compute_sqnr(after_prepare_result, after_prepare_result_fx)}"
    )

    # NB: SQNR is reported in dB; higher means closer agreement. This check is
    # more useful for QAT: for PTQ we only insert observers, which do not
    # change the output of the model, so it merely measures the numerical
    # difference between the two captures. For QAT it also tests whether the
    # fake-quant placement matches. The results are not exactly the same
    # because capture changes numerics, but they are still very close.
    assert compute_sqnr(after_prepare_result, after_prepare_result_fx) > 100
    logging.info(
        f"quant diff max: {torch.max(after_quant_result - after_quant_result_fx)}"
    )
    assert torch.max(after_quant_result - after_quant_result_fx) < 1e-1
    logging.info(
        f"quant sqnr: {compute_sqnr(after_quant_result, after_quant_result_fx)}"
    )
    assert compute_sqnr(after_quant_result, after_quant_result_fx) > 30


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_OPTIONS.keys())}",
    )
    parser.add_argument(
        "-ve",
        "--verify",
        action="store_true",
        required=False,
        default=False,
        help="flag for verifying XNNPACKQuantizer against fx graph mode quantization",
    )
    parser.add_argument(
        "-s",
        "--so_library",
        required=False,
        help="shared library for quantized operators",
    )

    args = parser.parse_args()
    # Check whether the out variants of the quantized ops are registered;
    # to_executorch needs them when emitting out-variant calls.
    has_out_ops = True
    try:
        _ = torch.ops.quantized_decomposed.add.out
    except AttributeError:
        logging.info("No registered quantized ops")
        has_out_ops = False
    if not has_out_ops:
        if args.so_library:
            torch.ops.load_library(args.so_library)
        else:
            raise RuntimeError(
                "Need to specify shared library path to register quantized ops (and their out variants) into "
                "EXIR. The required shared library is defined as `quantized_ops_aot_lib` in "
                "kernels/quantized/CMakeLists.txt if you are using CMake build, or `aot_lib` in "
                "kernels/quantized/targets.bzl for buck2. One example path would be cmake-out/kernels/quantized/"
                "libquantized_ops_aot_lib.[so|dylib]."
            )
    if not args.verify and args.model_name not in MODEL_NAME_TO_OPTIONS:
        raise RuntimeError(
            f"Model {args.model_name} is not a valid name or is not quantizable right now. "
            "Please contact the executorch team if you want to learn why or how to support "
            "quantization for the requested model. "
            f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
        )

    start = time.perf_counter()
    model, example_inputs, _, _ = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )
    end = time.perf_counter()
    logging.info(f"Model init time: {end - start}s")
    if args.verify:
        start = time.perf_counter()
        verify_xnnpack_quantizer_matching_fx_quant_model(
            args.model_name, model, example_inputs
        )
        end = time.perf_counter()
        logging.info(f"Verify time: {end - start}s")

    model = model.eval()
    # Pre-autograd export; eventually this will be replaced by torch.export.
    model = torch.export.export_for_training(model, example_inputs).module()
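    # quantize() (defined in .utils in this directory) applies the same
    # XNNPACKQuantizer prepare_pt2e/convert_pt2e flow, calibrating on
    # example_inputs.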
    start = time.perf_counter()
    quantized_model = quantize(model, example_inputs)
    end = time.perf_counter()
    logging.info(f"Quantize time: {end - start}s")

    start = time.perf_counter()
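    # IR validity checks are disabled because the quantized model contains
    # quantized_decomposed ops that fall outside the core ATen opset the
    # verifier expects.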
    edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
    edge_m = export_to_edge(
        quantized_model, example_inputs, edge_compile_config=edge_compile_config
    )
    end = time.perf_counter()
    logging.info(f"Export time: {end - start}s")

    start = time.perf_counter()
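    # extract_delegate_segments=False keeps delegate payloads inline in the
    # program rather than storing them as separate segments in the .pte file.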
    prog = edge_m.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )
    save_pte_program(prog, f"{args.model_name}_quantized")
    end = time.perf_counter()
    logging.info(f"Save time: {end - start}s")
    logging.info("finished")


if __name__ == "__main__":
    main()  # pragma: no cover
