xref: /aosp_15_r20/external/pytorch/test/test_quantization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: quantization"]
2
3import logging
4from torch.testing._internal.common_utils import run_tests
5
6# Quantization core tests. These include tests for
7# - quantized kernels
8# - quantized functional operators
9# - quantized workflow modules
10# - quantized workflow operators
11# - quantized tensor
12
13# 1. Quantized Kernels
14# TODO: merge the different quantized op tests into one test class
15from quantization.core.test_quantized_op import TestQuantizedOps  # noqa: F401
16from quantization.core.test_quantized_op import TestQNNPackOps  # noqa: F401
17from quantization.core.test_quantized_op import TestQuantizedLinear  # noqa: F401
18from quantization.core.test_quantized_op import TestQuantizedConv  # noqa: F401
19from quantization.core.test_quantized_op import TestDynamicQuantizedOps  # noqa: F401
20from quantization.core.test_quantized_op import TestComparatorOps  # noqa: F401
21from quantization.core.test_quantized_op import TestPadding  # noqa: F401
22from quantization.core.test_quantized_op import TestQuantizedEmbeddingOps  # noqa: F401
23# 2. Quantized Functional/Workflow Ops
24from quantization.core.test_quantized_functional import TestQuantizedFunctionalOps  # noqa: F401
25from quantization.core.test_workflow_ops import TestFakeQuantizeOps  # noqa: F401
26from quantization.core.test_workflow_ops import TestFusedObsFakeQuant  # noqa: F401
27# 3. Quantized Tensor
28from quantization.core.test_quantized_tensor import TestQuantizedTensor  # noqa: F401
29# 4. Modules
30from quantization.core.test_workflow_module import TestFakeQuantize  # noqa: F401
31from quantization.core.test_workflow_module import TestObserver  # noqa: F401
32from quantization.core.test_quantized_module import TestStaticQuantizedModule  # noqa: F401
33from quantization.core.test_quantized_module import TestDynamicQuantizedModule  # noqa: F401
34from quantization.core.test_quantized_module import TestReferenceQuantizedModule  # noqa: F401
35from quantization.core.test_workflow_module import TestRecordHistogramObserver  # noqa: F401
36from quantization.core.test_workflow_module import TestHistogramObserver  # noqa: F401
37from quantization.core.test_workflow_module import TestDistributed  # noqa: F401
38from quantization.core.test_workflow_module import TestFusedObsFakeQuantModule  # noqa: F401
39from quantization.core.test_backend_config import TestBackendConfig  # noqa: F401
40from quantization.core.test_utils import TestUtils  # noqa: F401
41try:
42    # This test has extra data dependencies, so in some environments, e.g. Meta internal
43    # Buck, it has its own test runner.
44    from quantization.core.test_docs import TestQuantizationDocs  # noqa: F401
45except ImportError as e:
46    logging.warning(e)
47
48# Eager Mode Workflow. Tests for the functionality of APIs and different features implemented
49# using eager mode.
50
51# 1. Eager mode post training quantization
52from quantization.eager.test_quantize_eager_ptq import TestQuantizeEagerPTQStatic  # noqa: F401
53from quantization.eager.test_quantize_eager_ptq import TestQuantizeEagerPTQDynamic  # noqa: F401
54from quantization.eager.test_quantize_eager_ptq import TestQuantizeEagerOps  # noqa: F401
55# 2. Eager mode quantization aware training
56from quantization.eager.test_quantize_eager_qat import TestQuantizeEagerQAT  # noqa: F401
57from quantization.eager.test_quantize_eager_qat import TestQuantizeEagerQATNumerics  # noqa: F401
58# 3. Eager mode fusion passes
59from quantization.eager.test_fuse_eager import TestFuseEager  # noqa: F401
60# 4. Testing model numerics between quantized and FP32 models
61from quantization.eager.test_model_numerics import TestModelNumericsEager  # noqa: F401
62# 5. Tooling: numeric_suite
63from quantization.eager.test_numeric_suite_eager import TestNumericSuiteEager  # noqa: F401
64# 6. Equalization and Bias Correction
65from quantization.eager.test_equalize_eager import TestEqualizeEager  # noqa: F401
66from quantization.eager.test_bias_correction_eager import TestBiasCorrectionEager  # noqa: F401
67
68
69# FX GraphModule Graph Mode Quantization. Tests for the functionality of APIs and different features implemented
70# using fx quantization.
71try:
72    from quantization.fx.test_quantize_fx import TestFuseFx  # noqa: F401
73    from quantization.fx.test_quantize_fx import TestQuantizeFx  # noqa: F401
74    from quantization.fx.test_quantize_fx import TestQuantizeFxOps  # noqa: F401
75    from quantization.fx.test_quantize_fx import TestQuantizeFxModels  # noqa: F401
76    from quantization.fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
77except ImportError as e:
78    # In FBCode the FX tests are split out into their own target for the sake of
79    # dev velocity; they are covered by that dedicated test target, `quantization_fx`.
80    logging.warning(e)
81
82# PyTorch 2 Export Quantization
83try:
84    # To be moved to compiler side later
85    from quantization.pt2e.test_graph_utils import TestGraphUtils  # noqa: F401
86    from quantization.pt2e.test_duplicate_dq import TestDuplicateDQPass  # noqa: F401
87    from quantization.pt2e.test_metadata_porting import TestMetaDataPorting  # noqa: F401
88    from quantization.pt2e.test_numeric_debugger import TestNumericDebugger  # noqa: F401
89    from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2E  # noqa: F401
90    from quantization.pt2e.test_representation import TestPT2ERepresentation  # noqa: F401
91    from quantization.pt2e.test_xnnpack_quantizer import TestXNNPACKQuantizer  # noqa: F401
92    from quantization.pt2e.test_xnnpack_quantizer import TestXNNPACKQuantizerModels  # noqa: F401
93    from quantization.pt2e.test_x86inductor_quantizer import TestQuantizePT2EX86Inductor  # noqa: F401
94    # TODO: Figure out a way to merge all QAT tests in one TestCase
95    from quantization.pt2e.test_quantize_pt2e_qat import TestQuantizePT2EQAT_ConvBn1d  # noqa: F401
96    from quantization.pt2e.test_quantize_pt2e_qat import TestQuantizePT2EQAT_ConvBn2d  # noqa: F401
97    from quantization.pt2e.test_quantize_pt2e_qat import TestQuantizePT2EQATModels  # noqa: F401
98except ImportError as e:
99    # In FBCode the PT2 tests are split out into their own target for the sake of
100    # dev velocity; they are covered by that dedicated test target, `quantization_pt2e`.
101    logging.warning(e)
102
103try:
104    from quantization.fx.test_numeric_suite_fx import TestFXGraphMatcher  # noqa: F401
105    from quantization.fx.test_numeric_suite_fx import TestFXGraphMatcherModels  # noqa: F401
106    from quantization.fx.test_numeric_suite_fx import TestFXNumericSuiteCoreAPIs  # noqa: F401
107    from quantization.fx.test_numeric_suite_fx import TestFXNumericSuiteNShadows  # noqa: F401
108    from quantization.fx.test_numeric_suite_fx import TestFXNumericSuiteCoreAPIsModels  # noqa: F401
109except ImportError as e:
110    logging.warning(e)
111
112# Test the model report module
113try:
114    from quantization.fx.test_model_report_fx import TestFxModelReportDetector  # noqa: F401
115    from quantization.fx.test_model_report_fx import TestFxModelReportObserver      # noqa: F401
116    from quantization.fx.test_model_report_fx import TestFxModelReportDetectDynamicStatic  # noqa: F401
117    from quantization.fx.test_model_report_fx import TestFxModelReportClass  # noqa: F401
118    from quantization.fx.test_model_report_fx import TestFxDetectInputWeightEqualization  # noqa: F401
119    from quantization.fx.test_model_report_fx import TestFxDetectOutliers  # noqa: F401
120    from quantization.fx.test_model_report_fx import TestFxModelReportVisualizer  # noqa: F401
121except ImportError as e:
122    logging.warning(e)
123
124# Equalization for FX mode
125try:
126    from quantization.fx.test_equalize_fx import TestEqualizeFx  # noqa: F401
127except ImportError as e:
128    logging.warning(e)
129
130# Backward Compatibility. Tests serialization and BC for quantized modules.
131try:
132    from quantization.bc.test_backward_compatibility import TestSerialization  # noqa: F401
133except ImportError as e:
134    logging.warning(e)
135
136# JIT Graph Mode Quantization
137from quantization.jit.test_quantize_jit import TestQuantizeJit  # noqa: F401
138from quantization.jit.test_quantize_jit import TestQuantizeJitPasses  # noqa: F401
139from quantization.jit.test_quantize_jit import TestQuantizeJitOps  # noqa: F401
140from quantization.jit.test_quantize_jit import TestQuantizeDynamicJitPasses  # noqa: F401
141from quantization.jit.test_quantize_jit import TestQuantizeDynamicJitOps  # noqa: F401
142# Quantization-specific fusion passes
143from quantization.jit.test_fusion_passes import TestFusionPasses  # noqa: F401
144from quantization.jit.test_deprecated_jit_quant import TestDeprecatedJitQuantized  # noqa: F401
145
146# AO Migration tests
147from quantization.ao_migration.test_quantization import TestAOMigrationQuantization  # noqa: F401
148from quantization.ao_migration.test_ao_migration import TestAOMigrationNNQuantized  # noqa: F401
149from quantization.ao_migration.test_ao_migration import TestAOMigrationNNIntrinsic  # noqa: F401
150try:
151    from quantization.ao_migration.test_quantization_fx import TestAOMigrationQuantizationFx  # noqa: F401
152except ImportError as e:
153    logging.warning(e)
154
155# Experimental functionality
156try:
157    from quantization.core.experimental.test_bits import TestBitsCPU  # noqa: F401
158except ImportError as e:
159    logging.warning(e)
160try:
161    from quantization.core.experimental.test_bits import TestBitsCUDA  # noqa: F401
162except ImportError as e:
163    logging.warning(e)
164try:
165    from quantization.core.experimental.test_float8 import TestFloat8DtypeCPU  # noqa: F401
166except ImportError as e:
167    logging.warning(e)
168try:
169    from quantization.core.experimental.test_float8 import TestFloat8DtypeCUDA  # noqa: F401
170except ImportError as e:
171    logging.warning(e)
172try:
173    from quantization.core.experimental.test_float8 import TestFloat8DtypeCPUOnlyCPU  # noqa: F401
174except ImportError as e:
175    logging.warning(e)
176
# When this module is executed directly (e.g. `python test_quantization.py`),
# hand control to the shared PyTorch test runner. The `# noqa: F401` imports
# above are otherwise unused; presumably they exist only to place each
# TestCase class in this module's namespace so the runner can discover it.
if __name__ == "__main__":
    run_tests()
179