# xref: /aosp_15_r20/external/pytorch/test/quantization/eager/test_fuse_eager.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: quantization"]

import copy

import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.qat as nniqat
from torch.ao.quantization import (
    quantize,
    prepare,
    convert,
    prepare_qat,
    quantize_qat,
    fuse_modules,
    fuse_modules_qat,
    QConfig,
    default_qconfig,
    default_qat_qconfig,
)

from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    ModelForFusion,
    ModelWithSequentialFusion,
    ModelForLinearBNFusion,
    ModelForFusionWithBias,
    ModelForConvTransposeBNFusion,
    SingleLayerLinearModel,
    test_only_eval_fn,
    test_only_train_fn,
    skipIfNoFBGEMM,
)

from torch.testing._internal.common_quantized import (
    override_quantized_engine,
    supported_qengines,
)

42
@skipIfNoFBGEMM
class TestFuseEager(QuantizationTestCase):
    """Eager-mode module-fusion tests.

    Verifies that ``fuse_modules`` / ``fuse_modules_qat`` replace
    (Conv, BN, ReLU)-style patterns with the corresponding intrinsic fused
    modules (``torch.ao.nn.intrinsic``), that the skipped positions become
    ``nn.Identity``, and that the fused models survive the PTQ
    (prepare/convert) and QAT (prepare_qat/convert) flows with correct
    module swaps and numerics.
    """

    def test_fuse_module_train(self):
        """Fuse Conv+BN(+ReLU) in train mode and check module types after
        fusion, after prepare_qat, and after convert."""
        model = ModelForFusion(default_qat_qconfig).train()
        # Test step by step fusion
        model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
        model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
        self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
                         msg="Fused Conv + BN + Relu first layer")
        self.assertEqual(type(model.bn1), torch.nn.Identity,
                         msg="Fused Conv + BN + Relu (skipped BN)")
        self.assertEqual(type(model.relu1), torch.nn.Identity,
                         msg="Fused Conv + BN + Relu (skipped Relu)")

        self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
                         msg="Fused submodule Conv + BN")
        self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
                         msg="Fused submodule Conv + BN (skipped BN)")
        self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
                         msg="Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                         msg="Non-fused submodule ReLU")
        model = prepare_qat(model)
        self.checkObservers(model)

        def checkQAT(model):
            # After prepare_qat, fused modules swap to their QAT variants.
            self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)

        checkQAT(model)
        test_only_train_fn(model, self.img_data_1d_train)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            test_only_eval_fn(model, self.img_data_1d)
            self.checkNoQconfig(model)

        # The converted model still contains a float-only batch norm, so
        # running it on quantized tensors is expected to fail.
        with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
            checkQuantized(model)

        # Same flow, but fusing all patterns in a single call and using the
        # one-shot quantize_qat API.
        model = ModelForFusion(default_qat_qconfig).train()
        model = fuse_modules_qat(
            model,
            [['conv1', 'bn1', 'relu1'],
             ['sub1.conv', 'sub1.bn']])
        model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
        with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
            checkQuantized(model)

    def test_fuse_module_eval(self):
        """Fuse Conv+BN(+ReLU), Conv+ReLU and BN+ReLU in eval mode (BN gets
        folded into the preceding conv) and check the PTQ flow."""
        model = ModelForFusion(default_qconfig)
        model.eval()
        model = fuse_modules(
            model,
            [['conv3', 'bn3', 'relu4'],
             ['conv1', 'bn1', 'relu1'],
             ['conv2', 'relu2'],
             ['bn2', 'relu3'],
             ['sub1.conv', 'sub1.bn']])
        self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                         msg="Fused Conv + BN + Relu first layer (BN is folded)")
        self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                         msg="Fused Conv + BN + Relu (Conv + folded BN only)")
        self.assertEqual(type(model.conv1[1]), nn.ReLU,
                         msg="Fused Conv + BN + Relu second layer (Relu only)")
        self.assertEqual(type(model.bn1), nn.Identity,
                         msg="Fused Conv + BN + Relu second layer (Skipped BN)")
        self.assertEqual(type(model.relu1), nn.Identity,
                         msg="Fused Conv + BN + Relu second layer (Skipped Relu)")
        self.assertEqual(type(model.conv2), nni.ConvReLU3d,
                         msg="Fused Conv + BN + Relu first layer (BN is folded)")
        self.assertEqual(type(model.bn2), nni.BNReLU3d,
                         msg="Fused BN + Relu first layer (Relu is folded))")
        self.assertEqual(type(model.relu3), nn.Identity,
                         msg="Fused BN + Relu second layer (Skipped Relu)")
        self.assertEqual(type(model.conv2[0]), nn.Conv3d,
                         msg="Fused Conv + BN + Relu (Conv + folded BN only)")
        self.assertEqual(type(model.conv2[1]), nn.ReLU,
                         msg="Fused Conv + BN + Relu second layer (Relu only)")
        self.assertEqual(type(model.relu2), nn.Identity,
                         msg="Fused Conv + BN + Relu second layer (Skipped Relu)")

        self.assertEqual(type(model.conv3), nni.ConvReLU1d,
                         msg="Fused Conv + Relu for Conv1d (folded BN)")
        self.assertEqual(type(model.conv3[0]), nn.Conv1d,
                         msg="Fused Conv + Relu for Conv1d ")
        self.assertEqual(type(model.conv3[1]), nn.ReLU,
                         msg="Fused Conv + Relu for Conv1d")
        self.assertEqual(type(model.bn3), nn.Identity,
                         msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)")

        self.assertEqual(type(model.sub1.conv), nn.Conv2d,
                         msg="Fused submodule Conv + folded BN")
        self.assertEqual(type(model.sub1.bn), nn.Identity,
                         msg="Fused submodule (skipped BN)")
        self.assertEqual(type(model.sub2.conv), nn.Conv2d,
                         msg="Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                         msg="Non-fused submodule ReLU")

        model = prepare(model)
        self.checkObservers(model)
        test_only_eval_fn(model, self.img_data_1d)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv3), nniq.ConvReLU1d)
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            self.assertEqual(type(model.bn2), nniq.BNReLU3d)
            test_only_eval_fn(model, self.img_data_1d)
            self.checkNoQconfig(model)

        checkQuantized(model)

        # Same flow, fusing everything in one call and quantizing one-shot.
        model = ModelForFusion(default_qconfig).eval()
        model = fuse_modules(
            model,
            [['conv1', 'bn1', 'relu1'],
             ['conv2', 'relu2'],
             ['bn2', 'relu3'],
             ['sub1.conv', 'sub1.bn'],
             ['conv3', 'bn3', 'relu4']])
        model = quantize(model, test_only_eval_fn, [self.img_data_1d])
        checkQuantized(model)

    def test_fusion_sequential_model_train(self):
        """In-place QAT fusion of a model whose patterns live inside
        nn.Sequential containers, across all supported quantized engines."""
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelWithSequentialFusion().train()
                model.to(torch.float)
                fuse_modules_qat(
                    model, [['conv1', 'relu1'],
                            ['features.0.0', 'features.0.1', 'features.0.2'],
                            ['features.1.0', 'features.1.1', 'features.1.2'],
                            ['features.2.0', 'features.2.1', 'features.2.2'],
                            ['classifier.0', 'classifier.1']],
                    inplace=True)
                self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                                 msg="Fused Conv + Relu: nni.ConvReLU2d")
                self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                                 msg="Fused Conv + Relu: Conv2d")
                self.assertEqual(type(model.conv1[1]), nn.ReLU,
                                 msg="Fused Conv + Relu: Relu")
                self.assertEqual(type(model.relu1), nn.Identity,
                                 msg="Fused Conv + Relu: Identity")
                for i in range(3):
                    self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d,
                                     msg="Fused submodule Conv + folded BN")
                    self.assertEqual(type(model.features[i][1]), nn.Identity,
                                     msg="Fused submodule (skipped BN)")
                    self.assertEqual(type(model.features[i][2]), nn.Identity,
                                     msg="Non-fused submodule Conv")
                self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
                self.assertEqual(type(model.classifier[1]), nn.Identity)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
                prepare_qat(model, inplace=True)
                self.checkObservers(model)
                model(self.img_data_2d[0][0])

                def checkQAT(model):
                    # NOTE(review): previously the loop and classifier checks
                    # below were dedented outside this helper, so calling it
                    # only re-checked conv1/relu1; all checks now run on every
                    # checkQAT(model) call, mirroring the eval-mode test.
                    self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
                    self.assertEqual(type(model.relu1), nn.Identity)
                    for i in range(3):
                        self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d,
                                         msg="Fused submodule Conv + folded BN")
                        self.assertEqual(type(model.features[i][1]), nn.Identity,
                                         msg="Fused submodule (skipped BN)")
                        self.assertEqual(type(model.features[i][2]), nn.Identity,
                                         msg="Non-fused submodule Conv")
                    self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
                    self.assertEqual(type(model.classifier[1]), nn.Identity)

                checkQAT(model)
                model(self.img_data_2d[1][0])
                convert(model, inplace=True)
                model(self.img_data_2d[1][0])
                self.checkModelWithSequentialQuantized(model)

    def test_fusion_sequential_model_eval(self):
        """In-place eval-mode fusion of a model with nn.Sequential
        containers, followed by the PTQ flow, for each quantized engine."""
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelWithSequentialFusion().eval()
                model.to(torch.float)
                fuse_modules(
                    model,
                    [['conv1', 'relu1'],
                     ['features.0.0', 'features.0.1', 'features.0.2'],
                     ['features.1.0', 'features.1.1', 'features.1.2'],
                     ['features.2.0', 'features.2.1', 'features.2.2'],
                     ['classifier.0', 'classifier.1']],
                    inplace=True)
                self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                                 msg="Fused Conv + Relu: nni.ConvReLU2d")
                self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                                 msg="Fused Conv + Relu: Conv2d")
                self.assertEqual(type(model.conv1[1]), nn.ReLU,
                                 msg="Fused Conv + Relu: Relu")
                self.assertEqual(type(model.relu1), nn.Identity,
                                 msg="Fused Conv + Relu: Identity")
                for i in range(3):
                    # In eval mode the BN is folded, so the fused module is
                    # ConvReLU2d rather than ConvBnReLU2d.
                    self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d,
                                     msg="Fused submodule Conv + folded BN")
                    self.assertEqual(type(model.features[i][1]), nn.Identity,
                                     msg="Fused submodule (skipped BN)")
                    self.assertEqual(type(model.features[i][2]), nn.Identity,
                                     msg="Non-fused submodule Conv")
                self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
                self.assertEqual(type(model.classifier[1]), nn.Identity)
                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                prepare(model, inplace=True)
                self.checkObservers(model)
                model(self.img_data_2d[0][0])
                convert(model, inplace=True)
                model(self.img_data_2d[1][0])
                self.checkModelWithSequentialQuantized(model)

    def checkModelWithSequentialQuantized(self, model):
        """Assert every fused position in ModelWithSequentialFusion was
        converted to its quantized intrinsic module."""
        self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
        self.assertEqual(type(model.relu1), nn.Identity)
        for i in range(3):
            self.assertEqual(type(model.features[i][0]), nniq.ConvReLU2d)
            self.assertEqual(type(model.features[i][1]), nn.Identity)
            self.assertEqual(type(model.features[i][2]), nn.Identity)
        self.assertEqual(type(model.classifier[0]), nniq.LinearReLU)
        self.assertEqual(type(model.classifier[1]), nn.Identity)

    def test_fusion_conv_with_bias(self):
        """QAT fusion of conv layers that carry a bias: the fused model must
        match the unfused reference numerically and keep BN statistics."""
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model_orig = ModelForFusionWithBias().train()

                # reference model
                model_ref = copy.deepcopy(model_orig)
                # output with no fusion.
                out_ref = model_ref(self.img_data_2d[0][0])

                # fused model; Identity observers make prepare_qat a no-op
                # numerically, isolating the effect of fusion itself.
                model_orig.qconfig = QConfig(activation=torch.nn.Identity,
                                             weight=torch.nn.Identity)
                model = fuse_modules_qat(
                    model_orig,
                    [["conv1", "bn1", "relu1"],
                     ["conv2", "bn2"]])
                prep_model = prepare_qat(model, inplace=False)
                # output with fusion but no observers.
                out_fused = prep_model(self.img_data_2d[0][0])

                self.assertEqual(out_ref, out_fused)

                def checkBN(bn_ref, bn):
                    # Fusion must carry the BN parameters and running stats
                    # over into the fused module unchanged.
                    self.assertEqual(bn_ref.weight, bn.weight)
                    self.assertEqual(bn_ref.bias, bn.bias)
                    self.assertEqual(bn_ref.running_mean, bn.running_mean)
                    self.assertEqual(bn_ref.running_var, bn.running_var)

                checkBN(model_ref.bn1, prep_model.conv1.bn)
                checkBN(model_ref.bn2, prep_model.conv2.bn)

                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                prepare_qat(model, inplace=True)

                model(self.img_data_2d[0][0])

                def checkQAT(model):
                    self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
                    self.assertEqual(type(model.bn1), nn.Identity)
                    self.assertEqual(type(model.relu1), nn.Identity)
                    self.assertEqual(type(model.conv2), nniqat.ConvBn2d)
                    self.assertEqual(type(model.bn2), nn.Identity)

                checkQAT(model)

    def test_fusion_linear_bn_eval(self):
        """Linear + BN1d folding in eval mode must be numerically exact."""
        model = ModelForLinearBNFusion().train()
        inp1 = torch.randn(8, 20)
        inp2 = torch.randn(8, 20)

        # Get some interesting values into the running mean and variance.
        model(inp1)
        model.eval()
        golden = model(inp2)

        model = fuse_modules(model, [["fc", "bn"]])
        self.assertEqual(type(model.bn), nn.Identity)
        self.assertEqual(golden, model(inp2))

    def test_fusion_convtranspose_bn_eval(self):
        """ConvTranspose{1,2,3}d + BN folding in eval mode must be
        numerically exact."""
        model = ModelForConvTransposeBNFusion().train()
        inp1 = torch.randn(8, 3, 16)
        inp2 = torch.randn(8, 3, 16)

        # Get some interesting values into the running mean and variance.
        model(inp1)
        model.eval()
        golden = model(inp2)

        model = fuse_modules(model, [["conv1", "bn1"], ["conv2", "bn2"], ["conv3", "bn3"]])
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.bn2), nn.Identity)
        self.assertEqual(type(model.bn3), nn.Identity)

        self.assertEqual(golden, model(inp2))

    def test_fuse_function_customization(self):
        """A user-supplied fuser_func must be honored by fuse_modules."""
        dummy_model = SingleLayerLinearModel().train()
        dummy_model.eval()

        # A custom fuse funct
        def custom_fuse_func(module, is_qat, add_fuser_mapping):
            return [torch.nn.Identity()]

        dummy_model = fuse_modules(dummy_model, [["fc1"]], fuser_func=custom_fuse_func)
        self.assertEqual(type(dummy_model.fc1), nn.Identity)

    def test_forward_hooks_preserved(self):
        r"""Test case that checks whether forward pre hooks of the first module and
        post forward hooks of the last module in modules list passed to fusion function preserved.
        (e.g. before fusion: [nn.Conv2d (with pre forward hooks), nn.BatchNorm2d, nn.ReLU (with post forward hooks)]
        after fusion: [nni.ConvBnReLU2d (with pre and post hooks), nn.Identity, nn.Identity])
        """
        model = ModelForFusion(default_qat_qconfig).train()

        counter = {
            'pre_forwards': 0,
            'forwards': 0,
        }
        fused = False

        def fw_pre_hook(fused_module_class, h_module, input):
            if fused:
                self.assertEqual(type(h_module), fused_module_class,
                                 "After fusion owner of the first module's forward pre hook is not a fused module")
            counter['pre_forwards'] += 1

        def fw_hook(fused_module_class, h_module, input, output):
            if fused:
                self.assertEqual(type(h_module), fused_module_class,
                                 "After fusion owner of the last module's forward hook is not a fused module")
            counter['forwards'] += 1

        # Registering two pre and two post forward hooks, thus expecting counter increment by two each inference
        model.conv1.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args))
        model.sub1.conv.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBn2d, *args))
        model.relu1.register_forward_hook(lambda *args: fw_hook(nni.ConvBnReLU2d, *args))
        model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args))

        test_only_eval_fn(model, self.img_data_1d)
        self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
        self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))

        model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
        model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])

        fused = True
        before_fusion_pre_count = counter['pre_forwards']
        before_fusion_post_count = counter['forwards']
        test_only_eval_fn(model, self.img_data_1d)
        # Hooks must still fire (now owned by the fused modules).
        self.assertEqual(counter['pre_forwards'] - before_fusion_pre_count, 2 * len(self.img_data_1d))
        self.assertEqual(counter['forwards'] - before_fusion_post_count, 2 * len(self.img_data_1d))

    def test_fuse_modules_with_nested_hooks(self):
        r"""Test case that checks whether a nested module with sub-sub modules registered with hooks
        can be safely fused. Safeguard for issues similar to https://github.com/pytorch/pytorch/issues/105063
        in the future.
        """
        def myhook(*x):
            return ""
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelWithSequentialFusion().eval()

                for sub_model in model.modules():
                    if isinstance(sub_model, nn.Sequential):
                        for layer in sub_model:
                            if hasattr(layer, 'register_forward_hook'):
                                layer.register_forward_hook(myhook)

                fuse_modules(model, [['features.0.0', 'features.0.1', 'features.0.2']], inplace=True)
                self.assertEqual(
                    type(model.features[0][0]),
                    nni.ConvReLU2d,
                    msg="Fused submodule Conv + folded BN"
                )
                self.assertEqual(
                    type(model.features[0][1]),
                    nn.Identity,
                    msg="Fused submodule (skipped BN)"
                )
                self.assertEqual(
                    type(model.features[0][2]),
                    nn.Identity,
                    msg="Non-fused submodule Conv"
                )
457
458
459if __name__ == '__main__':
460    raise RuntimeError(
461        "This test file is not meant to be run directly, use:\n\n"
462        "\tpython test/test_quantization.py TESTNAME\n\n"
463        "instead."
464    )
465