xref: /aosp_15_r20/external/pytorch/test/quantization/ao_migration/test_quantization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: quantization"]
2
3from .common import AOMigrationTestCase
4
5
class TestAOMigrationQuantization(AOMigrationTestCase):
    r"""Import checks for the `torch/quantization` migration to
    `torch/ao/quantization`.

    Every test passes a module name together with a fixed list of attribute
    names to `_test_function_import` (or `_test_dict_import`). Those helpers
    are not defined in this class; presumably they are inherited from
    `AOMigrationTestCase` and compare the old and new import locations --
    confirm against `.common`.
    """

    def test_function_import_quantize(self):
        """Run the function-import check for the `quantize` module."""
        self._test_function_import(
            "quantize",
            [
                "_convert",
                "_observer_forward_hook",
                "_propagate_qconfig_helper",
                "_remove_activation_post_process",
                "_remove_qconfig",
                "_add_observer_",
                "add_quant_dequant",
                "convert",
                "_get_observer_dict",
                "_get_unique_devices_",
                "_is_activation_post_process",
                "prepare",
                "prepare_qat",
                "propagate_qconfig_",
                "quantize",
                "quantize_dynamic",
                "quantize_qat",
                "_register_activation_post_process_hook",
                "swap_module",
            ],
        )

    def test_function_import_stubs(self):
        """Run the function-import check for the `stubs` module."""
        self._test_function_import(
            "stubs",
            [
                "QuantStub",
                "DeQuantStub",
                "QuantWrapper",
            ],
        )

    def test_function_import_quantize_jit(self):
        """Run the function-import check for the `quantize_jit` module."""
        self._test_function_import(
            "quantize_jit",
            [
                "_check_is_script_module",
                "_check_forward_method",
                "script_qconfig",
                "script_qconfig_dict",
                "fuse_conv_bn_jit",
                "_prepare_jit",
                "prepare_jit",
                "prepare_dynamic_jit",
                "_convert_jit",
                "convert_jit",
                "convert_dynamic_jit",
                "_quantize_jit",
                "quantize_jit",
                "quantize_dynamic_jit",
            ],
        )

    def test_function_import_fake_quantize(self):
        """Run the function-import check for the `fake_quantize` module."""
        self._test_function_import(
            "fake_quantize",
            [
                "_is_per_channel",
                "_is_per_tensor",
                "_is_symmetric_quant",
                "FakeQuantizeBase",
                "FakeQuantize",
                "FixedQParamsFakeQuantize",
                "FusedMovingAvgObsFakeQuantize",
                "default_fake_quant",
                "default_weight_fake_quant",
                "default_fixed_qparams_range_neg1to1_fake_quant",
                "default_fixed_qparams_range_0to1_fake_quant",
                "default_per_channel_weight_fake_quant",
                "default_histogram_fake_quant",
                "default_fused_act_fake_quant",
                "default_fused_wt_fake_quant",
                "default_fused_per_channel_wt_fake_quant",
                "_is_fake_quant_script_module",
                "disable_fake_quant",
                "enable_fake_quant",
                "disable_observer",
                "enable_observer",
            ],
        )

    def test_function_import_fuse_modules(self):
        """Run the function-import check for the `fuse_modules` module."""
        self._test_function_import(
            "fuse_modules",
            [
                "_fuse_modules",
                "_get_module",
                "_set_module",
                "fuse_conv_bn",
                "fuse_conv_bn_relu",
                "fuse_known_modules",
                "fuse_modules",
                "get_fuser_method",
            ],
        )

    def test_function_import_quant_type(self):
        """Run the function-import check for the `quant_type` module."""
        self._test_function_import(
            "quant_type",
            [
                "QuantType",
                "_get_quant_type_to_str",
            ],
        )

    def test_function_import_observer(self):
        """Run the function-import check for the `observer` module."""
        self._test_function_import(
            "observer",
            [
                "_PartialWrapper",
                "_with_args",
                "_with_callable_args",
                "ABC",
                "ObserverBase",
                "_ObserverBase",
                "MinMaxObserver",
                "MovingAverageMinMaxObserver",
                "PerChannelMinMaxObserver",
                "MovingAveragePerChannelMinMaxObserver",
                "HistogramObserver",
                "PlaceholderObserver",
                "RecordingObserver",
                "NoopObserver",
                "_is_activation_post_process",
                "_is_per_channel_script_obs_instance",
                "get_observer_state_dict",
                "load_observer_state_dict",
                "default_observer",
                "default_placeholder_observer",
                "default_debug_observer",
                "default_weight_observer",
                "default_histogram_observer",
                "default_per_channel_weight_observer",
                "default_dynamic_quant_observer",
                "default_float_qparams_observer",
            ],
        )

    def test_function_import_qconfig(self):
        """Run the function-import check for the `qconfig` module."""
        self._test_function_import(
            "qconfig",
            [
                "QConfig",
                "default_qconfig",
                "default_debug_qconfig",
                "default_per_channel_qconfig",
                "QConfigDynamic",
                "default_dynamic_qconfig",
                "float16_dynamic_qconfig",
                "float16_static_qconfig",
                "per_channel_dynamic_qconfig",
                "float_qparams_weight_only_qconfig",
                "default_qat_qconfig",
                "default_weight_only_qconfig",
                "default_activation_only_qconfig",
                "default_qat_qconfig_v2",
                "get_default_qconfig",
                "get_default_qat_qconfig",
                "_assert_valid_qconfig",
                "QConfigAny",
                "_add_module_to_qconfig_obs_ctr",
                "qconfig_equals",
            ],
        )

    def test_function_import_quantization_mappings(self):
        """Run the function-import and dict-import checks for the
        `quantization_mappings` module, in that order.
        """
        module_name = "quantization_mappings"
        self._test_function_import(
            module_name,
            [
                "no_observer_set",
                "get_default_static_quant_module_mappings",
                "get_static_quant_module_class",
                "get_dynamic_quant_module_class",
                "get_default_qat_module_mappings",
                "get_default_dynamic_quant_module_mappings",
                "get_default_qconfig_propagation_list",
                "get_default_compare_output_module_list",
                "get_default_float_to_quantized_operator_mappings",
                "get_quantized_operator",
                "_get_special_act_post_process",
                "_has_special_act_post_process",
            ],
        )
        self._test_dict_import(
            module_name,
            [
                "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS",
                "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS",
                "DEFAULT_QAT_MODULE_MAPPINGS",
                "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS",
                # NOTE(review): this entry was already disabled; the reason is
                # not recorded in this file, so it is kept disabled as-is.
                # "_INCLUDE_QCONFIG_PROPAGATE_LIST",
                "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS",
                "DEFAULT_MODULE_TO_ACT_POST_PROCESS",
            ],
        )

    def test_function_import_fuser_method_mappings(self):
        """Run the function-import and dict-import checks for the
        `fuser_method_mappings` module, in that order.
        """
        module_name = "fuser_method_mappings"
        self._test_function_import(
            module_name,
            [
                "fuse_conv_bn",
                "fuse_conv_bn_relu",
                "fuse_linear_bn",
                "get_fuser_method",
            ],
        )
        self._test_dict_import(module_name, ["_DEFAULT_OP_LIST_TO_FUSER_METHOD"])

    def test_function_import_utils(self):
        """Run the function-import check for the `utils` module."""
        self._test_function_import(
            "utils",
            [
                "activation_dtype",
                "activation_is_int8_quantized",
                "activation_is_statically_quantized",
                "calculate_qmin_qmax",
                "check_min_max_valid",
                "get_combined_dict",
                "get_qconfig_dtypes",
                "get_qparam_dict",
                "get_quant_type",
                "get_swapped_custom_module_class",
                "getattr_from_fqn",
                "is_per_channel",
                "is_per_tensor",
                "weight_dtype",
                "weight_is_quantized",
                "weight_is_statically_quantized",
            ],
        )
222