# Owner(s): ["oncall: quantization"]
import copy
import itertools
import sys
from enum import Enum

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
import torch.nn as nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization import ObserverBase
from torch.ao.quantization.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    QUANT_ANNOTATION_KEY,
    X86InductorQuantizer,
)
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
    skipIfNoInductorSupport,
    skipIfNoX86,
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, skipIfTorchDynamo

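# Tests for X86InductorQuantizer with the PT2E (torch.export) quantization flow:
# TestHelperModules below builds common conv/linear patterns, and the test cases
# check the quantize/dequantize nodes inserted into the captured graphs.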

if IS_WINDOWS and IS_CI:
    sys.stderr.write("Windows CI still has some issues to be fixed.\n")
    sys.exit(0)


class NodePosType(Enum):
    left = 1
    right = 2
    both = 3
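# NodePosType describes which operand of the binary op (e.g. add) is produced by
# the conv/linear in the helper modules below: the left operand, the right
# operand, or both.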


class TestHelperModules:
    class SingleConv2dModule(torch.nn.Module):
        def __init__(self, with_bn=False) -> None:
            super().__init__()
            self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1))
            self.bn = torch.nn.BatchNorm2d(6)
            self.with_bn = with_bn

        def forward(self, x):
            x = self.conv(x)
            if self.with_bn:
                x = self.bn(x)
            return x

    class Conv2dUnaryModule(torch.nn.Module):
        def __init__(self, post_op, use_bias: bool = False, with_bn=False) -> None:
            super().__init__()
            self.conv = nn.Conv2d(
                3, 6, (2, 2), stride=(1, 1), padding=(1, 1), bias=use_bias
            )
            self.post_op = post_op
            self.bn = torch.nn.BatchNorm2d(6)
            self.with_bn = with_bn
            self.maxpool = torch.nn.MaxPool2d((3, 3))

        def forward(self, x):
            x = self.conv(x)
            if self.with_bn:
                x = self.bn(x)
            x = self.post_op(x)
            x = self.maxpool(x)
            return x

    class Conv2dAddModule(torch.nn.Module):
        def __init__(
            self,
            inplace_add: bool = False,
            conv2d_type: NodePosType = NodePosType.left,
            use_bias: bool = False,
            with_bn: bool = False,
        ) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=3,
                out_channels=3,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
            )
            self.conv2 = torch.nn.Conv2d(
                in_channels=3,
                out_channels=3,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
            )
            self.relu = nn.ReLU()
            self.inplace_add = inplace_add
            self.conv2d_type = conv2d_type
            self.bn = torch.nn.BatchNorm2d(3)
            self.with_bn = with_bn

        def forward(self, x):
            if self.conv2d_type == NodePosType.left:
                if self.inplace_add:
                    tmp = self.conv(x)
                    if self.with_bn:
                        tmp = self.bn(tmp)
                    tmp += self.relu(x)
                    return tmp
                else:
                    tmp = self.conv(x)
                    if self.with_bn:
                        tmp = self.bn(tmp)
                    return tmp + self.relu(x)
            elif self.conv2d_type == NodePosType.right:
                if self.inplace_add:
                    tmp = self.relu(x)
                    tmp += self.conv(x)
                    return tmp
                else:
                    return self.relu(x) + self.conv(x)
            elif self.conv2d_type == NodePosType.both:
                if self.inplace_add:
                    tmp = self.conv(x)
                    tmp += self.conv2(x)
                    return tmp
                else:
                    return self.conv(x) + self.conv2(x)

    class Conv2dAddReLUModule(torch.nn.Module):
        def __init__(
            self,
            inplace_add: bool = False,
            conv2d_type: NodePosType = NodePosType.left,
            inplace_relu: bool = False,
            use_bias: bool = False,
            with_bn: bool = False,
        ) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=3,
                out_channels=3,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
            )
            self.conv2 = torch.nn.Conv2d(
                in_channels=3,
                out_channels=3,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
            )
            self.relu = nn.ReLU()
            self.inplace_add = inplace_add
            self.conv2d_type = conv2d_type
            self.relu2 = nn.ReLU(inplace=inplace_relu)
            self.bn = torch.nn.BatchNorm2d(3)
            self.with_bn = with_bn

        def forward(self, x):
            if self.conv2d_type == NodePosType.left:
                if self.inplace_add:
                    tmp = self.conv(x)
                    if self.with_bn:
                        tmp = self.bn(tmp)
                    tmp += self.relu(x)
                    return self.relu2(tmp)
                else:
                    tmp = self.conv(x)
                    if self.with_bn:
                        tmp = self.bn(tmp)
                    return self.relu2(tmp + self.relu(x))
            elif self.conv2d_type == NodePosType.right:
                if self.inplace_add:
                    tmp = self.relu(x)
                    tmp += self.conv(x)
                    return self.relu2(tmp)
                else:
                    return self.relu2(self.relu(x) + self.conv(x))
            elif self.conv2d_type == NodePosType.both:
                if self.inplace_add:
                    tmp = self.conv(x)
                    tmp += self.conv2(x)
                    return self.relu2(tmp)
                else:
                    return self.relu2(self.conv(x) + self.conv2(x))

    class Conv2dSingleOpPowModule(nn.Module):
        def __init__(self, single_op):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)
            self.single_op = single_op

        def forward(self, x):
            x = self.conv(x)
            x = self.single_op(x)
            return torch.pow(x, 2)

    class SerialsConv2dAddReLUModule(torch.nn.Module):
207        """Serials of 2 Conv2d -> Add -> ReLU Pattern."""

        def __init__(
            self,
        ) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=3,
                out_channels=3,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            )
            self.conv2 = torch.nn.Conv2d(
                in_channels=3,
                out_channels=3,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            )
            self.conv3 = torch.nn.Conv2d(
                in_channels=3,
                out_channels=3,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            )
            self.conv4 = torch.nn.Conv2d(
                in_channels=3,
                out_channels=3,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            )
            self.relu = nn.ReLU()
            self.relu2 = nn.ReLU()

        def forward(self, x):
            x1 = self.conv(x)
            res1 = self.relu(self.conv2(x1) + self.conv3(x1))
            res2 = self.relu2(self.conv4(res1) + res1)
            return res2

    class Conv2dCatMaxpool2d(torch.nn.Module):
        def __init__(
            self,
        ):
            super().__init__()
            self.conv = torch.nn.Conv2d(
                3, 16, 7, bias=True, stride=2, padding=3, dilation=1
            )
            self.conv2 = torch.nn.Conv2d(
                3, 16, 7, bias=True, stride=2, padding=3, dilation=1
            )
            self.relu = torch.nn.ReLU()
            self.maxpool = torch.nn.MaxPool2d(3, stride=2, padding=1)
            self.conv3 = torch.nn.Conv2d(
                32, 32, 7, bias=True, stride=2, padding=3, dilation=1
            )

        def forward(self, x):
            temp1 = self.relu(self.conv(x))
            temp2 = self.conv2(x + 1)
            temp3 = torch.cat((temp1, temp2), 1)
            temp4 = self.maxpool(temp3)
            temp5 = self.conv3(temp4)
            return temp5

    class Conv2dAvgPool2d(torch.nn.Module):
        def __init__(
            self,
        ):
            super().__init__()
            self.conv = torch.nn.Conv2d(
                3, 16, 7, bias=True, stride=2, padding=3, dilation=1
            )
            self.avgpool = torch.nn.AvgPool2d(3, stride=2, padding=1)

        def forward(self, x):
            temp1 = self.avgpool(self.conv(x))
            return temp1

    class Conv2dCatSameInputs(torch.nn.Module):
        def __init__(
            self,
        ):
            super().__init__()
            self.conv = torch.nn.Conv2d(
                3, 16, 7, bias=True, stride=2, padding=3, dilation=1
            )
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            temp1 = self.relu(self.conv(x))
            temp3 = torch.cat((temp1, temp1), 1)
            return temp3

    class Conv2dCatSingleInput(torch.nn.Module):
        def __init__(
            self,
        ):
            super().__init__()
            self.conv = torch.nn.Conv2d(
                3, 16, 7, bias=True, stride=2, padding=3, dilation=1
            )
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            temp1 = self.relu(self.conv(x))
            temp3 = torch.cat((temp1,), 1)
            return temp3

    class SingleLinearModule(torch.nn.Module):
        def __init__(self, use_bias) -> None:
            super().__init__()
            self.linear = nn.Linear(4, 4, bias=use_bias)

        def forward(self, x):
            return self.linear(x)

    class LinearUnaryModule(torch.nn.Module):
        def __init__(
            self, use_bias, postop, inplace_postop=False, post_op_algo="none"
        ) -> None:
            super().__init__()
            self.linear = nn.Linear(4, 4, bias=use_bias)
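            # nn.GELU takes an "approximate" algorithm ("none" or "tanh") rather
            # than an inplace flag, so it is constructed differently below.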
            if postop == nn.GELU:
                self.postop = postop(approximate=post_op_algo)
            else:
                self.postop = postop(inplace=inplace_postop)

        def forward(self, x):
            return self.postop(self.linear(x))

    class LinearAddModule(torch.nn.Module):
        def __init__(
            self,
            inplace_add: bool = False,
            linear_pos: NodePosType = NodePosType.left,
            use_bias: bool = False,
        ) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(
                in_features=16, out_features=16, bias=use_bias
            )
            self.linear2 = torch.nn.Linear(
                in_features=16, out_features=16, bias=use_bias
            )
            self.relu = nn.ReLU()
            self.inplace_add = inplace_add
            self.linear_pos = linear_pos

        def forward(self, x):
            if self.linear_pos == NodePosType.left:
                if self.inplace_add:
                    tmp = self.linear(x)
                    tmp += self.relu(x)
                    return tmp
                else:
                    tmp = self.linear(x)
                    return tmp + self.relu(x)
            elif self.linear_pos == NodePosType.right:
                if self.inplace_add:
                    tmp = self.relu(x)
                    tmp += self.linear(x)
                    return tmp
                else:
                    return self.relu(x) + self.linear(x)
            elif self.linear_pos == NodePosType.both:
                if self.inplace_add:
                    tmp = self.linear(x)
                    tmp += self.linear2(x)
                    return tmp
                else:
                    return self.linear(x) + self.linear2(x)

    class LinearAddReLUModule(torch.nn.Module):
        def __init__(
            self,
            inplace_add: bool = False,
            linear_pos: NodePosType = NodePosType.left,
            inplace_relu: bool = False,
            use_bias: bool = False,
        ) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(
                in_features=16, out_features=16, bias=use_bias
            )
            self.linear2 = torch.nn.Linear(
                in_features=16, out_features=16, bias=use_bias
            )
            self.relu = nn.ReLU()
            self.inplace_add = inplace_add
            self.linear_pos = linear_pos
            self.relu2 = nn.ReLU(inplace=inplace_relu)

        def forward(self, x):
            if self.linear_pos == NodePosType.left:
                if self.inplace_add:
                    tmp = self.linear(x)
                    tmp += self.relu(x)
                    return self.relu2(tmp)
                else:
                    tmp = self.linear(x)
                    return self.relu2(tmp + self.relu(x))
            elif self.linear_pos == NodePosType.right:
                if self.inplace_add:
                    tmp = self.relu(x)
                    tmp += self.linear(x)
                    return self.relu2(tmp)
                else:
                    return self.relu2(self.relu(x) + self.linear(x))
            elif self.linear_pos == NodePosType.both:
                if self.inplace_add:
                    tmp = self.linear(x)
                    tmp += self.linear2(x)
                    return self.relu2(tmp)
                else:
                    return self.relu2(self.linear(x) + self.linear2(x))

    class SerialsLinearAddReLUModule(torch.nn.Module):
432        """Serials of 2 Linear -> Add -> ReLU Pattern."""

        def __init__(
            self,
        ) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(in_features=16, out_features=16, bias=True)
            self.linear2 = torch.nn.Linear(in_features=16, out_features=16, bias=True)
            self.linear3 = torch.nn.Linear(in_features=16, out_features=16, bias=True)
            self.linear4 = torch.nn.Linear(in_features=16, out_features=16, bias=True)
            self.relu = nn.ReLU()
            self.relu2 = nn.ReLU()

        def forward(self, x):
            x1 = self.linear(x)
            res1 = self.relu(self.linear2(x1) + self.linear3(x1))
            res2 = self.relu2(self.linear4(res1) + res1)
            return res2

    class LinearAddModule2(torch.nn.Module):
        def __init__(
            self,
            inplace_add: bool = False,
        ) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(in_features=16, out_features=16, bias=True)
            self.linear2 = torch.nn.Linear(in_features=16, out_features=16, bias=True)
            self.inplace_add = inplace_add

        def forward(self, x):
            if self.inplace_add:
                tmp = self.linear(x)
                tmp += self.linear2(tmp)
                return tmp
            else:
                tmp = self.linear(x)
                return tmp + self.linear2(tmp)

    class Conv2dAddModule2(torch.nn.Module):
        def __init__(
            self,
            inplace_add: bool = False,
        ) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1
            )
            self.conv2 = torch.nn.Conv2d(
                in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1
            )
            self.inplace_add = inplace_add
            self.bn = torch.nn.BatchNorm2d(3)
            self.bn2 = torch.nn.BatchNorm2d(3)

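        # The first conv/bn output feeds both the second conv and the add, so it
        # has two users (see test_conv2d_binary2).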
        def forward(self, x):
            if self.inplace_add:
                tmp = self.bn(self.conv(x))
                tmp += self.bn2(self.conv2(tmp))
                return tmp
            else:
                tmp = self.bn(self.conv(x))
                return tmp + self.bn2(self.conv2(tmp))

    class SelfAttnLikeModule(torch.nn.Module):
        def __init__(
            self,
            input_dim,
            transpose_for_score=False,
            num_attention_heads=None,
            attention_head_size=None,
        ) -> None:
            super().__init__()
            self.input_dim = input_dim
            self.q_proj = nn.Linear(input_dim, input_dim, bias=False)
            self.k_proj = nn.Linear(input_dim, input_dim, bias=False)
            self.v_proj = nn.Linear(input_dim, input_dim, bias=False)
            self.softmax = nn.Softmax(dim=-1)
            self.transpose_for_score = transpose_for_score
            if self.transpose_for_score:
                assert num_attention_heads is not None
                assert attention_head_size is not None
                self.num_attention_heads = num_attention_heads
                self.attention_head_size = attention_head_size

        def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
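            # Reshape (batch, seq, hidden) -> (batch, heads, seq, head_size) so
            # attention is computed per head.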
            new_x_shape = x.size()[:-1] + (
                self.num_attention_heads,
                self.attention_head_size,
            )
            x = x.view(new_x_shape)
            return x.permute(0, 2, 1, 3)

        def forward(self, x):
            q = self.q_proj(x)
            k = self.k_proj(x)
            v = self.v_proj(x)
            if self.transpose_for_score:
                q = self.transpose_for_scores(q)
                k = self.transpose_for_scores(k)
                v = self.transpose_for_scores(v)
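            # Scaled dot-product attention: softmax(QK^T / sqrt(input_dim)) @ V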
            scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5)
            attention = self.softmax(scores)
            weighted = torch.matmul(attention, v)
            return weighted


class X86InductorQuantTestCase(QuantizationTestCase):
    def _test_quantizer(
        self,
        model,
        example_inputs,
        quantizer,
        expected_node_occurrence,
        expected_node_list=None,
        is_qat=False,
        debug=False,
    ):
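        # Standard PT2E flow: capture the model, prepare it (insert observers),
        # calibrate with the example inputs, convert to a quantized graph, then
        # check the expected quantize/dequantize node counts and ordering.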
        m_eager = model.train() if is_qat else model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        # A QAT model fails to deepcopy, so reuse it directly as the exported model
        export_model = m if is_qat else copy.deepcopy(m)
        m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        prepare_model = copy.deepcopy(m)
        m = convert_pt2e(m)
        convert_model = copy.deepcopy(m)
        if debug:
            convert_model.print_readable(True)
        pt2_quant_output = m(*example_inputs)
        node_occurrence = {
            ns.call_function(k): v for k, v in expected_node_occurrence.items()
        }
        if expected_node_list is None:
            expected_node_list = []
        node_list = [ns.call_function(n) for n in expected_node_list]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )
        return export_model, prepare_model, convert_model


@skipIfNoInductorSupport
class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
    @skipIfNoX86
    def test_conv2d(self):
        """
        Test pattern of single conv2d with X86InductorQuantizer.
        """
        with override_quantized_engine("x86"), torch.no_grad():
            m = TestHelperModules.SingleConv2dModule().eval()
            example_inputs = (torch.randn(2, 3, 16, 16),)
            quantizer = X86InductorQuantizer().set_global(
                xiq.get_default_x86_inductor_quantization_config()
            )
            node_occurrence = {
                # one quant/dequant pair for the conv activation input
                torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
                # note: quantize ops for weights are const-propagated
                torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
            }
            node_list = [
                torch.ops.quantized_decomposed.quantize_per_tensor.default,
                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                torch.ops.aten.conv2d.default,
            ]
            self._test_quantizer(
                m,
                example_inputs,
                quantizer,
                node_occurrence,
                node_list,
            )

    @skipIfNoX86
    def test_conv2d_unary(self):
        """
        Test pattern of conv2d with unary post ops (such as relu, hardtanh, hardswish, relu6) with X86InductorQuantizer.
        """
        unary_map = {
            "relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default],
            "relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default],
            "hardtanh": [
                torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False),
                torch.ops.aten.hardtanh.default,
            ],
            "hardtanh_inplace": [
                torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True),
                torch.ops.aten.hardtanh_.default,
            ],
            "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default],
            "relu6_inplace": [
                torch.nn.ReLU6(inplace=True),
                torch.ops.aten.hardtanh_.default,
            ],
            "hardswish": [
                torch.nn.Hardswish(inplace=False),
                torch.ops.aten.hardswish.default,
            ],
            "hardswish_inplace": [
                torch.nn.Hardswish(inplace=True),
                torch.ops.aten.hardswish_.default,
            ],
            "swish": [torch.nn.SiLU(inplace=False), torch.ops.aten.silu.default],
            "swish_inplace": [
                torch.nn.SiLU(inplace=True),
                torch.ops.aten.silu_.default,
            ],
        }
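        # Note: nn.ReLU6 is implemented via hardtanh, so it shows up as
        # aten.hardtanh/aten.hardtanh_ in the captured graph.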
        use_bias_list = [True, False]
        with override_quantized_engine("x86"), torch.no_grad():
            for unary_op, use_bias in itertools.product(
                unary_map.keys(), use_bias_list
            ):
                m = TestHelperModules.Conv2dUnaryModule(
                    unary_map[unary_op][0], use_bias=use_bias
                ).eval()
                example_inputs = (torch.randn(2, 3, 16, 16),)
                quantizer = X86InductorQuantizer().set_global(
                    xiq.get_default_x86_inductor_quantization_config()
                )
                node_occurrence = {
                    # quant/dequant pairs for the conv input, the maxpool input
                    # and the maxpool output
                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
                    # note: quantize ops for weights are const-propagated
                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
                }
                node_list = [
                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                    torch.ops.aten.conv2d.default,
                    unary_map[unary_op][1],
                ]
                self._test_quantizer(
                    m,
                    example_inputs,
                    quantizer,
                    node_occurrence,
                    node_list,
                )

    @skipIfNoX86
    def test_conv2d_binary(self):
        """
        Test pattern of conv2d with binary post ops (such as add) with X86InductorQuantizer.
        Currently, only add as binary post op is supported.
        """
        conv2d_type_list = [NodePosType.left, NodePosType.both]
        example_inputs = (torch.randn(2, 3, 6, 6),)
        quantizer = X86InductorQuantizer().set_global(
            xiq.get_default_x86_inductor_quantization_config()
        )
        with override_quantized_engine("x86"), torch.no_grad():
            for conv2d_type in conv2d_type_list:
                m = TestHelperModules.Conv2dAddModule(conv2d_type=conv2d_type).eval()
                if conv2d_type != NodePosType.both:
                    node_occurrence = {
                        # one quant/dequant pair for the conv input
                        # one for extra input node of add
                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
                        # quantize_per_channel for weights are const propagated
                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
                    }
                else:
                    node_occurrence = {
                        # one for input of the conv
                        # one for input of another conv
                        # 2 conv will share same input quant/dequant
                        # one for extra input node of add
                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
                        # quantize_per_channel for weights are const propagated
                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
                    }
                node_list = [
                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                    torch.ops.aten.conv2d.default,
                    torch.ops.aten.add.Tensor,
                ]
                self._test_quantizer(
                    m,
                    example_inputs,
                    quantizer,
                    node_occurrence,
                    node_list,
                )

    @skipIfNoX86
    def test_conv2d_binary2(self):
        """
        Test Pattern:
            tmp = conv2d_1(x)
            tmp2 = conv2d_2(tmp)
            return tmp + tmp2
        Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1
        """
        example_inputs = (torch.randn(2, 3, 6, 6),)
        quantizer = X86InductorQuantizer().set_global(
            xiq.get_default_x86_inductor_quantization_config()
        )
        inplace_add_list = [True, False]
        with override_quantized_engine("x86"), torch.no_grad():
            for inplace_add in inplace_add_list:
                m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add).eval()
                node_occurrence = {
                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
                    # quantize_per_channel for weights are const propagated
                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
                }
                node_list = [
                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                    torch.ops.aten.conv2d.default,
                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
                    (
                        torch.ops.aten.add_.Tensor
                        if inplace_add
                        else torch.ops.aten.add.Tensor
                    ),
                ]
                self._test_quantizer(
                    m,
                    example_inputs,
                    quantizer,
                    node_occurrence,
                    node_list,
                )

    @skipIfNoX86
    def test_conv2d_binary_unary(self):
        """
        Test pattern of conv2d with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
        Currently, only add as binary post op and relu as unary post op are supported.
        """
        conv2d_type_list = [NodePosType.left, NodePosType.both]
        example_inputs = (torch.randn(2, 3, 6, 6),)
        quantizer = X86InductorQuantizer().set_global(
            xiq.get_default_x86_inductor_quantization_config()
        )
        with override_quantized_engine("x86"), torch.no_grad():
            for conv2d_type in conv2d_type_list:
                m = TestHelperModules.Conv2dAddReLUModule(
                    conv2d_type=conv2d_type,
                ).eval()
                if conv2d_type != NodePosType.both:
                    node_occurrence = {
                        # one for input for conv
                        # one for extra input node of add
                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
                        # note: quantize op for weights are const propagated
                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
                    }
                else:
                    node_occurrence = {
                        # one for input of the conv
                        # one for input of another conv
                        # 2 conv will share same input quant/dequant
                        # one for extra input node of add
                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
                        # note: quantize op for weights are const propagated
                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
                    }
                node_list = [
                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                    torch.ops.aten.conv2d.default,
                    torch.ops.aten.add.Tensor,
                ]
                self._test_quantizer(
                    m,
                    example_inputs,
                    quantizer,
                    node_occurrence,
                    node_list,
                )

    @skipIfNoX86
    def test_conv2d_serials_binary_unary(self):
        """
        Test pattern of two consecutive conv2d -> add -> relu blocks with X86InductorQuantizer.
833        """
834        with override_quantized_engine("x86"), torch.no_grad():
835            m = TestHelperModules.SerialsConv2dAddReLUModule().eval()
836            example_inputs = (torch.randn(2, 3, 16, 16),)
837            quantizer = X86InductorQuantizer().set_global(
838                xiq.get_default_x86_inductor_quantization_config()
839            )
840            node_occurrence = {
841                torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
842                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 6,
843                # quantize_per_channel for weights are const propagated
844                torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
845                torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
846            }
847            node_list = [
848                torch.ops.quantized_decomposed.quantize_per_tensor.default,
849                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
850                torch.ops.aten.conv2d.default,
851                torch.ops.quantized_decomposed.quantize_per_tensor.default,
852                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
853                torch.ops.aten.conv2d.default,
854                torch.ops.aten.conv2d.default,
855                torch.ops.aten.add.Tensor,
856                torch.ops.aten.relu.default,
857            ]
858            self._test_quantizer(
859                m,
860                example_inputs,
861                quantizer,
862                node_occurrence,
863                node_list,
864            )
865
866    def _single_op_share_observer_recipe_test_helper(self, m, x, single_op):
867        quantizer = X86InductorQuantizer().set_global(
868            xiq.get_default_x86_inductor_quantization_config()
869        )
870        example_inputs = (x,)
871        node_occurrence = {
            # one quant/dequant pair for the conv input, two for the input/output
            # of the shared single op
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
            # quantize_per_channel for weights are const propagated
            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
        }
        node_list = [
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.conv2d.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            single_op,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        ]
        _, prepare_model, _ = self._test_quantizer(
            m,
            example_inputs,
            quantizer,
            node_occurrence,
            node_list,
        )
        # Check that the single op shares the same observer at its input and output
        for node in prepare_model.graph.nodes:
            if node.op == "call_function" and node.target is single_op:
                single_op_node = node
                input_obs_of_single_op = getattr(
                    prepare_model, single_op_node.args[0].target
                )
                output_obs_of_single_op = getattr(
                    prepare_model, next(iter(single_op_node.users)).target
                )
            elif (
                node.op == "call_function"
                and node.target is torch.ops.aten.conv2d.default
            ):
                conv_node = node
                input_obs_of_conv = getattr(prepare_model, conv_node.args[0].target)
        self.assertTrue(isinstance(input_obs_of_single_op, ObserverBase))
        self.assertTrue(isinstance(output_obs_of_single_op, ObserverBase))
        self.assertTrue(isinstance(input_obs_of_conv, ObserverBase))
        self.assertTrue(input_obs_of_single_op is output_obs_of_single_op)
        self.assertTrue(input_obs_of_single_op is not input_obs_of_conv)

    @skipIfNoX86
    def test_maxpool2d_recipe(self):
        r"""
        Test pattern: int8_in_int8_out_ops(maxpool) - non_quantizable op(pow)
        Since maxpool is an int8_in_int8_out op, there is an observer between maxpool and pow.
923        """
924        self._single_op_share_observer_recipe_test_helper(
925            TestHelperModules.Conv2dSingleOpPowModule(nn.MaxPool2d(1, 1)).eval(),
926            torch.rand(1, 2, 14, 14),
927            torch.ops.aten.max_pool2d.default,
928        )
929
930    @skipIfNoX86
931    def test_adaptive_avg_pool2d_recipe(self):
932        r"""
933        Test pattern: int8_in_int8_out_ops(adaptive_avg_pool2d) - non_quantizable op(pow)
        Since adaptive_avg_pool2d is an int8_in_int8_out op, there is an observer between adaptive_avg_pool2d and pow.
935        """
936        self._single_op_share_observer_recipe_test_helper(
937            TestHelperModules.Conv2dSingleOpPowModule(
938                nn.AdaptiveAvgPool2d((1, 1))
939            ).eval(),
940            torch.rand(1, 2, 14, 14),
941            torch.ops.aten.adaptive_avg_pool2d.default,
942        )
943
944    @skipIfNoX86
945    def test_flatten_recipe(self):
946        r"""
947        Test pattern: int8_in_int8_out_ops(flatten) - non_quantizable op(pow)
        Since flatten is an int8_in_int8_out op, there is an observer between flatten and pow.
949        """
950        self._single_op_share_observer_recipe_test_helper(
951            TestHelperModules.Conv2dSingleOpPowModule(
952                lambda x: torch.flatten(x, 1)
953            ).eval(),
954            torch.rand(1, 2, 14, 14),
955            torch.ops.aten.flatten.using_ints,
956        )
957
958    @skipIfNoX86
959    def test_cat_recipe(self):
960        r"""
961        Test pattern: conv -> cat -> maxpool2d
        Since cat and maxpool are int8_in_int8_out ops, their inputs and outputs should share the same observer.
963        """
964        m = TestHelperModules.Conv2dCatMaxpool2d().eval()
965        x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last)
966        quantizer = X86InductorQuantizer().set_global(
967            xiq.get_default_x86_inductor_quantization_config()
968        )
969        example_inputs = (x,)
970        node_occurrence = {
971            torch.ops.quantized_decomposed.quantize_per_tensor.default: 6,
972            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 6,
973            # quantize_per_channel for weights are const propagated
974            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
975            torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
976        }
977        node_list = [
978            torch.ops.quantized_decomposed.quantize_per_tensor.default,
979            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
980            torch.ops.aten.conv2d.default,
981            torch.ops.quantized_decomposed.quantize_per_tensor.default,
982            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
983            torch.ops.aten.cat.default,
984            torch.ops.quantized_decomposed.quantize_per_tensor.default,
985            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
986            torch.ops.aten.max_pool2d.default,
987            torch.ops.quantized_decomposed.quantize_per_tensor.default,
988            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
989        ]
990        _, prepare_model, _ = self._test_quantizer(
991            m,
992            example_inputs,
993            quantizer,
994            node_occurrence,
995            node_list,
996        )
        # Check that cat and maxpool2d share observers at their inputs and outputs
        for node in prepare_model.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.cat.default:
                cat_act_obs0 = getattr(prepare_model, node.all_input_nodes[0].target)
                cat_act_obs1 = getattr(prepare_model, node.all_input_nodes[1].target)
                cat_out_obs = getattr(prepare_model, next(iter(node.users)).target)
            elif (
                node.op == "call_function"
                and node.target is torch.ops.aten.max_pool2d.default
            ):
                maxpool_node = node
                input_obs_of_maxpool = getattr(
                    prepare_model, maxpool_node.args[0].target
                )
                output_obs_of_maxpool = getattr(
                    prepare_model, next(iter(maxpool_node.users)).target
                )
        self.assertTrue(isinstance(cat_act_obs0, ObserverBase))
        self.assertTrue(isinstance(cat_act_obs1, ObserverBase))
        self.assertTrue(isinstance(cat_out_obs, ObserverBase))
        self.assertTrue(isinstance(input_obs_of_maxpool, ObserverBase))
        self.assertTrue(isinstance(output_obs_of_maxpool, ObserverBase))
        self.assertTrue(cat_act_obs0 is cat_act_obs1)
        self.assertTrue(cat_act_obs0 is cat_out_obs)
        self.assertTrue(cat_out_obs is input_obs_of_maxpool)
        self.assertTrue(input_obs_of_maxpool is output_obs_of_maxpool)

    @skipIfNoX86
    def test_cat_recipe_same_inputs(self):
        r"""
        Test pattern: conv -> cat([input0, input0])
        Since cat has two input nodes referring to the same tensor, they should share the same observer.
1029        """
1030        m = TestHelperModules.Conv2dCatSameInputs().eval()
1031        x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last)
1032        quantizer = X86InductorQuantizer().set_global(
1033            xiq.get_default_x86_inductor_quantization_config()
1034        )
1035        example_inputs = (x,)
1036        node_occurrence = {
1037            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
1038            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
1039            # quantize_per_channel for weights are const propagated
1040            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1041            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
1042        }
1043        node_list = [
1044            torch.ops.quantized_decomposed.quantize_per_tensor.default,
1045            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1046            torch.ops.aten.conv2d.default,
1047            torch.ops.quantized_decomposed.quantize_per_tensor.default,
1048            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1049            torch.ops.aten.cat.default,
1050            torch.ops.quantized_decomposed.quantize_per_tensor.default,
1051            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1052        ]
1053        _, prepare_model, _ = self._test_quantizer(
1054            m,
1055            example_inputs,
1056            quantizer,
1057            node_occurrence,
1058            node_list,
1059        )
        # Check that cat shares the same observer at its inputs and output
        for node in prepare_model.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.cat.default:
                cat_act_obs0 = getattr(prepare_model, node.args[0][0].target)
                cat_act_obs1 = getattr(prepare_model, node.args[0][1].target)
                cat_out_obs = getattr(prepare_model, next(iter(node.users)).target)
        self.assertTrue(isinstance(cat_act_obs0, ObserverBase))
        self.assertTrue(isinstance(cat_act_obs1, ObserverBase))
        self.assertTrue(isinstance(cat_out_obs, ObserverBase))
        self.assertTrue(cat_act_obs0 is cat_act_obs1)
        self.assertTrue(cat_act_obs0 is cat_out_obs)

    @skipIfNoX86
    def test_cat_recipe_single_input(self):
        r"""
        Test pattern: conv -> cat([input0,])
        Since cat has a single input node, its input and output should share the same observer.
1077        """
1078        m = TestHelperModules.Conv2dCatSingleInput().eval()
1079        x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last)
1080        quantizer = X86InductorQuantizer().set_global(
1081            xiq.get_default_x86_inductor_quantization_config()
1082        )
1083        example_inputs = (x,)
1084        node_occurrence = {
1085            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
1086            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
1087            # quantize_per_channel for weights are const propagated
1088            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1089            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
1090        }
1091        node_list = [
1092            torch.ops.quantized_decomposed.quantize_per_tensor.default,
1093            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1094            torch.ops.aten.conv2d.default,
1095            torch.ops.quantized_decomposed.quantize_per_tensor.default,
1096            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1097            torch.ops.aten.cat.default,
1098            torch.ops.quantized_decomposed.quantize_per_tensor.default,
1099            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1100        ]
1101        _, prepare_model, _ = self._test_quantizer(
1102            m,
1103            example_inputs,
1104            quantizer,
1105            node_occurrence,
1106            node_list,
1107        )
        # Check that cat shares the same observer at its input and output
        for node in prepare_model.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.cat.default:
                cat_act_obs0 = getattr(prepare_model, node.args[0][0].target)
                cat_out_obs = getattr(prepare_model, next(iter(node.users)).target)
        self.assertTrue(isinstance(cat_act_obs0, ObserverBase))
        self.assertTrue(isinstance(cat_out_obs, ObserverBase))
        self.assertTrue(cat_act_obs0 is cat_out_obs)

    @skipIfNoX86
    def test_avg_pool2d_recipe(self):
        r"""
        Test pattern: conv -> AvgPool2d
        Since AvgPool2d is an int8_in_int8_out op, its inputs and outputs should share the same observer.
1122        """
1123        m = TestHelperModules.Conv2dAvgPool2d().eval()
1124        x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last)
1125        quantizer = X86InductorQuantizer().set_global(
1126            xiq.get_default_x86_inductor_quantization_config()
1127        )
1128        example_inputs = (x,)
1129        node_occurrence = {
1130            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
1131            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
1132            # quantize_per_channel for weights are const propagated
1133            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1134            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
1135        }
1136        node_list = [
1137            torch.ops.quantized_decomposed.quantize_per_tensor.default,
1138            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1139            torch.ops.aten.conv2d.default,
1140            torch.ops.quantized_decomposed.quantize_per_tensor.default,
1141            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1142            torch.ops.aten.avg_pool2d.default,
1143            torch.ops.quantized_decomposed.quantize_per_tensor.default,
1144            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1145        ]
1146        _, prepare_model, _ = self._test_quantizer(
1147            m,
1148            example_inputs,
1149            quantizer,
1150            node_occurrence,
1151            node_list,
1152        )
1153        for node in prepare_model.graph.nodes:
1154            if (
1155                node.op == "call_function"
1156                and node.target is torch.ops.aten.avg_pool2d.default
1157            ):
1158                avgpool_node = node
1159                input_obs_of_avgpool = getattr(
1160                    prepare_model, avgpool_node.args[0].target
1161                )
1162                output_obs_of_avgpool = getattr(
1163                    prepare_model, next(iter(avgpool_node.users)).target
1164                )
1165            elif (
1166                node.op == "call_function"
1167                and node.target is torch.ops.aten.conv2d.default
1168            ):
1169                conv_node = node
1170                output_obs_of_conv = getattr(
1171                    prepare_model, next(iter(conv_node.users)).target
1172                )
1173        self.assertTrue(isinstance(input_obs_of_avgpool, ObserverBase))
1174        self.assertTrue(isinstance(output_obs_of_avgpool, ObserverBase))
1175        self.assertTrue(isinstance(output_obs_of_conv, ObserverBase))
1176        self.assertTrue(input_obs_of_avgpool is output_obs_of_avgpool)
1177        self.assertTrue(input_obs_of_avgpool is output_obs_of_conv)
1178
1179    @skipIfNoX86
1180    def test_linear(self):
1181        """
1182        Test pattern of single linear with X86InductorQuantizer.
1183        """
1184        with override_quantized_engine("x86"), torch.no_grad():
1185            for use_bias in [True, False]:
1186                m = TestHelperModules.SingleLinearModule(use_bias).eval()
1187                example_inputs = (torch.randn(2, 4),)
1188                quantizer = X86InductorQuantizer().set_global(
1189                    xiq.get_default_x86_inductor_quantization_config()
1190                )
1191                node_occurrence = {
                    # one quant/dequant pair for the linear activation input
                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
                    # quantize_per_channel for weights are const propagated
                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
                }
                node_list = [
                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                    torch.ops.aten.linear.default,
                ]
                self._test_quantizer(
                    m,
                    example_inputs,
                    quantizer,
                    node_occurrence,
                    node_list,
                )

    def _test_linear_unary_helper(
        self,
        post_op_module,
        post_op_aten,
        post_op_aten_inplace,
        post_op_algo_list=None,
        is_qat=False,
        is_dynamic=False,
    ):
        """
        Test pattern of linear with unary post ops (e.g. relu) with X86InductorQuantizer.
        """
        use_bias_list = [True, False]
        # TODO test for inplace post ops after refactoring of capture_pre_autograd_graph
        inplace_list = [False]
        if post_op_algo_list is None:
            post_op_algo_list = [None]
        cases = itertools.product(use_bias_list, inplace_list, post_op_algo_list)
        with override_quantized_engine("x86"), torch.no_grad():
            for use_bias, inplace, post_op_algo in cases:
                if inplace and post_op_aten_inplace is None:
                    continue
                m = TestHelperModules.LinearUnaryModule(
                    use_bias=use_bias,
                    postop=post_op_module,
                    inplace_postop=inplace,
                    post_op_algo=post_op_algo,
                ).eval()
                example_inputs = (torch.randn(2, 4),)
                quantizer = X86InductorQuantizer().set_global(
                    xiq.get_default_x86_inductor_quantization_config(
                        is_qat=is_qat,
                        is_dynamic=is_dynamic,
                    )
                )
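                # Dynamic quantization computes scale/zero-point at runtime, so
                # the graph uses the .tensor overloads instead of .default.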
1247                quantize_per_tensor_op = (
1248                    torch.ops.quantized_decomposed.quantize_per_tensor.tensor
1249                    if is_dynamic
1250                    else torch.ops.quantized_decomposed.quantize_per_tensor.default
1251                )
1252                dequantize_per_tensor_op = (
1253                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
1254                    if is_dynamic
1255                    else torch.ops.quantized_decomposed.dequantize_per_tensor.default
1256                )
1257                node_occurrence = {
1258                    # one for input of the linear
1259                    quantize_per_tensor_op: 1,
1260                    dequantize_per_tensor_op: 1,
1261                    # quantize_per_channel for weights are const propagated
1262                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1263                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
1264                }
1265                node_list = [
1266                    quantize_per_tensor_op,
1267                    dequantize_per_tensor_op,
1268                    torch.ops.aten.linear.default,
1269                    post_op_aten_inplace if inplace else post_op_aten,
1270                ]
1271                self._test_quantizer(
1272                    m,
1273                    example_inputs,
1274                    quantizer,
1275                    node_occurrence,
1276                    node_list,
1277                    is_qat=is_qat,
1278                )
1279
1280    @skipIfNoX86
1281    def test_linear_unary(self):
1282        aten = torch.ops.aten
1283        self._test_linear_unary_helper(nn.ReLU, aten.relu.default, aten.relu_.default)
1284        self._test_linear_unary_helper(
1285            nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default
1286        )
1287        self._test_linear_unary_helper(
1288            nn.GELU, aten.gelu.default, None, ["none", "tanh"]
1289        )
1290
1291    @skipIfNoX86
1292    def test_linear_unary_qat(self):
1293        aten = torch.ops.aten
1294        self._test_linear_unary_helper(
1295            nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True
1296        )
1297        self._test_linear_unary_helper(
1298            nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default, is_qat=True
1299        )
1300        self._test_linear_unary_helper(
1301            nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_qat=True
1302        )
1303
1304    @skipIfNoX86
1305    def test_linear_unary_dynamic(self):
1306        aten = torch.ops.aten
1307        self._test_linear_unary_helper(
1308            nn.ReLU, aten.relu.default, aten.relu_.default, is_dynamic=True
1309        )
1310        self._test_linear_unary_helper(
1311            nn.LeakyReLU,
1312            aten.leaky_relu.default,
1313            aten.leaky_relu_.default,
1314            is_dynamic=True,
1315        )
1316        self._test_linear_unary_helper(
1317            nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_dynamic=True
1318        )
1319
1320    @skipIfNoX86
1321    def test_linear_unary_dynamic_qat(self):
1322        aten = torch.ops.aten
1323        self._test_linear_unary_helper(
1324            nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True, is_dynamic=True
1325        )
1326        self._test_linear_unary_helper(
1327            nn.LeakyReLU,
1328            aten.leaky_relu.default,
1329            aten.leaky_relu_.default,
1330            is_qat=True,
1331            is_dynamic=True,
1332        )
1333        self._test_linear_unary_helper(
1334            nn.GELU,
1335            aten.gelu.default,
1336            None,
1337            ["none", "tanh"],
1338            is_qat=True,
1339            is_dynamic=True,
1340        )
1341
1342    def _check_annotation_stat(self, gm, expected_stat_dict):
1343        # Check expected annotation statistics to ensure the annotation is correct
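        # For each node whose target appears in expected_stat_dict, subtract the observed
        # `_annotated` / `_is_output_of_quantized_pattern` flags from the expected counts.
        # Every counter must end up exactly zero, i.e. the graph contains exactly the
        # expected number of annotated nodes and quantized-pattern outputs per op.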
1344
1345        def _check_annotation(node):
1346            annot = node.meta.get(QUANT_ANNOTATION_KEY, None)
1347            if annot is None:
1348                return False, False
1349            return annot._annotated, annot._is_output_of_quantized_pattern
1350
1351        for node in gm.graph.nodes:
1352            if node.target in expected_stat_dict.keys():
1353                annotated, is_quant_out = _check_annotation(node)
1354                expected_stat_dict[node.target]["annotated"] -= annotated
1355                expected_stat_dict[node.target]["is_quant_out"] -= is_quant_out
1356        for op_stat in expected_stat_dict.values():
1357            assert all(v == 0 for v in op_stat.values())
1358
1359    def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False):
1360        """
1361        Test pattern of linear with binary post ops (such as add) with X86InductorQuantizer.
1362        Currently, only add as binary post op is supported.
1363        """
1364        linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both]
1365        # TODO test for inplace add after refactoring of capture_pre_autograd_graph
1366        inplace_add_list = [False]
1367        example_inputs = (torch.randn(2, 16),)
1368        quantizer = X86InductorQuantizer().set_global(
1369            xiq.get_default_x86_inductor_quantization_config(
1370                is_qat=is_qat,
1371                is_dynamic=is_dynamic,
1372            )
1373        )
1374        quantize_per_tensor_op = (
1375            torch.ops.quantized_decomposed.quantize_per_tensor.tensor
1376            if is_dynamic
1377            else torch.ops.quantized_decomposed.quantize_per_tensor.default
1378        )
1379        dequantize_per_tensor_op = (
1380            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
1381            if is_dynamic
1382            else torch.ops.quantized_decomposed.dequantize_per_tensor.default
1383        )
1384        cases = itertools.product(linear_pos_list, inplace_add_list)
1385        with override_quantized_engine("x86"), torch.no_grad():
1386            for linear_pos, inplace_add in cases:
1387                m = TestHelperModules.LinearAddModule(
1388                    inplace_add=inplace_add, linear_pos=linear_pos
1389                ).eval()
1390                if linear_pos != NodePosType.both:
1391                    node_occurrence = {
1392                        # Only one q-dq pair for the input of the linear
1393                        # No q-dq for the extra input node of the add
1394                        quantize_per_tensor_op: 1,
1395                        dequantize_per_tensor_op: 1,
1396                        # quantize_per_channel for weights are const propagated
1397                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1398                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
1399                    }
1400                else:
1401                    # convert_pt2e disables duplicate dequant for dynamic quant
1402                    num_dequant = 1 if is_dynamic else 2
1403                    node_occurrence = {
1404                        # One quantize_per_tensor for both linear nodes (shared)
1405                        # Two dequantize_per_tensor for two linear nodes
1406                        # No q-dq for extra input node of add
1407                        quantize_per_tensor_op: 1,
1408                        dequantize_per_tensor_op: num_dequant,
1409                        # quantize_per_channel for weights are const propagated
1410                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1411                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
1412                    }
1413                node_list = [
1414                    quantize_per_tensor_op,
1415                    dequantize_per_tensor_op,
1416                    torch.ops.aten.linear.default,
1417                    (
1418                        torch.ops.aten.add_.Tensor
1419                        if inplace_add
1420                        else torch.ops.aten.add.Tensor
1421                    ),
1422                ]
1423                fq_m = self._test_quantizer(
1424                    m,
1425                    example_inputs,
1426                    quantizer,
1427                    node_occurrence,
1428                    node_list,
1429                    is_qat=is_qat,
1430                )[-1]
1431                # One linear and the add are fused. The other linear, if present, is quantized alone.
1432                aten = torch.ops.aten
1433                add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
1434                expected_annotation_stat = {
1435                    aten.linear.default: {
1436                        "annotated": 2 if linear_pos == NodePosType.both else 1,
1437                        "is_quant_out": 1 if linear_pos == NodePosType.both else 0,
1438                    },
1439                    add_op: {"annotated": 1, "is_quant_out": 1},
1440                }
1441                self._check_annotation_stat(fq_m, expected_annotation_stat)
1442
1443    @skipIfNoX86
1444    def test_linear_binary(self):
1445        self._test_linear_binary_helper()
1446
1447    @skipIfNoX86
1448    def test_linear_binary_qat(self):
1449        self._test_linear_binary_helper(is_qat=True)
1450
1451    @skipIfNoX86
1452    def test_linear_binary_dynamic(self):
1453        self._test_linear_binary_helper(is_dynamic=True)
1454
1455    @skipIfNoX86
1456    def test_linear_binary_dynamic_qat(self):
1457        self._test_linear_binary_helper(is_qat=True, is_dynamic=True)
1458
1459    @skipIfNoX86
1460    def test_linear_binary2(self):
1461        """
1462        Test pattern:
1463            tmp = linear_1(x)
1464            tmp2 = linear_2(tmp)
1465            return tmp + tmp2
1466        Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1
1467        """
1468        example_inputs = (torch.randn(2, 16),)
1469        # TODO test for inplace add after refactoring of capture_pre_autograd_graph
1470        inplace_add_list = [False]
1471        is_qat_list = [False, True]
1472        is_dynamic_list = [False, True]
1473        cases = itertools.product(inplace_add_list, is_qat_list, is_dynamic_list)
1474        with override_quantized_engine("x86"), torch.no_grad():
1475            for inplace_add, is_qat, is_dynamic in cases:
1476                quantizer = X86InductorQuantizer().set_global(
1477                    xiq.get_default_x86_inductor_quantization_config(
1478                        is_qat=is_qat, is_dynamic=is_dynamic
1479                    )
1480                )
1481                m = TestHelperModules.LinearAddModule2(inplace_add=inplace_add).eval()
1482                quantize_per_tensor_op = (
1483                    torch.ops.quantized_decomposed.quantize_per_tensor.tensor
1484                    if is_dynamic
1485                    else torch.ops.quantized_decomposed.quantize_per_tensor.default
1486                )
1487                dequantize_per_tensor_op = (
1488                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
1489                    if is_dynamic
1490                    else torch.ops.quantized_decomposed.dequantize_per_tensor.default
1491                )
1492                # Two q-dq nodes for inputs of linear nodes
1493                # No q-dq for extra input node of add
1494                node_occurrence = {
1495                    quantize_per_tensor_op: 2,
1496                    dequantize_per_tensor_op: 2,
1497                    # quantize_per_channel for weights are const propagated
1498                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1499                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
1500                }
1501                node_list = [
1502                    torch.ops.quantized_decomposed.dequantize_per_channel.default,
1503                    quantize_per_tensor_op,
1504                    dequantize_per_tensor_op,
1505                    torch.ops.aten.linear.default,
1506                    (
1507                        torch.ops.aten.add_.Tensor
1508                        if inplace_add
1509                        else torch.ops.aten.add.Tensor
1510                    ),
1511                ]
1512                fq_m = self._test_quantizer(
1513                    m,
1514                    example_inputs,
1515                    quantizer,
1516                    node_occurrence,
1517                    node_list,
1518                )[-1]
1519                # One linear and the add are fused; the other linear is quantized alone.
1520                aten = torch.ops.aten
1521                add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
1522                expected_annotation_stat = {
1523                    aten.linear.default: {
1524                        "annotated": 2,
1525                        "is_quant_out": 1,
1526                    },
1527                    add_op: {"annotated": 1, "is_quant_out": 1},
1528                }
1529                self._check_annotation_stat(fq_m, expected_annotation_stat)
1530
1531    @skipIfNoX86
1532    def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False):
1533        """
1534        Test pattern of linear with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
1535        Currently, only add as binary post op and relu as unary post op are supported.
1536        """
1537        linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both]
1538        # TODO test for inplace add after refactoring of capture_pre_autograd_graph
1539        inplace_add_list = [False]
1540        # TODO test for inplace relu after refactoring of capture_pre_autograd_graph
1541        inplace_relu_list = [False]
1542        example_inputs = (torch.randn(2, 16),)
1543        quantizer = X86InductorQuantizer().set_global(
1544            xiq.get_default_x86_inductor_quantization_config(
1545                is_qat=is_qat,
1546                is_dynamic=is_dynamic,
1547            )
1548        )
1549        quantize_per_tensor_op = (
1550            torch.ops.quantized_decomposed.quantize_per_tensor.tensor
1551            if is_dynamic
1552            else torch.ops.quantized_decomposed.quantize_per_tensor.default
1553        )
1554        dequantize_per_tensor_op = (
1555            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
1556            if is_dynamic
1557            else torch.ops.quantized_decomposed.dequantize_per_tensor.default
1558        )
1559        cases = itertools.product(linear_pos_list, inplace_add_list, inplace_relu_list)
1560        with override_quantized_engine("x86"), torch.no_grad():
1561            for linear_pos, inplace_add, inplace_relu in cases:
1562                m = TestHelperModules.LinearAddReLUModule(
1563                    inplace_add=inplace_add,
1564                    linear_pos=linear_pos,
1565                    inplace_relu=inplace_relu,
1566                ).eval()
1567                if linear_pos != NodePosType.both:
1568                    node_occurrence = {
1569                        # Only one q-dq pair for the input of the linear
1570                        # No q-dq for the extra input node of the add
1571                        quantize_per_tensor_op: 1,
1572                        dequantize_per_tensor_op: 1,
1573                        # note: quantize ops for weights are const propagated
1574                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1575                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
1576                    }
1577                else:
1578                    # convert_pt2e disables duplicate dequant for dynamic quant
1579                    num_dequant = 1 if is_dynamic else 2
1580                    node_occurrence = {
1581                        # One quantize_per_tensor for both linear nodes (shared)
1582                        # Two dequantize_per_tensor for two linear nodes
1583                        # No q-dq for extra input node of add
1584                        quantize_per_tensor_op: 1,
1585                        dequantize_per_tensor_op: num_dequant,
1586                        # note: quantize ops for weights are const propagated
1587                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1588                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
1589                    }
1590                node_list = [
1591                    quantize_per_tensor_op,
1592                    dequantize_per_tensor_op,
1593                    torch.ops.aten.linear.default,
1594                    (
1595                        torch.ops.aten.add_.Tensor
1596                        if inplace_add
1597                        else torch.ops.aten.add.Tensor
1598                    ),
1599                ]
1600                fq_m = self._test_quantizer(
1601                    m,
1602                    example_inputs,
1603                    quantizer,
1604                    node_occurrence,
1605                    node_list,
1606                )[-1]
1607                # linear, add, relu are fused
1608                # The other linear is quantized alone if present
1609                aten = torch.ops.aten
1610                add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor
1611                relu_op = aten.relu_.default if inplace_relu else aten.relu.default
1612                expected_annotation_stat = {
1613                    aten.linear.default: {
1614                        "annotated": 2 if linear_pos == NodePosType.both else 1,
1615                        "is_quant_out": 1 if linear_pos == NodePosType.both else 0,
1616                    },
1617                    add_op: {"annotated": 1, "is_quant_out": 0},
1618                    relu_op: {"annotated": 1, "is_quant_out": 1},
1619                }
1620                self._check_annotation_stat(fq_m, expected_annotation_stat)
1621
1622    @skipIfNoX86
1623    def test_linear_binary_unary(self):
1624        self._test_linear_binary_unary_helper()
1625
1626    @skipIfNoX86
1627    def test_linear_binary_unary_qat(self):
1628        self._test_linear_binary_unary_helper(is_qat=True)
1629
1630    @skipIfNoX86
1631    def test_linear_binary_unary_dynamic(self):
1632        self._test_linear_binary_unary_helper(is_dynamic=True)
1633
1634    @skipIfNoX86
1635    def test_linear_binary_unary_dynamic_qat(self):
1636        self._test_linear_binary_unary_helper(is_qat=True, is_dynamic=True)
1637
1638    @skipIfNoX86
1639    def test_linear_binary_unary_serials(self):
1640        """
1641        Test pattern of two consecutive linear-add-relu blocks with X86InductorQuantizer.
1642        """
1643        is_qat_list = [False, True]
1644        is_dynamic_list = [False, True]
1645        cases = itertools.product(is_qat_list, is_dynamic_list)
1646        with override_quantized_engine("x86"), torch.no_grad():
1647            for is_qat, is_dynamic in cases:
1648                m = TestHelperModules.SerialsLinearAddReLUModule().eval()
1649                example_inputs = (torch.randn(2, 16),)
1650                quantizer = X86InductorQuantizer().set_global(
1651                    xiq.get_default_x86_inductor_quantization_config(
1652                        is_qat=is_qat,
1653                        is_dynamic=is_dynamic,
1654                    )
1655                )
1656                quantize_per_tensor_op = (
1657                    torch.ops.quantized_decomposed.quantize_per_tensor.tensor
1658                    if is_dynamic
1659                    else torch.ops.quantized_decomposed.quantize_per_tensor.default
1660                )
1661                dequantize_per_tensor_op = (
1662                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
1663                    if is_dynamic
1664                    else torch.ops.quantized_decomposed.dequantize_per_tensor.default
1665                )
1666                # convert_pt2e disables duplicate dequant for dynamic quant
1667                num_dequant = 3 if is_dynamic else 4
1668                node_occurrence = {
1669                    # quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4
1670                    # dequantize_per_tensor: 1 for each linear
1671                    # No q-dq for extra input node of add
1672                    quantize_per_tensor_op: 3,
1673                    dequantize_per_tensor_op: num_dequant,
1674                    # quantize_per_channel for weights are const propagated
1675                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1676                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
1677                }
1678                node_list = [
1679                    torch.ops.quantized_decomposed.dequantize_per_channel.default,
1680                    quantize_per_tensor_op,
1681                    dequantize_per_tensor_op,
1682                    torch.ops.aten.linear.default,
1683                    torch.ops.aten.linear.default,
1684                    torch.ops.aten.linear.default,
1685                    torch.ops.aten.add.Tensor,
1686                    torch.ops.aten.relu.default,
1687                ]
1688                fq_m = self._test_quantizer(
1689                    m,
1690                    example_inputs,
1691                    quantizer,
1692                    node_occurrence,
1693                    node_list,
1694                )[-1]
1695                # Two linear nodes are quantized alone
1696                # The other two are fused with add and relu
1697                aten = torch.ops.aten
1698                expected_annotation_stat = {
1699                    aten.linear.default: {
1700                        "annotated": 4,
1701                        "is_quant_out": 2,
1702                    },
1703                    aten.add.Tensor: {"annotated": 2, "is_quant_out": 0},
1704                    aten.relu.default: {"annotated": 2, "is_quant_out": 2},
1705                }
1706                self._check_annotation_stat(fq_m, expected_annotation_stat)
1707
1708    @skipIfTorchDynamo("very slow")
1709    @skipIfNoX86
1710    def test_qat_conv2d(self):
1711        """
1712        Test QAT pattern of conv2d_bn with X86InductorQuantizer.
1713        """
1714        with override_quantized_engine("x86"):
1715            m = TestHelperModules.SingleConv2dModule(with_bn=True)
1716            example_inputs = (torch.randn(2, 3, 16, 16),)
1717            quantizer = X86InductorQuantizer().set_global(
1718                xiq.get_default_x86_inductor_quantization_config(is_qat=True)
1719            )
1720            node_occurrence = {
1721                # one for the input of the conv, one for the output of the conv
1722                torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
1723                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
1724                # note: quantize ops for weights are const propagated
1725                torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1726                torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
1727                # BN should be folded into Conv
1728                torch.ops.aten._native_batch_norm_legit.default: 0,
1729            }
1730            node_list = [
1731                torch.ops.quantized_decomposed.quantize_per_tensor.default,
1732                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1733                torch.ops.aten.conv2d.default,
1734                torch.ops.quantized_decomposed.quantize_per_tensor.default,
1735                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1736            ]
1737            self._test_quantizer(
1738                m,
1739                example_inputs,
1740                quantizer,
1741                node_occurrence,
1742                node_list,
1743                is_qat=True,
1744            )
1745
1746    @skipIfTorchDynamo("very slow")
1747    @skipIfNoX86
1748    def test_qat_conv2d_unary(self):
1749        """
1750        Test QAT pattern of conv2d_bn with unary post ops with X86InductorQuantizer.
1751        Covered unary post ops are relu, hardtanh, relu6, hardswish and swish (SiLU).
1752        """
1753        unary_map = {
1754            "relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default],
1755            "relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default],
1756            "hardtanh": [
1757                torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False),
1758                torch.ops.aten.hardtanh.default,
1759            ],
1760            "hardtanh_inplace": [
1761                torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True),
1762                torch.ops.aten.hardtanh_.default,
1763            ],
1764            "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default],
1765            "relu6_inplace": [
1766                torch.nn.ReLU6(inplace=True),
1767                torch.ops.aten.hardtanh_.default,
1768            ],
1769            "hardswish": [
1770                torch.nn.Hardswish(inplace=False),
1771                torch.ops.aten.hardswish.default,
1772            ],
1773            "hardswish_inplace": [
1774                torch.nn.Hardswish(inplace=True),
1775                torch.ops.aten.hardswish_.default,
1776            ],
1777            "swish": [torch.nn.SiLU(inplace=False), torch.ops.aten.silu.default],
1778            "swish_inplace": [
1779                torch.nn.SiLU(inplace=True),
1780                torch.ops.aten.silu_.default,
1781            ],
1782        }
1783
1784        with override_quantized_engine("x86"):
1785            for unary_op in unary_map.keys():
1786                m = TestHelperModules.Conv2dUnaryModule(
1787                    unary_map[unary_op][0], with_bn=True
1788                )
1789                example_inputs = (torch.randn(2, 3, 16, 16),)
1790                quantizer = X86InductorQuantizer().set_global(
1791                    xiq.get_default_x86_inductor_quantization_config(is_qat=True)
1792                )
1793                node_occurrence = {
1794                    # one for the input of the conv, one for the output of the unary post op, one for the output of the maxpool
1795                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
1796                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
1797                    # note: quantize ops for weights are const propagated
1798                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1799                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
1800                    # BN should be folded into Conv
1801                    torch.ops.aten._native_batch_norm_legit.default: 0,
1802                }
1803                node_list = [
1804                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
1805                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1806                    torch.ops.aten.conv2d.default,
1807                    unary_map[unary_op][1],
1808                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
1809                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1810                ]
1811                self._test_quantizer(
1812                    m,
1813                    example_inputs,
1814                    quantizer,
1815                    node_occurrence,
1816                    node_list,
1817                    is_qat=True,
1818                )
1819
1820    @skipIfTorchDynamo("very slow")
1821    @skipIfNoX86
1822    def test_qat_conv2d_binary(self):
1823        """
1824        Test QAT pattern of conv2d_bn with binary post ops (such as add) with X86InductorQuantizer.
1825        Currently, only add as binary post op is supported.
1826        """
1827        example_inputs = (torch.randn(2, 3, 6, 6),)
1828        quantizer = X86InductorQuantizer().set_global(
1829            xiq.get_default_x86_inductor_quantization_config(is_qat=True)
1830        )
1831        with override_quantized_engine("x86"):
1832            for inplace_add in [True, False]:
1833                m = TestHelperModules.Conv2dAddModule(
1834                    inplace_add=inplace_add, with_bn=True
1835                )
1836                node_occurrence = {
1837                    # one for the input of the conv
1838                    # one for the output of the add
1839                    # one for the extra input node of the add
1840                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
1841                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
1842                    # quantize_per_channel for weights are const propagated
1843                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1844                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
1845                    # BN should be folded into Conv
1846                    torch.ops.aten._native_batch_norm_legit.default: 0,
1847                }
1848                node_list = [
1849                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
1850                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1851                    torch.ops.aten.conv2d.default,
1852                    (
1853                        torch.ops.aten.add_.Tensor
1854                        if inplace_add
1855                        else torch.ops.aten.add.Tensor
1856                    ),
1857                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
1858                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1859                ]
1860                self._test_quantizer(
1861                    m,
1862                    example_inputs,
1863                    quantizer,
1864                    node_occurrence,
1865                    node_list,
1866                    is_qat=True,
1867                )
1868
1869    @skipIfTorchDynamo("very slow")
1870    @skipIfNoX86
1871    def test_qat_conv2d_binary2(self):
1872        """
1873        Test qat Pattern:
1874            tmp = bn1(conv2d_1(x))
1875            tmp2 = bn2(conv2d_2(tmp))
1876            return tmp + tmp2
1877        Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1
1878        """
1879        example_inputs = (torch.randn(2, 3, 6, 6),)
1880        quantizer = X86InductorQuantizer().set_global(
1881            xiq.get_default_x86_inductor_quantization_config(is_qat=True)
1882        )
1883        inplace_add_list = [True, False]
1884        with override_quantized_engine("x86"), torch.no_grad():
1885            for inplace_add in inplace_add_list:
1886                m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add)
1887                node_occurrence = {
1888                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
1889                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
1890                    # quantize_per_channel for weights are const propagated
1891                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1892                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
1893                    # BN should be folded into Conv
1894                    torch.ops.aten._native_batch_norm_legit.default: 0,
1895                }
1896                node_list = [
1897                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
1898                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1899                    torch.ops.aten.conv2d.default,
1900                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
1901                    (
1902                        torch.ops.aten.add_.Tensor
1903                        if inplace_add
1904                        else torch.ops.aten.add.Tensor
1905                    ),
1906                ]
1907                self._test_quantizer(
1908                    m,
1909                    example_inputs,
1910                    quantizer,
1911                    node_occurrence,
1912                    node_list,
1913                    is_qat=True,
1914                )
1915
1916    @skipIfTorchDynamo("very slow")
1917    @skipIfNoX86
1918    def test_qat_conv2d_binary_unary(self):
1919        """
1920        Test QAT pattern of conv2d_bn with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
1921        Currently, only add as binary post op and relu as unary post op are supported.
1922        """
1923        example_inputs = (torch.randn(2, 3, 6, 6),)
1924        quantizer = X86InductorQuantizer().set_global(
1925            xiq.get_default_x86_inductor_quantization_config(is_qat=True)
1926        )
1927        with override_quantized_engine("x86"):
1928            m = TestHelperModules.Conv2dAddReLUModule(with_bn=True)
1929            node_occurrence = {
1930                # one for the input of the conv
1931                # one for the output of the relu
1932                # one for the extra input node of the add
1933                torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
1934                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
1935                # note: quantize ops for weights are const propagated
1936                torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1937                torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
1938                # BN should be folded into Conv
1939                torch.ops.aten._native_batch_norm_legit.default: 0,
1940            }
1941            node_list = [
1942                torch.ops.quantized_decomposed.quantize_per_tensor.default,
1943                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1944                torch.ops.aten.conv2d.default,
1945                torch.ops.aten.add.Tensor,
1946                torch.ops.quantized_decomposed.quantize_per_tensor.default,
1947                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1948            ]
1949            self._test_quantizer(
1950                m,
1951                example_inputs,
1952                quantizer,
1953                node_occurrence,
1954                node_list,
1955                is_qat=True,
1956            )
1957
1958    @skipIfNoX86
1959    def test_dynamic_quant_linear(self):
1960        """
1961        Test pattern of dynamic quantization of linear with X86InductorQuantizer.
1962        """
1963        with override_quantized_engine("x86"), torch.no_grad():
1964            m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval()
1965            example_inputs = (torch.randn(1, 4, 64),)
1966            quantizer = X86InductorQuantizer().set_global(
1967                xiq.get_default_x86_inductor_quantization_config(is_dynamic=True)
1968            )
1969            node_occurrence = {
1970                torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
1971                torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
1972                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
1973                # quantize_per_channel for weights are const propagated
1974                torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1975                torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
1976            }
1977            node_list = [
1978                torch.ops.quantized_decomposed.choose_qparams.tensor,
1979                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
1980                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
1981                torch.ops.aten.linear.default,
1982            ]
1983            self._test_quantizer(
1984                m,
1985                example_inputs,
1986                quantizer,
1987                node_occurrence,
1988                node_list,
1989            )
1990
1991    @skipIfNoX86
1992    def test_qat_dynamic_quant_linear(self):
1993        """
1994        Test pattern of qat dynamic quantization of linear with X86InductorQuantizer.
1995        """
1996        with override_quantized_engine("x86"), torch.no_grad():
1997            m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval()
1998            example_inputs = (torch.randn(1, 4, 64),)
1999            quantizer = X86InductorQuantizer().set_global(
2000                xiq.get_default_x86_inductor_quantization_config(
2001                    is_qat=True, is_dynamic=True
2002                )
2003            )
2004            node_occurrence = {
2005                torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
2006                torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
2007                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
2008                # quantize_per_channel for weights are const propagated
2009                torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
2010                torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
2011            }
2012            node_list = [
2013                torch.ops.quantized_decomposed.choose_qparams.tensor,
2014                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
2015                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
2016                torch.ops.aten.linear.default,
2017            ]
2018            self._test_quantizer(
2019                m,
2020                example_inputs,
2021                quantizer,
2022                node_occurrence,
2023                node_list,
2024                is_qat=True,
2025            )
2026
2027    @skipIfNoX86
2028    def test_set_module_name_qconfig(self):
2029        """Test case for quantizing a specific submodule by configuring `set_module_name_qconfig`.
2030
2031        Expect that all linear layers within the submodule `sub` are quantized.
2032        """
2033
2034        class Sub(torch.nn.Module):
2035            def __init__(self) -> None:
2036                super().__init__()
2037                self.linear1 = torch.nn.Linear(5, 10)
2038                self.relu1 = torch.nn.ReLU(inplace=False)
2039                self.linear2 = torch.nn.Linear(10, 5)
2040
2041            def forward(self, x):
2042                x = self.linear1(x)
2043                x = self.relu1(x)
2044                x = self.linear2(x)
2045                return x
2046
2047        class M(torch.nn.Module):
2048            def __init__(self) -> None:
2049                super().__init__()
2050                self.linear = torch.nn.Linear(5, 5)
2051                self.sub = Sub()
2052
2053            def forward(self, x):
2054                x = self.linear(x)
2055                x = self.sub(x)
2056                return x
2057
2058        m = M().eval()
2059        example_inputs = (torch.randn(3, 5),)
2060        # Leave the global config unset (`None`) and set the default config only for the submodule `sub`.
2061        quantizer = X86InductorQuantizer()
2062        quantizer.set_module_name_qconfig(
2063            "sub", xiq.get_default_x86_inductor_quantization_config()
2064        )
2065        node_occurrence = {
2066            torch.ops.aten.linear.default: 3,
2067            # quantize and dequantize the input of two linear layers from `sub`
2068            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
2069            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
2070            # dequantize the weight of two linear layers from `sub`
2071            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
2072        }
2073        node_list = [
2074            # first linear is not quantized
2075            torch.ops.aten.linear.default,
2076            # two Q/DQ pairs for two linear layers from `sub`
2077            torch.ops.quantized_decomposed.quantize_per_tensor.default,
2078            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2079            torch.ops.aten.linear.default,
2080            torch.ops.quantized_decomposed.quantize_per_tensor.default,
2081            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2082            torch.ops.aten.linear.default,
2083        ]
2084        self._test_quantizer(
2085            m,
2086            example_inputs,
2087            quantizer,
2088            node_occurrence,
2089            node_list,
2090        )
2091
2092    @skipIfNoX86
2093    def test_set_module_name_qconfig_with_underscores(self) -> None:
2094        """Test that if a module name has an underscore, we can still quantize it."""
2095
2096        class M(torch.nn.Module):
2097            def __init__(self) -> None:
2098                super().__init__()
2099                # This module name has underscores, which can be part of a mangled name.
2100                self.foo_bar = torch.nn.Linear(2, 2)
2101                self.baz = torch.nn.Linear(2, 2)
2102
2103            def forward(self, x):
2104                return self.baz(self.foo_bar(x))
2105
2106        # Leave the global config unset and set the default config only for the submodule whose name includes an underscore.
2107        quantizer = X86InductorQuantizer()
2108        quantizer.set_module_name_qconfig(
2109            "foo_bar", xiq.get_default_x86_inductor_quantization_config()
2110        )
2111        example_inputs = (torch.randn(2, 2),)
2112        m = M().eval()
2113        m = capture_pre_autograd_graph(m, example_inputs)
2114        m = prepare_pt2e(m, quantizer)
2115        # Use a linear count instead of names because the names might change, but
2116        # the order should be the same.
2117        count = 0
2118        for n in m.graph.nodes:
2119            if n.op == "call_function" and n.target == torch.ops.aten.linear.default:
2120                # Get the weight observer to see the per-channel vs per-tensor.
2121                weight_observer_node = n.args[1]
2122                if count == 0:
2123                    # for foo_bar.
2124                    self.assertEqual(
2125                        weight_observer_node.op,
2126                        "call_module",
2127                        f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module",
2128                    )
2129                    observer_instance = getattr(m, weight_observer_node.target)
2130                    self.assertEqual(
2131                        observer_instance.qscheme, torch.per_channel_symmetric
2132                    )
2133                else:
2134                    # For baz it should have no observer at all.
2135                    self.assertNotEqual(
2136                        weight_observer_node.op,
2137                        "call_module",
2138                        f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module",
2139                    )
2140                count += 1
2141
2142    @skipIfNoX86
2143    def test_set_module_name_and_module_type_case1(self):
2144        """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time.
2145
2146        Expect that none of the linear layers are quantized except the last one.
2147        """
2148
2149        class M(torch.nn.Module):
2150            def __init__(self) -> None:
2151                super().__init__()
2152                self.linear1 = torch.nn.Linear(5, 10)
2153                self.linear2 = torch.nn.Linear(10, 5)
2154                self.sub = torch.nn.Linear(5, 5)
2155
2156            def forward(self, x):
2157                x = self.linear1(x)
2158                x = self.linear2(x)
2159                x = self.sub(x)
2160                return x
2161
2162        m = M().eval()
2163        example_inputs = (torch.randn(3, 5),)
2164        # Set `sub` with default config and then `None` for all `Linear`.
2165        # The config set by `set_module_name_qconfig` has higher priority than `set_module_type_qconfig`.
2166        quantizer = X86InductorQuantizer()
2167        quantizer.set_module_name_qconfig(
2168            "sub", xiq.get_default_x86_inductor_quantization_config()
2169        ).set_module_type_qconfig(torch.nn.Linear, None)
2170
2171        node_occurrence = {
2172            torch.ops.aten.linear.default: 3,
2173            # quantize and dequantize the input of the last linear
2174            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
2175            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
2176            # dequantize the weight of the last linear
2177            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
2178        }
2179        node_list = [
2180            # first and second linear are not quantized
2181            torch.ops.aten.linear.default,
2182            torch.ops.aten.linear.default,
2183            # last linear is quantized
2184            torch.ops.quantized_decomposed.quantize_per_tensor.default,
2185            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2186            torch.ops.aten.linear.default,
2187        ]
2188        self._test_quantizer(
2189            m,
2190            example_inputs,
2191            quantizer,
2192            node_occurrence,
2193            node_list,
2194        )
2195
2196    @skipIfNoX86
2197    def test_set_module_name_and_module_type_case2(self):
2198        """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time.
2199
2200        Expect that all linear layers are quantized except the last one.
2201        """
2202
2203        class M(torch.nn.Module):
2204            def __init__(self) -> None:
2205                super().__init__()
2206                self.linear1 = torch.nn.Linear(5, 10)
2207                self.linear2 = torch.nn.Linear(10, 5)
2208                self.sub = torch.nn.Linear(5, 5)
2209
2210            def forward(self, x):
2211                x = self.linear1(x)
2212                x = self.linear2(x)
2213                x = self.sub(x)
2214                return x
2215
2216        m = M().eval()
2217        example_inputs = (torch.randn(3, 5),)
2218        # Set `sub` to `None` and then the default config for all `Linear` modules.
2219        quantizer = X86InductorQuantizer()
2220        quantizer.set_module_name_qconfig("sub", None).set_module_type_qconfig(
2221            torch.nn.Linear, xiq.get_default_x86_inductor_quantization_config()
2222        )
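        # The name-based `None` for `sub` takes precedence over the type-based config,
        # so `sub` stays unquantized while the other `Linear` layers are quantized.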
2223
2224        node_occurrence = {
2225            torch.ops.aten.linear.default: 3,
2226            # quantize and dequantize the inputs of the first and second linear
2227            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
2228            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
2229            # dequantize the weight of the first and second linear
2230            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
2231        }
2232        node_list = [
2233            # Q/DQ for first linear
2234            torch.ops.quantized_decomposed.quantize_per_tensor.default,
2235            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2236            torch.ops.aten.linear.default,
2237            # Q/DQ for second linear
2238            torch.ops.quantized_decomposed.quantize_per_tensor.default,
2239            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2240            torch.ops.aten.linear.default,
2241            # last linear is not quantized
2242            torch.ops.aten.linear.default,
2243        ]
2244        self._test_quantizer(
2245            m,
2246            example_inputs,
2247            quantizer,
2248            node_occurrence,
2249            node_list,
2250        )
2251
2252    @skipIfNoX86
2253    def test_set_module_name_qconfig_for_dynamic_quant(self):
2254        """Test that quantize a specific submodule for dynamic quantization."""
2255
2256        with override_quantized_engine("x86"), torch.no_grad():
2257            for is_qat in [False, True]:
2258                m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval()
2259                example_inputs = (torch.randn(1, 4, 64),)
2260                # only quantize `q_proj` and `v_proj`
2261                dynamic_config = xiq.get_default_x86_inductor_quantization_config(
2262                    is_dynamic=True, is_qat=is_qat
2263                )
2264                quantizer = (
2265                    X86InductorQuantizer()
2266                    .set_module_name_qconfig("q_proj", dynamic_config)
2267                    .set_module_name_qconfig("v_proj", dynamic_config)
2268                )
2269                node_occurrence = {
2270                    # quantize and dequantize the input
2271                    torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
2272                    torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
2273                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
2274                    # dequantize the weight of q_proj and v_proj
2275                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
2276                }
2277                node_list = [
2278                    # quantize and dequantize the input
2279                    torch.ops.quantized_decomposed.choose_qparams.tensor,
2280                    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
2281                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
2282                    # q_proj
2283                    torch.ops.aten.linear.default,
2284                    # k_proj
2285                    torch.ops.aten.linear.default,
2286                    # v_proj
2287                    torch.ops.aten.linear.default,
2288                ]
2289                self._test_quantizer(
2290                    m,
2291                    example_inputs,
2292                    quantizer,
2293                    node_occurrence,
2294                    node_list,
2295                    is_qat=is_qat,
2296                )
2297
2298    @skipIfNoX86
2299    def test_set_module_name_with_mixed_configs(self):
2300        """Test case for setting module names with mixed static/dynamic or QAT/non-QAT configurations.
2301
2302        The config for 'v_proj' will always be ignored and raise a warning.
2303        """
2304        with override_quantized_engine("x86"), torch.no_grad():
2305            with self.assertWarns(UserWarning) as context:
2306                for q_is_dynamic, v_is_dynamic, q_is_qat, v_is_qat in itertools.product(
2307                    [False, True], repeat=4
2308                ):
2309                    if q_is_dynamic == v_is_dynamic and q_is_qat == v_is_qat:
2310                        continue
2311                    m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval()
2312                    example_inputs = (torch.randn(1, 4, 64),)
2313                    quantizer = (
2314                        X86InductorQuantizer()
2315                        .set_module_name_qconfig(
2316                            "q_proj",
2317                            xiq.get_default_x86_inductor_quantization_config(
2318                                is_qat=q_is_qat, is_dynamic=q_is_dynamic
2319                            ),
2320                        )
2321                        .set_module_name_qconfig(
2322                            "v_proj",
2323                            xiq.get_default_x86_inductor_quantization_config(
2324                                is_qat=v_is_qat, is_dynamic=v_is_dynamic
2325                            ),
2326                        )
2327                    )
2328                    quant_op = (
2329                        torch.ops.quantized_decomposed.quantize_per_tensor.tensor
2330                        if q_is_dynamic
2331                        else torch.ops.quantized_decomposed.quantize_per_tensor.default
2332                    )
2333                    dequant_op = (
2334                        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
2335                        if q_is_dynamic
2336                        else torch.ops.quantized_decomposed.dequantize_per_tensor.default
2337                    )
2338                    node_occurrence = {
2339                        # quantize and dequantize the input
2340                        quant_op: 1,
2341                        dequant_op: 1,
2342                        # only `q_proj` was quantized, dequantize its weight
2343                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
2344                    }
2345                    node_list = [
2346                        # quantize and dequantize the input
2347                        quant_op,
2348                        dequant_op,
2349                        # q_proj
2350                        torch.ops.aten.linear.default,
2351                        # k_proj/v_proj
2352                        torch.ops.aten.linear.default,
2353                        torch.ops.aten.linear.default,
2354                    ]
2355                    self._test_quantizer(
2356                        m,
2357                        example_inputs,
2358                        quantizer,
2359                        node_occurrence,
2360                        node_list,
2361                        is_qat=q_is_qat,
2362                    )
2363                    warning_msg = (
2364                        "Mixed QAT and Non-QAT"
2365                        if q_is_qat != v_is_qat
2366                        else "Mixed dynamic and static"
2367                    )
2368                    self.assertTrue(
2369                        any(
2370                            warning_msg in msg
2371                            for msg in [str(w.message) for w in context.warnings]
2372                        )
2373                    )
2374
2375    @skipIfNoX86
2376    def test_set_module_name_and_module_type_with_mixed_configs(self):
2377        """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time with mixed the configs.
2378
2379        Expect that only the last linear(`sub`) is quantized using static quantization.
2380        """
2381
2382        class M(torch.nn.Module):
2383            def __init__(self) -> None:
2384                super().__init__()
2385                self.linear1 = torch.nn.Linear(5, 10)
2386                self.linear2 = torch.nn.Linear(10, 5)
2387                self.sub = torch.nn.Linear(5, 5)
2388
2389            def forward(self, x):
2390                x = self.linear1(x)
2391                x = self.linear2(x)
2392                x = self.sub(x)
2393                return x
2394
2395        m = M().eval()
2396        example_inputs = (torch.randn(3, 5),)
2397        # Set `sub` with the static config and then a dynamic config for all `Linear` modules (the dynamic config is ignored).
2398        quantizer = X86InductorQuantizer()
2399        quantizer.set_module_name_qconfig(
2400            "sub", xiq.get_default_x86_inductor_quantization_config(is_dynamic=False)
2401        ).set_module_type_qconfig(
2402            torch.nn.Linear,
2403            xiq.get_default_x86_inductor_quantization_config(is_dynamic=True),
2404        )
2405
2406        node_occurrence = {
2407            torch.ops.aten.linear.default: 3,
2408            # quantize and dequantize the input of the last linear
2409            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
2410            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
2411            # dequantize the weight of the last linear
2412            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
2413        }
2414        node_list = [
2415            # first and second linear are not quantized
2416            torch.ops.aten.linear.default,
2417            torch.ops.aten.linear.default,
2418            # Q/DQ pairs for the last linear
2419            torch.ops.quantized_decomposed.quantize_per_tensor.default,
2420            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2421            torch.ops.aten.linear.default,
2422        ]
2423        self._test_quantizer(
2424            m,
2425            example_inputs,
2426            quantizer,
2427            node_occurrence,
2428            node_list,
2429        )
2430
2431    @skipIfNoX86
2432    def test_filter_conv2d_recipe(self):
2433        """
2434        Test removing conv2d from default recipe of X86InductorQuantizer.
2435        """
2436        with override_quantized_engine("x86"), torch.no_grad():
2437            m = TestHelperModules.Conv2dUnaryModule(torch.nn.ReLU(inplace=False)).eval()
2438            example_inputs = (torch.randn(2, 3, 16, 16),)
2439            quantizer = X86InductorQuantizer().set_global(
2440                xiq.get_default_x86_inductor_quantization_config()
2441            )
2442            quantizer.set_module_type_qconfig(torch.nn.Conv2d, None)
2443            node_occurrence = {
2444                # conv2d is removed from the recipe, so no q/dq nodes are inserted
2445                torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
2446                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0,
2447                # note: quantize ops for weights are const propagated
2448                torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
2449                torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
2450            }
2451            node_list = [
2452                torch.ops.aten.conv2d.default,
2453                torch.ops.aten.relu.default,
2454            ]
2455            self._test_quantizer(
2456                m,
2457                example_inputs,
2458                quantizer,
2459                node_occurrence,
2460                node_list,
2461            )
2462
2463    @skipIfNoX86
2464    def test_filter_linear_recipe(self):
2465        """
2466        Test removing linear from the default recipe of X86InductorQuantizer.
2467        """
2468        with override_quantized_engine("x86"), torch.no_grad():
2469            m = TestHelperModules.LinearUnaryModule(
2470                use_bias=True,
2471                postop=nn.ReLU,
2472            ).eval()
2473            example_inputs = (torch.randn(2, 4),)
2474            quantizer = X86InductorQuantizer().set_global(
2475                xiq.get_default_x86_inductor_quantization_config()
2476            )
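            # `nn.Linear` calls `torch.nn.functional.linear` under the hood, so filtering
            # the functional op also disables quantization of the Linear module here.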
2477            quantizer.set_function_type_qconfig(torch.nn.functional.linear, None)
2478            node_occurrence = {
2479                # linear is filtered out of the recipe, so no Q/DQ is inserted for its activation
2480                torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
2481                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0,
2482                # and no per-channel Q/DQ for its weight either
2483                torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
2484                torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
2485            }
2486            node_list = [
2487                torch.ops.aten.linear.default,
2488                torch.ops.aten.relu.default,
2489            ]
2490            self._test_quantizer(
2491                m,
2492                example_inputs,
2493                quantizer,
2494                node_occurrence,
2495                node_list,
2496            )
2497
2498    @skipIfNoX86
2499    def test_filter_maxpool2d_recipe(self):
2500        """
2501        Test removing maxpool2d from the default recipe of X86InductorQuantizer.
2502        """
2503        with override_quantized_engine("x86"), torch.no_grad():
2504            m = TestHelperModules.Conv2dUnaryModule(torch.nn.ReLU(inplace=False)).eval()
2505            example_inputs = (torch.randn(2, 3, 16, 16),)
2506            quantizer = X86InductorQuantizer().set_global(
2507                xiq.get_default_x86_inductor_quantization_config()
2508            )
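            # Filtering max_pool2d only drops its own annotation; the preceding conv + relu
            # is still quantized, hence the single Q/DQ pair expected on the conv input below.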
2509            quantizer.set_function_type_qconfig(torch.nn.functional.max_pool2d, None)
2510            node_occurrence = {
2511                # one Q/DQ pair for the input of the conv; maxpool2d itself is no longer annotated
2512                torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
2513                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
2514                # note: quantize ops for weights are const-propagated
2515                torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
2516                torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
2517            }
2518            node_list = [
2519                torch.ops.quantized_decomposed.quantize_per_tensor.default,
2520                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2521                torch.ops.aten.conv2d.default,
2522                torch.ops.aten.relu.default,
2523                torch.ops.aten.max_pool2d.default,
2524            ]
2525            self._test_quantizer(
2526                m,
2527                example_inputs,
2528                quantizer,
2529                node_occurrence,
2530                node_list,
2531            )
2532
2533    @skipIfNoX86
2534    def test_attention_block(self):
2535        """
2536        Test an attention-like block pattern with X86InductorQuantizer.
2537        """
2538        for annotate_matmul in [False, True]:
2539            with override_quantized_engine("x86"), torch.no_grad():
2540                m = TestHelperModules.SelfAttnLikeModule(
2541                    input_dim=64 * 16,
2542                    transpose_for_score=True,
2543                    num_attention_heads=16,
2544                    attention_head_size=64,
2545                ).eval()
2546                example_inputs = (torch.randn(2, 384, 1024),)
2547
2548                m(*example_inputs)
2549
2550                quantizer = X86InductorQuantizer().set_global(
2551                    xiq.get_default_x86_inductor_quantization_config()
2552                )
2553
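                # Without the extra qconfig only the linear layers are annotated here;
                # setting a config for `torch.matmul` opts the batched matmuls into the
                # same global recipe.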
2554                if annotate_matmul:
2555                    quantizer.set_function_type_qconfig(
2556                        torch.matmul, quantizer.get_global_quantization_config()
2557                    )
2558
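                # Annotating matmul adds Q/DQ around the matmul activation inputs, which
                # is where the extra quantize/dequantize counts below come from.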
2559                node_occurrence = {
2560                    torch.ops.quantized_decomposed.quantize_per_tensor.default: (
2561                        5 if annotate_matmul else 1
2562                    ),
2563                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: (
2564                        7 if annotate_matmul else 3
2565                    ),
2566                    # quantize_per_channel ops for weights are const-propagated
2567                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
2568                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
2569                }
2570                if annotate_matmul:
2571                    node_list = [
2572                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
2573                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2574                        torch.ops.aten.linear.default,
2575                        torch.ops.aten.view.default,
2576                        torch.ops.aten.permute.default,
2577                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
2578                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2579                        torch.ops.aten.matmul.default,
2580                        torch.ops.aten.div.Tensor,
2581                        torch.ops.aten.softmax.int,
2582                    ]
2583                else:
2584                    node_list = [
2585                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
2586                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2587                        torch.ops.aten.linear.default,
2588                        torch.ops.aten.view.default,
2589                        torch.ops.aten.permute.default,
2590                        torch.ops.aten.matmul.default,
2591                        torch.ops.aten.div.Tensor,
2592                        torch.ops.aten.softmax.int,
2593                    ]
2594                self._test_quantizer(
2595                    m,
2596                    example_inputs,
2597                    quantizer,
2598                    node_occurrence,
2599                    node_list,
2600                )
2601