# Owner(s): ["oncall: quantization"]

from .common import AOMigrationTestCase


class TestAOMigrationQuantization(AOMigrationTestCase):
    r"""Import checks for the ``torch/quantization`` -> ``torch/ao/quantization`` move.

    Each test names one submodule and the symbols (functions, classes,
    default objects) and/or dictionaries belonging to it. The actual checking
    is delegated to ``_test_function_import`` and ``_test_dict_import``,
    which are inherited from ``AOMigrationTestCase`` (defined outside this
    file).
    """

    def test_function_import_quantize(self):
        """Symbols of the ``quantize`` submodule."""
        self._test_function_import(
            "quantize",
            [
                "_convert",
                "_observer_forward_hook",
                "_propagate_qconfig_helper",
                "_remove_activation_post_process",
                "_remove_qconfig",
                "_add_observer_",
                "add_quant_dequant",
                "convert",
                "_get_observer_dict",
                "_get_unique_devices_",
                "_is_activation_post_process",
                "prepare",
                "prepare_qat",
                "propagate_qconfig_",
                "quantize",
                "quantize_dynamic",
                "quantize_qat",
                "_register_activation_post_process_hook",
                "swap_module",
            ],
        )

    def test_function_import_stubs(self):
        """Stub/wrapper modules of the ``stubs`` submodule."""
        self._test_function_import(
            "stubs",
            [
                "QuantStub",
                "DeQuantStub",
                "QuantWrapper",
            ],
        )

    def test_function_import_quantize_jit(self):
        """TorchScript entry points of the ``quantize_jit`` submodule."""
        self._test_function_import(
            "quantize_jit",
            [
                "_check_is_script_module",
                "_check_forward_method",
                "script_qconfig",
                "script_qconfig_dict",
                "fuse_conv_bn_jit",
                "_prepare_jit",
                "prepare_jit",
                "prepare_dynamic_jit",
                "_convert_jit",
                "convert_jit",
                "convert_dynamic_jit",
                "_quantize_jit",
                "quantize_jit",
                "quantize_dynamic_jit",
            ],
        )

    def test_function_import_fake_quantize(self):
        """Classes, defaults and toggles of the ``fake_quantize`` submodule."""
        self._test_function_import(
            "fake_quantize",
            [
                "_is_per_channel",
                "_is_per_tensor",
                "_is_symmetric_quant",
                "FakeQuantizeBase",
                "FakeQuantize",
                "FixedQParamsFakeQuantize",
                "FusedMovingAvgObsFakeQuantize",
                "default_fake_quant",
                "default_weight_fake_quant",
                "default_fixed_qparams_range_neg1to1_fake_quant",
                "default_fixed_qparams_range_0to1_fake_quant",
                "default_per_channel_weight_fake_quant",
                "default_histogram_fake_quant",
                "default_fused_act_fake_quant",
                "default_fused_wt_fake_quant",
                "default_fused_per_channel_wt_fake_quant",
                "_is_fake_quant_script_module",
                "disable_fake_quant",
                "enable_fake_quant",
                "disable_observer",
                "enable_observer",
            ],
        )

    def test_function_import_fuse_modules(self):
        """Fusion helpers of the ``fuse_modules`` submodule."""
        self._test_function_import(
            "fuse_modules",
            [
                "_fuse_modules",
                "_get_module",
                "_set_module",
                "fuse_conv_bn",
                "fuse_conv_bn_relu",
                "fuse_known_modules",
                "fuse_modules",
                "get_fuser_method",
            ],
        )

    def test_function_import_quant_type(self):
        """Symbols of the ``quant_type`` submodule."""
        self._test_function_import(
            "quant_type",
            [
                "QuantType",
                "_get_quant_type_to_str",
            ],
        )

    def test_function_import_observer(self):
        """Observer classes, helpers and defaults of the ``observer`` submodule."""
        self._test_function_import(
            "observer",
            [
                "_PartialWrapper",
                "_with_args",
                "_with_callable_args",
                "ABC",
                "ObserverBase",
                "_ObserverBase",
                "MinMaxObserver",
                "MovingAverageMinMaxObserver",
                "PerChannelMinMaxObserver",
                "MovingAveragePerChannelMinMaxObserver",
                "HistogramObserver",
                "PlaceholderObserver",
                "RecordingObserver",
                "NoopObserver",
                "_is_activation_post_process",
                "_is_per_channel_script_obs_instance",
                "get_observer_state_dict",
                "load_observer_state_dict",
                "default_observer",
                "default_placeholder_observer",
                "default_debug_observer",
                "default_weight_observer",
                "default_histogram_observer",
                "default_per_channel_weight_observer",
                "default_dynamic_quant_observer",
                "default_float_qparams_observer",
            ],
        )

    def test_function_import_qconfig(self):
        """QConfig types, presets and helpers of the ``qconfig`` submodule."""
        self._test_function_import(
            "qconfig",
            [
                "QConfig",
                "default_qconfig",
                "default_debug_qconfig",
                "default_per_channel_qconfig",
                "QConfigDynamic",
                "default_dynamic_qconfig",
                "float16_dynamic_qconfig",
                "float16_static_qconfig",
                "per_channel_dynamic_qconfig",
                "float_qparams_weight_only_qconfig",
                "default_qat_qconfig",
                "default_weight_only_qconfig",
                "default_activation_only_qconfig",
                "default_qat_qconfig_v2",
                "get_default_qconfig",
                "get_default_qat_qconfig",
                "_assert_valid_qconfig",
                "QConfigAny",
                "_add_module_to_qconfig_obs_ctr",
                "qconfig_equals",
            ],
        )

    def test_function_import_quantization_mappings(self):
        """Accessor functions and mapping dicts of ``quantization_mappings``."""
        submodule = "quantization_mappings"
        accessors = [
            "no_observer_set",
            "get_default_static_quant_module_mappings",
            "get_static_quant_module_class",
            "get_dynamic_quant_module_class",
            "get_default_qat_module_mappings",
            "get_default_dynamic_quant_module_mappings",
            "get_default_qconfig_propagation_list",
            "get_default_compare_output_module_list",
            "get_default_float_to_quantized_operator_mappings",
            "get_quantized_operator",
            "_get_special_act_post_process",
            "_has_special_act_post_process",
        ]
        mappings = [
            "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS",
            "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS",
            "DEFAULT_QAT_MODULE_MAPPINGS",
            "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS",
            # "_INCLUDE_QCONFIG_PROPAGATE_LIST" is intentionally not checked here.
            # NOTE(review): the reason it was excluded is not visible in this
            # file -- confirm before re-enabling it.
            "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS",
            "DEFAULT_MODULE_TO_ACT_POST_PROCESS",
        ]
        self._test_function_import(submodule, accessors)
        self._test_dict_import(submodule, mappings)

    def test_function_import_fuser_method_mappings(self):
        """Fuser functions and the op-list table of ``fuser_method_mappings``."""
        submodule = "fuser_method_mappings"
        fusers = [
            "fuse_conv_bn",
            "fuse_conv_bn_relu",
            "fuse_linear_bn",
            "get_fuser_method",
        ]
        tables = ["_DEFAULT_OP_LIST_TO_FUSER_METHOD"]
        self._test_function_import(submodule, fusers)
        self._test_dict_import(submodule, tables)

    def test_function_import_utils(self):
        """Dtype/qparam helper functions of the ``utils`` submodule."""
        self._test_function_import(
            "utils",
            [
                "activation_dtype",
                "activation_is_int8_quantized",
                "activation_is_statically_quantized",
                "calculate_qmin_qmax",
                "check_min_max_valid",
                "get_combined_dict",
                "get_qconfig_dtypes",
                "get_qparam_dict",
                "get_quant_type",
                "get_swapped_custom_module_class",
                "getattr_from_fqn",
                "is_per_channel",
                "is_per_tensor",
                "weight_dtype",
                "weight_is_quantized",
                "weight_is_statically_quantized",
            ],
        )