# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import unittest

import torch
import torchvision
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization import QConfigMapping  # @manual
from torch.ao.quantization.backend_config import get_executorch_backend_config
from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx
from torch.ao.quantization.quantize_pt2e import (
    _convert_to_reference_decomposed_fx,
    convert_pt2e,
    prepare_pt2e,
)

from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import export
from torch.testing import FileCheck
from torch.testing._internal.common_quantized import override_quantized_engine

# load executorch out variant ops
torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib")


class TestQuantization(unittest.TestCase):
    """prepare_pt2e and convert_pt2e are OSS APIs; the rest are all meta-only
    APIs for now, but we plan to open source them in the future.
    """

    def test_resnet(self) -> None:
        """Quantize ResNet-18 with the PT2E flow, lower it to the edge dialect,
        and compare its numerics against the FX graph mode reference flow.
        """
        # The context manager sets the quantized engine to qnnpack and restores
        # the previous engine on exit, so no manual assignment is needed.
        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.resnet18().eval()
            # Untouched copy of the float model for the FX reference flow below.
            m_copy = copy.deepcopy(m)

            # program capture
            m = torch.export.export_for_training(
                m, copy.deepcopy(example_inputs)
            ).module()

            quantizer = XNNPACKQuantizer()
            operator_config = get_symmetric_quantization_config(is_per_channel=True)
            quantizer.set_global(operator_config)
            m = prepare_pt2e(m, quantizer)  # pyre-fixme[6]
            # NOTE(review): these two observers are expected to be the very same
            # module (a shared observer); confirm which graph nodes they annotate.
            self.assertIs(m.activation_post_process_3, m.activation_post_process_2)
            after_prepare_result = m(*example_inputs)[0]
            m = convert_pt2e(m)

            # TODO: conv, conv_relu, linear delegation
            # quantized ops to implement: add_relu
            compile_config = EdgeCompileConfig(
                _check_ir_validity=False,
            )
            m = to_edge(
                export(m, example_inputs), compile_config=compile_config
            ).transform([QuantFusionPass(), SpecPropPass()])

            after_quant_result = m.exported_program().module()(*example_inputs)[0]
            # The lowered graph must still contain edge-dialect quantize and
            # dequantize ops, in that order.
            FileCheck().check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor"
            ).check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor"
            ).run(
                m.exported_program().graph_module.code
            )
            # after_quant_fusion_result = m(*example_inputs)[0]

            # TODO: implement torch.ops.quantized_decomposed.add_relu.out
            # m = m.to_executorch().dump_graph_module()
            # after_to_executorch = m(*example_inputs)[0]
            # test the result before and after to_executorch matches
            # TODO: debug why this is a mismatch
            # self.assertTrue(torch.equal(after_quant_fusion_result, after_to_executorch))
            # self.assertEqual(compute_sqnr(after_quant_fusion_result, after_to_executorch), torch.tensor(float("inf")))

            # comparing with existing fx graph mode quantization reference flow
            qconfig = default_per_channel_symmetric_qnnpack_qconfig
            qconfig_mapping = QConfigMapping().set_global(qconfig)
            backend_config = get_executorch_backend_config()
            m_fx = prepare_fx(
                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
            )
            # Index [0] to match the PT2E results above (batch size is 1).
            after_prepare_result_fx = m_fx(*example_inputs)[0]
            m_fx = _convert_to_reference_decomposed_fx(
                m_fx, backend_config=backend_config
            )
            after_quant_result_fx = m_fx(*example_inputs)[0]

            # The results match (within atol=1e-6) after prepare.
            self.assertTrue(
                torch.allclose(after_prepare_result, after_prepare_result_fx, atol=1e-6)
            )

            # There are slight differences after convert due to different
            # implementations of quant/dequant. Bound the *absolute* error: a
            # signed max would let arbitrarily large negative deviations pass.
            max_abs_diff = torch.max(
                torch.abs(after_quant_result - after_quant_result_fx)
            )
            self.assertLess(max_abs_diff.item(), 1e-1)
            sqnr = compute_sqnr(after_quant_result, after_quant_result_fx)
            self.assertGreater(sqnr.item(), 35)