xref: /aosp_15_r20/external/pytorch/test/quantization/ao_migration/test_quantization_fx.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: quantization"]
2
3from .common import AOMigrationTestCase
4
5
class TestAOMigrationQuantizationFx(AOMigrationTestCase):
    """Import-compatibility checks for the FX quantization modules.

    Every test passes a module path and the names expected in that module to
    the shared ``_test_function_import`` helper inherited from
    ``AOMigrationTestCase``. That helper is defined in ``.common``, which is not
    shown here. Presumably it compares each name under the legacy
    ``torch.quantization`` path with the ``torch.ao.quantization`` path;
    confirm in ``common.py``.
    When a module was renamed, the new-side module path is passed explicitly
    through ``new_package_name``.
    """

    def test_function_import_quantize_fx(self):
        # Public entry points and private helpers of quantize_fx.
        names = [
            "_check_is_graph_module",
            "_swap_ff_with_fxff",
            "_fuse_fx",
            "QuantizationTracer",
            "_prepare_fx",
            "_prepare_standalone_module_fx",
            "fuse_fx",
            "Scope",
            "ScopeContextManager",
            "prepare_fx",
            "prepare_qat_fx",
            "_convert_fx",
            "convert_fx",
            "_convert_standalone_module_fx",
        ]
        self._test_function_import("quantize_fx", names)

    def test_function_import_fx(self):
        # Top-level names re-exported by the fx package itself.
        self._test_function_import("fx", ["prepare", "convert", "fuse"])

    def test_function_import_fx_graph_module(self):
        names = [
            "FusedGraphModule",
            "ObservedGraphModule",
            "_is_observed_module",
            "ObservedStandaloneGraphModule",
            "_is_observed_standalone_module",
            "QuantizedGraphModule",
        ]
        self._test_function_import("fx.graph_module", names)

    def test_function_import_fx_pattern_utils(self):
        names = [
            "QuantizeHandler",
            "_register_fusion_pattern",
            "get_default_fusion_patterns",
            "_register_quant_pattern",
            "get_default_quant_patterns",
            "get_default_output_activation_post_process_map",
        ]
        self._test_function_import("fx.pattern_utils", names)

    def test_function_import_fx_equalize(self):
        # fx._equalize is a private module, but its contents are still checked
        # name by name.
        names = [
            "reshape_scale",
            "_InputEqualizationObserver",
            "_WeightEqualizationObserver",
            "calculate_equalization_scale",
            "EqualizationQConfig",
            "input_equalization_observer",
            "weight_equalization_observer",
            "default_equalization_qconfig",
            "fused_module_supports_equalization",
            "nn_module_supports_equalization",
            "node_supports_equalization",
            "is_equalization_observer",
            "get_op_node_and_weight_eq_obs",
            "maybe_get_weight_eq_obs_node",
            "maybe_get_next_input_eq_obs",
            "maybe_get_next_equalization_scale",
            "scale_input_observer",
            "scale_weight_node",
            "scale_weight_functional",
            "clear_weight_quant_obs_node",
            "remove_node",
            "update_obs_for_equalization",
            "convert_eq_obs",
            "_convert_equalization_ref",
            "get_layer_sqnr_dict",
            "get_equalization_qconfig_dict",
        ]
        self._test_function_import("fx._equalize", names)

    def test_function_import_fx_quantization_patterns(self):
        # The legacy fx.quantization_patterns module is paired with
        # fx.quantize_handler on the new side.
        handler_names = [
            "QuantizeHandler",
            "BinaryOpQuantizeHandler",
            "CatQuantizeHandler",
            "ConvReluQuantizeHandler",
            "LinearReLUQuantizeHandler",
            "BatchNormQuantizeHandler",
            "EmbeddingQuantizeHandler",
            "RNNDynamicQuantizeHandler",
            "DefaultNodeQuantizeHandler",
            "FixedQParamsOpQuantizeHandler",
            "CopyNodeQuantizeHandler",
            "CustomModuleQuantizeHandler",
            "GeneralTensorShapeOpQuantizeHandler",
            "StandaloneModuleQuantizeHandler",
        ]
        legacy_module = "fx.quantization_patterns"
        self._test_function_import(
            legacy_module, handler_names, new_package_name="fx.quantize_handler"
        )

    def test_function_import_fx_match_utils(self):
        names = [
            "_MatchResult",
            "MatchAllNode",
            "_is_match",
            "_find_matches",
        ]
        self._test_function_import("fx.match_utils", names)

    # The next three modules each expose a single function of the same name.
    def test_function_import_fx_prepare(self):
        self._test_function_import("fx.prepare", ["prepare"])

    def test_function_import_fx_convert(self):
        self._test_function_import("fx.convert", ["convert"])

    def test_function_import_fx_fuse(self):
        self._test_function_import("fx.fuse", ["fuse"])

    def test_function_import_fx_fusion_patterns(self):
        # The legacy fx.fusion_patterns module is paired with fx.fuse_handler
        # on the new side.
        legacy_module = "fx.fusion_patterns"
        self._test_function_import(
            legacy_module,
            ["FuseHandler", "DefaultFuseHandler"],
            new_package_name="fx.fuse_handler",
        )

    # The matching test for torch.quantization.fx.quantization_types was removed.
    #   old path: torch.quantization.fx.quantization_types
    #   new path: torch.ao.quantization.utils
    # Both paths currently work; the old one is planned for deprecation.

    def test_function_import_fx_utils(self):
        names = [
            "get_custom_module_class_keys",
            "get_linear_prepack_op_for_dtype",
            "get_qconv_prepack_op",
            "get_new_attr_name_with_prefix",
            "graph_module_from_producer_nodes",
            "assert_and_get_unique_device",
            "create_getattr_from_value",
            "all_node_args_have_no_tensors",
            "get_non_observable_arg_indexes_and_types",
            "maybe_get_next_module",
        ]
        self._test_function_import("fx.utils", names)
153