1# Owner(s): ["oncall: quantization"]
2
3import copy
4import math
5import operator
6import unittest
7
8import torch
9import torch.nn as nn
10import torch.nn.functional as F
11from torch.ao.quantization import (
12    default_dynamic_qconfig,
13    QConfigMapping,
14    get_default_qconfig_mapping,
15)
16import torch.ao.nn.quantized as nnq
17toq = torch.ops.quantized
18from torch.ao.quantization.quantize_fx import (
19    convert_fx,
20    convert_to_reference_fx,
21    prepare_fx,
22    prepare_qat_fx,
23)
24from torch.testing._internal.common_quantization import (
25    ConvBnModel,
26    ConvBnReLUModel,
27    ConvModel,
28    QuantizationTestCase,
29    skipIfNoFBGEMM,
30    skipIfNoQNNPACK,
31    withQNNPACKBackend,
32    SingleLayerLinearDynamicModel,
33    SingleLayerLinearModel,
34    LSTMwithHiddenDynamicModel,
35    SparseNNModel,
36    skip_if_no_torchvision,
37    TwoLayerLinearModel
38)
39from torch.testing._internal.common_utils import skipIfTorchDynamo
40from torch.ao.quantization.quantization_mappings import (
41    get_default_static_quant_module_mappings,
42    get_default_dynamic_quant_module_mappings,
43    get_default_float_to_quantized_operator_mappings,
44)
45from torch.testing._internal.common_cuda import TEST_CUDA
46from torch.testing._internal.common_quantization import NodeSpec as ns
47from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns
48import torch.ao.quantization.fx.quantize_handler as qh
49from torch.ao.ns.fx.pattern_utils import (
50    get_type_a_related_to_b,
51)
52from torch.ao.ns.fx.graph_matcher import (
53    get_matching_subgraph_pairs,
54    GraphMatchingException,
55)
56from torch.ao.ns.fx.utils import (
57    compute_sqnr,
58    compute_normalized_l2_error,
59    compute_cosine_similarity,
60)
61from torch.ao.ns.fx.mappings import (
62    get_node_type_to_io_type_map,
63    get_unmatchable_types_map,
64    get_base_name_to_sets_of_related_ops,
65    get_base_name_for_op,
66    add_op_to_sets_of_related_ops,
67)
68from torch.ao.ns.fx.weight_utils import (
69    get_op_to_type_to_weight_extraction_fn,
70)
71from torch.ao.ns._numeric_suite_fx import (
72    extract_weights,
73    _extract_weights_impl,
74    add_loggers,
75    _add_loggers_impl,
76    OutputLogger,
77    add_shadow_loggers,
78    _add_shadow_loggers_impl,
79    extract_logger_info,
80    extract_shadow_logger_info,
81    extend_logger_results_with_comparison,
82    prepare_n_shadows_model,
83    convert_n_shadows_model,
84    extract_results_n_shadows_model,
85    OutputComparisonLogger,
86    print_comparisons_n_shadows_model,
87    loggers_set_enabled,
88    loggers_set_save_activations,
89    _prepare_n_shadows_add_loggers_model,
90    _n_shadows_compare_weights,
91)
92from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
93from torch.ao.quantization.backend_config import get_native_backend_config
94from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
95
96
97# Note: these models are not for use outside of this file. While it's good
98# to reuse code, we also need to be able to iterate on tests
99# quickly when debugging. If a test model has a large number of callsites
100# across many different files, debugging individual test cases
101# becomes slower.
102class LinearReluFunctional(nn.Module):
103    def __init__(self) -> None:
104        super().__init__()
105        self.w1 = nn.Parameter(torch.empty(4, 4))
106        self.b1 = nn.Parameter(torch.zeros(4))
107        torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
108
109    def forward(self, x):
110        x = F.linear(x, self.w1, self.b1)
111        x = F.relu(x)
112        return x
113
114
115class LinearFunctional(nn.Module):
116    def __init__(self) -> None:
117        super().__init__()
118        self.w1 = nn.Parameter(torch.empty(4, 4))
119        self.b1 = nn.Parameter(torch.zeros(4))
120        torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
121
122    def forward(self, x):
123        x = F.linear(x, self.w1, self.b1)
124        return x
125
126
127class LinearReluLinearFunctional(nn.Module):
128    def __init__(self) -> None:
129        super().__init__()
130        self.w = nn.Parameter(torch.empty(4, 4))
131        self.b = nn.Parameter(torch.zeros(4))
132        torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))
133
134    def forward(self, x):
135        x = F.linear(x, self.w, self.b)
136        x = F.relu(x)
137        x = F.linear(x, self.w, self.b)
138        return x
139
140
141class AddMulFunctional(nn.Module):
142    def forward(self, x, y):
143        x = x + 1.0
144        x = x * 1.0
145        x = 1.0 + x
146        x = 1.0 * x
147        x = x + y
148        x = x * y
149        return x
150
151
152class AllConvAndLinearFusionModules(torch.nn.Module):
153    def __init__(self) -> None:
154        super().__init__()
155        # conv1d
156        self.conv1d_0 = nn.Conv1d(1, 1, 1)
157        # conv1d - relu
158        self.conv1d_1 = nn.Conv1d(1, 1, 1)
159        self.relu_0 = nn.ReLU()
160        # conv1d - bn (qat only)
161        self.conv1d_2 = nn.Conv1d(1, 1, 1)
162        self.bn1d_0 = nn.BatchNorm1d(1)
163        # conv1d - bn - relu (qat only)
164        self.conv1d_3 = nn.Conv1d(1, 1, 1)
165        self.bn1d_1 = nn.BatchNorm1d(1)
166        self.relu_4 = nn.ReLU()
167        # conv2d
168        self.conv2d_0 = nn.Conv2d(1, 1, 1)
169        # conv2d - relu
170        self.conv2d_1 = nn.Conv2d(1, 1, 1)
171        self.relu_1 = nn.ReLU()
172        # conv2d - bn (qat only)
173        self.conv2d_2 = nn.Conv2d(1, 1, 1)
174        self.bn2d_0 = nn.BatchNorm2d(1)
175        # conv2d - bn - relu (qat only)
176        self.conv2d_3 = nn.Conv2d(1, 1, 1)
177        self.bn2d_1 = nn.BatchNorm2d(1)
178        self.relu_5 = nn.ReLU()
179        # conv3d
180        self.conv3d_0 = nn.Conv3d(1, 1, 1)
181        # conv3d - relu
182        self.conv3d_1 = nn.Conv3d(1, 1, 1)
183        self.relu_2 = nn.ReLU()
184        # conv3d - bn (qat only)
185        self.conv3d_2 = nn.Conv3d(1, 1, 1)
186        self.bn3d_0 = nn.BatchNorm3d(1)
187        # conv3d - bn - relu (qat only)
188        self.conv3d_3 = nn.Conv3d(1, 1, 1)
189        self.bn3d_1 = nn.BatchNorm3d(1)
190        self.relu_6 = nn.ReLU()
191        # linear
192        self.linear_0 = nn.Linear(1, 1)
193        # linear - relu
194        self.linear_1 = nn.Linear(1, 1)
195        self.relu_3 = nn.ReLU()
196
197    def forward(self, x):
198        # conv1d
199        x = self.conv1d_0(x)
200        x = self.conv1d_1(x)
201        x = self.relu_0(x)
202        x = self.conv1d_2(x)
203        x = self.bn1d_0(x)
204        x = self.conv1d_3(x)
205        x = self.bn1d_1(x)
206        x = self.relu_4(x)
207        # conv2d
208        x = x.reshape(1, 1, 1, 1)
209        x = self.conv2d_0(x)
210        x = self.conv2d_1(x)
211        x = self.relu_1(x)
212        x = self.conv2d_2(x)
213        x = self.bn2d_0(x)
214        x = self.conv2d_3(x)
215        x = self.bn2d_1(x)
216        x = self.relu_5(x)
217        # conv3d
218        x = x.reshape(1, 1, 1, 1, 1)
219        x = self.conv3d_0(x)
220        x = self.conv3d_1(x)
221        x = self.relu_2(x)
222        x = self.conv3d_2(x)
223        x = self.bn3d_0(x)
224        x = self.conv3d_3(x)
225        x = self.bn3d_1(x)
226        x = self.relu_6(x)
227        # linear
228        x = x.reshape(1, 1)
229        x = self.linear_0(x)
230        x = self.linear_1(x)
231        x = self.relu_3(x)
232        return x
233
234
235class AllConvFunctional(torch.nn.Module):
236    def __init__(self, weight1d, weight2d, weight3d, bias1d, bias2d, bias3d):
237        super().__init__()
238        self.weight1d = torch.nn.Parameter(weight1d)
239        self.weight2d = torch.nn.Parameter(weight2d)
240        self.weight3d = torch.nn.Parameter(weight3d)
241        self.bias1d = torch.nn.Parameter(bias1d)
242        self.bias2d = torch.nn.Parameter(bias2d)
243        self.bias3d = torch.nn.Parameter(bias3d)
244        self.stride1d = 1
245        self.padding1d = 0
246        self.dilation1d = 1
247        self.stride2d = (1, 1)
248        self.padding2d = (0, 0)
249        self.dilation2d = (1, 1)
250        self.groups = 1
251        self.stride3d = (1, 1, 1)
252        self.padding3d = (0, 0, 0)
253        self.dilation3d = (1, 1, 1)
254
255    def forward(self, x):
256        x = F.conv1d(
257            x, self.weight1d, self.bias1d, self.stride1d, self.padding1d,
258            self.dilation1d, self.groups)
259        x = F.conv1d(
260            x, self.weight1d, self.bias1d, self.stride1d, self.padding1d,
261            self.dilation1d, self.groups)
262        x = F.relu(x)
263        x = F.conv2d(
264            x, self.weight2d, self.bias2d, self.stride2d, self.padding2d,
265            self.dilation2d, self.groups)
266        x = F.conv2d(
267            x, self.weight2d, self.bias2d, self.stride2d, self.padding2d,
268            self.dilation2d, self.groups)
269        x = F.relu(x)
270        x = F.conv3d(
271            x, self.weight3d, self.bias3d, self.stride3d, self.padding3d,
272            self.dilation3d, self.groups)
273        x = F.conv3d(
274            x, self.weight3d, self.bias3d, self.stride3d, self.padding3d,
275            self.dilation3d, self.groups)
276        x = F.relu(x)
277        return x
278
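# Note: torch.fx.wrap marks the helpers below as leaf functions, so symbolic
# tracing keeps each of them as a single call_function node instead of tracing
# into the underlying torch.nn.functional call; tests such as
# test_user_defined_function below rely on this to exercise graph matching of
# user-defined functions.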
279@torch.fx.wrap
280def _wrapped_hardswish(x):
281    return F.hardswish(x)
282
283@torch.fx.wrap
284def _wrapped_hardswish_fp16(x):
285    x = x.dequantize()
286    x = F.hardswish(x)
287    x = x.to(torch.float16)
288    return x
289
290@torch.fx.wrap
291def _wrapped_sigmoid(x):
292    return F.sigmoid(x)
293
294@torch.fx.wrap
295def _wrapped_linear(x, w, b):
296    return F.linear(x, w, b)
297
298def get_all_quant_patterns():
299    """ we are in the process to migrate the frontend of fx graph mode quant
300    to use backend_config_dict, so some of the patterns are moved to backend_config_dict
301    this function will include these patterns so that we can still have all the patterns
302    """
303    # TODO: once the frontend refactor of FX graph mode quantization is done, we can
304    # remove this call and get all patterns from backend_config_dict
305    all_quant_patterns = get_default_quant_patterns()
306    # some of the patterns have been moved to the (native) backend_config_dict, so
307    # we need to add them back here
308    for pattern, quantize_handler in _get_pattern_to_quantize_handlers(get_native_backend_config()).items():
309        all_quant_patterns[pattern] = quantize_handler
310    return all_quant_patterns
311
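# The hypothetical helper below is added for illustration only (the name is ours
# and it is not called by any test): a minimal sketch of how the combined pattern
# dict can be inspected. Keys are patterns (an op, a method-name string, or a
# fused-op tuple) and values are QuantizeHandler classes, as exercised by
# test_op_relationship_mapping below; the exact contents depend on the active
# backend config.
def _example_inspect_quant_patterns():
    patterns = get_all_quant_patterns()
    assert len(patterns) > 0
    return patterns
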
312class TestFXGraphMatcher(QuantizationTestCase):
313
314    @skipIfNoFBGEMM
315    def test_simple_mod(self):
316        m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
317        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),))
318        mp_copy = copy.deepcopy(mp)
319        mq = convert_fx(mp_copy)
320        results = get_matching_subgraph_pairs(mp, mq)
321
322        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
323        conv_name_0 = 'base_op_' + get_base_name_for_op(
324            base_name_to_sets_of_related_ops, nn.Conv2d) + '_0'
325
326        expected_types = {
327            conv_name_0: ((nn.Conv2d, torch.ao.quantization.MinMaxObserver), (nnq.Conv2d, nnq.Conv2d)),
328        }
329        self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
330
331    @skipIfNoFBGEMM
332    def test_simple_fun(self):
333        class M(nn.Module):
334            def __init__(self) -> None:
335                super().__init__()
336                self.w = nn.Parameter(torch.empty(1, 4))
337                self.b = nn.Parameter(torch.zeros(1))
338                torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))
339
340            def forward(self, x):
341                return F.linear(x, self.w, self.b)
342
343        m = M().eval()
344        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),))
345        mp_copy = copy.deepcopy(mp)
346        mq = convert_fx(mp_copy)
347        results = get_matching_subgraph_pairs(mp, mq)
348
349        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
350        linear_name_0 = 'base_op_' + get_base_name_for_op(
351            base_name_to_sets_of_related_ops, F.linear) + '_0'
352
353        expected_types = {
354            linear_name_0:
355                ((F.linear, torch.ao.quantization.MinMaxObserver), (toq.linear, toq.linear))
356        }
357        self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
358
359    @skipIfNoFBGEMM
360    def test_simple_fusion(self):
361        m = LinearReluFunctional().eval()
362        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(4, 4),))
363        mp_copy = copy.deepcopy(mp)
364        mq = convert_fx(mp_copy)
365        results = get_matching_subgraph_pairs(mp, mq)
366
367        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
368        linear_name_0 = 'base_op_' + get_base_name_for_op(
369            base_name_to_sets_of_related_ops, F.linear) + '_0'
370
371        expected_types = {
372            linear_name_0:
373                ((F.linear, torch.ao.quantization.MinMaxObserver), (toq.linear_relu, toq.linear_relu)),
374        }
375        self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
376
377    @skipIfNoFBGEMM
378    def test_simple_mod_multi(self):
379        m = nn.Sequential(
380            nn.Sequential(
381                nn.Conv2d(1, 1, 1),
382            ),
383            nn.Conv2d(1, 1, 1),
384        ).eval()
385        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),))
386        mp_copy = copy.deepcopy(mp)
387        mq = convert_fx(mp_copy)
388        # assume success if no exceptions
389        results = get_matching_subgraph_pairs(mp, mq)
390
391    @skipIfNoFBGEMM
392    def test_simple_tensor_ops(self):
393        class M(nn.Module):
394            def forward(self, x, y):
395                z = x + y
396                return z
397
398        m = M().eval()
399        example_inputs = (torch.randn(1), torch.randn(1))
400        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
401        mp_copy = copy.deepcopy(mp)
402        mq = convert_fx(mp_copy)
403        # assume success if no exceptions
404        results = get_matching_subgraph_pairs(mp, mq)
405
406    @skipIfNoFBGEMM
407    def test_matching_failure_node_count(self):
408        # verify that matching graphs with matching node types but
409        # different counts of matchable nodes fails
410        m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
411        m2 = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval()
412        example_inputs = (torch.randn(1, 1, 1, 1),)
413        mp1 = prepare_fx(m1, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
414        mp2 = prepare_fx(m2, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
415        with self.assertRaises(GraphMatchingException) as ex:
416            results = get_matching_subgraph_pairs(mp1, mp2)
417
418    @skipIfNoFBGEMM
419    def test_matching_failure_node_type(self):
420        # verify that matching graphs with non-matching node types fails
421        m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
422        m2 = nn.Sequential(nn.Linear(1, 1)).eval()
423        example_inputs = (torch.randn(1, 1, 1, 1),)
424        mp1 = prepare_fx(m1, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
425        example_inputs = (torch.randn(1, 1),)
426        mp2 = prepare_fx(m2, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
427        with self.assertRaises(GraphMatchingException) as ex:
428            results = get_matching_subgraph_pairs(mp1, mp2)
429
430    @skipIfNoFBGEMM
431    def test_nodes_before_cat(self):
432        # verify that nodes before cat get matched
433        class M(nn.Module):
434            def forward(self, x0):
435                x1 = torch.add(x0, 1.0)
436                y1 = torch.add(x0, 1.0)
437                x2 = torch.cat([x1, y1])
438                return x2
439
440        m = M().eval()
441        example_inputs = (torch.randn(1),)
442        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
443        mp_copy = copy.deepcopy(mp)
444        mq = convert_fx(mp_copy)
445        results = get_matching_subgraph_pairs(mp, mq)
446
447        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
448        cat_name_0 = 'base_op_' + get_base_name_for_op(
449            base_name_to_sets_of_related_ops, torch.cat) + '_0'
450        add_name_0 = 'base_op_' + get_base_name_for_op(
451            base_name_to_sets_of_related_ops, torch.add) + '_0'
452        add_name_1 = 'base_op_' + get_base_name_for_op(
453            base_name_to_sets_of_related_ops, torch.add) + '_1'
454
455        expected_types = {
456            cat_name_0: ((torch.cat, torch.cat), (torch.cat, torch.cat)),
457            add_name_0: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)),
458            add_name_1: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)),
459        }
460        self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
461
462    @skipIfNoFBGEMM
463    def test_dict_return_type(self):
464        # verify that we can traverse up nodes which return dictionaries
465        class M(nn.Module):
466            def forward(self, x0):
467                x1 = torch.add(x0, 1.0)
468                y1 = torch.add(x0, 1.0)
469                z1 = torch.add(x0, 1.0)
470                a1 = {'x1': x1, 'y1': (y1,), 'z1': [{'key': (z1,)}]}
471                return a1
472
473        m = M().eval()
474        example_inputs = (torch.randn(1),)
475        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
476        mp_copy = copy.deepcopy(mp)
477        mq = convert_fx(mp_copy)
478        results = get_matching_subgraph_pairs(mp, mq)
479
480        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
481        add_name_0 = 'base_op_' + get_base_name_for_op(
482            base_name_to_sets_of_related_ops, torch.add) + '_0'
483        add_name_1 = 'base_op_' + get_base_name_for_op(
484            base_name_to_sets_of_related_ops, torch.add) + '_1'
485        add_name_2 = 'base_op_' + get_base_name_for_op(
486            base_name_to_sets_of_related_ops, torch.add) + '_2'
487
488        expected_types = {
489            add_name_0: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)),
490            add_name_1: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)),
491            add_name_2: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)),
492        }
493        self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
494
495    @skipIfNoFBGEMM
496    def test_nodes_with_equal_types_get_matched(self):
497        class M(nn.Module):
498            def __init__(self) -> None:
499                super().__init__()
500                self.conv1 = nn.Conv2d(1, 1, 1)
501                self.conv2 = nn.Conv2d(1, 1, 1)
502
503            def forward(self, x):
504                x = self.conv1(x)
505                x = self.conv2(x)
506                x = torch.mul(x, x)
507                x = torch.sigmoid(x)
508                x = F.relu(x)
509                return x
510
511        m = M().eval()
512        # prevent conv2 from getting quantized, so we can test
513        # modules with equal types
514        qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping().set_module_name("conv2", None)
515        example_inputs = (torch.randn(1, 1, 1, 1),)
516        mp = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs)
517        mp_copy = copy.deepcopy(mp)
518        mq = convert_fx(mp_copy)
519        results = get_matching_subgraph_pairs(mp, mq)
520
521        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
522        conv_name_0 = 'base_op_' + get_base_name_for_op(
523            base_name_to_sets_of_related_ops, nn.Conv2d) + '_0'
524        conv_name_1 = 'base_op_' + get_base_name_for_op(
525            base_name_to_sets_of_related_ops, nn.Conv2d) + '_1'
526        mul_name_0 = 'base_op_' + get_base_name_for_op(
527            base_name_to_sets_of_related_ops, torch.mul) + '_0'
528        relu_name_0 = 'base_op_' + get_base_name_for_op(
529            base_name_to_sets_of_related_ops, torch.relu) + '_0'
530        sigmoid_name_0 = 'base_op_' + get_base_name_for_op(
531            base_name_to_sets_of_related_ops, torch.sigmoid) + '_0'
532
533        # all of these should be matched
534        expected_types = {
535            conv_name_1:
536                ((nn.Conv2d, torch.ao.quantization.HistogramObserver), (nnq.Conv2d, nnq.Conv2d)),
537            conv_name_0:
538                ((nn.Conv2d, torch.ao.quantization.HistogramObserver), (nn.Conv2d, nn.Conv2d)),
539            mul_name_0: ((torch.mul, torch.ao.quantization.HistogramObserver), (toq.mul, toq.mul)),
540            relu_name_0: ((F.relu, torch.ao.quantization.FixedQParamsObserver), (F.relu, F.relu)),
541            sigmoid_name_0:
542                ((torch.sigmoid, torch.ao.quantization.FixedQParamsObserver), (torch.sigmoid, torch.sigmoid)),
543        }
544        self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
545
546    def test_methods(self):
547        """
548        Verify that graph matching works on methods
549        """
550        class M(nn.Module):
551            def forward(self, x):
552                x = x.sigmoid()
553                return x
554
555        m1 = M().eval()
556        m2 = M().eval()
557        qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping()
558        example_inputs = (torch.randn(1),)
559        m1p = prepare_fx(m1, qconfig_mapping, example_inputs=example_inputs)
560        m2p = prepare_fx(m2, qconfig_mapping, example_inputs=example_inputs)
561        results = get_matching_subgraph_pairs(m1p, m2p)
562        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
563        sigmoid_name_0 = 'base_op_' + get_base_name_for_op(
564            base_name_to_sets_of_related_ops, torch.sigmoid) + '_0'
565        expected_types = {
566            sigmoid_name_0:
567                (('sigmoid', torch.ao.quantization.FixedQParamsObserver), ('sigmoid', torch.ao.quantization.FixedQParamsObserver)),
568        }
569        self.assert_types_for_matched_subgraph_pairs(
570            results, expected_types, m1p, m2p)
571
572    def test_op_relationship_mapping(self):
573        """
574        Tests that the mapping of op relationships is complete.
575        """
576        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
577        type_a_related_to_b = \
578            get_type_a_related_to_b(base_name_to_sets_of_related_ops)
579
580        # 1. check static quant module mappings
581        static_quant_mod_mappings = get_default_static_quant_module_mappings()
582        for fp32_type, int8_type in static_quant_mod_mappings.items():
583            # skip quants and dequants for the purposes of Numeric Suite
584            types_to_skip = (
585                torch.ao.quantization.QuantStub,
586                torch.ao.quantization.DeQuantStub,
587                nnq.FloatFunctional,
588                # the ConvTranspose3d swap is not implemented in FX Graph
589                # mode quantization yet
590                nn.ConvTranspose3d,
591                # the GroupNorm swap is not implemented in FX Graph
592                # mode quantization yet
593                nn.GroupNorm,
594                # nnq.ReLU6 is no longer swapped, because nn.ReLU6 can
595                # take quantized inputs
596                nn.ReLU6,
597            )
598            if fp32_type in types_to_skip:
599                continue
600
601            # verify relatedness
602            in_type_a_related_to_b = \
603                (fp32_type, int8_type) in type_a_related_to_b
604            self.assertTrue(
605                in_type_a_related_to_b,
606                f"{fp32_type} and {int8_type} need a relationship mapping")
607
608        # 2. check static quant op mappings
609        static_quant_fun_mappings = get_default_float_to_quantized_operator_mappings()
610        for fp32_type, int8_type in static_quant_fun_mappings.items():
611            # verify relatedness
612            in_type_a_related_to_b = \
613                (fp32_type, int8_type) in type_a_related_to_b
614            self.assertTrue(
615                in_type_a_related_to_b,
616                f"{fp32_type} and {int8_type} need a relationship mapping")
617
618        # 3. check dynamic quant mappings
619        dynamic_quant_mappings = get_default_dynamic_quant_module_mappings()
620        for fp32_type, int8_type in dynamic_quant_mappings.items():
621            # TODO(future PR): enable correct weight extraction for these
622            # and remove from this list.
623            types_to_skip = (
624                nn.GRUCell,
625                nn.GRU,
626                nn.LSTMCell,
627                nn.RNNCell,
628            )
629            if fp32_type in types_to_skip:
630                continue
631            # verify relatedness
632            in_type_a_related_to_b = \
633                (fp32_type, int8_type) in type_a_related_to_b
634            self.assertTrue(
635                in_type_a_related_to_b,
636                f"{fp32_type} and {int8_type} need a relationship mapping")
637
638        # 4. go through the ops mapped to each QuantizeHandler type, and verify
639        # correctness.
640        def _op_in_base_sets_of_related_ops(op):
641            for ops in base_name_to_sets_of_related_ops.values():
642                if op in ops:
643                    return True
644            return False
645
646        unmatchable_types_map = get_unmatchable_types_map()
647        FUNS_UNMATCHABLE = unmatchable_types_map['funs_unmatchable']
648        MODS_UNMATCHABLE = unmatchable_types_map['mods_unmatchable']
649        METHS_UNMATCHABLE = unmatchable_types_map['meths_unmatchable']
650
651        def _op_is_unmatchable(op):
652            return (
653                op in FUNS_UNMATCHABLE or
654                op in MODS_UNMATCHABLE or
655                op in METHS_UNMATCHABLE
656            )
657
658        default_quant_patterns = get_all_quant_patterns()
659        for pattern, qhandler_cls in default_quant_patterns.items():
660            base_op = None
661            if isinstance(pattern, tuple):
662                base_op = pattern[-1]
663            elif isinstance(pattern, str):
664                base_op = pattern
665            else:
666                base_op = pattern
667
668            qhandler_cls_all_ops_quantizeable = [
669                qh.CatQuantizeHandler,
670                qh.ConvReluQuantizeHandler,
671                qh.LinearReLUQuantizeHandler,
672                qh.BatchNormQuantizeHandler,
673                qh.EmbeddingQuantizeHandler,
674                qh.RNNDynamicQuantizeHandler,
675            ]
676
677            qhandler_cls_quant_op_same_signature = [
678                qh.FixedQParamsOpQuantizeHandler,
679                qh.CopyNodeQuantizeHandler,
680                qh.GeneralTensorShapeOpQuantizeHandler,
681            ]
682
683            if qhandler_cls == qh.BinaryOpQuantizeHandler:
684                # these ops do not have quantized equivalents
685                ops_to_skip = [
686                    torch.bmm,
687                    torch.div,
688                    torch.sub,
689                    operator.truediv,
690                    operator.sub
691                ]
692                if base_op in ops_to_skip:
693                    continue
694                self.assertTrue(
695                    _op_in_base_sets_of_related_ops(base_op),
696                    f"{base_op} not in sets of related ops")
697            elif qhandler_cls == qh.RNNDynamicQuantizeHandler:
698                # TODO(future PR): add support for all classes in
699                # RNNDynamicQuantizeHandler
700                pass
701            elif qhandler_cls == qh.DefaultNodeQuantizeHandler:
702                self.assertTrue(
703                    _op_in_base_sets_of_related_ops(base_op),
704                    f"{base_op} not in sets of related ops")
705            elif qhandler_cls in qhandler_cls_quant_op_same_signature:
706                # these ops use the same op signature for fp32 and quantized
707                # tensors
708                self.assertTrue(
709                    _op_in_base_sets_of_related_ops(base_op) or
710                    _op_is_unmatchable(base_op),
711                    f"{base_op} not in sets of related ops or unmatchable")
712            elif qhandler_cls in qhandler_cls_all_ops_quantizeable:
713                self.assertTrue(
714                    _op_in_base_sets_of_related_ops(base_op),
715                    f"{base_op} not in sets of related ops")
716            else:
717                # skip torch.sum (no quantized equivalent) and RNN modules (not handled here)
718                if base_op in [
719                        torch.sum,
720                        nn.GRUCell,
721                        nn.GRU,
722                        nn.LSTMCell,
723                        nn.RNNCell,
724                ]:
725                    continue
726                if isinstance(base_op, tuple):
727                    # skip fusion patterns
728                    continue
729                # no explicit quantize handler class matched; check whether the
730                # operator is in the related op set directly
731                if not (_op_in_base_sets_of_related_ops(base_op) or _op_is_unmatchable(base_op)):
732                    raise AssertionError(
733                        f"handling for {qhandler_cls} for op {base_op} not implemented")
734
735    @skipIfNoFBGEMM
736    def test_user_defined_function(self):
737        """
738        Verify that graph matching works on user-defined functions
739        """
740        class M1(nn.Module):
741            def forward(self, x):
742                x = F.hardswish(x)
743                return x
744
745        class M2(nn.Module):
746            def forward(self, x):
747                x = _wrapped_hardswish(x)
748                return x
749
750        qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping()
751        example_inputs = (torch.randn(1, 1, 1, 1),)
752        m1 = prepare_fx(M1().eval(), qconfig_mapping, example_inputs=example_inputs)
753        m2 = prepare_fx(M2().eval(), qconfig_mapping, example_inputs=example_inputs)
754
755        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
756        add_op_to_sets_of_related_ops(
757            base_name_to_sets_of_related_ops, _wrapped_hardswish, F.hardswish)
758
759        results = get_matching_subgraph_pairs(
760            m1, m2,
761            base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops)
762
763        hardswish_name_0 = 'base_op_' + get_base_name_for_op(
764            base_name_to_sets_of_related_ops, F.hardswish) + '_0'
765
766        expected_types = {
767            hardswish_name_0:
768                ((F.hardswish, torch.ao.quantization.HistogramObserver), (_wrapped_hardswish, _wrapped_hardswish)),
769        }
770        self.assert_types_for_matched_subgraph_pairs(
771            results, expected_types, m1, m2)
772
773    @skipIfNoFBGEMM
774    def test_results_order(self):
775        m = nn.Sequential(
776            nn.Conv2d(1, 1, 1),
777            nn.Linear(1, 1),
778        ).eval()
779        example_inputs = (torch.randn(1, 1, 1, 1),)
780        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
781        mp_copy = copy.deepcopy(mp)
782        mq = convert_fx(mp_copy)
783        results = get_matching_subgraph_pairs(mp, mq)
784        self.assertTrue(len(results) == 2)
785        results_iter = iter(results.items())
786        _, (subgraph_a_0, subgraph_b_0) = next(results_iter)
787        self.assertTrue(subgraph_a_0.start_node.name == '_0' and
788                        subgraph_b_0.start_node.name == '_0')
789        _, (subgraph_a_1, subgraph_b_1) = next(results_iter)
790        self.assertTrue(subgraph_a_1.start_node.name == '_1' and
791                        subgraph_b_1.start_node.name == '_1')
792
793
794class TestFXGraphMatcherModels(QuantizationTestCase):
795
796    @skipIfTorchDynamo("too slow")
797    @skipIfNoFBGEMM
798    @skip_if_no_torchvision
799    def test_mobilenet_v2(self):
800        # verify that the mobilenet_v2 graph can be matched
801        import torchvision
802        m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).eval().float()
803        example_inputs = (torch.randn(1, 3, 224, 224),)
804        mp = prepare_fx(copy.deepcopy(m), {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
805        # assume success if no exceptions
806        results_m_mp = get_matching_subgraph_pairs(torch.fx.symbolic_trace(m), mp)
807        mp_copy = copy.deepcopy(mp)
808        mq = convert_fx(mp_copy)
809        # assume success if no exceptions
810        results_mp_mq = get_matching_subgraph_pairs(mp, mq)
811
812    @skipIfNoFBGEMM
813    @skip_if_no_torchvision
814    def test_mobilenet_v2_qat(self):
815        # verify that the mobilenet_v2 QAT graph can be matched
816        import torchvision
817        m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).float()
818        example_inputs = (torch.randn(1, 3, 224, 224),)
819        mp = prepare_qat_fx(
820            copy.deepcopy(m),
821            {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')},
822            example_inputs=example_inputs)
823        # assume success if no exceptions
824        results_m_mp = get_matching_subgraph_pairs(torch.fx.symbolic_trace(m), mp)
825        mp_copy = copy.deepcopy(mp)
826        mq = convert_fx(mp_copy)
827        # assume success if no exceptions
828        results_mp_mq = get_matching_subgraph_pairs(mp, mq)
829
830
831class FXNumericSuiteQuantizationTestCase(QuantizationTestCase):
832    def _test_extract_weights(
833        self, m, example_inputs, results_len=0, qconfig_dict=None, prepare_fn=prepare_fx
834    ):
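        # Trace the fp32 model, quantize a copy, and check that weight extraction
        # returns the expected number of matched weight pairs for both the
        # (fp32 vs prepared) and (prepared vs quantized) comparisons.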
835        m = torch.fx.symbolic_trace(m)
836        if qconfig_dict is None:
837            qconfig_dict = {'': torch.ao.quantization.default_qconfig}
838        mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=example_inputs)
839        mp_copy = copy.deepcopy(mp)
840        mq = convert_fx(mp_copy)
841
842        # test both the public API as well as the internal GraphModule API
843        for extract_weights_fun in (extract_weights, _extract_weights_impl):
844            # test both m vs mp and mp vs mq
845            for m1, m2 in ((m, mp), (mp, mq)):
846                results = extract_weights_fun('a', m1, 'b', m2)
847                self.assertTrue(
848                    len(results) == results_len,
849                    f"expected len {results_len}, got len {len(results)}")
850                self.assert_ns_compare_dict_valid(results)
851                extend_logger_results_with_comparison(
852                    results, 'a', 'b', compute_sqnr, 'sqnr')
853                extend_logger_results_with_comparison(
854                    results, 'a', 'b', compute_normalized_l2_error, 'l2_error')
855                extend_logger_results_with_comparison(
856                    results, 'a', 'b', compute_cosine_similarity,
857                    'cosine_similarity')
858
859    def _test_match_activations(
860        self, m, data, prepared_expected_node_occurrence=None, results_len=0,
861        should_log_inputs=False,
862        qconfig_dict=None,
863        skip_scripting=False,
864        prepare_fn=prepare_fx,
865    ):
866        if qconfig_dict is None:
867            qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping()
868        if prepare_fn == prepare_fx:
869            m.eval()
870        else:
871            m.train()
872        mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=data)
873        mp(*data)
874        mp_copy = copy.deepcopy(mp)
875        mq = convert_fx(mp_copy)
876
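        # instrument two pairs of models with OutputLoggers at the matched
        # subgraphs: (fp32 vs prepared) and (prepared vs quantized)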
877        m_ns, mp_ns2 = add_loggers(
878            'a', m, 'b', copy.deepcopy(mp), OutputLogger,
879            should_log_inputs=should_log_inputs)
880        mp_ns, mq_ns = add_loggers(
881            'a', mp, 'b', mq, OutputLogger,
882            should_log_inputs=should_log_inputs)
883
884        if prepared_expected_node_occurrence:
885            self.checkGraphModuleNodes(
886                m_ns, expected_node_occurrence=prepared_expected_node_occurrence)
887            self.checkGraphModuleNodes(
888                mp_ns2, expected_node_occurrence=prepared_expected_node_occurrence)
889            self.checkGraphModuleNodes(
890                mp_ns, expected_node_occurrence=prepared_expected_node_occurrence)
891            self.checkGraphModuleNodes(
892                mq_ns, expected_node_occurrence=prepared_expected_node_occurrence)
893
894        if not skip_scripting:
895            m_ns = torch.jit.script(m_ns)
896            mp_ns = torch.jit.script(mp_ns)
897            mq_ns = torch.jit.script(mq_ns)
898
899        # calibrate
900        m_ns(*data)
901        mp_ns2(*data)
902        mp_ns(*data)
903        mq_ns(*data)
904
905        # check activation result correctness
906        results = []
907        for m1, m2 in ((m_ns, mp_ns2), (mp_ns, mq_ns)):
908            act_compare_dict = extract_logger_info(
909                m1, m2, OutputLogger, 'b')
910            self.assertTrue(
911                len(act_compare_dict) == results_len,
912                f"expected len {results_len}, got len {len(act_compare_dict)}")
913            self.assert_ns_compare_dict_valid(act_compare_dict)
914            extend_logger_results_with_comparison(
915                act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr')
916            extend_logger_results_with_comparison(
917                act_compare_dict, 'a', 'b', compute_normalized_l2_error, 'l2_error')
918            extend_logger_results_with_comparison(
919                act_compare_dict, 'a', 'b', compute_cosine_similarity,
920                'cosine_similarity')
921            results.append(act_compare_dict)
922        return results
923
924    def _test_match_shadow_activations(
925        self, m, data, prepared_expected_node_occurrence=None, results_len=None,
926        should_log_inputs=False, qconfig_dict=None, skip_scripting=False,
927        prepare_fn=prepare_fx, compare_fp32_vs_fp32_prepared=True,
928    ):
929        if qconfig_dict is None:
930            qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping()
931        if prepare_fn == prepare_fx:
932            m.eval()
933        else:
934            m.train()
935        print("qconfig_dict:", qconfig_dict)
936        mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=data)
937        print("prepared:", mp)
938        mp(*data)
939        mp_copy = copy.deepcopy(mp)
940        mq = convert_fx(mp_copy)
941        print("quantized:", mq)
942
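        # create shadow models in which corresponding subgraphs of the two models
        # run side by side on the same inputs, with OutputLoggers recording both
        # outputs for comparison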
943        if compare_fp32_vs_fp32_prepared:
944            m_shadows_mp = add_shadow_loggers(
945                'a', copy.deepcopy(m), 'b', copy.deepcopy(mp),
946                OutputLogger, should_log_inputs=should_log_inputs)
947        mp_shadows_mq = add_shadow_loggers(
948            'a', mp, 'b', mq, OutputLogger,
949            should_log_inputs=should_log_inputs)
950
951        if prepared_expected_node_occurrence:
952            if compare_fp32_vs_fp32_prepared:
953                self.checkGraphModuleNodes(
954                    m_shadows_mp, expected_node_occurrence=prepared_expected_node_occurrence)
955            self.checkGraphModuleNodes(
956                mp_shadows_mq, expected_node_occurrence=prepared_expected_node_occurrence)
957
958        if not skip_scripting:
959            if compare_fp32_vs_fp32_prepared:
960                m_shadows_mp = torch.jit.script(m_shadows_mp)
961            mp_shadows_mq = torch.jit.script(mp_shadows_mq)
962
963        # calibrate
964        if compare_fp32_vs_fp32_prepared:
965            m_shadows_mp(*data)
966        mp_shadows_mq(*data)
967
968        # check activation result correctness
969        results = []
970        models = (m_shadows_mp, mp_shadows_mq) if \
971            compare_fp32_vs_fp32_prepared else (mp_shadows_mq,)
972        for model in models:
973            act_compare_dict = extract_shadow_logger_info(
974                model, OutputLogger, 'b')
975            if results_len is not None:
976                self.assertTrue(
977                    len(act_compare_dict) == results_len,
978                    f"expected len {results_len}, got len {len(act_compare_dict)}")
979            self.assert_ns_compare_dict_valid(act_compare_dict)
980            extend_logger_results_with_comparison(
981                act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr')
982            extend_logger_results_with_comparison(
983                act_compare_dict, 'a', 'b', compute_normalized_l2_error, 'l2_error')
984            extend_logger_results_with_comparison(
985                act_compare_dict, 'a', 'b', compute_cosine_similarity,
986                'cosine_similarity')
987            results.append(act_compare_dict)
988        return results
989
990
991class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
992
993    @skipIfNoFBGEMM
994    def test_extract_weights_mod_ptq(self):
995        m = AllConvAndLinearFusionModules().eval()
996        example_inputs = (torch.randn(1, 1, 1, 1),)
997        self._test_extract_weights(m, example_inputs, results_len=14)
998
999    @skipIfNoFBGEMM
1000    def test_extract_weights_mod_qat(self):
1001        m = AllConvAndLinearFusionModules().train()
1002        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
1003        example_inputs = (torch.randn(1, 1, 1, 1),)
1004        self._test_extract_weights(
1005            m, example_inputs, results_len=14, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)
1006
1007    @skipIfNoFBGEMM
1008    def test_extract_weights_linear_fun_ptq(self):
1009        m = LinearReluLinearFunctional().eval()
1010        example_inputs = (torch.randn(1, 4),)
1011        self._test_extract_weights(m, example_inputs, results_len=2)
1012
1013    @skipIfNoFBGEMM
1014    def test_extract_weights_linear_fun_qat(self):
1015        m = LinearReluLinearFunctional().train()
1016        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
1017        example_inputs = (torch.randn(1, 4),)
1018        self._test_extract_weights(
1019            m, example_inputs, results_len=2, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)
1020
1021    @skipIfNoFBGEMM
1022    def test_extract_weights_conv_fun_ptq(self):
1023        w1d = torch.randn(1, 1, 1)
1024        w2d = torch.randn(1, 1, 1, 1)
1025        w3d = torch.randn(1, 1, 1, 1, 1)
1026        b1d = torch.randn(1)
1027        b2d = torch.randn(1)
1028        b3d = torch.randn(1)
1029        m = AllConvFunctional(w1d, w2d, w3d, b1d, b2d, b3d).eval()
1030        example_inputs = (torch.randn(1, 1, 1, 1),)
1031        self._test_extract_weights(m, example_inputs, results_len=6)
1032
1033    @skipIfNoFBGEMM
1034    def test_extract_weights_conv_fun_qat(self):
1035        w1d = torch.randn(1, 1, 1)
1036        w2d = torch.randn(1, 1, 1, 1)
1037        w3d = torch.randn(1, 1, 1, 1, 1)
1038        b1d = torch.randn(1)
1039        b2d = torch.randn(1)
1040        b3d = torch.randn(1)
1041        m = AllConvFunctional(w1d, w2d, w3d, b1d, b2d, b3d).train()
1042        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
1043        example_inputs = (torch.randn(1, 1, 1, 1),)
1044        self._test_extract_weights(
1045            m, example_inputs, results_len=6, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)
1046
1047    @skipIfNoFBGEMM
1048    def test_extract_weights_dynamic(self):
1049        # TODO(future PR): add Linear-ReLU, after #55393 is fixed.
1050        m = nn.Sequential(nn.Linear(1, 1)).eval()
1051        qconfig_dict = {
1052            'object_type': [
1053                (nn.Linear, default_dynamic_qconfig),
1054            ],
1055        }
1056        example_inputs = (torch.randn(1, 1),)
1057        self._test_extract_weights(m, example_inputs, results_len=1, qconfig_dict=qconfig_dict)
1058
1059    @skipIfNoFBGEMM
1060    def test_extract_weights_fqn(self):
1061        m = nn.Sequential(
1062            nn.Sequential(nn.Conv2d(1, 1, 1)),
1063            nn.Conv2d(1, 1, 1),
1064        ).eval()
1065        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
1066        example_inputs = (torch.randn(1, 1, 1, 1),)
1067        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
1068        mq = convert_fx(copy.deepcopy(mp))
1069        results = extract_weights('a', mp, 'b', mq)
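        # results are keyed by the name of the matched node; each entry records
        # the extracted weight along with the fully qualified module name (fqn)
        # it came from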
1070        fqn_a_0 = results['_0_0']['weight']['a'][0]['fqn']
1071        fqn_b_0 = results['_0_0']['weight']['b'][0]['fqn']
1072        self.assertTrue(fqn_a_0 == '0.0' and fqn_a_0 == fqn_b_0)
1073        fqn_a_1 = results['_1']['weight']['a'][0]['fqn']
1074        fqn_b_1 = results['_1']['weight']['b'][0]['fqn']
1075        self.assertTrue(fqn_a_1 == '1' and fqn_a_1 == fqn_b_1)
1076
1077    def _test_match_activations_mod_impl(self, prepare_fn=prepare_fx):
1078        m = nn.Sequential(
1079            torch.ao.quantization.QuantStub(),
1080            nn.Conv2d(1, 1, 1),
1081            nn.Conv2d(1, 1, 1),
1082        ).eval()
1083        qconfig_dict = None
1084        if prepare_fn == prepare_qat_fx:
1085            qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
1086        expected_occurrence = {
1087            ns.call_module(OutputLogger): 2,
1088        }
1089        self._test_match_activations(
1090            m, (torch.randn(2, 1, 2, 2),),
1091            prepared_expected_node_occurrence=expected_occurrence,
1092            results_len=2, qconfig_dict=qconfig_dict, prepare_fn=prepare_fn)
1093
1094    @skipIfNoFBGEMM
1095    def test_match_activations_mod_ptq(self):
1096        self._test_match_activations_mod_impl(prepare_fn=prepare_fx)
1097
1098    @skipIfNoFBGEMM
1099    def test_match_activations_mod_qat(self):
1100        self._test_match_activations_mod_impl(prepare_fn=prepare_qat_fx)
1101
1102    def _test_match_activations_fun_impl(self, prepare_fn=prepare_fx):
1103        m = LinearReluLinearFunctional().eval()
1104        qconfig_dict = None
1105        if prepare_fn == prepare_qat_fx:
1106            qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
1107        expected_occurrence = {
1108            ns.call_module(OutputLogger): 2,
1109        }
1110        self._test_match_activations(
1111            m, (torch.randn(4, 4),),
1112            prepared_expected_node_occurrence=expected_occurrence,
1113            results_len=2, prepare_fn=prepare_fn, qconfig_dict=qconfig_dict)
1114
1115    @skipIfNoFBGEMM
1116    def test_match_activations_fun_ptq(self):
1117        self._test_match_activations_fun_impl(prepare_fn=prepare_fx)
1118
1119    @skipIfNoFBGEMM
1120    def test_match_activations_fun_qat(self):
1121        self._test_match_activations_fun_impl(prepare_fn=prepare_qat_fx)
1122
1123    @skipIfNoFBGEMM
1124    def test_match_activations_meth_ptq(self):
1125        """
1126        Verify that add_loggers works on methods
1127        """
1128        class M(nn.Module):
1129            def forward(self, x):
1130                x = x.sigmoid()
1131                return x
1132
1133        m = M().eval()
1134        res = self._test_match_activations(
1135            m, (torch.randn(4, 4),),
1136            results_len=1)
1137
1138    @skipIfNoFBGEMM
1139    def test_match_activations_fqn(self):
1140        m = nn.Sequential(
1141            nn.Sequential(nn.Conv2d(1, 1, 1)),
1142            nn.Conv2d(1, 1, 1),
1143        ).eval()
1144        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
1145        example_inputs = (torch.randn(1, 1, 1, 1),)
1146        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
1147        mq = convert_fx(copy.deepcopy(mp))
1148        mp_ns, mq_ns = add_loggers('a', mp, 'b', mq, OutputLogger)
1149        datum = torch.randn(1, 1, 1, 1)
1150        mp_ns(datum)
1151        mq_ns(datum)
1152
1153        results = extract_logger_info(mp_ns, mq_ns, OutputLogger, 'b')
1154        fqn_a_0 = results['_0_0']['node_output']['a'][0]['fqn']
1155        fqn_b_0 = results['_0_0']['node_output']['b'][0]['fqn']
1156        self.assertTrue(fqn_a_0 == '0.0' and fqn_a_0 == fqn_b_0)
1157        fqn_a_1 = results['_1']['node_output']['a'][0]['fqn']
1158        fqn_b_1 = results['_1']['node_output']['b'][0]['fqn']
1159        self.assertTrue(fqn_a_1 == '1' and fqn_a_1 == fqn_b_1)
1160
1161    def _test_add_shadow_loggers_mod_impl(self, prepare_fn=prepare_fx):
1162        m = nn.Sequential(
1163            nn.Conv2d(1, 1, 1),
1164            nn.Conv2d(1, 1, 1),
1165        ).eval()
1166        qconfig_dict = None
1167        if prepare_fn == prepare_qat_fx:
1168            qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
1169        res = self._test_match_shadow_activations(
1170            m, (torch.randn(1, 1, 4, 4),), results_len=2,
1171            prepare_fn=prepare_fn, qconfig_dict=qconfig_dict)
1172
1173    @skipIfNoFBGEMM
1174    def test_add_shadow_loggers_mod_ptq(self):
1175        self._test_add_shadow_loggers_mod_impl(prepare_fn=prepare_fx)
1176
1177    @skipIfNoFBGEMM
1178    def test_add_shadow_loggers_mod_qat(self):
1179        self._test_add_shadow_loggers_mod_impl(prepare_fn=prepare_qat_fx)
1180
1181    def _test_add_shadow_loggers_fun_impl(self, prepare_fn=prepare_fx):
1182        m = LinearReluLinearFunctional()
1183        qconfig_dict = None
1184        if prepare_fn == prepare_qat_fx:
1185            qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
1186        res = self._test_match_shadow_activations(
1187            m, (torch.randn(4, 4),), results_len=2, prepare_fn=prepare_fn,
1188            qconfig_dict=qconfig_dict)
1189
1190    @skipIfNoFBGEMM
1191    def test_add_shadow_loggers_fun_ptq(self):
1192        self._test_add_shadow_loggers_fun_impl(prepare_fn=prepare_fx)
1193
1194    @skipIfNoFBGEMM
1195    def test_add_shadow_loggers_fun_qat(self):
1196        self._test_add_shadow_loggers_fun_impl(prepare_fn=prepare_qat_fx)
1197
1198    @skipIfNoFBGEMM
1199    def test_add_shadow_loggers_meth_ptq(self):
1200        """
1201        Verify that add_shadow_loggers works on methods
1202        """
1203        class M(nn.Module):
1204            def forward(self, x):
1205                x = x.sigmoid()
1206                return x
1207
1208        m = M().eval()
1209        res = self._test_match_shadow_activations(
1210            m, (torch.randn(4, 4),),
1211            # For now, sigmoid is not supported for shadowing because the dtype
1212            # inference for it is not implemented yet. So, this is just testing
1213            # that shadowing models with method calls does not crash.
1214            results_len=0)
1215
1216    @skipIfNoFBGEMM
1217    def test_shadow_activations_fqn(self):
1218        m = nn.Sequential(
1219            nn.Sequential(nn.Conv2d(1, 1, 1)),
1220            nn.Conv2d(1, 1, 1),
1221        ).eval()
1222        qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping()
1223        example_inputs = (torch.randn(1, 1, 1, 1),)
1224        mp = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs)
1225        mq = convert_fx(copy.deepcopy(mp))
1226        mp_shadows_mq = add_shadow_loggers('a', mp, 'b', mq, OutputLogger)
1227        datum = torch.randn(1, 1, 1, 1)
1228        mp_shadows_mq(datum)
1229
1230        results = extract_shadow_logger_info(mp_shadows_mq, OutputLogger, 'b')
1231        fqn_a_0 = results['_0_0']['node_output']['a'][0]['fqn']
1232        fqn_b_0 = results['_0_0']['node_output']['b'][0]['fqn']
1233        self.assertTrue(fqn_a_0 == '0.0' and fqn_a_0 == fqn_b_0)
1234        fqn_a_1 = results['_1']['node_output']['a'][0]['fqn']
1235        fqn_b_1 = results['_1']['node_output']['b'][0]['fqn']
1236        self.assertTrue(fqn_a_1 == '1' and fqn_a_1 == fqn_b_1)
1237
1238    @skipIfNoFBGEMM
1239    def test_logging_inputs(self):
1240        """
1241        Verifies that logging inputs works correctly
1242        """
1243        class M(nn.Module):
1244            def __init__(self) -> None:
1245                super().__init__()
1246                self.conv = nn.Conv2d(1, 1, 1)
1247
1248            def forward(self, x):
1249                x = self.conv(x)
1250                x = torch.cat([x, x], dim=0)
1251                return x
1252
1253        m = M().eval()
1254        self._test_match_shadow_activations(
1255            m, (torch.randn(1, 1, 4, 4),),
1256            results_len=1,
1257            should_log_inputs=True)
1258
1259    @skipIfNoFBGEMM
1260    def test_ops_with_same_fp32_and_int8_signature(self):
1261        """
1262        Verifies that we can match pairs of ops which have the same aten
1263        signature for fp32 and int8 tensors.
1264        """
1265        class M(nn.Module):
1266            def __init__(self) -> None:
1267                super().__init__()
1268                self.max_pool_2d = nn.MaxPool2d(2)
1269
1270            def forward(self, x):
1271                x = self.max_pool_2d(x)
1272                x = F.relu(x)
1273                return x
1274
1275        m = M().eval()
1276        self._test_match_activations(
1277            m, (torch.randn(1, 1, 2, 2),),
1278            results_len=2)
1279
1280    @skipIfNoFBGEMM
1281    def test_add_mul_inputs_activations(self):
1282        m = AddMulFunctional().eval()
1283        res = self._test_match_activations(
1284            m, (torch.randn(2, 2), torch.randn(2, 2)),
1285            results_len=6, should_log_inputs=True)
1286
1287    @skipIfNoFBGEMM
1288    def test_linear_fp16_weights(self):
1289        qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig}
1290        m = LinearReluFunctional().eval()
1291        example_inputs = (torch.randn(1, 4),)
1292        self._test_extract_weights(m, example_inputs, results_len=1, qconfig_dict=qconfig_dict)
1293
1294    @skipIfNoFBGEMM
1295    def test_linear_fp16_activations(self):
1296        for should_log_inputs in (True, False):
1297            qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig}
1298            m = LinearReluFunctional().eval()
1299            num_loggers = 2 if should_log_inputs else 1
1300            expected_occurrence = {
1301                ns.call_module(OutputLogger): num_loggers,
1302            }
1303            res = self._test_match_activations(
1304                m, (torch.randn(4, 4),),
1305                prepared_expected_node_occurrence=expected_occurrence,
1306                results_len=1,
1307                qconfig_dict=qconfig_dict,
1308                should_log_inputs=should_log_inputs)
1309
1310    @skipIfNoFBGEMM
1311    def test_linear_fp16_shadow_activations(self):
1312        for should_log_inputs in (True, False):
1313            qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig}
1314            m = LinearReluFunctional().eval()
1315            num_loggers = 4 if should_log_inputs else 2
1316            expected_occurrence = {
1317                ns.call_module(OutputLogger): num_loggers,
1318            }
1319            res2 = self._test_match_shadow_activations(
1320                m, (torch.randn(4, 4),),
1321                prepared_expected_node_occurrence=expected_occurrence,
1322                results_len=1,
1323                qconfig_dict=qconfig_dict,
1324                should_log_inputs=should_log_inputs)
1325
1326    @skipIfNoFBGEMM
1327    def test_linear_fp16_vs_linear_fp16_shadow_activations(self):
1328        m = LinearFunctional().eval()
1329        qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig}
1330        example_inputs = (torch.randn(1, 4),)
1331        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
1332        mq1 = convert_fx(copy.deepcopy(mp))
1333        mq2 = convert_fx(copy.deepcopy(mp))
1334        mq1_shadows_mq2 = _add_shadow_loggers_impl(
1335            'a', mq1, 'b', mq2, OutputLogger, should_log_inputs=False)
1336        mq1_shadows_mq2(torch.randn(4, 4))
1337        act_compare_dict = extract_shadow_logger_info(
1338            mq1_shadows_mq2, OutputLogger, 'b')
1339        self.assertTrue(len(act_compare_dict) == 1)
1340        self.assert_ns_compare_dict_valid(act_compare_dict)
1341
1342
1343    @skipIfNoFBGEMM
1344    def test_op_with_either_fp32_or_int8_input(self):
1345        """
1346        Verify that shadowing works with ops which accept either fp32 or
1347        int8 inputs.
1348        """
1349        class M(nn.Module):
1350            def __init__(self) -> None:
1351                super().__init__()
1352                self.relu = nn.ReLU()
1353
1354            def forward(self, x):
1355                x = self.relu(x)
1356                x = F.relu(x)
1357                return x
1358
1359        m = M()
1360        res = self._test_match_shadow_activations(
1361            m, (torch.randn(4, 4),),
1362            # Note: shadowing relu by itself is currently not supported,
1363            # this test is just testing that it does not crash
1364            results_len=0)
1365
1366    def _test_int8_shadows_int8_impl(self, m):
1367        """
1368        Verify that shadowing works where both modules are int8
1369        """
1370        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
1371        example_inputs = (torch.randn(4, 1, 4, 4),)
1372        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
1373        mp(*example_inputs)
1374        mq1 = convert_fx(copy.deepcopy(mp))
1375        mq2 = convert_fx(mp)
1376        mq1_shadows_mq2 = add_shadow_loggers('a', mq1, 'b', mq2, OutputLogger)
1377        mq1_shadows_mq2(torch.randn(4, 1, 4, 4))
1378        act_compare_dict = extract_shadow_logger_info(
1379            mq1_shadows_mq2, OutputLogger, 'b')
1380        self.assertTrue(len(act_compare_dict) == 1)
1381        self.assert_ns_compare_dict_valid(act_compare_dict)
1382
1383    @skipIfNoFBGEMM
1384    def test_int8_shadows_int8_mod(self):
1385        m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
1386        self._test_int8_shadows_int8_impl(m)
1387
1388    @skipIfNoFBGEMM
1389    def test_int8_shadows_int8_fun(self):
1390        m = LinearFunctional().eval()
1391        self._test_int8_shadows_int8_impl(m)
1392
1393    @skipIfNoFBGEMM
1394    def test_user_module_scriptable(self):
1395        # Logging of the output of this class is not supported, because it is
1396        # neither a tensor nor an RNN return type.
1397        class M1(nn.Module):
1398            def forward(self, x):
1399                x1 = x * 2
1400                x2 = x * 4
1401                return (x1, x2)
1402
1403        class M2(nn.Module):
1404            def __init__(self) -> None:
1405                super().__init__()
1406                self.m1 = M1()
1407
1408            def forward(self, x):
1409                x1, x2 = self.m1(x)
1410                return x1, x2
1411
1412        m = M2().eval()
1413        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
1414        prepare_custom_config_dict = {
1415            'non_traceable_module_class': [M1],
1416        }
1417        example_inputs = (torch.randn(1),)
1418        mp1 = prepare_fx(
1419            m,
1420            qconfig_dict,
1421            example_inputs=example_inputs,
1422            prepare_custom_config=prepare_custom_config_dict)
1423        mp2 = copy.deepcopy(mp1)
1424        unmatchable_types_map = get_unmatchable_types_map()
1425        unmatchable_types_map['mods_unmatchable'].add(M1)
1426        mp1_ns, mp2_ns = _add_loggers_impl(
1427            'a', mp1, 'b', mp2, OutputLogger, should_log_inputs=False,
1428            unmatchable_types_map=unmatchable_types_map)
1429
1430        # Scripting a model with loggers should succeed. If it fails because of
1431        # incorrect dtypes, we can blocklist the associated types from being instrumented.
1432        mp1_ns_scripted = torch.jit.script(mp1_ns)
1433        mp2_ns_scripted = torch.jit.script(mp2_ns)
1434
1435    @skipIfNoFBGEMM
1436    def test_user_module(self):
1437        """
1438        For user defined modules,
1439        1. weight extraction should not crash
1440        2. unshadowed activations should only have loggers for known types
1441        3. shadowed activations should only have loggers for known types with
1442             known dtypes
1443        """
1444        class UserModule(nn.Module):
1445            def forward(self, x):
1446                return x
1447
1448        class M(nn.Module):
1449            def __init__(self) -> None:
1450                super().__init__()
1451                self.linear = nn.Linear(1, 1)
1452                self.user_module = UserModule()
1453
1454            def forward(self, x):
1455                x = self.linear(x)
1456                x = self.user_module(x)
1457                return x
1458
1459        m = M().eval()
1460
1461        # quantize without tracing through UserModule
1462        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
1463        prepare_custom_config_dict = {'non_traceable_module_name': ['user_module']}
1464        example_inputs = (torch.randn(1, 1, 1),)
1465        mp = prepare_fx(
1466            m,
1467            qconfig_dict,
1468            example_inputs=example_inputs,
1469            prepare_custom_config=prepare_custom_config_dict)
1470        mp(*example_inputs)
1471        mq = convert_fx(copy.deepcopy(mp))
1472
1473        # weight extraction should not crash
1474        weights = _extract_weights_impl('fp32_prepared', mp, 'int8', mq)
1475
1476        # unshadowed activations should have loggers
1477
1478        # add loggers, without retracing
1479        # note: converting again because we cannot copy a quantized linear
1480        mp_ns, mq_ns = _add_loggers_impl(
1481            'fp32_prepared', copy.deepcopy(mp), 'int8',
1482            convert_fx(copy.deepcopy(mp)), OutputLogger,
1483            should_log_inputs=True)
1484        # both fp32 and int8 models should have 2 loggers each, 2 for I/O
1485        # of linear, and 0 for I/O of user_module
1486        unshadowed_expected_occurrence = {
1487            ns.call_module(OutputLogger): 2,
1488        }
1489        self.checkGraphModuleNodes(
1490            mp_ns, expected_node_occurrence=unshadowed_expected_occurrence)
1491        self.checkGraphModuleNodes(
1492            mq_ns, expected_node_occurrence=unshadowed_expected_occurrence)
1493
1494        # shadowed activations should only have loggers for nodes where
1495        # the types are known and we can do a dtype cast
1496
1497        # add shadow loggers, without retracing
1498        mp_shadows_mq_ns = _add_shadow_loggers_impl(
1499            'fp32_prepared', mp, 'int8', mq, OutputLogger,
1500            should_log_inputs=True)
1501        # 4 loggers for I/O of linear, 0 loggers for I/O of user_module
1502        shadowed_expected_occurrence = {
1503            ns.call_module(OutputLogger): 4,
1504        }
1505        self.checkGraphModuleNodes(
1506            mp_shadows_mq_ns, expected_node_occurrence=shadowed_expected_occurrence)
1507
1508    def test_op_io_dtype_coverage(self):
1509        """
1510        Tests that all the ops quantization cares about have input and output
1511        dtypes defined.
1512        """
1513        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
1514        type_a_related_to_b = \
1515            get_type_a_related_to_b(base_name_to_sets_of_related_ops)
1516
1517        # TODO(future PR): clean this up
1518        node_type_to_io_type_map = get_node_type_to_io_type_map()
1519        FUNS_IO_TYPE_FP32 = node_type_to_io_type_map['funs_io_type_fp32']
1520        FUNS_IO_TYPE_INT8 = node_type_to_io_type_map['funs_io_type_int8']
1521        FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map['funs_io_type_fp32_or_int8']
1522        MODS_IO_TYPE_FP32 = node_type_to_io_type_map['mods_io_type_fp32']
1523        MODS_IO_TYPE_INT8 = node_type_to_io_type_map['mods_io_type_int8']
1524        MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map['mods_io_type_fp32_or_int8']
1525        METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map['meths_io_type_fp32_or_int8']
1526
1527        unmatchable_types_map = get_unmatchable_types_map()
1528        FUNS_UNMATCHABLE = unmatchable_types_map['funs_unmatchable']
1529        MODS_UNMATCHABLE = unmatchable_types_map['mods_unmatchable']
1530        METHS_UNMATCHABLE = unmatchable_types_map['meths_unmatchable']
1531
1532        # 1. check static quant module mappings
1533        static_quant_mod_mappings = get_default_static_quant_module_mappings()
1534        for fp32_type, int8_type in static_quant_mod_mappings.items():
1535            types_to_skip = (
1536                torch.ao.quantization.QuantStub,
1537                torch.ao.quantization.DeQuantStub,
1538                nnq.FloatFunctional,
1539                # TODO(future PR): look into whether shadowing embeddings
1540                # makes sense
1541                nn.Embedding,
1542                nn.EmbeddingBag,
1543                # the ConvTranspose3d swap is not implemented in FX Graph
1544                # mode quantization yet
1545                nn.ConvTranspose3d,
1546                # the GroupNorm swap is not implemented in FX Graph
1547                # mode quantization yet
1548                nn.GroupNorm,
1549                # nnq.ReLU6 is no longer swapped, because nn.ReLU6 can
1550                # take quantized inputs
1551                nn.ReLU6,
1552            )
1553            if fp32_type in types_to_skip:
1554                continue
1555            self.assertTrue(
1556                fp32_type in MODS_IO_TYPE_FP32,
1557                f"missing IO type handling for {fp32_type}")
1558            self.assertTrue(
1559                int8_type in MODS_IO_TYPE_INT8,
1560                f"missing IO type handling for {int8_type}")
1561
1562        # 2. check static quant op mappings
1563        static_quant_fun_mappings = get_default_float_to_quantized_operator_mappings()
1564        for fp32_type, int8_type in static_quant_fun_mappings.items():
1565            self.assertTrue(
1566                fp32_type in FUNS_IO_TYPE_FP32,
1567                f"missing IO type handling for {fp32_type}")
1568            self.assertTrue(
1569                int8_type in FUNS_IO_TYPE_INT8,
1570                f"missing IO type handling for {int8_type}")
1571
1572        # 3. check dynamic quant mappings
1573        dynamic_quant_mappings = get_default_dynamic_quant_module_mappings()
1574        for fp32_type1, fp32_type2 in dynamic_quant_mappings.items():
1575            # TODO(future PR): verify correct I/O for these and remove from
1576            # this list.
1577            types_to_skip = (
1578                nn.GRUCell,
1579                nn.GRU,
1580                nn.LSTMCell,
1581                nn.RNNCell,
1582                # TODO(future PR): look into whether shadowing embeddings
1583                # makes sense
1584                nn.Embedding,
1585                nn.EmbeddingBag,
1586            )
1587            if fp32_type1 in types_to_skip:
1588                continue
1589            self.assertTrue(
1590                fp32_type1 in MODS_IO_TYPE_FP32,
1591                f"missing IO type handling for {fp32_type1}")
1592            self.assertTrue(
1593                fp32_type2 in MODS_IO_TYPE_FP32,
1594                f"missing IO type handling for {fp32_type2}")
1595
1596        # 4. go through the ops mapped to each QuantizeHandler type, and verify
1597        # correctness.
1598        default_quant_patterns = get_all_quant_patterns()
1599        for pattern, qhandler_cls in default_quant_patterns.items():
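                # for fused tuple patterns the base op is the last element of the
                # tuple; otherwise (function, module type, or method name string)
                # the pattern itself is the base op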
1600            if isinstance(pattern, tuple):
1601                base_op = pattern[-1]
1602            else:
1603                base_op = pattern
1607
1608            if (
1609                qhandler_cls in (
1610                    qh.BinaryOpQuantizeHandler,
1611                    qh.RNNDynamicQuantizeHandler,
1612                )
1613            ):
1614                # TODO(future PR): implement shadowing for binary ops
1615                # TODO(future PR): implement shadowing for RNN ops
1616                continue
1617            elif qhandler_cls == qh.CatQuantizeHandler:
1618                self.assertTrue(
1619                    base_op in FUNS_IO_TYPE_FP32_OR_INT8,
1620                    f"missing IO type handling for {base_op}")
1621            elif (
1622                qhandler_cls in (
1623                    qh.ConvReluQuantizeHandler,
1624                    qh.LinearReLUQuantizeHandler,
1625                    qh.BatchNormQuantizeHandler,
1626                    qh.DefaultNodeQuantizeHandler,
1627                )
1628            ):
1629                self.assertTrue(
1630                    (base_op in FUNS_IO_TYPE_FP32) or (base_op in MODS_IO_TYPE_FP32),
1631                    f"missing IO type handling for {base_op}")
1632            elif (
1633                qhandler_cls in (
1634                    qh.FixedQParamsOpQuantizeHandler,
1635                    qh.CopyNodeQuantizeHandler,
1636                    qh.GeneralTensorShapeOpQuantizeHandler,
1637                )
1638            ):
1639                if (
1640                    base_op in FUNS_UNMATCHABLE or
1641                    base_op in MODS_UNMATCHABLE or
1642                    base_op in METHS_UNMATCHABLE
1643                ):
1644                    continue
1645
1646                self.assertTrue(
1647                    (base_op in FUNS_IO_TYPE_FP32_OR_INT8) or
1648                    (base_op in MODS_IO_TYPE_FP32_OR_INT8) or
1649                    (base_op in METHS_IO_TYPE_FP32_OR_INT8) or
1650                    # Softmax has a different signature for the quantized
1651                    # version, so it does not fit into the cases above.
1652                    (base_op is torch.nn.Softmax),
1653                    f"missing IO type handling for {base_op}")
1654            elif qhandler_cls == qh.EmbeddingQuantizeHandler:
1655                # embedding shadowing is not implemented, for now
1656                continue
1657            else:
1658                if (
1659                    base_op in FUNS_UNMATCHABLE or
1660                    base_op in MODS_UNMATCHABLE or
1661                    base_op in METHS_UNMATCHABLE
1662                ):
1663                    continue
1664                if qhandler_cls(None, {}).is_general_tensor_value_op():
1665                    self.assertTrue(
1666                        (base_op in FUNS_IO_TYPE_FP32_OR_INT8) or
1667                        (base_op in MODS_IO_TYPE_FP32_OR_INT8) or
1668                        (base_op in METHS_IO_TYPE_FP32_OR_INT8),
1669                        f"missing IO type handling for {base_op} using {qhandler_cls}")
1670                else:
1671                    self.assertTrue(
1672                        (base_op in FUNS_IO_TYPE_FP32_OR_INT8) or
1673                        (base_op in MODS_IO_TYPE_FP32_OR_INT8) or
1674                        (base_op in METHS_IO_TYPE_FP32_OR_INT8) or
1675                        (base_op in FUNS_IO_TYPE_FP32) or
1676                        (base_op in MODS_IO_TYPE_FP32),
1677                        f"missing IO type handling for {base_op} using {qhandler_cls}")
1678
1679    @skipIfNoFBGEMM
1680    def test_user_defined_function(self):
1681        """
1682        Verify that NS APIs work on user defined functions
1683        """
1684        class M1(nn.Module):
1685            def __init__(self) -> None:
1686                super().__init__()
1687                self.w1 = nn.Parameter(torch.empty(1, 1))
1688                self.b1 = nn.Parameter(torch.zeros(1))
1689                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
1690
1691            def forward(self, x):
1692                x = F.hardswish(x)
1693                x = x.sigmoid()
1694                x = F.linear(x, self.w1, self.b1)
1695                return x
1696
1697        class M2(nn.Module):
1698            def __init__(self) -> None:
1699                super().__init__()
1700                self.w1 = nn.Parameter(torch.empty(1, 1))
1701                self.b1 = nn.Parameter(torch.zeros(1))
1702                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
1703
1704            def forward(self, x):
1705                x = _wrapped_hardswish(x)
1706                x = _wrapped_sigmoid(x)
1707                x = _wrapped_linear(x, self.w1, self.b1)
1708                return x
1709
1710        qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping()
1711        example_inputs = (torch.randn(1, 1),)
1712        m1 = prepare_fx(M1().eval(), qconfig_mapping, example_inputs=example_inputs)
1713        m2 = prepare_fx(M2().eval(), qconfig_mapping, example_inputs=example_inputs)
1714        data = torch.randn(1, 1)
1715
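            # relate the wrapped user functions to their torch counterparts so
            # that the graph matcher can pair the corresponding nodes across
            # the two models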
1716        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
1717        add_op_to_sets_of_related_ops(
1718            base_name_to_sets_of_related_ops, _wrapped_hardswish, F.hardswish)
1719        add_op_to_sets_of_related_ops(
1720            base_name_to_sets_of_related_ops, _wrapped_sigmoid, F.sigmoid)
1721        add_op_to_sets_of_related_ops(
1722            base_name_to_sets_of_related_ops, _wrapped_linear, F.linear)
1723
1724        op_to_type_to_weight_extraction_fn = \
1725            get_op_to_type_to_weight_extraction_fn()
1726        op_to_type_to_weight_extraction_fn['call_function'][_wrapped_linear] = \
1727            torch.ao.ns.fx.weight_utils.get_linear_fun_weight
1728
1729        # test compare weights
1730        results = extract_weights(
1731            'a', m1, 'b', m2,
1732            base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
1733            op_to_type_to_weight_extraction_fn=op_to_type_to_weight_extraction_fn)
1734        self.assertTrue(len(results) == 1)
1735        self.assertTrue(len(results['_wrapped_linear']['weight']) == 2)
1736
1737        # test unshadowed activations
1738
1739        m1_ns, m2_ns = _add_loggers_impl(
1740            'a', copy.deepcopy(m1), 'b', copy.deepcopy(m2), OutputLogger,
1741            should_log_inputs=False,
1742            base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops)
1743
1744        # calibrate
1745        m1_ns(data)
1746        m2_ns(data)
1747
1748        # check activation result correctness
1749        act_compare_dict = extract_logger_info(m1_ns, m2_ns, OutputLogger, 'b')
1750        self.assertTrue(len(act_compare_dict) == 3)
1751        self.assert_ns_compare_dict_valid(act_compare_dict)
1752
1753        # test shadowed activations
1754
1755        node_type_to_io_type_map = get_node_type_to_io_type_map()
1756        node_type_to_io_type_map['funs_io_type_fp32'].add(_wrapped_hardswish)
1757        node_type_to_io_type_map['funs_io_type_fp32'].add(_wrapped_sigmoid)
1758
1759        m2_shadows_m1_ns = _add_shadow_loggers_impl(
1760            'a', m2, 'b', m1, OutputLogger,
1761            should_log_inputs=False,
1762            base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
1763            node_type_to_io_type_map=node_type_to_io_type_map)
1764
1765        # calibrate
1766        m2_shadows_m1_ns(data)
1767
1768        # check activation result correctness
1769        act_compare_dict = extract_shadow_logger_info(
1770            m2_shadows_m1_ns, OutputLogger, 'b')
1771        self.assertTrue(len(act_compare_dict) == 2)
1772        self.assert_ns_compare_dict_valid(act_compare_dict)
1773
1774    @skipIfNoFBGEMM
1775    def test_layer_names(self):
1776        m = nn.Sequential(
1777            nn.Conv2d(1, 1, 1),
1778            nn.Conv2d(1, 1, 1),
1779            nn.Sigmoid(),
1780        ).eval()
1781        qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping("fbgemm")
1782        example_inputs = (torch.randn(1, 1, 1, 1),)
1783        mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_mapping, example_inputs=example_inputs)
1784        mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
1785
1786        # extract weights
1787        results = extract_weights('fp32', mp, 'int8', mq)
1788        mq_node_names = [node.name for node in mq.graph.nodes]
1789        for layer_name in results.keys():
1790            self.assertTrue(layer_name in mq_node_names)
1791
1792        # match activations
1793        mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
1794        mp_ns, mq_ns = add_loggers(
1795            'fp32', copy.deepcopy(mp), 'int8', mq, OutputLogger)
1796        data = torch.randn(1, 1, 1, 1)
1797        mp_ns(data)
1798        mq_ns(data)
1799        results = extract_logger_info(mp_ns, mq_ns, OutputLogger, 'int8')
1800        mq_node_names = [node.name for node in mq_ns.graph.nodes]
1801        for layer_name in results.keys():
1802            self.assertTrue(layer_name in mq_node_names)
1803
1804        # match shadow activations
1805        mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
1806        mp_shadows_mq = add_shadow_loggers(
1807            'fp32', mp, 'int8', mq, OutputLogger)
1808        mp_shadows_mq(data)
1809        results = extract_shadow_logger_info(
1810            mp_shadows_mq, OutputLogger, 'int8')
1811        mq_node_names = [node.name for node in mp_shadows_mq.graph.nodes]
1812        for layer_name in results.keys():
1813            self.assertTrue(layer_name in mq_node_names)
1814
1815    @skipIfNoFBGEMM
1816    def test_extend_logger_results_with_comparison(self):
1817        m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval()
1818        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
1819        example_inputs = (torch.randn(1, 1, 1, 1),)
1820        mp = torch.ao.quantization.quantize_fx.prepare_fx(
1821            m, qconfig_dict, example_inputs=example_inputs)
1822        mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
1823
1824        # extract weights
1825        results = extract_weights('fp32', mp, 'int8', mq)
1826        extend_logger_results_with_comparison(
1827            results, 'fp32', 'int8', compute_sqnr, 'sqnr_int8_vs_fp32')
1828        extend_logger_results_with_comparison(
1829            results, 'fp32', 'int8', compute_normalized_l2_error, 'l2_error_int8_vs_fp32')
1830        extend_logger_results_with_comparison(
1831            results, 'fp32', 'int8', compute_cosine_similarity,
1832            'cosine_similarity_int8_vs_fp32')
1833
1834        for layer_results in results.values():
1835            assert 'sqnr_int8_vs_fp32' in \
1836                layer_results['weight']['int8'][0].keys()
1837            assert 'l2_error_int8_vs_fp32' in \
1838                layer_results['weight']['int8'][0].keys()
1839            assert 'cosine_similarity_int8_vs_fp32' in \
1840                layer_results['weight']['int8'][0].keys()
1841
1842    @skipIfNoFBGEMM
1843    def test_int8_shadows_fp32_simple(self):
1844        m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1), nn.ReLU()).eval()
1845        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
1846        example_inputs = (torch.randn(1, 1, 1, 1),)
1847        mp = torch.ao.quantization.quantize_fx.prepare_fx(
1848            m, qconfig_dict, example_inputs=example_inputs)
1849        mp(torch.randn(1, 1, 1, 1))
1850        mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
1851        mq_ref = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
1852        mp_shadows_mq = add_shadow_loggers(
1853            'int8', mq, 'fp32', mp, OutputLogger)
1854
1855        # verify that scale and zp were extracted correctly
1856
1857        # for the first op, the scale+zp live as attributes on the module
1858        scale_0 = mp_shadows_mq._0_input_scale_0
1859        scale_0_ref = getattr(mq_ref, '0_input_scale_0')
1860        self.assertEqual(scale_0, scale_0_ref)
1861        zp_0 = mp_shadows_mq._0_input_zero_point_0
1862        zp_0_ref = getattr(mq_ref, '0_input_zero_point_0')
1863        self.assertEqual(zp_0, zp_0_ref)
1864
1865        # for the second op, the scale and zp of its input
1866        # must equal the scale and zp of the output of the first op
1867        scale_1 = mp_shadows_mq._1_input_scale_0
1868        scale_1_ref = getattr(mq_ref, '0').scale
1869        self.assertEqual(scale_1, scale_1_ref)
1870        zp_1 = mp_shadows_mq._1_input_zero_point_0
1871        zp_1_ref = getattr(mq_ref, '0').zero_point
1872        self.assertEqual(zp_1, zp_1_ref)
1873
1874        # verify running data works
1875        mp_shadows_mq(torch.randn(1, 1, 1, 1))
1876        act_compare_dict = extract_shadow_logger_info(
1877            mp_shadows_mq, OutputLogger, 'fp32')
1878        self.assertTrue(len(act_compare_dict) == 2)
1879        self.assert_ns_compare_dict_valid(act_compare_dict)
1880
1881    @skipIfNoFBGEMM
1882    def test_int8_shadows_fp32_coverage(self):
1883        class M(torch.nn.Module):
1884            def __init__(self) -> None:
1885                super().__init__()
1886                self.adaptive_avg_pool = nn.AdaptiveAvgPool2d(1)
1887                self.conv = nn.Conv2d(1, 1, 1)
1888
1889            def forward(self, x):
1890                x = self.adaptive_avg_pool(x)
1891                # input qparams of conv will be input qparams of adaptive_avg_pool
1892                x = self.conv(x)
1893                x = torch.mul(x, x)
1894                x = self.conv(x)
1895                x = torch.add(x, x)
1896                x = F.relu(x)
1897                x = self.conv(x)
1898                return x
1899
1900        m = M().eval()
1901        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
1902        example_inputs = (torch.randn(1, 1, 1, 1),)
1903        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
1904        mp(*example_inputs)
1905        mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
1906        mq_ref = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
1907        mp_shadows_mq = add_shadow_loggers(
1908            'int8', mq, 'fp32', mp, OutputLogger)
1909        mp_shadows_mq(torch.randn(1, 1, 1, 1))
1910        act_compare_dict = extract_shadow_logger_info(
1911            mp_shadows_mq, OutputLogger, 'fp32')
1912        self.assertTrue(len(act_compare_dict) == 3)
1913        self.assert_ns_compare_dict_valid(act_compare_dict)
1914
1915    @skipIfNoFBGEMM
1916    def test_loggers_preserve_qat_numerics(self):
1917        m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1))
1918        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
1919        example_inputs = (torch.randn(1, 1, 1, 1),)
1920        mp = prepare_qat_fx(m, qconfig_dict, example_inputs=example_inputs)
1921        mp(*example_inputs)
1922        mc = convert_fx(copy.deepcopy(mp))
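            # disable observers so the extra forward passes below do not update
            # qparams and change the numerics being compared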
1923        mp.apply(torch.ao.quantization.disable_observer)
1924
1925        ref_fp32 = mp(*example_inputs)
1926        ref_int8 = mc(*example_inputs)
1927
1928        mp_ns, mc_ns = add_loggers('fp32', mp, 'int8', mc, OutputLogger)
1929        ref_fp32_ns = mp_ns(*example_inputs)
1930        ref_int8_ns = mc_ns(*example_inputs)
1931        self.assertEqual(ref_fp32, ref_fp32_ns)
1932        self.assertEqual(ref_int8, ref_int8_ns)
1933
1934    @skipIfNoFBGEMM
1935    def test_shadow_loggers_preserve_qat_numerics(self):
1936        m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1))
1937        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
1938        example_inputs = (torch.randn(1, 1, 1, 1),)
1939        mp = prepare_qat_fx(m, qconfig_dict, example_inputs=example_inputs)
1940        mp(*example_inputs)
1941        mc = convert_fx(copy.deepcopy(mp))
1942        mp.apply(torch.ao.quantization.disable_observer)
1943
1944        ref_fp32 = mp(*example_inputs)
1945        ref_int8 = mc(*example_inputs)
1946
1947        mc_shadows_mp = add_shadow_loggers('int8', mc, 'fp32', mp, OutputLogger)
1948        ref_shadow = mc_shadows_mp(*example_inputs)
1949        self.assertEqual(ref_fp32, ref_shadow)
1950
1951    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
1952    def test_extract_weights_cuda(self):
1953        # Note: this is not using quantization because quantized kernels do not
1954        # work on cuda yet.
1955        m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda()
1956        m2 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda()
1957        results = extract_weights('a', m1, 'b', m2)
1958        extend_logger_results_with_comparison(
1959            results, 'a', 'b', compute_sqnr, 'sqnr')
1960        self.assert_ns_compare_dict_valid(results)
1961
1962    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
1963    def test_add_loggers_cuda(self):
1964        # Note: this is not using quantization because quantized kernels do not
1965        # work on cuda yet.
1966        m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda()
1967        m2 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda()
1968        m1_ns, m2_ns = add_loggers('a', m1, 'b', m2, OutputLogger)
1969        datum = torch.randn(1, 1, 1, 1)
1970        datum = datum.cuda()
1971
1972        m1_ns(datum)
1973        m2_ns(datum)
1974
1975        act_compare_dict = extract_logger_info(m1_ns, m2_ns, OutputLogger, 'b')
1976        extend_logger_results_with_comparison(
1977            act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr')
1978
1979    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
1980    def test_add_shadow_loggers_cuda(self):
1981        # Note: this is not using quantization because quantized kernels do not
1982        # work on cuda yet.
1983        m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda()
1984        m2 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda()
1985        m1_shadows_m2 = add_shadow_loggers('a', m1, 'b', m2, OutputLogger)
1986        datum = torch.randn(1, 1, 1, 1)
1987        datum = datum.cuda()
1988
1989        m1_shadows_m2(datum)
1990
1991        act_compare_dict = extract_shadow_logger_info(m1_shadows_m2, OutputLogger, 'b')
1992        extend_logger_results_with_comparison(
1993            act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr')
1994
1995    def test_fp16_shadows_fp32(self):
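            # smoke test: adding shadow loggers between an fp16 reference model
            # and the original fp32 model should not crash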
1996        m = LinearReluFunctional().eval()
1997        example_inputs = (torch.randn(1, 4),)
1998        qconfig_dict = {"": torch.ao.quantization.float16_static_qconfig}
1999        mp = prepare_fx(copy.deepcopy(m), qconfig_dict, example_inputs=example_inputs)
2000        mq = convert_to_reference_fx(mp)
2001        mq_shadows_m = add_shadow_loggers('a', mq, 'b', m, OutputLogger)
2002
2003    def test_mul_add_cat_stack_skips_shadowing(self):
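            # shadowing of binary ops and cat/stack is not implemented, so this
            # only checks that the flow runs and reports zero results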
2004        class M(nn.Module):
2005            def forward(self, x):
2006                x = x * x
2007                x = torch.mul(x, x)
2008                x = x + x
2009                x = torch.add(x, x)
2010                x = torch.cat([x])
2011                x = torch.stack([x])
2012                return x
2013
2014        m = M().eval()
2015        self._test_match_shadow_activations(
2016            m, (torch.randn(1, 1, 4, 4),),
2017            results_len=0)
2018
2019    def test_op_with_only_kwargs_skips_shadowing(self):
2020        class M(nn.Module):
2021            def forward(self, x):
2022                x = torch.cat(tensors=[x])
2023                x = torch.stack(tensors=[x])
2024                return x
2025
2026        m = M().eval()
2027        self._test_match_shadow_activations(
2028            m, (torch.randn(1, 1, 4, 4),),
2029            results_len=0)
2030
2031    def test_unsupported_op_copy_skips_shadowing(self):
2032        """
2033        Copying a `call_function` node is not implemented; verify that this
2034        does not crash shadowing but instead skips the node.
2035        """
2036        class M(nn.Module):
2037            def forward(self, x):
2038                # the second argument leads to attempting to copy a
2039                # call_function node
2040                x = F.layer_norm(x, x.shape[1:])
2041                return x
2042
2043        m = M().eval()
2044        self._test_match_shadow_activations(
2045            m, (torch.randn(1, 1, 4, 4),),
2046            results_len=0)
2047
2048    def test_linear_kwargs_shadow(self):
2049
2050        class M(nn.Module):
2051            def __init__(self) -> None:
2052                super().__init__()
2053                self.w1 = nn.Parameter(torch.empty(4, 4))
2054                self.b1 = nn.Parameter(torch.zeros(4))
2055                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
2056
2057            def forward(self, x):
2058                x = F.linear(input=x, weight=self.w1, bias=self.b1)
2059                return x
2060
2061        # note: FX graph mode quantization does not have good support
2062        # for kwargs-only right now, so we pass in two unquantized
2063        # models
2064        m = M().eval()
2065        mt = torch.fx.symbolic_trace(m)
2066        mt_copy = copy.deepcopy(mt)
2067
2068        mt_shadows_mt_copy = add_shadow_loggers(
2069            'a', mt, 'b', mt_copy, OutputLogger)
2070
2071        mt_shadows_mt_copy(torch.randn(4, 4))
2072        act_compare_dict = extract_shadow_logger_info(
2073            mt_shadows_mt_copy, OutputLogger, 'b')
2074        self.assertTrue(len(act_compare_dict) == 1)
2075
2076@skipIfNoQNNPACK
2077class TestFXNumericSuiteNShadows(FXNumericSuiteQuantizationTestCase):
2078    """
2079    Tests the "n shadows" workflow.
2080    """
2081
2082    def _test_impl(self, m, example_input, qconfig_mappings):
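            # standard n-shadows flow: prepare with the candidate qconfig
            # mappings, calibrate, convert, enable loggers, run the example
            # input once more, then extract and print the comparisons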
2083        backend_config = get_native_backend_config()
2084
2085        # test that input is valid
2086        _ = m(*example_input)
2087
2088        msp = prepare_n_shadows_model(
2089            m, example_input, qconfig_mappings, backend_config)
2090        # print('msp', msp)
2091
2092        for _ in range(2):
2093            msp(*example_input)
2094
2095        msq = convert_n_shadows_model(msp)
2096
2097        loggers_set_enabled(msq, True)
2098        msq(*example_input)
2099
2100        results = extract_results_n_shadows_model(msq)
2101        print_comparisons_n_shadows_model(results)
2102        return msq
2103
2104    @withQNNPACKBackend
2105    def test_linear_mod(self):
2106        class M(nn.Module):
2107            def __init__(self) -> None:
2108                super().__init__()
2109                self.fc1 = nn.Linear(2, 2)
2110
2111            def forward(self, x):
2112                x = self.fc1(x)
2113                return x
2114
2115        m = M().eval()
2116        example_input = (torch.randn(2, 2),)
2117
2118        qconfig_mappings = \
2119            QConfigMultiMapping().set_global([torch.ao.quantization.default_qconfig])
2120        self._test_impl(m, example_input, qconfig_mappings)
2121
2122    @withQNNPACKBackend
2123    def test_linear_relu_mod(self):
2124        class M(nn.Module):
2125            def __init__(self) -> None:
2126                super().__init__()
2127                self.fc1 = nn.Linear(2, 2)
2128                self.fc2 = nn.Linear(2, 2)
2129                self.relu = nn.ReLU()
2130
2131            def forward(self, x):
2132                x = self.fc1(x)
2133                x = self.fc2(x)
2134                x = self.relu(x)
2135                return x
2136
2137        m = M().eval()
2138        example_input = (torch.randn(2, 2),)
2139
2140        qconfig_mappings = (
2141            QConfigMultiMapping().set_global([
2142                torch.ao.quantization.default_qconfig,
2143                torch.ao.quantization.default_dynamic_qconfig
2144            ])
2145        )
2146        self._test_impl(m, example_input, qconfig_mappings)
2147
2148    @withQNNPACKBackend
2149    def test_conv_bn_relu_mod(self):
2150        class M(nn.Module):
2151            def __init__(self) -> None:
2152                super().__init__()
2153                self.conv = nn.Conv2d(1, 1, 1)
2154                self.bn = nn.BatchNorm2d(1)
2155                self.relu = nn.ReLU()
2156
2157            def forward(self, x):
2158                x = self.conv(x)
2159                x = self.bn(x)
2160                x = self.relu(x)
2161                return x
2162
2163        m = M().eval()
2164        example_input = (torch.randn(32, 1, 16, 16),)
2165
2166        qconfig_mappings = QConfigMultiMapping() \
2167            .set_global([
2168                torch.ao.quantization.default_qconfig,
2169                torch.ao.quantization.default_per_channel_qconfig
2170            ])
2171        self._test_impl(m, example_input, qconfig_mappings)
2172
2173    @withQNNPACKBackend
2174    def test_functions(self):
2175        class M(nn.Module):
2176            def __init__(self) -> None:
2177                super().__init__()
2178                self.w1 = nn.Parameter(torch.randn(2, 2))
2179                self.b1 = nn.Parameter(torch.zeros(2))
2180                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
2181
2182            def forward(self, x):
2183                x = F.sigmoid(x)
2184                x = F.linear(x, self.w1, self.b1)
2185                x = F.linear(x, self.w1[:], self.b1)
2186                x = F.relu(x)
2187                x = x + x
2188                x = torch.cat([x])
2189                x = torch.cat((x,))
2190                x = torch.cat(tensors=[x])
2191                # TODO(future PR): enable layernorm
2192                # blocked on FX graph mode quant not inserting observer for
2193                # second arg, if the second arg is a module input
2194                # x = F.layer_norm(x, x.shape)
2195                # x = F.layer_norm(x, x.shape[1:])
2196                # x = x.reshape(1, -1) * 2
2197                # x = F.layer_norm(x.reshape(1, -1), x.shape[1:])
2198                x = torch.matmul(x, x.reshape(2, 2))
2199                x = torch.matmul(x.reshape(2, 2), x.reshape(2, 2))
2200                # TODO(future PR): enable below after FX graph mode quantization handles
2201                # it, currently this is not supported
2202                # x = F.linear(input=x, weight=self.w1, bias=self.b1)
2203                return x
2204
2205        m = M().eval()
2206        example_input = (torch.randn(2, 2),)
2207
2208        qconfig_mappings = QConfigMultiMapping() \
2209            .set_global([torch.ao.quantization.default_qconfig])
2210        self._test_impl(m, example_input, qconfig_mappings)
2211
2212    @withQNNPACKBackend
2213    def test_partial_qconfig_mapping(self):
2214        class M(nn.Module):
2215            def __init__(self) -> None:
2216                super().__init__()
2217                self.fc = nn.Linear(2, 2)
2218                self.w1 = nn.Parameter(torch.randn(2, 2))
2219                self.b1 = nn.Parameter(torch.randn(2))
2220                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
2221
2222            def forward(self, x):
2223                x = self.fc(x)
2224                x = F.linear(x, self.w1, self.b1)
2225                x = F.relu(x)
2226                x = x + x
2227                return x
2228
2229        m = M().eval()
2230        example_input = (torch.randn(2, 2),)
2231        qconfig = torch.ao.quantization.default_qconfig
2232
2233        qconfig_mappings = QConfigMultiMapping() \
2234            .set_object_type(F.linear, [qconfig]) \
2235            .set_object_type(F.relu, [qconfig])
2236        self._test_impl(m, example_input, qconfig_mappings)
2237
2238    @withQNNPACKBackend
2239    def test_logger_enabled_and_save_activations_flags(self):
2240        m = nn.Sequential(nn.Linear(1, 1)).eval()
2241        example_input = (torch.randn(1, 1),)
2242
2243        qconfig_mappings = QConfigMultiMapping() \
2244            .set_global([torch.ao.quantization.default_qconfig])
2245        backend_config = get_native_backend_config()
2246
2247        msp = prepare_n_shadows_model(
2248            m, example_input, qconfig_mappings, backend_config)
2249
2250        for _ in range(2):
2251            msp(*example_input)
2252
2253        def _check_logger_count(model, exp_count_stats, exp_count_comparisons):
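                # OutputComparisonLogger is assumed to be a subclass of
                # OutputLogger, so it is also caught by the outer check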
2254            for name, mod in model.named_modules():
2255                if isinstance(mod, OutputLogger):
2256                    self.assertTrue(
2257                        len(mod.stats) == exp_count_stats,
2258                        f'stats: expected {len(mod.stats)} to equal {exp_count_stats}')
2259                    if isinstance(mod, OutputComparisonLogger):
2260                        self.assertTrue(
2261                            len(mod.comparisons) == exp_count_comparisons,
2262                            f'comparisons: expected {len(mod.comparisons)} to equal {exp_count_comparisons}')
2263
2264        # check behavior with save_activations enabled
2265        msq = convert_n_shadows_model(copy.deepcopy(msp))
2266        loggers_set_enabled(msq, True)
2267        loggers_set_save_activations(msq, True)
2268        # after prepare calibration but before convert calibration, loggers
2269        # should not have anything saved
2270        _check_logger_count(msq, 0, 0)
2271        msq(*example_input)
2272        # loggers should save each item after calibration
2273        _check_logger_count(msq, 1, 1)
2274
2275        # check behavior with save_activations disabled
2276        msq = convert_n_shadows_model(copy.deepcopy(msp))
2277        loggers_set_enabled(msq, True)
2278        loggers_set_save_activations(msq, False)
2279        # after prepare calibration but before convert calibration, loggers
2280        # should not have anything saved
2281        _check_logger_count(msq, 0, 0)
2282        msq(*example_input)
2283        # stats should be empty, but comparisons should be there
2284        _check_logger_count(msq, 0, 1)
2285
2286    @skipIfTorchDynamo("too slow")
2287    @skip_if_no_torchvision
2288    @withQNNPACKBackend
2289    def test_mobilenet_v2(self):
2290        import torchvision
2291        m = torchvision.models.quantization.mobilenet_v2(
2292            pretrained=False, quantize=False).eval()
2293        example_input = (torch.randn(1, 3, 224, 224),)
2294
2295        qconfig_mappings = QConfigMultiMapping() \
2296            .set_global([torch.ao.quantization.default_qconfig, torch.ao.quantization.default_dynamic_qconfig])
2297
2298        self._test_impl(m, example_input, qconfig_mappings)
2299
2300    @withQNNPACKBackend
2301    def test_qconfig_multi_mapping_deduplication(self):
2302        # check that insertion deduplicates qconfigs
2303        qconfig_multi_mapping = QConfigMultiMapping().set_global(
2304            [torch.ao.quantization.default_qconfig, torch.ao.quantization.default_qconfig]
2305        )
2306        self.assertEqual(len(qconfig_multi_mapping.qconfig_mappings_list), 1)
2307
2308    @withQNNPACKBackend
2309    def test_qconfig_multi_mapping_insert_padding(self):
2310        # test that inserting a higher priority qconfig style with fewer elements than a lower priority qconfig will
2311        # result in adding None to the extra QConfigMappings at that same style+key
2312        qconfig_multi_mapping = (
2313            QConfigMultiMapping()
2314            .set_global(
2315                [
2316                    torch.ao.quantization.default_qconfig,
2317                    torch.ao.quantization.default_dynamic_qconfig,
2318                ]
2319            )
2320            .set_object_type(torch.nn.Linear, [torch.ao.quantization.default_qconfig])
2321            .set_module_name_regex("fc", [torch.ao.quantization.default_qconfig])
2322            .set_module_name("fc2", [torch.ao.quantization.default_qconfig])
2323            .set_module_name_object_type_order(
2324                "", nn.Linear, 0, [torch.ao.quantization.default_qconfig]
2325            )
2326        )
2327
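            # with two global qconfigs in play, each single-qconfig style above
            # forces a None entry in the second QConfigMapping, which the
            # assertions below verify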
2328        self.assertEqual(
2329            qconfig_multi_mapping.qconfig_mappings_list[1].object_type_qconfigs[
2330                torch.nn.Linear
2331            ],
2332            None,
2333        )
2334        self.assertEqual(
2335            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_regex_qconfigs[
2336                "fc"
2337            ],
2338            None,
2339        )
2340        self.assertEqual(
2341            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"],
2342            None,
2343        )
2344        self.assertEqual(
2345            qconfig_multi_mapping.qconfig_mappings_list[
2346                1
2347            ].module_name_object_type_order_qconfigs[("", nn.Linear, 0)],
2348            None,
2349        )
2350
2351    @withQNNPACKBackend
2352    def test_qconfig_multi_mapping_retroactive_padding(self):
2353        # test that inserting a lower priority qconfig style with more elements than the existing higher priority qconfig styles
2354        # will result in the new QConfigMapping having None at all previously existing styles+keys
2355        qconfig_multi_mapping = (
2356            QConfigMultiMapping()
2357            .set_object_type(torch.nn.Linear, [torch.ao.quantization.default_qconfig])
2358            .set_module_name_regex("fc", [torch.ao.quantization.default_qconfig])
2359            .set_module_name("fc2", [torch.ao.quantization.default_qconfig])
2360            .set_module_name_object_type_order(
2361                "", nn.Linear, 0, [torch.ao.quantization.default_qconfig]
2362            )
2363            .set_global(
2364                [
2365                    torch.ao.quantization.default_qconfig,
2366                    torch.ao.quantization.default_dynamic_qconfig,
2367                ]
2368            )
2369        )
2370
2371        self.assertEqual(
2372            qconfig_multi_mapping.qconfig_mappings_list[1].object_type_qconfigs[
2373                torch.nn.Linear
2374            ],
2375            None,
2376        )
2377        self.assertEqual(
2378            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_regex_qconfigs[
2379                "fc"
2380            ],
2381            None,
2382        )
2383        self.assertEqual(
2384            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"],
2385            None,
2386        )
2387        self.assertEqual(
2388            qconfig_multi_mapping.qconfig_mappings_list[
2389                1
2390            ].module_name_object_type_order_qconfigs[("", nn.Linear, 0)],
2391            None,
2392        )
2393
2394    @withQNNPACKBackend
2395    def test_qconfig_multi_mapping_end_to_end(self):
2396        # test that the prepare/convert_n_shadows_model works as expected
2397        # with qconfig_multi_mapping and avoids unwanted matches
2398
2399        m = TwoLayerLinearModel().eval()
2400        example_input = m.get_example_inputs()
2401
2402        qconfig_multi_mapping = (
2403            QConfigMultiMapping()
2404            .set_global(
2405                [
2406                    torch.ao.quantization.default_qconfig,
2407                    torch.ao.quantization.default_dynamic_qconfig,
2408                ]
2409            )
2410            .set_module_name("fc2", [None, torch.ao.quantization.default_qconfig])
2411        )
2412        self.assertEqual(
2413            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"],
2414            None,
2415        )
2416        msq = self._test_impl(m, example_input, qconfig_multi_mapping)
2417
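            # fc1 gets both a static and a dynamic shadow wrapper; the None
            # entry for fc2 should not produce a shadow wrapper, so fc2 ends up
            # with a single (static) one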
2418        self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0)
2419        self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8)
2420        self.checkQuantizedLinear(msq.shadow_wrapper_1_1.mod_0)
2421        self.assertRaisesRegex(AttributeError, ".*", lambda: msq.shadow_wrapper_1_2)
2422
2423    @withQNNPACKBackend
2424    def test_qconfig_multi_mapping_from_list(self):
2425        # test QConfigMultiMapping.from_list_qconfig_mapping works as expected
2426
2427        m = TwoLayerLinearModel().eval()
2428        example_input = m.get_example_inputs()
2429
2430        qconfig_mappings_list = [
2431            QConfigMapping().set_global(torch.ao.quantization.default_qconfig),
2432            QConfigMapping()
2433            .set_global(torch.ao.quantization.default_dynamic_qconfig)
2434            .set_module_name("fc2", torch.ao.quantization.default_qconfig),
2435        ]
2436
2437        qconfig_multi_mapping = QConfigMultiMapping().from_list_qconfig_mapping(
2438            qconfig_mappings_list
2439        )
2440        self.assertEqual(
2441            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"],
2442            None,
2443        )
2444
2445        msq = self._test_impl(m, example_input, qconfig_multi_mapping)
2446
2447        self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0)
2448        self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8)
2449        self.checkQuantizedLinear(msq.shadow_wrapper_1_1.mod_0)
2450        self.assertRaisesRegex(AttributeError, ".*", lambda: msq.shadow_wrapper_1_2)
2451
2452    @withQNNPACKBackend
2453    def test_qconfig_multi_mapping_ordering(self):
2454        # test that the module ordering ignores None
2455
2456        m = TwoLayerLinearModel().eval()
2457        example_input = m.get_example_inputs()
2458        qconfig_multi_mapping = (
2459            QConfigMultiMapping()
2460            .set_global(
2461                [
2462                    torch.ao.quantization.default_qconfig,
2463                    torch.ao.quantization.default_dynamic_qconfig,
2464                ]
2465            )
2466            .set_module_name(
2467                "fc2",
2468                [
2469                    None,
2470                    torch.ao.quantization.default_dynamic_qconfig,
2471                    torch.ao.quantization.default_qat_qconfig_v2,
2472                ],
2473            )
2474        )
2475        self.assertEqual(len(qconfig_multi_mapping.qconfig_mappings_list), 2)
2476        msq = self._test_impl(m, example_input, qconfig_multi_mapping)
2477
2478        self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0)
2479        self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8)
2480        self.checkDynamicQuantizedLinear(msq.shadow_wrapper_1_1.mod_0, torch.qint8)
2481        self.checkQuantizedLinear(msq.shadow_wrapper_1_2.mod_0)
2482
2483    @withQNNPACKBackend
2484    def test_qconfig_multi_mapping_repr(self):
2485        qconfig_multi_mapping = (
2486            QConfigMultiMapping()
2487            .set_global(
2488                [
2489                    torch.ao.quantization.default_qconfig,
2490                    torch.ao.quantization.default_dynamic_qconfig,
2491                ]
2492            )
2493            .set_module_name(
2494                "fc2",
2495                [
2496                    None,
2497                    torch.ao.quantization.default_dynamic_qconfig,
2498                    torch.ao.quantization.default_qat_qconfig_v2,
2499                ],
2500            )
2501        )
2502        self.assertTrue(isinstance(qconfig_multi_mapping.__repr__(), str))
2503
2504    @withQNNPACKBackend
2505    def test_custom_functions_and_tracer(self):
2506        class M(nn.Module):
2507            def __init__(self) -> None:
2508                super().__init__()
2509                self.fc1 = nn.Linear(2, 2)
2510                self.fc2 = nn.Linear(2, 2)
2511
2512            def forward(self, x):
2513                x = self.fc1(x)
2514                x = self.fc2(x)
2515                return x
2516
2517        m = M().eval()
2518        example_inputs = (torch.randn(2, 2),)
2519
2520        qconfig_mappings = QConfigMultiMapping().set_global(
2521            [torch.ao.quantization.default_qat_qconfig]
2522        )
2523
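            # the custom tracer is constructed to skip tracing into 'fc2'
            # (skipped module names), exercising the custom_tracer hook of
            # prepare_n_shadows_model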
2524        custom_tracer = torch.ao.quantization.quantize_fx.QuantizationTracer(
2525            ["fc2"], []
2526        )
2527
2528        custom_prepare_fn = torch.ao.quantization.quantize_fx.prepare_qat_fx
2529
2530        def custom_convert_fn(module, to_print):
2531            print(to_print)
2532            mod = torch.ao.quantization.quantize_fx.convert_fx(module)
2533            return mod
2534
2535        backend_config = get_native_backend_config()
2536
2537        # test that input is valid
2538        _ = m(*example_inputs)
2539
2540        kwargs = {"to_print": "working"}
2541
2542        msp = prepare_n_shadows_model(
2543            m,
2544            example_inputs,
2545            qconfig_mappings,
2546            backend_config,
2547            custom_prepare_fn=custom_prepare_fn,
2548            custom_prepare_kwargs=None,
2549            custom_tracer=custom_tracer,
2550        )
2551
2552        for _ in range(2):
2553            msp(*example_inputs)
2554
2555        msq = convert_n_shadows_model(
2556            msp, custom_convert_fn=custom_convert_fn, custom_convert_kwargs=kwargs
2557        )
2558        print(msq)
2559        loggers_set_enabled(msq, True)
2560        msq(*example_inputs)
2561
2562        results = extract_results_n_shadows_model(msq)
2563        print_comparisons_n_shadows_model(results)
2564
2565    def _test_extract_weights_impl(self, m, example_input, qconfig_mapping):
2566        backend_config = get_native_backend_config()
2567        results = _n_shadows_compare_weights(
2568            m, example_input, qconfig_mapping, backend_config)
2569        print_comparisons_n_shadows_model(results)
2570
2571    @withQNNPACKBackend
2572    def test_extract_weights_linear(self):
2573        class M(nn.Module):
2574            def __init__(self) -> None:
2575                super().__init__()
2576                self.w1 = nn.Parameter(torch.randn(2, 2))
2577                self.b1 = nn.Parameter(torch.randn(2))
2578                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
2579                self.w2 = nn.Parameter(torch.randn(2, 2))
2580                self.b2 = nn.Parameter(torch.randn(2))
2581                torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))
2582                self.w3 = nn.Parameter(torch.randn(2, 2))
2583                self.b3 = nn.Parameter(torch.randn(2))
2584                torch.nn.init.kaiming_uniform_(self.w3, a=math.sqrt(5))
2585                self.w4 = nn.Parameter(torch.randn(2, 2))
2586                self.b4 = nn.Parameter(torch.randn(2))
2587                torch.nn.init.kaiming_uniform_(self.w4, a=math.sqrt(5))
2588
2589            def forward(self, x):
2590                x = F.linear(x, self.w1, self.b1)
2591                x = F.linear(x, self.w2, self.b2)
2592                x = F.relu(x)
2593                x = F.linear(x, self.w3, self.b3)
2594                x = F.linear(x, self.w4, self.b4)
2595                return x
2596
2597        per_tensor_qconfig = torch.ao.quantization.default_qconfig
2598
2599        m = M().eval()
2600        example_input = (torch.randn(2, 2),)
2601        qconfig_mapping = get_default_qconfig_mapping()
2602        # test unquantized
2603        qconfig_mapping.set_module_name_object_type_order(
2604            '', F.linear, 2, None)
2605        # test per-tensor
2606        qconfig_mapping.set_module_name_object_type_order(
2607            '', F.linear, 3, per_tensor_qconfig)
2608        self._test_extract_weights_impl(m, example_input, qconfig_mapping)
2609
2610
2611    def _test_add_loggers_impl(self, m, example_input, qconfig_mapping):
2612        backend_config = get_native_backend_config()
2613        m_copy = copy.deepcopy(m)
2614
2615        # test that input is valid
2616        _ = m(*example_input)
2617
2618        msp = _prepare_n_shadows_add_loggers_model(
2619            m, example_input, qconfig_mapping, backend_config)
2620        # print('msp', msp)
2621
2622        msp(*example_input)
2623
2624        msq = convert_n_shadows_model(msp)
2625        # print('msq', msq)
2626
2627        loggers_set_enabled(msq, True)
2628        output_fp32 = msq(*example_input)
2629
2630        results = extract_results_n_shadows_model(msq)
2631        # print(results)
2632        # print_comparisons_n_shadows_model(results)
2633
2634        # get the last quantized output from results
2635        inner_results = results['model']['node_output']
2636        last_subgraph = list(inner_results.keys())[-1]
2637        output_shadow = inner_results[last_subgraph][0]['values'][-1]
2638
2639        # verify that the fp32 and shadow (quantized) outputs match the reference fp32 and quantized models
2640        output_fp32_ref = m_copy(*example_input)
2641        mp_ref = prepare_fx(m_copy, qconfig_mapping, example_input)
2642        for _ in range(2):
2643            mp_ref(*example_input)
2644        mq_ref = convert_fx(mp_ref)
2645        output_shadow_ref = mq_ref(*example_input)
2646        self.assertTrue(
2647            torch.allclose(output_fp32, output_fp32_ref),
2648            f"fp32 comparison: {output_fp32} not close to {output_fp32_ref}")
2649
2650        # print('shadow', output_shadow.shape, output_shadow)
2651        # print('shadow_ref', output_shadow_ref.shape, output_shadow_ref)
2652
2653        self.assertTrue(
2654            torch.allclose(output_shadow, output_shadow_ref),
2655            f"shadow comparison: {output_shadow} not close to {output_shadow_ref}")
2656
2657        return msq
2658
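    # The four tests below cover each quantized/fp32 combination for a two-layer
    # Linear nn.Sequential; setting a module name's qconfig to None keeps that
    # layer in fp32.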
2659    @withQNNPACKBackend
2660    def test_add_loggers_linear_mod_quant_quant(self):
2661        m = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
2662        example_input = (torch.randn(2, 2),)
2663        qconfig_mapping = get_default_qconfig_mapping()
2664        self._test_add_loggers_impl(m, example_input, qconfig_mapping)
2665
2666    @withQNNPACKBackend
2667    def test_add_loggers_linear_mod_fp32_quant(self):
2668        m = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
2669        example_input = (torch.randn(2, 2),)
2670        qconfig_mapping = get_default_qconfig_mapping()
2671        qconfig_mapping.set_module_name('0', None)
2672        self._test_add_loggers_impl(m, example_input, qconfig_mapping)
2673
2674    @withQNNPACKBackend
2675    def test_add_loggers_linear_mod_quant_fp32(self):
2676        m = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
2677        example_input = (torch.randn(2, 2),)
2678        qconfig_mapping = get_default_qconfig_mapping()
2679        qconfig_mapping.set_module_name('1', None)
2680        self._test_add_loggers_impl(m, example_input, qconfig_mapping)
2681
2682    @withQNNPACKBackend
2683    def test_add_loggers_linear_mod_fp32_fp32(self):
2684        m = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
2685        example_input = (torch.randn(2, 2),)
2686        qconfig_mapping = get_default_qconfig_mapping()
2687        qconfig_mapping.set_module_name('0', None)
2688        qconfig_mapping.set_module_name('1', None)
2689        self._test_add_loggers_impl(m, example_input, qconfig_mapping)
2690
2691    @withQNNPACKBackend
2692    def test_add_loggers_conv_bn_relu_fusion_quant(self):
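        # Conv2d + BatchNorm2d + ReLU should fuse into a single quantized subgraph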
2693        m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1), nn.ReLU())
2694        m.eval()
2695        example_input = (torch.randn(16, 1, 4, 4),)
2696        qconfig_mapping = get_default_qconfig_mapping()
2697        self._test_add_loggers_impl(m, example_input, qconfig_mapping)
2698
2699    @withQNNPACKBackend
2700    def test_add_loggers_conv_bn_relu_fusion_fp32(self):
2701        m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1), nn.ReLU())
2702        m.eval()
2703        example_input = (torch.randn(16, 1, 4, 4),)
2704        qconfig_mapping = get_default_qconfig_mapping()
2705        qconfig_mapping.set_module_name('0', None)
2706        qconfig_mapping.set_module_name('1', None)
2707        qconfig_mapping.set_module_name('2', None)
2708        self._test_add_loggers_impl(m, example_input, qconfig_mapping)
2709
2710    @withQNNPACKBackend
2711    def test_add_loggers_functions(self):
2712        class M(nn.Module):
2713            def __init__(self) -> None:
2714                super().__init__()
2715                self.w1 = nn.Parameter(torch.randn(2, 2))
2716                self.b1 = nn.Parameter(torch.randn(2))
2717                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
2718
2719            def forward(self, x):
2720                x = F.linear(x, self.w1, self.b1)
2721                x = F.relu(x)
2722                x = x + x
2723                x = x + 1
2724                # TODO(future PR): support first arg being a scalar
2725                # x = 1 + x
2726                x = torch.cat([x, x])
2727                x = torch.cat([x, x])
2728                x = torch.cat(tensors=[x, x])
2729                # rrelu is not matchable by FX quantization
2730                x = torch.nn.functional.rrelu(x)
2731                x = F.linear(x, self.w1, self.b1)
2732                return x
2733
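        # The model mixes quantizeable functionals, binary ops, several
        # torch.cat call conventions, and rrelu (which quantization does not
        # match); run it with both a populated and an empty QConfigMapping.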
2734        m = M().eval()
2735        example_input = (torch.randn(16, 2),)
2736        for qconfig_mapping in (
2737            get_default_qconfig_mapping(),
2738            QConfigMapping(),
2739        ):
2740            self._test_add_loggers_impl(m, example_input, qconfig_mapping)
2741
2742    @skipIfTorchDynamo("too slow")
2743    @skip_if_no_torchvision
2744    @withQNNPACKBackend
2745    def test_add_loggers_mobilenet_v2(self):
2746        import torchvision
2747        m = torchvision.models.quantization.mobilenet_v2(
2748            pretrained=False, quantize=False).eval()
2749        example_input = (torch.randn(8, 3, 224, 224),)
2750        qconfig_mapping = get_default_qconfig_mapping()
2751        self._test_add_loggers_impl(m, example_input, qconfig_mapping)
2752
2753
2754class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase):
2755    """
2756    Tests numeric suite core APIs on non-toy models.
2757    """
2758
2759    @skipIfNoFBGEMM
2760    def test_compare_weights_conv(self):
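        # each conv model should yield exactly one extracted weight comparison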
2761        test_cases = (
2762            (ConvModel(),),
2763            (ConvBnModel(),),
2764            (ConvBnReLUModel(),),
2765        )
2766        for m, in test_cases:
2767            m.eval()
2768            example_inputs = (torch.randn(1, 3, 5, 5),)
2769            self._test_extract_weights(m, example_inputs, results_len=1)
2770
2771    @skipIfNoFBGEMM
2772    def test_compare_weights_linear(self):
2773        test_cases = (
2774            (SingleLayerLinearModel(), None),
2775            (
2776                SingleLayerLinearDynamicModel(),
2777                {"object_type": [(nn.Linear, default_dynamic_qconfig)]},
2778            ),
2779        )
2780        for m, qconfig_dict in test_cases:
2781            m.eval()
2782            example_inputs = (torch.randn(1, 3, 5, 5),)
2783            res = self._test_extract_weights(
2784                m, example_inputs, results_len=1, qconfig_dict=qconfig_dict)
2785
2786    @skipIfNoFBGEMM
2787    def test_compare_weights_lstm_dynamic(self):
2788        qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
2789        lstm_input = torch.rand((1, 1, 2))
2790        lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))
2791        example_inputs = (lstm_input, lstm_hidden)
2792        m = LSTMwithHiddenDynamicModel().eval()
2793        res = self._test_extract_weights(
2794            m, example_inputs, results_len=1, qconfig_dict=qconfig_dict)
2795
2796    @skipIfNoFBGEMM
2797    def test_compare_activations_conv(self):
2798        test_cases = (
2799            (ConvModel(),),
2800            (ConvBnModel(),),
2801            (ConvBnReLUModel(),),
2802        )
2803        for m, in test_cases:
2804            m.eval()
2805            res = self._test_match_activations(
2806                m, (torch.randn(1, 3, 4, 4),), results_len=1)
2807
2808    @skipIfNoFBGEMM
2809    def test_compare_activations_linear(self):
2810        test_cases = (
2811            (SingleLayerLinearModel(), None),
2812            (
2813                SingleLayerLinearDynamicModel(),
2814                {"object_type": [(nn.Linear, default_dynamic_qconfig)]},
2815            ),
2816        )
2817        for m, qconfig_dict in test_cases:
2818            m.eval()
2819            res = self._test_match_activations(
2820                m, (torch.randn(5, 5),), results_len=1, qconfig_dict=qconfig_dict)
2821
2822    @skipIfNoFBGEMM
2823    def test_compare_activations_lstm_dynamic(self):
2824        qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
2825        m = LSTMwithHiddenDynamicModel().eval()
2826        lstm_input = torch.rand((1, 1, 2))
2827        lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))
2828        # TODO(future PR): enable scripting (quant prepared LSTM not scriptable)
2829        res = self._test_match_activations(
2830            m, (lstm_input, lstm_hidden), results_len=1, qconfig_dict=qconfig_dict,
2831            skip_scripting=True)
2832
2833    @skipIfNoFBGEMM
2834    def test_compare_shadow_activations_conv(self):
2835        test_cases = (
2836            (ConvModel(),),
2837            (ConvBnModel(),),
2838            (ConvBnReLUModel(),),
2839        )
2840        for m, in test_cases:
2841            m.eval()
2842            res = self._test_match_shadow_activations(
2843                m, (torch.randn(1, 3, 4, 4),), results_len=1)
2844
2845    @skipIfNoFBGEMM
2846    def test_compare_shadow_activations_linear(self):
2847        test_cases = (
2848            (SingleLayerLinearModel(), None),
2849            (
2850                SingleLayerLinearDynamicModel(),
2851                {"object_type": [(nn.Linear, default_dynamic_qconfig)]},
2852            ),
2853        )
2854        for m, qconfig_dict in test_cases:
2855            m.eval()
2856            res = self._test_match_shadow_activations(
2857                m, (torch.randn(5, 5),), results_len=1, qconfig_dict=qconfig_dict)
2858
2859    @skipIfNoFBGEMM
2860    def test_compare_shadow_activations_lstm_dynamic(self):
2861        qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
2862        m = LSTMwithHiddenDynamicModel().eval()
2863        lstm_input = torch.rand((1, 1, 2))
2864        lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))
2865        # TODO(future PR): enable scripting (quant prepared LSTM not scriptable)
2866        res = self._test_match_shadow_activations(
2867            m, (lstm_input, lstm_hidden), results_len=1, qconfig_dict=qconfig_dict,
2868            skip_scripting=True)
2869
2870    @skipIfNoFBGEMM
2871    def test_sparsenn_compare_activations(self):
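        # idx/offsets feed the sparse (embedding) part of SparseNNModel and x
        # feeds the dense part; run with input logging both enabled and disabled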
2872        for should_log_inputs in (True, False):
2873            sparse_nn = SparseNNModel().eval()
2874            idx = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
2875            offsets = torch.tensor([0, 4], dtype=torch.long)
2876            x = torch.randn(2, 4)
2877            self._test_match_activations(
2878                sparse_nn, (idx, offsets, x),
2879                results_len=5,
2880                should_log_inputs=should_log_inputs)
2881
2882    @skipIfNoFBGEMM
2883    def test_sparsenn_shadow(self):
2884        for should_log_inputs in (True, False):
2885            sparse_nn = SparseNNModel().eval()
2886            idx = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
2887            offsets = torch.tensor([0, 4], dtype=torch.long)
2888            x = torch.randn(2, 4)
2889            self._test_match_shadow_activations(
2890                sparse_nn, (idx, offsets, x),
2891                results_len=3,
2892                should_log_inputs=should_log_inputs)
2893
2894    @skipIfTorchDynamo("too slow")
2895    @skip_if_no_torchvision
2896    @skipIfNoFBGEMM
2897    def test_resnet18(self):
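        # end-to-end shadow activation comparison on torchvision's quantizable ResNet-18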
2898        import torchvision
2899        m = torchvision.models.quantization.resnet18(pretrained=False, quantize=False).eval()
2900        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
2901        self._test_match_shadow_activations(
2902            m, (torch.randn(1, 3, 224, 224),),
2903            qconfig_dict=qconfig_dict,
2904            should_log_inputs=False)
2905
2906    @skipIfTorchDynamo("too slow")
2907    @skip_if_no_torchvision
2908    @skipIfNoFBGEMM
2909    def test_mobilenet_v2(self):
2910        import torchvision
2911        m = torchvision.models.quantization.mobilenet_v2(pretrained=False, quantize=False).eval()
2912        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
2913        self._test_match_shadow_activations(
2914            m, (torch.randn(1, 3, 224, 224),),
2915            qconfig_dict=qconfig_dict,
2916            should_log_inputs=False)
2917