xref: /aosp_15_r20/external/pytorch/test/test_quantization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: quantization"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport logging
4*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker# Quantization core tests. These include tests for
7*da0073e9SAndroid Build Coastguard Worker# - quantized kernels
8*da0073e9SAndroid Build Coastguard Worker# - quantized functional operators
9*da0073e9SAndroid Build Coastguard Worker# - quantized workflow modules
10*da0073e9SAndroid Build Coastguard Worker# - quantized workflow operators
11*da0073e9SAndroid Build Coastguard Worker# - quantized tensor
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker# 1. Quantized Kernels
14*da0073e9SAndroid Build Coastguard Worker# TODO: merge the different quantized op tests into one test class
15*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_quantized_op import TestQuantizedOps  # noqa: F401
16*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_quantized_op import TestQNNPackOps  # noqa: F401
17*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_quantized_op import TestQuantizedLinear  # noqa: F401
18*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_quantized_op import TestQuantizedConv  # noqa: F401
19*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_quantized_op import TestDynamicQuantizedOps  # noqa: F401
20*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_quantized_op import TestComparatorOps  # noqa: F401
21*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_quantized_op import TestPadding  # noqa: F401
22*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_quantized_op import TestQuantizedEmbeddingOps  # noqa: F401
23*da0073e9SAndroid Build Coastguard Worker# 2. Quantized Functional/Workflow Ops
24*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_quantized_functional import TestQuantizedFunctionalOps  # noqa: F401
25*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_workflow_ops import TestFakeQuantizeOps  # noqa: F401
26*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_workflow_ops import TestFusedObsFakeQuant  # noqa: F401
27*da0073e9SAndroid Build Coastguard Worker# 3. Quantized Tensor
28*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_quantized_tensor import TestQuantizedTensor  # noqa: F401
29*da0073e9SAndroid Build Coastguard Worker# 4. Modules
30*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_workflow_module import TestFakeQuantize  # noqa: F401
31*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_workflow_module import TestObserver  # noqa: F401
32*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_quantized_module import TestStaticQuantizedModule  # noqa: F401
33*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_quantized_module import TestDynamicQuantizedModule  # noqa: F401
34*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_quantized_module import TestReferenceQuantizedModule  # noqa: F401
35*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_workflow_module import TestRecordHistogramObserver  # noqa: F401
36*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_workflow_module import TestHistogramObserver  # noqa: F401
37*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_workflow_module import TestDistributed  # noqa: F401
38*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_workflow_module import TestFusedObsFakeQuantModule  # noqa: F401
39*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_backend_config import TestBackendConfig  # noqa: F401
40*da0073e9SAndroid Build Coastguard Workerfrom quantization.core.test_utils import TestUtils  # noqa: F401
41*da0073e9SAndroid Build Coastguard Workertry:
42*da0073e9SAndroid Build Coastguard Worker    # This test has extra data dependencies, so in some environments, e.g. Meta internal
43*da0073e9SAndroid Build Coastguard Worker    # Buck, it has its own test runner.
44*da0073e9SAndroid Build Coastguard Worker    from quantization.core.test_docs import TestQuantizationDocs  # noqa: F401
45*da0073e9SAndroid Build Coastguard Workerexcept ImportError as e:
46*da0073e9SAndroid Build Coastguard Worker    logging.warning(e)
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker# Eager Mode Workflow. Tests for the functionality of APIs and different features implemented
49*da0073e9SAndroid Build Coastguard Worker# using eager mode.
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker# 1. Eager mode post training quantization
52*da0073e9SAndroid Build Coastguard Workerfrom quantization.eager.test_quantize_eager_ptq import TestQuantizeEagerPTQStatic  # noqa: F401
53*da0073e9SAndroid Build Coastguard Workerfrom quantization.eager.test_quantize_eager_ptq import TestQuantizeEagerPTQDynamic  # noqa: F401
54*da0073e9SAndroid Build Coastguard Workerfrom quantization.eager.test_quantize_eager_ptq import TestQuantizeEagerOps  # noqa: F401
55*da0073e9SAndroid Build Coastguard Worker# 2. Eager mode quantization aware training
56*da0073e9SAndroid Build Coastguard Workerfrom quantization.eager.test_quantize_eager_qat import TestQuantizeEagerQAT  # noqa: F401
57*da0073e9SAndroid Build Coastguard Workerfrom quantization.eager.test_quantize_eager_qat import TestQuantizeEagerQATNumerics  # noqa: F401
58*da0073e9SAndroid Build Coastguard Worker# 3. Eager mode fusion passes
59*da0073e9SAndroid Build Coastguard Workerfrom quantization.eager.test_fuse_eager import TestFuseEager  # noqa: F401
60*da0073e9SAndroid Build Coastguard Worker# 4. Testing model numerics between quantized and FP32 models
61*da0073e9SAndroid Build Coastguard Workerfrom quantization.eager.test_model_numerics import TestModelNumericsEager  # noqa: F401
62*da0073e9SAndroid Build Coastguard Worker# 5. Tooling: numeric_suite
63*da0073e9SAndroid Build Coastguard Workerfrom quantization.eager.test_numeric_suite_eager import TestNumericSuiteEager  # noqa: F401
64*da0073e9SAndroid Build Coastguard Worker# 6. Equalization and Bias Correction
65*da0073e9SAndroid Build Coastguard Workerfrom quantization.eager.test_equalize_eager import TestEqualizeEager  # noqa: F401
66*da0073e9SAndroid Build Coastguard Workerfrom quantization.eager.test_bias_correction_eager import TestBiasCorrectionEager  # noqa: F401
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker# FX GraphModule Graph Mode Quantization. Tests for the functionality of APIs and different features implemented
70*da0073e9SAndroid Build Coastguard Worker# using fx quantization.
71*da0073e9SAndroid Build Coastguard Workertry:
72*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_quantize_fx import TestFuseFx  # noqa: F401
73*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_quantize_fx import TestQuantizeFx  # noqa: F401
74*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_quantize_fx import TestQuantizeFxOps  # noqa: F401
75*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_quantize_fx import TestQuantizeFxModels  # noqa: F401
76*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
77*da0073e9SAndroid Build Coastguard Workerexcept ImportError as e:
78*da0073e9SAndroid Build Coastguard Worker    # In FBCode we separate FX out into a separate target for the sake of dev
79*da0073e9SAndroid Build Coastguard Worker    # velocity. These are covered by a separate test target `quantization_fx`
80*da0073e9SAndroid Build Coastguard Worker    logging.warning(e)
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker# PyTorch 2 Export Quantization
83*da0073e9SAndroid Build Coastguard Workertry:
84*da0073e9SAndroid Build Coastguard Worker    # To be moved to compiler side later
85*da0073e9SAndroid Build Coastguard Worker    from quantization.pt2e.test_graph_utils import TestGraphUtils  # noqa: F401
86*da0073e9SAndroid Build Coastguard Worker    from quantization.pt2e.test_duplicate_dq import TestDuplicateDQPass  # noqa: F401
87*da0073e9SAndroid Build Coastguard Worker    from quantization.pt2e.test_metadata_porting import TestMetaDataPorting  # noqa: F401
88*da0073e9SAndroid Build Coastguard Worker    from quantization.pt2e.test_numeric_debugger import TestNumericDebugger  # noqa: F401
89*da0073e9SAndroid Build Coastguard Worker    from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2E  # noqa: F401
90*da0073e9SAndroid Build Coastguard Worker    from quantization.pt2e.test_representation import TestPT2ERepresentation  # noqa: F401
91*da0073e9SAndroid Build Coastguard Worker    from quantization.pt2e.test_xnnpack_quantizer import TestXNNPACKQuantizer  # noqa: F401
92*da0073e9SAndroid Build Coastguard Worker    from quantization.pt2e.test_xnnpack_quantizer import TestXNNPACKQuantizerModels  # noqa: F401
93*da0073e9SAndroid Build Coastguard Worker    from quantization.pt2e.test_x86inductor_quantizer import TestQuantizePT2EX86Inductor  # noqa: F401
94*da0073e9SAndroid Build Coastguard Worker    # TODO: Figure out a way to merge all QAT tests in one TestCase
95*da0073e9SAndroid Build Coastguard Worker    from quantization.pt2e.test_quantize_pt2e_qat import TestQuantizePT2EQAT_ConvBn1d  # noqa: F401
96*da0073e9SAndroid Build Coastguard Worker    from quantization.pt2e.test_quantize_pt2e_qat import TestQuantizePT2EQAT_ConvBn2d  # noqa: F401
97*da0073e9SAndroid Build Coastguard Worker    from quantization.pt2e.test_quantize_pt2e_qat import TestQuantizePT2EQATModels  # noqa: F401
98*da0073e9SAndroid Build Coastguard Workerexcept ImportError as e:
99*da0073e9SAndroid Build Coastguard Worker    # In FBCode we separate PT2 out into a separate target for the sake of dev
100*da0073e9SAndroid Build Coastguard Worker    # velocity. These are covered by a separate test target `quantization_pt2e`
101*da0073e9SAndroid Build Coastguard Worker    logging.warning(e)
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Workertry:
104*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_numeric_suite_fx import TestFXGraphMatcher  # noqa: F401
105*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_numeric_suite_fx import TestFXGraphMatcherModels  # noqa: F401
106*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_numeric_suite_fx import TestFXNumericSuiteCoreAPIs  # noqa: F401
107*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_numeric_suite_fx import TestFXNumericSuiteNShadows  # noqa: F401
108*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_numeric_suite_fx import TestFXNumericSuiteCoreAPIsModels  # noqa: F401
109*da0073e9SAndroid Build Coastguard Workerexcept ImportError as e:
110*da0073e9SAndroid Build Coastguard Worker    logging.warning(e)
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker# Test the model report module
113*da0073e9SAndroid Build Coastguard Workertry:
114*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_model_report_fx import TestFxModelReportDetector  # noqa: F401
115*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_model_report_fx import TestFxModelReportObserver      # noqa: F401
116*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_model_report_fx import TestFxModelReportDetectDynamicStatic  # noqa: F401
117*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_model_report_fx import TestFxModelReportClass  # noqa: F401
118*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_model_report_fx import TestFxDetectInputWeightEqualization  # noqa: F401
119*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_model_report_fx import TestFxDetectOutliers  # noqa: F401
120*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_model_report_fx import TestFxModelReportVisualizer  # noqa: F401
121*da0073e9SAndroid Build Coastguard Workerexcept ImportError as e:
122*da0073e9SAndroid Build Coastguard Worker    logging.warning(e)
123*da0073e9SAndroid Build Coastguard Worker
124*da0073e9SAndroid Build Coastguard Worker# Equalization for FX mode
125*da0073e9SAndroid Build Coastguard Workertry:
126*da0073e9SAndroid Build Coastguard Worker    from quantization.fx.test_equalize_fx import TestEqualizeFx  # noqa: F401
127*da0073e9SAndroid Build Coastguard Workerexcept ImportError as e:
128*da0073e9SAndroid Build Coastguard Worker    logging.warning(e)
129*da0073e9SAndroid Build Coastguard Worker
130*da0073e9SAndroid Build Coastguard Worker# Backward Compatibility. Tests serialization and BC for quantized modules.
131*da0073e9SAndroid Build Coastguard Workertry:
132*da0073e9SAndroid Build Coastguard Worker    from quantization.bc.test_backward_compatibility import TestSerialization  # noqa: F401
133*da0073e9SAndroid Build Coastguard Workerexcept ImportError as e:
134*da0073e9SAndroid Build Coastguard Worker    logging.warning(e)
135*da0073e9SAndroid Build Coastguard Worker
136*da0073e9SAndroid Build Coastguard Worker# JIT Graph Mode Quantization
137*da0073e9SAndroid Build Coastguard Workerfrom quantization.jit.test_quantize_jit import TestQuantizeJit  # noqa: F401
138*da0073e9SAndroid Build Coastguard Workerfrom quantization.jit.test_quantize_jit import TestQuantizeJitPasses  # noqa: F401
139*da0073e9SAndroid Build Coastguard Workerfrom quantization.jit.test_quantize_jit import TestQuantizeJitOps  # noqa: F401
140*da0073e9SAndroid Build Coastguard Workerfrom quantization.jit.test_quantize_jit import TestQuantizeDynamicJitPasses  # noqa: F401
141*da0073e9SAndroid Build Coastguard Workerfrom quantization.jit.test_quantize_jit import TestQuantizeDynamicJitOps  # noqa: F401
142*da0073e9SAndroid Build Coastguard Worker# Quantization specific fusion passes
143*da0073e9SAndroid Build Coastguard Workerfrom quantization.jit.test_fusion_passes import TestFusionPasses  # noqa: F401
144*da0073e9SAndroid Build Coastguard Workerfrom quantization.jit.test_deprecated_jit_quant import TestDeprecatedJitQuantized  # noqa: F401
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker# AO Migration tests
147*da0073e9SAndroid Build Coastguard Workerfrom quantization.ao_migration.test_quantization import TestAOMigrationQuantization  # noqa: F401
148*da0073e9SAndroid Build Coastguard Workerfrom quantization.ao_migration.test_ao_migration import TestAOMigrationNNQuantized  # noqa: F401
149*da0073e9SAndroid Build Coastguard Workerfrom quantization.ao_migration.test_ao_migration import TestAOMigrationNNIntrinsic  # noqa: F401
150*da0073e9SAndroid Build Coastguard Workertry:
151*da0073e9SAndroid Build Coastguard Worker    from quantization.ao_migration.test_quantization_fx import TestAOMigrationQuantizationFx  # noqa: F401
152*da0073e9SAndroid Build Coastguard Workerexcept ImportError as e:
153*da0073e9SAndroid Build Coastguard Worker    logging.warning(e)
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker# Experimental functionality
156*da0073e9SAndroid Build Coastguard Workertry:
157*da0073e9SAndroid Build Coastguard Worker    from quantization.core.experimental.test_bits import TestBitsCPU  # noqa: F401
158*da0073e9SAndroid Build Coastguard Workerexcept ImportError as e:
159*da0073e9SAndroid Build Coastguard Worker    logging.warning(e)
160*da0073e9SAndroid Build Coastguard Workertry:
161*da0073e9SAndroid Build Coastguard Worker    from quantization.core.experimental.test_bits import TestBitsCUDA  # noqa: F401
162*da0073e9SAndroid Build Coastguard Workerexcept ImportError as e:
163*da0073e9SAndroid Build Coastguard Worker    logging.warning(e)
164*da0073e9SAndroid Build Coastguard Workertry:
165*da0073e9SAndroid Build Coastguard Worker    from quantization.core.experimental.test_float8 import TestFloat8DtypeCPU  # noqa: F401
166*da0073e9SAndroid Build Coastguard Workerexcept ImportError as e:
167*da0073e9SAndroid Build Coastguard Worker    logging.warning(e)
168*da0073e9SAndroid Build Coastguard Workertry:
169*da0073e9SAndroid Build Coastguard Worker    from quantization.core.experimental.test_float8 import TestFloat8DtypeCUDA  # noqa: F401
170*da0073e9SAndroid Build Coastguard Workerexcept ImportError as e:
171*da0073e9SAndroid Build Coastguard Worker    logging.warning(e)
172*da0073e9SAndroid Build Coastguard Workertry:
173*da0073e9SAndroid Build Coastguard Worker    from quantization.core.experimental.test_float8 import TestFloat8DtypeCPUOnlyCPU  # noqa: F401
174*da0073e9SAndroid Build Coastguard Workerexcept ImportError as e:
175*da0073e9SAndroid Build Coastguard Worker    logging.warning(e)
176*da0073e9SAndroid Build Coastguard Worker
if __name__ == '__main__':
    # Only runs when this file is executed directly, not when it is imported.
    # None of the test classes imported above is referenced anywhere else in
    # this module (hence the `noqa: F401` markers). They are brought into the
    # module namespace so the runner started here can collect them.
    run_tests()
179