# Owner(s): ["module: unknown"]


import logging

import torch
import torch.ao.quantization as tq
from torch import nn
from torch.ao import pruning
from torch.ao.pruning import fqn_to_module
from torch.ao.quantization.quantize_fx import (
    convert_fx,
    convert_to_reference_fx,
    prepare_fx,
    prepare_qat_fx,
)
from torch.testing._internal.common_utils import TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

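# Default WeightNormSparsifier configuration: 80% target sparsity using (1, 4)
# blocks with all 4 elements of each chosen block zeroed out.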
sparse_defaults = {
    "sparsity_level": 0.8,
    "sparse_block_shape": (1, 4),
    "zeros_per_block": 4,
}


def _get_model_and_sparsifier_and_sparse_config(qconfig=None):
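    """Return a small Sequential model, a WeightNormSparsifier, and a sparse config.

    The last Linear/ReLU pair is wrapped in QuantStub/DeQuantStub so the model can be
    quantized in eager mode. If ``qconfig`` is given, it is attached to the QuantStub
    (index 4) and the last Linear (index 5). The sparse config targets the weights of
    the Linear modules at indices 0 and 5.
    """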
    model = nn.Sequential(
        nn.Linear(4, 4),  # 0
        nn.ReLU(),
        nn.Linear(4, 4),  # 2
        nn.ReLU(),
        tq.QuantStub(),
        nn.Linear(4, 4),  # 5
        nn.ReLU(),
        tq.DeQuantStub(),
    )
    if qconfig:
        model[4].qconfig = qconfig
        model[5].qconfig = qconfig

    sparsifier = pruning.WeightNormSparsifier(**sparse_defaults)

    sparse_config = [
        {
            "tensor_fqn": "5.weight",
            "sparsity_level": 0.7,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 4,
        },
        {"tensor_fqn": "0.weight"},
    ]
    return model, sparsifier, sparse_config


def _squash_mask_calibrate_and_convert(model, sparsifier, input):
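    """Update the sparsity masks, fold them into the weights (removing the
    parametrizations), run one calibration forward pass, then convert the model
    to its quantized form in place."""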
    sparsifier.step()
    sparsifier.squash_mask()
    model(input)
    tq.convert(model, inplace=True)


def _calculate_sparsity(tensor):
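    """Return the fraction of elements in ``tensor`` that are exactly zero."""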
    return ((tensor == 0).sum() / tensor.numel()).item()


# This series of tests checks the composability goals for sparsity and quantization,
# namely that performing quantization and sparsity model manipulations in various
# orderings does not cause problems.
class TestComposability(TestCase):
    # This test checks whether performing quantization prepare before sparse prepare
    # causes any issues and verifies that the correct observers are inserted and that
    # the quantized model works as expected.
    def test_q_prep_before_s_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )

        tq.prepare(mod, inplace=True)
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))
        # check that correct observers were inserted
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # This test checks whether performing sparsity prepare before quantization prepare
    # causes any issues. In particular, the previous quantization flow was unable to match
    # the post-sparse-prepare module names (adding parametrizations changes the module class
    # names), which would result in those parametrized modules not being quantized. This test
    # verifies that the fix for this was successful.
    def test_s_prep_before_q_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )

        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # If the sparsified modules have not undergone the final squash mask operation, it's possible
    # that the problem outlined in test_s_prep_before_q_prep would occur. This test verifies
    # both that the fix to the convert flow avoids this issue and that the resulting quantized
    # module uses the sparse version of the weight value.
    def test_convert_without_squash_mask(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )

        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        sparsifier.step()
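        # record the sparsity of the masked (still parametrized) weight before convert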
        sparsity_level = _calculate_sparsity(mod[5].weight)
        mod(torch.randn(1, 4, 4, 4))
        tq.convert(mod, inplace=True)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing sparse prepare before fusion causes any issues. The
    # worry was that the link created between the sparsifier and the modules that need to
    # be sparsified would be broken.
    def test_s_prep_before_fusion(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )
        sparsifier.prepare(mod, config=sparse_config)
        tq.fuse_modules(mod, [["5", "6"]], inplace=True)
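        # fusion replaces mod[5] with a fused LinearReLU wrapper (the original Linear
        # is now mod[5][0]) and mod[6] with an Identity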
        mod[5].qconfig = tq.get_default_qconfig("fbgemm")
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare or fusion
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5][0], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # This tests whether performing fusion before sparse prepare causes any issues. The
    # main worry was that the links to the modules in the sparse config would be broken by fusion.
    def test_fusion_before_s_prep(self):
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )
        tq.fuse_modules(mod, [["5", "6"]], inplace=True)

        # the original fqn ("5.weight") is broken by fusion, but sparsification
        # still works if the correct post-fusion fqn ("5.0.weight") is used
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.weight"},
        ]

        sparsifier.prepare(mod, config=sparse_config)
        mod[5].qconfig = tq.get_default_qconfig("fbgemm")
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5][0], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(mod[5][0].weight)
        mod(torch.randn(1, 4, 4, 4))
        tq.convert(mod, inplace=True)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing sparse prepare before qat prepare causes issues.
    # The primary worries were that qat_prep wouldn't recognize the parametrized
    # modules and that the convert step for qat would remove the parametrizations
    # from the modules.
    def test_s_prep_before_qat_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qat_qconfig("fbgemm")
        )
        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare_qat(mod, inplace=True)
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        self.assertTrue(isinstance(mod[5], torch.ao.nn.qat.Linear))
        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))
        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing qat prepare before sparse prepare causes issues.
    def test_qat_prep_before_s_prep(self):
        mod, sparsifier, _ = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qat_qconfig("fbgemm")
        )
        tq.prepare_qat(mod, inplace=True)

        # the sparse config needs to be set up against the new (post-qat-prepare) modules
        sparse_config = [
            {
                "tensor_fqn": "5.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added and
        # that none were lost during qat prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        self.assertTrue(isinstance(mod[5], torch.ao.nn.qat.Linear))

        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])


def _module_has_activation_post_process(model, fqn_of_module):
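    """Return True if the traced graph contains an activation_post_process (observer)
    node whose input is the module at ``fqn_of_module``."""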
    for node in model.graph.nodes:
        # look for an observer whose arg is the target module
        if "activation_post_process" in node.name:
            if node.args[0].target == fqn_of_module:
                return True
    return False


class TestFxComposability(TestCase):
    r"""This series of tests checks that various steps of the quantization and sparsity flow
    compose cleanly despite variation in sequencing.
    """

    def test_q_prep_fx_before_s_prep(self):
        r"""
        This test checks that the ordering prepare_fx -> sparse prepare -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask between sparse prepare and convert_fx. This also tests the
        automatic fusion that occurs during prepare_fx.
        """
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config()

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )

        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # the original fqns are broken by the automatic fusion in fx,
        # but sparsification still works if the correct post-fusion fqns are used
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
            )
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_q_prep_fx_s_prep_ref_conv(self):
        r"""
        This checks that the ordering prepare_fx -> sparse prepare -> convert_to_reference_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask before convert_to_reference_fx.
        """
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config()

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )

        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # the original fqns are broken by the automatic fusion in fx,
        # but sparsification still works if the correct post-fusion fqns are used
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_to_reference_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU)
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear
            )
        )

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_before_q_prep_fx(self):
        r"""
        This test checks that the ordering sparse prepare -> prepare_fx -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask before convert_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )
        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
            )
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_before_qat_prep_fx(self):
        r"""
        This test checks that the ordering sparse prepare -> prepare_qat_fx -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask before convert_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qat_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )
        mod = prepare_qat_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5"), "parametrizations"))
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.qat.LinearReLU)
        )

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
            )
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_q_prep_fx_ref(self):
        r"""
        This checks that the ordering sparse prepare -> prepare_fx -> convert_to_reference_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask before convert_to_reference_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )
        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_to_reference_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU)
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear
            )
        )

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])