# Owner(s): ["oncall: quantization"]

from .common import AOMigrationTestCase


class TestAOMigrationQuantizationFx(AOMigrationTestCase):
    """Import-compatibility checks for the FX quantization modules after the AO migration.

    Every test builds the list of public (and some private) names a module is
    expected to expose and hands it, together with the module path, to the
    inherited ``_test_function_import`` helper.  When a module was renamed as
    part of the migration, its new path is supplied via ``new_package_name``.

    NOTE(review): the helper lives in ``.common`` and is not visible here; it
    presumably compares ``torch.quantization.<module>`` against
    ``torch.ao.quantization.<module>`` -- confirm against ``AOMigrationTestCase``.
    """

    # torch.quantization.quantize_fx: top-level FX graph-mode entry points.
    def test_function_import_quantize_fx(self):
        module = "quantize_fx"
        names = [
            "_check_is_graph_module",
            "_swap_ff_with_fxff",
            "_fuse_fx",
            "QuantizationTracer",
            "_prepare_fx",
            "_prepare_standalone_module_fx",
            "fuse_fx",
            "Scope",
            "ScopeContextManager",
            "prepare_fx",
            "prepare_qat_fx",
            "_convert_fx",
            "convert_fx",
            "_convert_standalone_module_fx",
        ]
        self._test_function_import(module, names)

    # torch.quantization.fx package namespace.
    def test_function_import_fx(self):
        module = "fx"
        names = [
            "prepare",
            "convert",
            "fuse",
        ]
        self._test_function_import(module, names)

    # Graph-module wrapper classes and their type predicates.
    def test_function_import_fx_graph_module(self):
        module = "fx.graph_module"
        names = [
            "FusedGraphModule",
            "ObservedGraphModule",
            "_is_observed_module",
            "ObservedStandaloneGraphModule",
            "_is_observed_standalone_module",
            "QuantizedGraphModule",
        ]
        self._test_function_import(module, names)

    # Fusion / quantization pattern registration helpers.
    def test_function_import_fx_pattern_utils(self):
        module = "fx.pattern_utils"
        names = [
            "QuantizeHandler",
            "_register_fusion_pattern",
            "get_default_fusion_patterns",
            "_register_quant_pattern",
            "get_default_quant_patterns",
            "get_default_output_activation_post_process_map",
        ]
        self._test_function_import(module, names)

    # Equalization observers, qconfigs and graph-rewrite helpers.
    def test_function_import_fx_equalize(self):
        module = "fx._equalize"
        names = [
            "reshape_scale",
            "_InputEqualizationObserver",
            "_WeightEqualizationObserver",
            "calculate_equalization_scale",
            "EqualizationQConfig",
            "input_equalization_observer",
            "weight_equalization_observer",
            "default_equalization_qconfig",
            "fused_module_supports_equalization",
            "nn_module_supports_equalization",
            "node_supports_equalization",
            "is_equalization_observer",
            "get_op_node_and_weight_eq_obs",
            "maybe_get_weight_eq_obs_node",
            "maybe_get_next_input_eq_obs",
            "maybe_get_next_equalization_scale",
            "scale_input_observer",
            "scale_weight_node",
            "scale_weight_functional",
            "clear_weight_quant_obs_node",
            "remove_node",
            "update_obs_for_equalization",
            "convert_eq_obs",
            "_convert_equalization_ref",
            "get_layer_sqnr_dict",
            "get_equalization_qconfig_dict",
        ]
        self._test_function_import(module, names)

    # fx.quantization_patterns was renamed to fx.quantize_handler.
    def test_function_import_fx_quantization_patterns(self):
        module = "fx.quantization_patterns"
        names = [
            "QuantizeHandler",
            "BinaryOpQuantizeHandler",
            "CatQuantizeHandler",
            "ConvReluQuantizeHandler",
            "LinearReLUQuantizeHandler",
            "BatchNormQuantizeHandler",
            "EmbeddingQuantizeHandler",
            "RNNDynamicQuantizeHandler",
            "DefaultNodeQuantizeHandler",
            "FixedQParamsOpQuantizeHandler",
            "CopyNodeQuantizeHandler",
            "CustomModuleQuantizeHandler",
            "GeneralTensorShapeOpQuantizeHandler",
            "StandaloneModuleQuantizeHandler",
        ]
        self._test_function_import(
            module,
            names,
            new_package_name="fx.quantize_handler",
        )

    # Pattern-matching utilities.
    def test_function_import_fx_match_utils(self):
        module = "fx.match_utils"
        names = [
            "_MatchResult",
            "MatchAllNode",
            "_is_match",
            "_find_matches",
        ]
        self._test_function_import(module, names)

    def test_function_import_fx_prepare(self):
        module = "fx.prepare"
        names = [
            "prepare",
        ]
        self._test_function_import(module, names)

    def test_function_import_fx_convert(self):
        module = "fx.convert"
        names = [
            "convert",
        ]
        self._test_function_import(module, names)

    def test_function_import_fx_fuse(self):
        module = "fx.fuse"
        names = [
            "fuse",
        ]
        self._test_function_import(module, names)

    # fx.fusion_patterns was renamed to fx.fuse_handler.
    def test_function_import_fx_fusion_patterns(self):
        module = "fx.fusion_patterns"
        names = [
            "FuseHandler",
            "DefaultFuseHandler",
        ]
        self._test_function_import(
            module,
            names,
            new_package_name="fx.fuse_handler",
        )

    # There is intentionally no test for torch.quantization.fx.quantization_types.
    #   old path: torch.quantization.fx.quantization_types
    #   new path: torch.ao.quantization.utils
    # Both paths are currently valid; the old one is planned for deprecation.

    # General FX quantization utility functions.
    def test_function_import_fx_utils(self):
        module = "fx.utils"
        names = [
            "get_custom_module_class_keys",
            "get_linear_prepack_op_for_dtype",
            "get_qconv_prepack_op",
            "get_new_attr_name_with_prefix",
            "graph_module_from_producer_nodes",
            "assert_and_get_unique_device",
            "create_getattr_from_value",
            "all_node_args_have_no_tensors",
            "get_non_observable_arg_indexes_and_types",
            "maybe_get_next_module",
        ]
        self._test_function_import(module, names)