xref: /aosp_15_r20/external/pytorch/test/quantization/ao_migration/test_ao_migration.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: quantization"]
2
3from .common import AOMigrationTestCase
4
5
class TestAOMigrationNNQuantized(AOMigrationTestCase):
    """Migration checks for ``torch.nn.quantized``, ``torch.nn.quantizable``
    and ``torch.nn.qat``.

    Every test makes exactly one ``_test_function_import(package, names,
    base=...)`` call. That helper is defined in ``.common`` (not visible
    here); presumably it checks that each listed name resolves to the same
    object under ``torch.<base>.<package>`` and under the corresponding
    ``torch.ao`` location -- confirm against ``AOMigrationTestCase``.
    """

    # Dotted ``base`` prefixes passed to ``_test_function_import``.
    _NN_QUANTIZED = "nn.quantized"
    _NN_QUANTIZED_MODULES = "nn.quantized.modules"
    _NN_QUANTIZABLE_MODULES = "nn.quantizable.modules"
    _NN_QAT_MODULES = "nn.qat.modules"
    _NN_QAT_DYNAMIC_MODULES = "nn.qat.dynamic.modules"

    def test_functional_import(self):
        """Migration of ``torch.nn.quantized.functional``."""
        self._test_function_import(
            "functional",
            [
                "avg_pool2d",
                "avg_pool3d",
                "adaptive_avg_pool2d",
                "adaptive_avg_pool3d",
                "conv1d",
                "conv2d",
                "conv3d",
                "interpolate",
                "linear",
                "max_pool1d",
                "max_pool2d",
                "celu",
                "leaky_relu",
                "hardtanh",
                "hardswish",
                "threshold",
                "elu",
                "hardsigmoid",
                "clamp",
                "upsample",
                "upsample_bilinear",
                "upsample_nearest",
            ],
            base=self._NN_QUANTIZED,
        )

    def test_modules_import(self):
        """Migration of ``torch.nn.quantized.modules``."""
        self._test_function_import(
            "modules",
            [
                # Layer modules
                "BatchNorm2d",
                "BatchNorm3d",
                "Conv1d",
                "Conv2d",
                "Conv3d",
                "ConvTranspose1d",
                "ConvTranspose2d",
                "ConvTranspose3d",
                "DeQuantize",
                "ELU",
                "Embedding",
                "EmbeddingBag",
                "GroupNorm",
                "Hardswish",
                "InstanceNorm1d",
                "InstanceNorm2d",
                "InstanceNorm3d",
                "LayerNorm",
                "LeakyReLU",
                "Linear",
                "MaxPool2d",
                "Quantize",
                "ReLU6",
                "Sigmoid",
                "Softmax",
                "Dropout",
                # Functional wrapper modules
                "FloatFunctional",
                "FXFloatFunctional",
                "QFunctional",
            ],
            base=self._NN_QUANTIZED,
        )

    def test_modules_activation(self):
        """Migration of ``torch.nn.quantized.modules.activation``."""
        self._test_function_import(
            "activation",
            [
                "ReLU6",
                "Hardswish",
                "ELU",
                "LeakyReLU",
                "Sigmoid",
                "Softmax",
            ],
            base=self._NN_QUANTIZED_MODULES,
        )

    def test_modules_batchnorm(self):
        """Migration of ``torch.nn.quantized.modules.batchnorm``."""
        self._test_function_import(
            "batchnorm",
            ["BatchNorm2d", "BatchNorm3d"],
            base=self._NN_QUANTIZED_MODULES,
        )

    def test_modules_conv(self):
        """Migration of ``torch.nn.quantized.modules.conv``."""
        self._test_function_import(
            "conv",
            [
                # Private helper re-exported by the legacy module.
                "_reverse_repeat_padding",
                "Conv1d",
                "Conv2d",
                "Conv3d",
                "ConvTranspose1d",
                "ConvTranspose2d",
                "ConvTranspose3d",
            ],
            base=self._NN_QUANTIZED_MODULES,
        )

    def test_modules_dropout(self):
        """Migration of ``torch.nn.quantized.modules.dropout``."""
        self._test_function_import(
            "dropout", ["Dropout"], base=self._NN_QUANTIZED_MODULES
        )

    def test_modules_embedding_ops(self):
        """Migration of ``torch.nn.quantized.modules.embedding_ops``."""
        self._test_function_import(
            "embedding_ops",
            ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"],
            base=self._NN_QUANTIZED_MODULES,
        )

    def test_modules_functional_modules(self):
        """Migration of ``torch.nn.quantized.modules.functional_modules``."""
        self._test_function_import(
            "functional_modules",
            ["FloatFunctional", "FXFloatFunctional", "QFunctional"],
            base=self._NN_QUANTIZED_MODULES,
        )

    def test_modules_linear(self):
        """Migration of ``torch.nn.quantized.modules.linear``."""
        self._test_function_import(
            "linear",
            ["Linear", "LinearPackedParams"],
            base=self._NN_QUANTIZED_MODULES,
        )

    def test_modules_normalization(self):
        """Migration of ``torch.nn.quantized.modules.normalization``."""
        self._test_function_import(
            "normalization",
            [
                "LayerNorm",
                "GroupNorm",
                "InstanceNorm1d",
                "InstanceNorm2d",
                "InstanceNorm3d",
            ],
            base=self._NN_QUANTIZED_MODULES,
        )

    def test_modules_utils(self):
        """Migration of ``torch.nn.quantized.modules.utils``."""
        self._test_function_import(
            "utils",
            [
                "_ntuple_from_first",
                "_pair_from_first",
                "_quantize_weight",
                "_hide_packed_params_repr",
                "WeightedQuantizedModule",
            ],
            base=self._NN_QUANTIZED_MODULES,
        )

    def test_import_nn_quantized_dynamic_import(self):
        """Migration of ``torch.nn.quantized.dynamic``."""
        self._test_function_import(
            "dynamic",
            [
                # Recurrent / linear modules
                "Linear",
                "LSTM",
                "GRU",
                "LSTMCell",
                "RNNCell",
                "GRUCell",
                # Convolution modules
                "Conv1d",
                "Conv2d",
                "Conv3d",
                "ConvTranspose1d",
                "ConvTranspose2d",
                "ConvTranspose3d",
            ],
            base=self._NN_QUANTIZED,
        )

    def test_import_nn_quantizable_activation(self):
        """Migration of ``torch.nn.quantizable.modules.activation``."""
        self._test_function_import(
            "activation",
            ["MultiheadAttention"],
            base=self._NN_QUANTIZABLE_MODULES,
        )

    def test_import_nn_quantizable_rnn(self):
        """Migration of ``torch.nn.quantizable.modules.rnn``."""
        self._test_function_import(
            "rnn", ["LSTM", "LSTMCell"], base=self._NN_QUANTIZABLE_MODULES
        )

    def test_import_nn_qat_conv(self):
        """Migration of ``torch.nn.qat.modules.conv``."""
        self._test_function_import(
            "conv", ["Conv1d", "Conv2d", "Conv3d"], base=self._NN_QAT_MODULES
        )

    def test_import_nn_qat_embedding_ops(self):
        """Migration of ``torch.nn.qat.modules.embedding_ops``."""
        self._test_function_import(
            "embedding_ops",
            ["Embedding", "EmbeddingBag"],
            base=self._NN_QAT_MODULES,
        )

    def test_import_nn_qat_linear(self):
        """Migration of ``torch.nn.qat.modules.linear``."""
        self._test_function_import(
            "linear", ["Linear"], base=self._NN_QAT_MODULES
        )

    def test_import_nn_qat_dynamic_linear(self):
        """Migration of ``torch.nn.qat.dynamic.modules.linear``."""
        self._test_function_import(
            "linear", ["Linear"], base=self._NN_QAT_DYNAMIC_MODULES
        )
224
225
class TestAOMigrationNNIntrinsic(AOMigrationTestCase):
    """Migration checks for ``torch.nn.intrinsic`` and its subpackages.

    All but the last test make one ``_test_function_import(package, names,
    base=...)`` call. That helper is defined in ``.common`` (not visible
    here); presumably it compares each listed name between the legacy
    ``torch.<base>.<package>`` module and its ``torch.ao`` counterpart --
    confirm against ``AOMigrationTestCase``. The last test only checks that
    the ``dynamic`` subpackages are reachable through attribute access.
    """

    # Dotted ``base`` prefixes passed to ``_test_function_import``.
    _NN = "nn"
    _NN_INTRINSIC = "nn.intrinsic"
    _NN_INTRINSIC_MODULES = "nn.intrinsic.modules"
    _NN_INTRINSIC_QAT_MODULES = "nn.intrinsic.qat.modules"
    _NN_INTRINSIC_QUANTIZED_MODULES = "nn.intrinsic.quantized.modules"

    # Fused-module names checked both at the ``intrinsic`` package level and
    # in ``intrinsic.modules.fused``. Stored as a tuple so it cannot be
    # mutated; each test hands the helper its own fresh list copy.
    _FUSED_MODULE_NAMES = (
        "_FusedModule",
        "ConvBn1d",
        "ConvBn2d",
        "ConvBn3d",
        "ConvBnReLU1d",
        "ConvBnReLU2d",
        "ConvBnReLU3d",
        "ConvReLU1d",
        "ConvReLU2d",
        "ConvReLU3d",
        "LinearReLU",
        "BNReLU2d",
        "BNReLU3d",
        "LinearBn1d",
    )

    def test_modules_import_nn_intrinsic(self):
        """Migration of ``torch.nn.intrinsic``."""
        self._test_function_import(
            "intrinsic", list(self._FUSED_MODULE_NAMES), base=self._NN
        )

    def test_modules_nn_intrinsic_fused(self):
        """Migration of ``torch.nn.intrinsic.modules.fused``."""
        self._test_function_import(
            "fused",
            list(self._FUSED_MODULE_NAMES),
            base=self._NN_INTRINSIC_MODULES,
        )

    def test_modules_import_nn_intrinsic_qat(self):
        """Migration of ``torch.nn.intrinsic.qat``."""
        self._test_function_import(
            "qat",
            [
                "LinearReLU",
                "LinearBn1d",
                "ConvReLU1d",
                "ConvReLU2d",
                "ConvReLU3d",
                "ConvBn1d",
                "ConvBn2d",
                "ConvBn3d",
                "ConvBnReLU1d",
                "ConvBnReLU2d",
                "ConvBnReLU3d",
                # Free functions, not modules.
                "update_bn_stats",
                "freeze_bn_stats",
            ],
            base=self._NN_INTRINSIC,
        )

    def test_modules_intrinsic_qat_conv_fused(self):
        """Migration of ``torch.nn.intrinsic.qat.modules.conv_fused``."""
        self._test_function_import(
            "conv_fused",
            [
                "ConvBn1d",
                "ConvBnReLU1d",
                "ConvReLU1d",
                "ConvBn2d",
                "ConvBnReLU2d",
                "ConvReLU2d",
                "ConvBn3d",
                "ConvBnReLU3d",
                "ConvReLU3d",
                # Free functions, not modules.
                "update_bn_stats",
                "freeze_bn_stats",
            ],
            base=self._NN_INTRINSIC_QAT_MODULES,
        )

    def test_modules_intrinsic_qat_linear_fused(self):
        """Migration of ``torch.nn.intrinsic.qat.modules.linear_fused``."""
        self._test_function_import(
            "linear_fused", ["LinearBn1d"], base=self._NN_INTRINSIC_QAT_MODULES
        )

    def test_modules_intrinsic_qat_linear_relu(self):
        """Migration of ``torch.nn.intrinsic.qat.modules.linear_relu``."""
        self._test_function_import(
            "linear_relu", ["LinearReLU"], base=self._NN_INTRINSIC_QAT_MODULES
        )

    def test_modules_import_nn_intrinsic_quantized(self):
        """Migration of ``torch.nn.intrinsic.quantized``."""
        self._test_function_import(
            "quantized",
            [
                "BNReLU2d",
                "BNReLU3d",
                "ConvReLU1d",
                "ConvReLU2d",
                "ConvReLU3d",
                "LinearReLU",
            ],
            base=self._NN_INTRINSIC,
        )

    def test_modules_intrinsic_quantized_bn_relu(self):
        """Migration of ``torch.nn.intrinsic.quantized.modules.bn_relu``."""
        self._test_function_import(
            "bn_relu",
            ["BNReLU2d", "BNReLU3d"],
            base=self._NN_INTRINSIC_QUANTIZED_MODULES,
        )

    def test_modules_intrinsic_quantized_conv_relu(self):
        """Migration of ``torch.nn.intrinsic.quantized.modules.conv_relu``."""
        self._test_function_import(
            "conv_relu",
            ["ConvReLU1d", "ConvReLU2d", "ConvReLU3d"],
            base=self._NN_INTRINSIC_QUANTIZED_MODULES,
        )

    def test_modules_intrinsic_quantized_linear_relu(self):
        """Migration of ``torch.nn.intrinsic.quantized.modules.linear_relu``."""
        self._test_function_import(
            "linear_relu",
            ["LinearReLU"],
            base=self._NN_INTRINSIC_QUANTIZED_MODULES,
        )

    def test_modules_no_import_nn_intrinsic_quantized_dynamic(self):
        """Reach both ``dynamic`` subpackages through attribute access only.

        Neither subpackage is named in an import statement. An
        ``AttributeError`` from ``getattr`` below means the subpackage is
        not exposed as an attribute of its parent after ``import torch``.
        """
        # TODO(future PR): generalize this
        import torch

        # New ``torch.ao`` location first, then the legacy alias, walking
        # one attribute at a time from the top-level ``torch`` package.
        for dotted_path in (
            "ao.nn.intrinsic.quantized.dynamic",
            "nn.intrinsic.quantized.dynamic",
        ):
            node = torch
            for attr_name in dotted_path.split("."):
                node = getattr(node, attr_name)
362