# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import argparse
import copy
import logging
import time

import torch
from executorch.exir import EdgeCompileConfig
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.extension.export_util.utils import export_to_edge, save_pte_program
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization import (  # @manual
    default_per_channel_symmetric_qnnpack_qconfig,
    QConfigMapping,
)
from torch.ao.quantization.backend_config import get_executorch_backend_config
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

from ...models import MODEL_NAME_TO_MODEL
from ...models.model_factory import EagerModelFactory

from .. import MODEL_NAME_TO_OPTIONS
from .utils import quantize


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def verify_xnnpack_quantizer_matching_fx_quant_model(
    model_name, model, example_inputs
):
    """Verify the XNNPACKQuantizer flow against FX graph mode quantization as a sanity check."""

    if model_name in ["edsr", "mobilebert"]:
        # EDSR has control flow that is not traceable by torch.fx.symbolic_trace
        # mobilebert is not symbolically traceable with torch.fx.symbolic_trace
        return
    if model_name == "ic3":
        # We don't compare inception_v3 results with fx, since a mul op with a Scalar
        # input is quantized differently in fx, and we don't want to replicate that
        # behavior in XNNPACKQuantizer
        return

    model.eval()
    m_copy = copy.deepcopy(model)
    m = model

    # 1. PyTorch 2 export quantization flow (recommended/default flow)
    m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module()
    quantizer = XNNPACKQuantizer()
    quantization_config = get_symmetric_quantization_config(is_per_channel=True)
    quantizer.set_global(quantization_config)
    m = prepare_pt2e(m, quantizer)
    # calibration
    after_prepare_result = m(*example_inputs)
    logging.info(f"prepare_pt2e: {m}")
    m = convert_pt2e(m)
    after_quant_result = m(*example_inputs)

    # 2. the previous FX graph mode quantization reference flow
    qconfig = default_per_channel_symmetric_qnnpack_qconfig
    qconfig_mapping = QConfigMapping().set_global(qconfig)
    backend_config = get_executorch_backend_config()
    m_fx = prepare_fx(
        m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
    )
    after_prepare_result_fx = m_fx(*example_inputs)
    logging.info(f"prepare_fx: {m_fx}")
    m_fx = _convert_to_reference_decomposed_fx(m_fx, backend_config=backend_config)
    after_quant_result_fx = m_fx(*example_inputs)

    # 3. compare results
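    # Descriptive note (added for clarity): compute_sqnr reports a
    # signal-to-quantization-noise ratio in dB, roughly
    # 20 * log10(||ref|| / ||ref - other||), so larger values mean the two
    # flows agree more closely.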
    if model_name == "dl3":
        # dl3 output format: {"out": a, "aux": b}
        after_prepare_result = after_prepare_result["out"]
        after_prepare_result_fx = after_prepare_result_fx["out"]
        after_quant_result = after_quant_result["out"]
        after_quant_result_fx = after_quant_result_fx["out"]
    logging.info(f"m: {m}")
    logging.info(f"m_fx: {m_fx}")
    logging.info(
        f"prepare sqnr: {compute_sqnr(after_prepare_result, after_prepare_result_fx)}"
    )

    # NB: this check is more useful for QAT, since for PTQ we only insert observers,
    # which do not change the output of the model, so it just measures the numerical
    # difference between the two captures. For QAT it also tests whether the fake
    # quant placement matches. The results are not exactly the same because capture
    # changes the numerics, but they are still very close.
    assert compute_sqnr(after_prepare_result, after_prepare_result_fx) > 100
    logging.info(
        f"quant diff max: {torch.max(after_quant_result - after_quant_result_fx)}"
    )
    assert torch.max(after_quant_result - after_quant_result_fx) < 1e-1
    logging.info(
        f"quant sqnr: {compute_sqnr(after_quant_result, after_quant_result_fx)}"
    )
    assert compute_sqnr(after_quant_result, after_quant_result_fx) > 30


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_OPTIONS.keys())}",
    )
    parser.add_argument(
        "-ve",
        "--verify",
        action="store_true",
        required=False,
        default=False,
        help="flag for verifying XNNPACKQuantizer against fx graph mode quantization",
    )
    parser.add_argument(
        "-s",
        "--so_library",
        required=False,
        help="shared library for quantized operators",
    )

    args = parser.parse_args()
    # See if we have quantized op out variants registered
    has_out_ops = True
    try:
        _ = torch.ops.quantized_decomposed.add.out
    except AttributeError:
        logging.info("No registered quantized ops")
        has_out_ops = False
    if not has_out_ops:
        if args.so_library:
            torch.ops.load_library(args.so_library)
        else:
            raise RuntimeError(
                "Need to specify shared library path to register quantized ops (and their out variants) into "
                "EXIR. The required shared library is defined as `quantized_ops_aot_lib` in "
                "kernels/quantized/CMakeLists.txt if you are using CMake build, or `aot_lib` in "
                "kernels/quantized/targets.bzl for buck2. One example path would be cmake-out/kernels/quantized/"
                "libquantized_ops_aot_lib.[so|dylib]."
            )
    if not args.verify and args.model_name not in MODEL_NAME_TO_OPTIONS:
        raise RuntimeError(
            f"Model {args.model_name} is not a valid name or is not quantizable right now; "
            "please contact the ExecuTorch team if you want to learn why or how to support "
            "quantization for the requested model. "
            f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
        )

    start = time.perf_counter()
    model, example_inputs, _, _ = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )
    end = time.perf_counter()
    # logging.info(f"Model init time: {end - start}s")
    if args.verify:
        start = time.perf_counter()
        verify_xnnpack_quantizer_matching_fx_quant_model(
            args.model_name, model, example_inputs
        )
        end = time.perf_counter()
        # logging.info(f"Verify time: {end - start}s")

    model = model.eval()
    # pre-autograd export; eventually this will become torch.export
    model = torch.export.export_for_training(model, example_inputs).module()
    start = time.perf_counter()
    quantized_model = quantize(model, example_inputs)
    end = time.perf_counter()
    logging.info(f"Quantize time: {end - start}s")

    start = time.perf_counter()
    edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
    edge_m = export_to_edge(
        quantized_model, example_inputs, edge_compile_config=edge_compile_config
    )
    end = time.perf_counter()
    logging.info(f"Export time: {end - start}s")

    start = time.perf_counter()
    prog = edge_m.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )
    save_pte_program(prog, f"{args.model_name}_quantized")
    end = time.perf_counter()
    logging.info(f"Save time: {end - start}s")
    logging.info("finished")


if __name__ == "__main__":
    main()  # pragma: no cover
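
# A minimal usage sketch (the module path and the model name "mv2" are
# assumptions based on the repository layout; valid names come from
# MODEL_NAME_TO_OPTIONS):
#
#   python -m examples.xnnpack.quantization.example -m mv2 \
#       -s cmake-out/kernels/quantized/libquantized_ops_aot_lib.so
#
# Add --verify to also cross-check the XNNPACKQuantizer flow against FX graph
# mode quantization before the quantized .pte program is written.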