# Owner(s): ["oncall: quantization"]

import os
import sys
import unittest
from typing import Set

# torch
import torch
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.quantization.quantize_fx as quantize_fx
import torch.nn as nn
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver
from torch.fx import GraphModule
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_quantized import (
    override_qengines,
    qengine_is_fbgemm,
)

# Testing utils
from torch.testing._internal.common_utils import IS_AVX512_VNNI_SUPPORTED, TestCase
from torch.testing._internal.quantization_torch_package_models import (
    LinearReluFunctional,
)


def remove_prefix(text, prefix):
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text

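# Note: on Python 3.9+ this helper is equivalent to the built-in
# str.removeprefix, e.g.
#   >>> remove_prefix("TestSerialization.test_linear", "TestSerialization.")
#   'test_linear'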

def get_filenames(self, subname):
    # NB: we take __file__ from the module that defined the test
    # class, so we place the serialized-artifacts directory where the test
    # script lives, NOT where test/common_utils.py lives.
    module_id = self.__class__.__module__
    munged_id = remove_prefix(self.id(), module_id + ".")
    test_file = os.path.realpath(sys.modules[module_id].__file__)
    base_name = os.path.join(os.path.dirname(test_file), "../serialized", munged_id)

    subname_output = ""
    if subname:
        base_name += "_" + subname
        subname_output = f" ({subname})"

    input_file = base_name + ".input.pt"
    state_dict_file = base_name + ".state_dict.pt"
    scripted_module_file = base_name + ".scripted.pt"
    traced_module_file = base_name + ".traced.pt"
    expected_file = base_name + ".expected.pt"
    package_file = base_name + ".package.pt"
    get_attr_targets_file = base_name + ".get_attr_targets.pt"

    return (
        input_file,
        state_dict_file,
        scripted_module_file,
        traced_module_file,
        expected_file,
        package_file,
        get_attr_targets_file,
    )

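# Illustrative example (exact paths depend on where this file lives): for
# TestSerialization.test_linear defined in this module, munged_id is
# "TestSerialization.test_linear" and the artifacts resolve to
#   ../serialized/TestSerialization.test_linear.input.pt
#   ../serialized/TestSerialization.test_linear.state_dict.pt
# and so on for the .scripted/.traced/.expected/.package files.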

class TestSerialization(TestCase):
    """Test backward compatibility for serialization and numerics"""

    # Copied and modified from TestCase.assertExpected
    def _test_op(
        self,
        qmodule,
        subname=None,
        input_size=None,
        input_quantized=True,
        generate=False,
        prec=None,
        new_zipfile_serialization=False,
    ):
        r"""Test that quantized modules serialized previously can be loaded
        with current code, to make sure we don't break backward compatibility
        for the serialization of quantized modules.
        """
        (
            input_file,
            state_dict_file,
            scripted_module_file,
            traced_module_file,
            expected_file,
            _package_file,
            _get_attr_targets_file,
        ) = get_filenames(self, subname)

        # only generate once.
        if generate and qengine_is_fbgemm():
            input_tensor = torch.rand(*input_size).float()
            if input_quantized:
                input_tensor = torch.quantize_per_tensor(
                    input_tensor, 0.5, 2, torch.quint8
                )
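                # quantize_per_tensor applies the affine mapping
                # q = clamp(round(x / scale) + zero_point, 0, 255) for quint8;
                # with scale=0.5 and zero_point=2, e.g. x=1.0 -> q=4, which
                # dequantizes back to (4 - 2) * 0.5 = 1.0.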
            torch.save(input_tensor, input_file)
            # Temporary fix to use _use_new_zipfile_serialization until #38379 lands.
            torch.save(
                qmodule.state_dict(),
                state_dict_file,
                _use_new_zipfile_serialization=new_zipfile_serialization,
            )
            torch.jit.save(torch.jit.script(qmodule), scripted_module_file)
            torch.jit.save(torch.jit.trace(qmodule, input_tensor), traced_module_file)
            torch.save(qmodule(input_tensor), expected_file)

        input_tensor = torch.load(input_file)
        # weights_only=False because we sometimes get a ScriptObject here
        qmodule.load_state_dict(torch.load(state_dict_file, weights_only=False))
        qmodule_scripted = torch.jit.load(scripted_module_file)
        qmodule_traced = torch.jit.load(traced_module_file)
        expected = torch.load(expected_file)
        self.assertEqual(qmodule(input_tensor), expected, atol=prec)
        self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec)
        self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec)

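    # Hypothetical regeneration flow (mirrors test_linear below with the
    # generate flag flipped); artifacts are only written when
    # qengine_is_fbgemm() is true, so this would run under the fbgemm engine:
    #   module = nnq.Linear(3, 1, bias_=True, dtype=torch.qint8)
    #   self._test_op(module, input_size=[1, 3], generate=True)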
    def _test_op_graph(
        self,
        qmodule,
        subname=None,
        input_size=None,
        input_quantized=True,
        generate=False,
        prec=None,
        new_zipfile_serialization=False,
    ):
        r"""
        Input: a floating point module

        If generate == True, traces and scripts the module, quantizes the
        results with PTQ, and saves the results.

        If generate == False, traces and scripts the module, quantizes the
        results with PTQ, and compares to the saved results.
        """
        (
            input_file,
            state_dict_file,
            scripted_module_file,
            traced_module_file,
            expected_file,
            _package_file,
            _get_attr_targets_file,
        ) = get_filenames(self, subname)

        # only generate once.
        if generate and qengine_is_fbgemm():
            input_tensor = torch.rand(*input_size).float()
            torch.save(input_tensor, input_file)

            # convert to TorchScript
            scripted = torch.jit.script(qmodule)
            traced = torch.jit.trace(qmodule, input_tensor)

            # quantize

            def _eval_fn(model, data):
                model(data)

            qconfig_dict = {"": torch.ao.quantization.default_qconfig}
            scripted_q = torch.ao.quantization.quantize_jit(
                scripted, qconfig_dict, _eval_fn, [input_tensor]
            )
            traced_q = torch.ao.quantization.quantize_jit(
                traced, qconfig_dict, _eval_fn, [input_tensor]
            )

            torch.jit.save(scripted_q, scripted_module_file)
            torch.jit.save(traced_q, traced_module_file)
            torch.save(scripted_q(input_tensor), expected_file)

        input_tensor = torch.load(input_file)
        qmodule_scripted = torch.jit.load(scripted_module_file)
        qmodule_traced = torch.jit.load(traced_module_file)
        expected = torch.load(expected_file)
        self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec)
        self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec)

    def _test_obs(
        self, obs, input_size, subname=None, generate=False, check_numerics=True
    ):
        """
        Test that observer state can be loaded from a state_dict.
        """
        (
            input_file,
            state_dict_file,
            _,
            _traced_module_file,
            expected_file,
            _package_file,
            _get_attr_targets_file,
        ) = get_filenames(self, None)
        if generate:
            input_tensor = torch.rand(*input_size).float()
            torch.save(input_tensor, input_file)
            torch.save(obs(input_tensor), expected_file)
            torch.save(obs.state_dict(), state_dict_file)

        input_tensor = torch.load(input_file)
        obs.load_state_dict(torch.load(state_dict_file))
        expected = torch.load(expected_file)
        if check_numerics:
            self.assertEqual(obs(input_tensor), expected)

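    # Note: a MinMaxObserver state_dict holds the running min_val/max_val
    # buffers, so loading it restores the ranges the observer uses to compute
    # scale and zero_point via calculate_qparams().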
    def _test_package(self, fp32_module, input_size, generate=False):
        """
        Verifies that files created in the past with torch.package
        work with today's FX graph mode quantization transforms.
        """
        (
            input_file,
            _state_dict_file,
            _scripted_module_file,
            _traced_module_file,
            expected_file,
            package_file,
            get_attr_targets_file,
        ) = get_filenames(self, None)

        package_name = "test"
        resource_name_model = "test.pkl"

        def _do_quant_transforms(
            m: torch.nn.Module,
            input_tensor: torch.Tensor,
        ) -> torch.nn.Module:
            example_inputs = (input_tensor,)
            # do the quantization transforms and return the result
            qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
            mp = quantize_fx.prepare_fx(m, {"": qconfig}, example_inputs=example_inputs)
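            # running the prepared model is the calibration pass: observers
            # inserted by prepare_fx record activation ranges that convert_fx
            # uses to compute quantization parameters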
            mp(input_tensor)
            mq = quantize_fx.convert_fx(mp)
            return mq

        def _get_get_attr_target_strings(m: GraphModule) -> Set[str]:
            results = set()
            for node in m.graph.nodes:
                if node.op == "get_attr":
                    results.add(node.target)
            return results

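        # Illustrative: the snapshot of get_attr targets (e.g. packed-weight
        # or scale/zero_point attribute names) is compared across releases
        # because renaming these attributes in the FX transforms would break
        # previously packaged models.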
        if generate and qengine_is_fbgemm():
            input_tensor = torch.randn(*input_size)
            torch.save(input_tensor, input_file)

            # save the model with torch.package
            with torch.package.PackageExporter(package_file) as exp:
                exp.intern("torch.testing._internal.quantization_torch_package_models")
                exp.save_pickle(package_name, resource_name_model, fp32_module)

            # do the quantization transforms and save the result
            mq = _do_quant_transforms(fp32_module, input_tensor)
            get_attrs = _get_get_attr_target_strings(mq)
            torch.save(get_attrs, get_attr_targets_file)
            q_result = mq(input_tensor)
            torch.save(q_result, expected_file)

        # load input tensor
        input_tensor = torch.load(input_file)
        expected_output_tensor = torch.load(expected_file)
        expected_get_attrs = torch.load(get_attr_targets_file, weights_only=False)

        # load model from package and verify output and get_attr targets match
        imp = torch.package.PackageImporter(package_file)
        m = imp.load_pickle(package_name, resource_name_model)
        mq = _do_quant_transforms(m, input_tensor)

        get_attrs = _get_get_attr_target_strings(mq)
        self.assertTrue(
            get_attrs == expected_get_attrs,
            f"get_attrs: expected {expected_get_attrs}, got {get_attrs}",
        )
        output_tensor = mq(input_tensor)
        self.assertTrue(torch.allclose(output_tensor, expected_output_tensor))

    @override_qengines
    def test_linear(self):
        module = nnq.Linear(3, 1, bias_=True, dtype=torch.qint8)
        self._test_op(module, input_size=[1, 3], generate=False)

    @override_qengines
    def test_linear_relu(self):
        module = nniq.LinearReLU(3, 1, bias=True, dtype=torch.qint8)
        self._test_op(module, input_size=[1, 3], generate=False)

    @override_qengines
    def test_linear_dynamic(self):
        module_qint8 = nnqd.Linear(3, 1, bias_=True, dtype=torch.qint8)
        self._test_op(
            module_qint8,
            "qint8",
            input_size=[1, 3],
            input_quantized=False,
            generate=False,
        )
        if qengine_is_fbgemm():
            module_float16 = nnqd.Linear(3, 1, bias_=True, dtype=torch.float16)
            self._test_op(
                module_float16,
                "float16",
                input_size=[1, 3],
                input_quantized=False,
                generate=False,
            )

    @override_qengines
    def test_conv2d(self):
        module = nnq.Conv2d(
            3,
            3,
            kernel_size=3,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
            padding_mode="zeros",
        )
        self._test_op(module, input_size=[1, 3, 6, 6], generate=False)

    @override_qengines
    def test_conv2d_nobias(self):
        module = nnq.Conv2d(
            3,
            3,
            kernel_size=3,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=False,
            padding_mode="zeros",
        )
        self._test_op(module, input_size=[1, 3, 6, 6], generate=False)

    @override_qengines
    def test_conv2d_graph(self):
        module = nn.Sequential(
            torch.ao.quantization.QuantStub(),
            nn.Conv2d(
                3,
                3,
                kernel_size=3,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=True,
                padding_mode="zeros",
            ),
        )
        self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)

    @override_qengines
    def test_conv2d_nobias_graph(self):
        module = nn.Sequential(
            torch.ao.quantization.QuantStub(),
            nn.Conv2d(
                3,
                3,
                kernel_size=3,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=False,
                padding_mode="zeros",
            ),
        )
        self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)

    @override_qengines
    def test_conv2d_graph_v2(self):
        # tests the same thing as test_conv2d_graph, but for version 2 of
        # ConvPackedParams{n}d
        module = nn.Sequential(
            torch.ao.quantization.QuantStub(),
            nn.Conv2d(
                3,
                3,
                kernel_size=3,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=True,
                padding_mode="zeros",
            ),
        )
        self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)

    @override_qengines
    def test_conv2d_nobias_graph_v2(self):
        # tests the same thing as test_conv2d_nobias_graph, but for version 2 of
        # ConvPackedParams{n}d
        module = nn.Sequential(
            torch.ao.quantization.QuantStub(),
            nn.Conv2d(
                3,
                3,
                kernel_size=3,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=False,
                padding_mode="zeros",
            ),
        )
        self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)

    @override_qengines
    def test_conv2d_graph_v3(self):
        # tests the same thing as test_conv2d_graph, but for version 3 of
        # ConvPackedParams{n}d
        module = nn.Sequential(
            torch.ao.quantization.QuantStub(),
            nn.Conv2d(
                3,
                3,
                kernel_size=3,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=True,
                padding_mode="zeros",
            ),
        )
        self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)

    @override_qengines
    def test_conv2d_nobias_graph_v3(self):
        # tests the same thing as test_conv2d_nobias_graph, but for version 3 of
        # ConvPackedParams{n}d
        module = nn.Sequential(
            torch.ao.quantization.QuantStub(),
            nn.Conv2d(
                3,
                3,
                kernel_size=3,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=False,
                padding_mode="zeros",
            ),
        )
        self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False)

    @override_qengines
    def test_conv2d_relu(self):
        module = nniq.ConvReLU2d(
            3,
            3,
            kernel_size=3,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
            padding_mode="zeros",
        )
        self._test_op(module, input_size=[1, 3, 6, 6], generate=False)
        # TODO: graph mode quantized conv2d module

    @override_qengines
    def test_conv3d(self):
        if qengine_is_fbgemm():
            module = nnq.Conv3d(
                3,
                3,
                kernel_size=3,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=True,
                padding_mode="zeros",
            )
            self._test_op(module, input_size=[1, 3, 6, 6, 6], generate=False)
            # TODO: graph mode quantized conv3d module

    @override_qengines
    def test_conv3d_relu(self):
        if qengine_is_fbgemm():
            module = nniq.ConvReLU3d(
                3,
                3,
                kernel_size=3,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=True,
                padding_mode="zeros",
            )
            self._test_op(module, input_size=[1, 3, 6, 6, 6], generate=False)
            # TODO: graph mode quantized conv3d module

    @override_qengines
    @unittest.skipIf(
        IS_AVX512_VNNI_SUPPORTED,
        "This test fails on machines with AVX512_VNNI support. Ref: GH Issue 59098",
    )
    def test_lstm(self):
        class LSTMModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lstm = nnqd.LSTM(input_size=3, hidden_size=7, num_layers=1).to(
                    dtype=torch.float
                )

            def forward(self, x):
                x = self.lstm(x)
                return x

        if qengine_is_fbgemm():
            mod = LSTMModule()
            self._test_op(
                mod,
                input_size=[4, 4, 3],
                input_quantized=False,
                generate=False,
                new_zipfile_serialization=True,
            )

    def test_per_channel_observer(self):
        obs = PerChannelMinMaxObserver()
        self._test_obs(obs, input_size=[5, 5], generate=False)

    def test_per_tensor_observer(self):
        obs = MinMaxObserver()
        self._test_obs(obs, input_size=[5, 5], generate=False)

    def test_default_qat_qconfig(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(5, 5)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.linear(x)
                x = self.relu(x)
                return x

        model = Model()
        model.linear.weight = torch.nn.Parameter(torch.randn(5, 5))
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
        ref_model = torch.ao.quantization.QuantWrapper(model)
        ref_model = torch.ao.quantization.prepare_qat(ref_model)
        self._test_obs(
            ref_model, input_size=[5, 5], generate=False, check_numerics=False
        )

    @skipIfNoFBGEMM
    def test_linear_relu_package_quantization_transforms(self):
        m = LinearReluFunctional(4).eval()
        self._test_package(m, input_size=(1, 1, 4, 4), generate=False)