# Owner(s): ["oncall: quantization"]

# Standard library
import io
import itertools
import unittest
from typing import List, Tuple

# torch
import torch
import torch.jit
import torch.jit.quantized
import torch.nn as nn
import torch.nn.functional as F

# torch.ao.quantization
from torch.ao.quantization import (
    default_dynamic_qconfig,
    default_histogram_observer,
    default_observer,
    default_per_channel_weight_observer,
    default_qconfig,
    default_weight_observer,
    float16_dynamic_qconfig,
    fuse_modules,
    get_default_qconfig,
    per_channel_dynamic_qconfig,
    PlaceholderObserver,
    QConfig,
    quantize,
    quantize_dynamic,
    quantize_dynamic_jit,
    quantize_jit,
)

# torch.ao.quantization.quantize_jit
from torch.ao.quantization.quantize_jit import (
    convert_dynamic_jit,
    convert_jit,
    fuse_conv_bn_jit,
    prepare_dynamic_jit,
    prepare_jit,
    script_qconfig,
)
from torch.jit._recursive import wrap_cpp_module
from torch.testing import FileCheck

# Annotated models
from torch.testing._internal.common_quantization import (
    AnnotatedConvBnModel,
    AnnotatedConvModel,
    AnnotatedConvTransposeModel,
    AnnotatedNestedModel,
    AnnotatedSingleLayerLinearModel,
    AnnotatedSkipQuantModel,
    ConvBnModel,
    ConvModel,
    ConvTransposeModel,
    default_per_channel_qconfig,
    get_script_module,
    NestedModel,
    QuantizationTestCase,
    SingleLayerLinearModel,
    skipIfNoFBGEMM,
    SkipQuantModel,
    test_only_eval_fn,
)

# Testing utils
from torch.testing._internal.common_quantized import (
    override_qengines,
    qengine_is_fbgemm,
    qengine_is_qnnpack,
)
from torch.testing._internal.common_utils import set_default_dtype
from torch.testing._internal.jit_utils import (
    attrs_with_prefix,
    get_forward,
    get_forward_graph,
)

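# The tests below exercise the TorchScript graph-mode quantization flow. As a
# minimal sketch (assuming a scripted float model `model` and a calibration
# tensor `example`, both hypothetical names), the flow these passes implement is:
#
#     qconfig_dict = {"": default_qconfig}
#     prepared = prepare_jit(model, qconfig_dict)  # insert observers
#     prepared(example)                            # calibrate
#     quantized = convert_jit(prepared)            # insert quant/dequant and fuse
#
# fuse_conv_bn_jit, observer insertion, quant/dequant insertion and finalize are
# the individual passes behind prepare_jit/convert_jit that are tested here.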

class TestQuantizeJitPasses(QuantizationTestCase):
    """Test graph mode quantization passes used by quantize_jit"""

    def test_skip_dequant_constant_prop(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 5, 3).float()

            def forward(self, x):
                return self.conv(x)

        m = torch.jit.script(M())
        observer = default_per_channel_weight_observer.with_args(ch_axis=1)
        qconfig_dict = {"": QConfig(activation=default_observer, weight=observer)}
        m = prepare_jit(m, qconfig_dict)
        data = torch.randn(1, 3, 10, 10, dtype=torch.float)

        m(data)
        m = convert_jit(m, debug=True)

        freezed = torch.jit.freeze(m)
        freezed(data)

        # After freezing, weight becomes Constant.
        # We have this pattern in the original graph: Constant f32_weight -> quant -> dequant
        # After skipping dequant during Constant Propagation, the resulting graph will be:
        # Constant int8_weight -> dequant
        FileCheck().check_count("aten::quantize_per_tensor", 2, exactly=True).run(
            freezed.graph
        )
        FileCheck().check_count("aten::quantize_per_channel", 0, exactly=True).run(
            freezed.graph
        )
        FileCheck().check_count("aten::dequantize", 3, exactly=True).run(freezed.graph)
        FileCheck().check("aten::quantize_per_tensor").check_next(
            "aten::dequantize"
        ).check_not("aten::quantize_per_channel").check("aten::dequantize").check_next(
            "aten::conv2d"
        ).check_next(
            "aten::quantize_per_tensor"
        ).check_next(
            "aten::dequantize"
        ).run(
            freezed.graph
        )

    def test_foldbn_trivial(self):
        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
        conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

        # Test trivial case
        class TestModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.conv = conv_module[dim](1, 20, 5, 1)
                self.bn = bn_module[dim](num_features=20)
                self.bn.eps = 0.0023

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        options = itertools.product([True, False], [2, 3])
        data = {2: torch.rand(1, 1, 6, 6), 3: torch.rand(1, 1, 6, 6, 6)}
        # Check that the transformation doesn't change numerics
        for tracing, dim in options:
            eager = TestModule(dim).eval()
            x = data[dim]
            scripted_or_traced = get_script_module(eager, tracing, x).eval()
            # Check that in the original script module's forward we have two
            # CallMethod nodes. One of them should be for conv.forward and the other
            # for bn.forward.
            FileCheck().check_count(
                'prim::CallMethod[name="forward"]', 2, exactly=True
            ).run(str(get_forward(scripted_or_traced._c).graph))

            # Run FoldConvBatchnorm pass.
            scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)

            # Check that after the pass one of the CallMethods is gone (supposedly,
            # the bn.forward).
            FileCheck().check_count(
                'prim::CallMethod[name="forward"]', 1, exactly=True
            ).run(str(get_forward_graph(scripted_or_traced._c)))

            # Check that the transformation doesn't change numerics
            self.assertEqual(eager(x), scripted_or_traced(x))

    def test_foldbn_trivial_nobias(self):
        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
        conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

        # Test trivial case
        class TestModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.conv = conv_module[dim](1, 20, 5, 1, bias=False)
                self.bn = bn_module[dim](num_features=20)
                # to make sure new bias is not zero
                self.bn.eps = 0.0027
                self.bn.bias = torch.nn.Parameter(torch.rand([20]))

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        options = itertools.product([True, False], [2, 3])
        data = {2: torch.rand(1, 1, 6, 6), 3: torch.rand(1, 1, 6, 6, 6)}
        for tracing, dim in options:
            eager = TestModule(dim).eval()
            x = data[dim]
            scripted_or_traced = get_script_module(eager, tracing, x).eval()
            # Check that in the original script module's forward we have two
            # CallMethod nodes. One of them should be for conv.forward and the other
            # for bn.forward.
            FileCheck().check_count(
                'prim::CallMethod[name="forward"]', 2, exactly=True
            ).run(str(get_forward_graph(scripted_or_traced._c)))

            # Run FoldConvBatchnorm pass.
            scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)

            # Check that after the pass one of the CallMethods is gone (supposedly,
            # the bn.forward).
            FileCheck().check_count(
                'prim::CallMethod[name="forward"]', 1, exactly=True
            ).run(str(get_forward_graph(scripted_or_traced._c)))

            # Check that the transformation doesn't change numerics
            self.assertEqual(eager(x), scripted_or_traced(x))

    def test_foldbn_in_submodule(self):
        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
        conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

        # Test that we find Conv-BN patterns in submodules
        class SubModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.conv = conv_module[dim](1, 20, 5, 1)
                self.bn = bn_module[dim](num_features=20)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        class TestModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.sub = SubModule(dim)

            def forward(self, x):
                x = self.sub(x)
                return x

        options = itertools.product([True, False], [2, 3])
        data = {2: torch.rand(1, 1, 10, 10), 3: torch.rand(1, 1, 10, 10, 10)}
        for tracing, dim in options:
            eager = TestModule(dim).eval()
            x = data[dim]
            scripted_or_traced = get_script_module(eager, tracing, x).eval()
            FileCheck().check_count(
                'prim::CallMethod[name="forward"]', 2, exactly=True
            ).run(str(get_forward_graph(scripted_or_traced.sub._c)))

            scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)

            FileCheck().check_count(
                'prim::CallMethod[name="forward"]', 1, exactly=True
            ).run(str(get_forward_graph(scripted_or_traced.sub._c)))

            self.assertEqual(eager(x), scripted_or_traced(x))

    def test_foldbn_shared_classtype(self):
        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
        conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

        class TestModule(torch.nn.Module):
            def __init__(self, dim, bias=False):
                super().__init__()
                self.conv1 = conv_module[dim](5, 5, 3, bias=bias)
                self.bn1 = bn_module[dim](num_features=5)
                self.bn1.running_mean.fill_(-0.2)
                self.bn1.bias = torch.nn.Parameter(torch.rand([5]))
                # to make sure new bias is not zero
                self.bn1.eps = 0.0023
                self.conv2 = conv_module[dim](5, 5, 3, bias=bias)
                self.bn2 = bn_module[dim](num_features=5)
                self.bn2.eps = 0.0029
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu(x)
                x = self.conv2(x)
                x = self.bn2(x)
                x = self.relu(x)
                return x

        options = itertools.product([True, False], [2, 3], [True, False])
        data = {2: torch.rand(1, 5, 6, 6), 3: torch.rand(1, 5, 6, 6, 6)}
        for tracing, dim, bias in options:
            eager = TestModule(dim, bias).eval()
            x = data[dim]
            scripted_or_traced = get_script_module(eager, tracing, x)
            folded = fuse_conv_bn_jit(scripted_or_traced)
            self.assertEqual(eager(x), folded(x))

    def test_foldbn_no_fusion(self):
        """Test that we don't fuse when the module types do not match"""

        class CustomConv(torch.nn.Module):
            def forward(self, x):
                return x

        class CustomBn(torch.nn.Module):
            def forward(self, x):
                return x

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = CustomConv()
                self.bn = CustomBn()

            def forward(self, x):
                return self.bn(self.conv(x))

        m = torch.jit.script(M())
        m = fuse_conv_bn_jit(m)
        FileCheck().check_count("prim::CallMethod", 2, exactly=True).run(m.graph)

    @set_default_dtype(torch.double)
    def test_foldbn_complex_cases(self):
        # This test case tries combinations of conv2d/conv3d with bias/no-bias,
        # as well as BatchNorm with affine/no-affine, while varying the
        # number of layers.
        # It only works when the default dtype is double.
        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
        conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

        class SubModule(torch.nn.Module):
            def __init__(self, dim, num_blocks, enable_bias, enable_affine):
                super().__init__()
                layers = []
                for i in range(num_blocks):
                    layers.append(conv_module[dim](20, 20, 5, 1, bias=enable_bias))
                    bn_obj = bn_module[dim](num_features=20, affine=enable_affine)
                    if enable_affine:
                        bn_obj.weight = torch.nn.Parameter(
                            torch.rand_like(bn_obj.weight)
                        )
                        bn_obj.bias = torch.nn.Parameter(torch.rand_like(bn_obj.bias))
                    bn_obj.running_mean = torch.rand_like(bn_obj.running_mean)
                    bn_obj.running_var = torch.rand_like(bn_obj.running_var)
                    layers.append(bn_obj)
                self.layers = nn.Sequential(*layers)

            def forward(self, x):
                return self.layers(x)

        class TestModule(torch.nn.Module):
            def __init__(self, dim, num_blocks, enable_bias, enable_affine):
                super().__init__()
                self.sub = SubModule(dim, num_blocks, enable_bias, enable_affine)

            def forward(self, x):
                x = self.sub(x)
                return x

        options = itertools.product(
            [True, False], [2, 3], [True, False], [True, False], [1, 2]
        )
        data = {2: torch.rand(1, 20, 10, 10), 3: torch.rand(1, 20, 10, 10, 10)}
        for tracing, dim, enable_bias, enable_bn_affine, num_layers in options:
            eager = TestModule(dim, num_layers, enable_bias, enable_bn_affine).eval()
            x = data[dim]
            scripted_or_traced = get_script_module(eager, tracing, x).eval()

            FileCheck().check_count(
                'prim::CallMethod[name="forward"]', num_layers * 2, exactly=True
            ).run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))

            scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)

            FileCheck().check_count(
                'prim::CallMethod[name="forward"]', num_layers, exactly=True
            ).run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))

            self.assertEqual(eager(x), scripted_or_traced(x))

    def test_fuse_linear(self):
        class FunctionalLinear(torch.nn.Module):
            def __init__(self, weight, bias):
                super().__init__()
                self.weight = weight
                self.bias = bias

            def forward(self, x):
                res = torch.matmul(x, self.weight.t())
                if self.bias is not None:
                    res.add_(self.bias)
                return res

        x1 = torch.rand(3)
        w1 = torch.rand(5, 3)
        b1 = torch.rand(5)

        x2 = torch.rand(5, 5)
        w2 = torch.rand(5, 5)
        b2 = torch.rand(5)

        x3 = torch.rand(5, 5, 5)
        w3 = torch.rand(5, 5)
        b3 = torch.rand(5)
        for has_bias, (x, weight, b) in itertools.product(
            [True, False], [(x1, w1, b1), (x2, w2, b2), (x3, w3, b3)]
        ):
            bias = b if has_bias else None
            model = torch.jit.trace(FunctionalLinear(weight, bias), [x])
            for node in model.graph.nodes():
                if node.kind() == "aten::matmul":
                    source_range_1 = node.sourceRange()
            torch._C._jit_pass_fuse_linear(model.graph)
            for node in model.graph.nodes():
                if node.kind() == "aten::linear":
                    source_range_2 = node.sourceRange()
            FileCheck().check("aten::linear").run(model.graph)
            check_not = ["aten::matmul", "aten::addmm", "aten::add_", "aten::t("]
            for cn in check_not:
                FileCheck().check_not(cn).run(model.graph)
            self.assertTrue(source_range_1 == source_range_2)
            # make sure it runs
            model(x)

        # check matmuls are not fused
        class Matmul(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return torch.matmul(x, self.weight)

        x = torch.rand(5, 6, 5)
        w = torch.rand(5, 5, 100)
        model = torch.jit.trace(Matmul(w), [x])
        torch._C._jit_pass_fuse_linear(model.graph)
        # check 3d matmul is not fused
        FileCheck().check("aten::matmul").run(model.graph)
        FileCheck().check_not("aten::linear").run(model.graph)
        # make sure it runs
        model(x)

    def test_insert_observers(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 5, 3)

            def forward(self, x):
                return self.conv(x)

        m = torch.jit.script(M())
        qconfig_dict = {"": default_qconfig}
        m = prepare_jit(m, qconfig_dict)
        # for input and output of conv
        assert len(attrs_with_prefix(m, "_observer_")) == 2
        # for weight
        assert len(attrs_with_prefix(m.conv, "_observer_")) == 1

    def test_insert_observers_interface(self):
        @torch.jit.interface
        class SubInterface(torch.nn.Module):
            def addOne(self, inp) -> torch.Tensor:
                pass

        class Sub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(5, 5)

            def addOne(self, inp):
                return self.fc(inp) + 1

            def forward(self, x):
                return self.addOne(x)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 5, 3)
                self.sub = Sub()

            def forward(self, x):
                return self.sub(self.conv(x))

        m = torch.jit.script(M())
        qconfig_dict = {"sub.conv": default_qconfig}
        m = prepare_jit(m, qconfig_dict)

    def test_insert_observers_interface_unshare_type(self):
        @torch.jit.interface
        class OperatorIf(nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class Operator(nn.Module):
            def __init__(self, a):
                super().__init__()
                self.a = a

            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                return self.a * (inp + self.a)

        class Inner(nn.Module):
            op: OperatorIf

            def __init__(self, op):
                super().__init__()
                self.op = op

            def forward(self, inp):
                return self.op(inp)

        class Outer(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.inner_a = Inner(Operator(1))
                self.inner_b = Inner(Operator(3.0))

            def forward(self, inp):
                return self.inner_a(inp) + self.inner_b(inp)

        qconfig_dict = {"inner_a": default_qconfig, "inner_b": default_qconfig}

        eager_model = Outer()
        for tracing in [True, False]:
            x = torch.rand(3)
            script_model = get_script_module(eager_model, tracing, x)
            # make sure it runs
            prepare_jit(script_model, qconfig_dict)

    def test_insert_observers_child_qconfig(self):
        class Sub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.fc(x)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 5, 3)
                self.sub = Sub()

            def forward(self, x):
                return self.sub(self.conv(x))

        m = torch.jit.script(M())
        qconfig_dict = {"sub.fc": default_qconfig}
        m = prepare_jit(m, qconfig_dict)
        # input and output of sub
        assert len(attrs_with_prefix(m, "_observer_")) == 2
        # not quantized
        assert len(attrs_with_prefix(m.conv, "_observer_")) == 0
        # no observers since we observe at the outermost call site
        assert len(attrs_with_prefix(m.sub, "_observer_")) == 0
        # weight of linear
        assert len(attrs_with_prefix(m.sub.fc, "_observer_")) == 1

    @unittest.skipUnless(
        "fbgemm" in torch.backends.quantized.supported_engines,
        " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
        " with AVX2 instruction set support or newer.",
    )
    def test_insert_observers_skip_values(self):
        class ConvFunctionalReLU(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 5, 3)

            def forward(self, x):
                return F.relu(self.conv(x))

        class ConvReLUModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 5, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(self.conv(x))

        class AddReLUModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()
                self.conv = torch.nn.Conv2d(3, 3, 3).float()

            def forward(self, x):
                out = self.conv(x)
                out += x
                return self.relu(out)

        class AddFunctionalReLU(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3).float()

            def forward(self, x):
                out = self.conv(x)
                out += x
                return F.relu(out)

        def attrs_with_prefix(module, prefix):
            return [x for x, _ in module._modules._c.items() if x.startswith(prefix)]

        qconfig_dict = {"": default_qconfig}
        m = torch.jit.script(ConvFunctionalReLU())
        m = prepare_jit(m, qconfig_dict)
        # observer for weight of conv
        assert len(attrs_with_prefix(m.conv, "_observer_")) == 1
        # observer for input of conv and output of relu
        assert len(attrs_with_prefix(m, "_observer_")) == 2

        m = torch.jit.script(ConvReLUModule())
        m = prepare_jit(m, qconfig_dict)
        # observer for input of conv and output of relu
        assert len(attrs_with_prefix(m, "_observer_")) == 2
        # observer for weight of conv
        assert len(attrs_with_prefix(m.conv, "_observer_")) == 1
        # no observer attached to relu itself; its output is observed at the parent
        assert len(attrs_with_prefix(m.relu, "_observer_")) == 0

        m = torch.jit.script(AddReLUModule())
        qconfig_dict = {"": default_qconfig}
        m = prepare_jit(m, qconfig_dict)
        assert len(attrs_with_prefix(m, "_observer")) == 3
        assert len(attrs_with_prefix(m.relu, "_observer")) == 0
        FileCheck().check("aten::add_").check_not(
            'Observer = prim::GetAttr[name="_observer_'
        ).check("ReLU = prim::GetAttr").run(str(get_forward_graph(m._c)))

        m = torch.jit.script(AddFunctionalReLU())
        qconfig_dict = {"": default_qconfig}
        m = prepare_jit(m, qconfig_dict)
        assert len(attrs_with_prefix(m, "_observer")) == 3
        FileCheck().check("aten::add_").check_not(
            'Observer = prim::GetAttr[name="_observer_'
        ).check("CallFunction").check('Observer = prim::GetAttr[name="_observer_').run(
            str(get_forward_graph(m._c))
        )

    def test_insert_observers_weight_dtype(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 5, 3)

            def forward(self, x):
                return F.relu(self.conv(x))

        m = torch.jit.script(M())
        qconfig_dict = {"": default_qconfig}
        m = prepare_jit(m, qconfig_dict)
        activation_dtypes = {
            obs.getattr("dtype")
            for x, obs in m._modules._c.items()
            if x.startswith("_observer_")
        }
        weight_dtypes = {
            obs.getattr("dtype")
            for x, obs in m.conv._modules._c.items()
            if x.startswith("_observer_")
        }
        assert len(activation_dtypes) == 1, "Expected to have 1 activation dtype"
        assert len(weight_dtypes) == 1, "Expected to have 1 weight dtype"
        assert next(iter(activation_dtypes)) != next(
            iter(weight_dtypes)
        ), "Expected activation dtype to be different from weight dtype"

    def test_insert_observers_for_reused_weight(self):
        class M(torch.nn.Module):
            def forward(self, x, y, weight):
                x = F.conv2d(x, weight)
                y = F.conv2d(y, weight)
                return x + y

        m = torch.jit.script(M()).eval()
        m = prepare_jit(m, {"": default_qconfig})
        # 3 for x, y, weight, one for output of each F.conv2d and one for output of add
        assert len(attrs_with_prefix(m, "_observer")) == 6

    def test_insert_observers_shared_class_type(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 5, 3).float()
                self.conv2 = torch.nn.Conv2d(3, 5, 3).float()

            def forward(self, x):
                return self.conv2(self.conv1(x))

        m = torch.jit.script(M())
        qconfig_dict = {"": default_qconfig}
        m = prepare_jit(m, qconfig_dict)
        # conv1 and conv2 share the same type, we need to
        # make sure we didn't quantize the type twice
        conv1_observers = attrs_with_prefix(m.conv1, "_observer_")
        conv2_observers = attrs_with_prefix(m.conv2, "_observer_")
        assert len(conv1_observers) == 1, "Expected to have 1 observer submodule"
        assert len(conv2_observers) == 1, "Expected to have 1 observer submodule"
        assert (
            conv1_observers == conv2_observers
        ), "Expected conv1 and conv2 to have the same observers since the class type is shared"

    def test_insert_observers_for_general_ops(self):
        """Make sure we skip observers for ops that don't require
        observation, e.g. flatten
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3).float()

            def forward(self, x):
                x = self.conv(x)
                x = torch.flatten(x)
                return x

        m = torch.jit.script(M())
        qconfig_dict = {"": default_qconfig}
        m = prepare_jit(m, qconfig_dict)
        # input and output of conv
        assert len(attrs_with_prefix(m, "_observer_")) == 2
        FileCheck().check('Observer = prim::GetAttr[name="_observer_').check(
            'prim::GetAttr[name="conv"]'
        ).check("prim::CallMethod").check(
            'Observer = prim::GetAttr[name="_observer_'
        ).check(
            "aten::flatten"
        ).check_not(
            'Observer = prim::GetAttr[name="_observer_'
        ).run(
            m.graph
        )

    # TODO: this is too long, split this into test_insert_observers.py and remove
    # the insert_observers prefix
    def test_insert_observers_propagate_observed(self):
        """Make sure we propagate observed property through general ops"""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
                self.conv2 = torch.nn.Conv2d(3, 3, 3).float()

            def forward(self, x):
                x = self.conv1(x)
                x = torch.flatten(x)
                # we don't want to insert observer for input of self.conv2
                # because output of self.conv1 is already observed
                x = self.conv2(x)
                return x

        m = torch.jit.script(M())
        qconfig_dict = {"": default_qconfig}
        m = prepare_jit(m, qconfig_dict)
        # input of conv1, output of conv1 and output of conv2
        assert len(attrs_with_prefix(m, "_observer_")) == 3
        FileCheck().check('Observer = prim::GetAttr[name="_observer_').check(
            'prim::GetAttr[name="conv1"]'
        ).check("prim::CallMethod").check(
            'Observer = prim::GetAttr[name="_observer_'
        ).check(
            "aten::flatten"
        ).check_not(
            'Observer = prim::GetAttr[name="_observer_'
        ).check(
            'prim::GetAttr[name="conv2"]'
        ).check(
            'Observer = prim::GetAttr[name="_observer_'
        ).run(
            m.graph
        )

    def test_insert_observers_propagate_observed_in_submodule(self):
        """Make sure we propagate observed property through general ops"""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
                self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
                self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))

            def forward(self, x):
                x = self.conv1(x)
                x = self.avgpool(x)
                # we don't want to insert observer for input of self.conv2
                # because output of self.conv1 is already observed
                x = self.conv2(x)
                return x

        m = torch.jit.script(M())
        qconfig_dict = {"": default_qconfig}
        m = prepare_jit(m, qconfig_dict)
        # input of conv1, output of conv1 and output of conv2
        assert len(attrs_with_prefix(m, "_observer_")) == 3
        FileCheck().check('Observer = prim::GetAttr[name="_observer_').check(
            'prim::GetAttr[name="conv1"]'
        ).check("prim::CallMethod").check(
            'Observer = prim::GetAttr[name="_observer_'
        ).check(
            "prim::CallMethod"
        ).check_not(
            'Observer = prim::GetAttr[name="_observer_'
        ).check(
            'prim::GetAttr[name="conv2"]'
        ).check(
            'Observer = prim::GetAttr[name="_observer_'
        ).run(
            m.graph
        )

    def test_insert_observers_propagate_observed_for_function(self):
        def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
            batchsize, num_channels, height, width = x.data.size()
            channels_per_group = num_channels // groups
            # reshape
            x = x.view(batchsize, groups, channels_per_group, height, width)
            x = torch.transpose(x, 1, 2).contiguous()
            # flatten
            x = x.view(batchsize, -1, height, width)
            return x

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 3, 1).float()
                self.conv2 = torch.nn.Conv2d(3, 3, 1).float()

            def forward(self, x):
                x = self.conv1(x)
                x = channel_shuffle(x, 1)
                x = self.conv2(x)
                return x

        data = [
            (
                torch.rand((1, 3, 10, 10), dtype=torch.float),
                torch.randint(0, 1, (1,), dtype=torch.long),
            )
            for _ in range(2)
        ]
        m = torch.jit.script(M()).eval()
        m = prepare_jit(m, {"": default_qconfig})
        # we want to test that channel_shuffle is going to pass
        # the observed property from the output of conv1 to input of conv2
        # so that we don't insert observers for input of conv2
        assert (
            len(
                attrs_with_prefix(
                    m,
                    "_observer_",
                )
            )
            == 3
        )

    def test_insert_observers_for_if(self):
        class QuantProp(torch.nn.Module):
            def __init__(self, use_skip):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 1).float()
                self.use_skip = use_skip

            def forward(self, x):
                if self.use_skip:
                    x = self.conv(x)
                    return torch.reshape(x, x.shape)
                else:
                    x = self.conv(x)
                    return torch.reshape(x, x.shape)

        class Res(torch.nn.Module):
            def __init__(self, use_skip):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 1).float()
                self.use_skip = use_skip

            def forward(self, x):
                if self.use_skip:
                    return self.conv(x)
                else:
                    return self.conv(x)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant_prop = QuantProp(True)
                self.res = Res(False)

            def forward(self, x):
                x = self.quant_prop(x)
                x = self.res(x)
                return x

        data = [torch.rand(1, 3, 10, 10, dtype=torch.float)]
        result = {False: [1, 2, 2], True: [2, 1, 0]}
        for tracing in [True, False]:
            if tracing:
                m = torch.jit.trace(M(), data).eval()
            else:
                m = torch.jit.script(M()).eval()
            m = prepare_jit(m, {"": default_qconfig})
            assert (
                len(
                    attrs_with_prefix(
                        m,
                        "_observer_",
                    )
                )
                == result[tracing][0]
            )
            assert (
                len(
                    attrs_with_prefix(
                        m.quant_prop,
                        "_observer_",
                    )
                )
                == result[tracing][1]
            )
            assert (
                len(
                    attrs_with_prefix(
                        m.res,
                        "_observer_",
                    )
                )
                == result[tracing][2]
            )

    def test_insert_observers_for_nested_if(self):
        class Res(torch.nn.Module):
            def __init__(self, use_skip):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 1).float()
                self.cond = use_skip
                self.use_skip = use_skip

            def forward(self, x):
                if self.use_skip:
                    if self.cond:
                        return self.conv(x)
                    else:
                        return self.conv(x)
                else:
                    return self.conv(x)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.res1 = Res(True)
                self.res2 = Res(False)

            def forward(self, x):
                x = self.res1(x)
                x = self.res2(x)
                return x

        data = torch.rand((1, 3, 10, 10), dtype=torch.float)
        result = {True: 3, False: 1}
        for tracing in [True, False]:
            if tracing:
                m = torch.jit.trace(M(), data).eval()
            else:
                m = torch.jit.script(M()).eval()
            m = prepare_jit(m, {"": default_qconfig})
            assert len(attrs_with_prefix(m, "_observer_")) == result[tracing]

    def test_insert_observers_for_if_consistent_observation(self):
        """Check that quantization for if works as long as the
        outputs of all branches are quantized/observed consistently
        """

        class M(torch.nn.Module):
            def __init__(self, cond):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3).float()
                self.cond = cond

            def forward(self, x):
                x = self.conv(x)
                # x is already observed
                if self.cond:
                    x = torch.flatten(x)
                return x

        class M2(torch.nn.Module):
            def __init__(self, cond):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
                self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
                self.cond = cond

            def forward(self, x):
                x = self.conv1(x)
                if self.cond:
                    x = self.conv2(x)
                    # x will be observed in the branch
                else:
                    x = torch.flatten(x)
                # since the outputs of both branches are quantized,
                # the if node is quantized consistently
                return x

        data = torch.rand((1, 3, 5, 5), dtype=torch.float)
        options = list(itertools.product([True, False], [True, False]))
        for cond, tracing in options:
            if tracing:
                m = torch.jit.trace(M(cond), data)
            else:
                m = torch.jit.script(M(cond))
            m = prepare_jit(m, {"": default_qconfig})
            assert len(attrs_with_prefix(m, "_observer_")) == 2

        for cond, tracing in options:
            if tracing:
                m = torch.jit.trace(M2(cond), data)
            else:
                m = torch.jit.script(M2(cond))
            m = prepare_jit(m, {"": default_qconfig})
            num_observers = 2 if tracing and not cond else 3
            assert len(attrs_with_prefix(m, "_observer_")) == num_observers

    def test_insert_quant_dequant(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 5, 3).float()

            def forward(self, x):
                return self.conv(x)

        for is_per_channel in [True, False]:
            m = torch.jit.script(M())
            observer = (
                default_per_channel_weight_observer.with_args(ch_axis=1)
                if is_per_channel
                else default_observer
            )
            qconfig_dict = {"": QConfig(activation=observer, weight=observer)}
            m = prepare_jit(m, qconfig_dict)
            data = torch.randn(1, 3, 10, 10, dtype=torch.float)

            m(data)
            m = convert_jit(m, debug=True)
            assert (
                len(m._modules._c.items()) == 1
            ), "Expected to have a single conv submodule"
            # make sure the quantized model is executable
            m(data)
            quant_func = (
                "aten::quantize_per_channel"
                if is_per_channel
                else "aten::quantize_per_tensor"
            )
            FileCheck().check_count(quant_func, 3, exactly=True).run(m.graph)

    def test_insert_quant_dequant_shared_class_type(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
                self.conv2 = torch.nn.Conv2d(3, 3, 3).float()

            def forward(self, x):
                return self.conv2(self.conv1(x))

        for is_per_channel in [True, False]:
            m = torch.jit.script(M())
            observer = (
                default_per_channel_weight_observer.with_args(ch_axis=1)
                if is_per_channel
                else default_observer
            )
            qconfig = QConfig(activation=observer, weight=observer)
            qconfig_dict = {"": qconfig}
            m = prepare_jit(m, qconfig_dict)
            # observers for input, output and value between conv1/conv2
            assert (
                len(attrs_with_prefix(m, "_observer_")) == 3
            ), "Expected to have 3 observers"
            # observer for weight
            assert (
                len(attrs_with_prefix(m.conv1, "_observer_")) == 1
            ), "Expected to have 1 observer"
            # observer for weight
            assert (
                len(attrs_with_prefix(m.conv2, "_observer_")) == 1
            ), "Expected to have 1 observer"

            data = torch.randn(1, 3, 10, 10, dtype=torch.float)
            m(data)
            m = convert_jit(m, debug=True)
            m(data)
            assert m.conv1._c._type() == m.conv2._c._type()

            # check all observers have been removed
            assert (
                len(attrs_with_prefix(m, "_observer_")) == 0
            ), "Expected to have 0 observers"
            assert (
                len(attrs_with_prefix(m.conv1, "_observer_")) == 0
            ), "Expected to have 0 observers"
            assert (
                len(attrs_with_prefix(m.conv2, "_observer_")) == 0
            ), "Expected to have 0 observers"

            quant_func = (
                "aten::quantize_per_channel"
                if is_per_channel
                else "aten::quantize_per_tensor"
            )
            for module in ["conv1", "conv2"]:
                conv = m._c.getattr(module)
                # quantize weight
                FileCheck().check(quant_func).check_next("aten::dequantize").check(
                    'prim::CallMethod[name="_conv_forward"]'
                ).check("return").run(get_forward_graph(conv))
                # no quantize node in _conv_forward
                FileCheck().check_not(quant_func).check("aten::conv2d").check_not(
                    quant_func
                ).check("return").run(conv._get_method("_conv_forward").graph)

    def test_dedup_module_uses(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.relu(x)
                x -= 0.5
                return self.relu(x)

        data = torch.randn((2, 2))
        m = torch.jit.script(M())
        ref_res = m(data)
        assert (
            len([x for x, _ in m._modules._c.items() if x.startswith("relu")]) == 1
        ), "Expected to have 1 relu module before dedup module uses"
        torch._C._jit_pass_dedup_module_uses(m._c)
        m = torch.jit._recursive.wrap_cpp_module(m._c)
        res = m(data)
        assert (
            len([x for x, _ in m._modules._c.items() if x.startswith("relu")]) == 2
        ), "Expected to have 2 relu modules after dedup module uses"
        self.assertEqual(res, ref_res)

    def test_replicate_dequantize(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 1).float()

            def forward(self, x):
                x = torch.dequantize(x)
                r = self.conv(x)
                r += x
                return r

        x = torch.randn([1, 3, 10, 10], dtype=torch.float)
        x = torch.quantize_per_tensor(x, 0.5, 1, torch.quint8)
        m = torch.jit.script(M())
        ref_res = m(x)
        FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph)
        torch._C._jit_pass_replicate_dequantize(m.graph)
        FileCheck().check_count("aten::dequantize", 2, exactly=True).run(m.graph)
        res = get_forward(m._c)(x)
        self.assertEqual(res, ref_res)

    def test_replicate_dequantize_in_block(self):
        class M(torch.nn.Module):
            def __init__(self, cond):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 1).float()

                self.cond = cond

            def forward(self, x):
                x = torch.dequantize(x)
                if self.cond:
                    x = self.conv(x)
                else:
                    x = x + 3
                return x

        x = torch.randn([1, 3, 10, 10], dtype=torch.float)
        x = torch.quantize_per_tensor(x, 0.5, 1, torch.quint8)
        m = torch.jit.script(M(True))
        ref_res = m(x)
        FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph)
        torch._C._jit_pass_replicate_dequantize(m.graph)
        FileCheck().check_count("aten::dequantize", 2, exactly=True).run(m.graph)
        # check dequantize is right before CallMethod of conv
        FileCheck().check("aten::dequantize").check_next("CallMethod").run(m.graph)
        # check dequantize is right before add
        FileCheck().check("aten::dequantize").check("aten::dequantize").check_next(
            "aten::add"
        ).run(m.graph)
        res = get_forward(m._c)(x)
        self.assertEqual(res, ref_res)

    def test_swap_functional_linear(self):
        # TODO: This pass replaces any function called "linear" with "aten::linear"
        # No longer necessary, and also quite surprising
        def linear(input, weight, bias):
            return torch.nn.functional.linear(input, weight, bias)

        class M(torch.nn.Module):
            def forward(self, x, weight, bias):
                x = torch.dequantize(x)
                weight = torch.dequantize(weight)
                x = linear(x, weight, bias)
                x = torch.quantize_per_tensor(
                    x, scale=1.0, zero_point=0, dtype=torch.quint8
                )
                return x

        x = torch.rand((10, 5), dtype=torch.float)
        x = torch.quantize_per_tensor(x, scale=0.5, zero_point=1, dtype=torch.quint8)
        weight = torch.rand((5, 5), dtype=torch.float)
        weight = torch.quantize_per_tensor(
            weight, scale=0.5, zero_point=1, dtype=torch.qint8
        )
        bias = torch.rand((5), dtype=torch.float)
        m = torch.jit.script(M())
        ref_res = m(x, weight, bias)
        FileCheck().check("CallFunction").run(m.graph)
        torch._C._jit_pass_swap_functional_linear(m.graph)
        FileCheck().check("aten::linear").check_not("CallFunction").run(m.graph)
        res = m(x, weight, bias)
        self.assertEqual(res, ref_res)

    def test_replicate_quantize_for_if(self):
        """We want to move quantize nodes for output of prim::If
        inside the prim::If blocks so that we can match quantization
        patterns.
        """

        class Res(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 1).float()
                self.conv2 = torch.nn.Conv2d(3, 3, 1).float()
                self.use_skip = True

            def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor:
                # to avoid being frozen
                self.use_skip = cond
                if self.use_skip:
                    return self.conv(x)
                else:
                    return self.conv2(x)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.res1 = Res()
                self.res2 = Res()

            def forward(self, x):
                x = self.res1(x, True)
                x = self.res2(x, False)
                return x

        data = [[torch.rand((1, 3, 10, 10), dtype=torch.float)]]
        qconfig_dict = {"": default_qconfig}
        m = torch.jit.script(M()).eval()
        m = quantize_jit(m, qconfig_dict, test_only_eval_fn, [data])
        # make sure patterns in both branches are fused
        FileCheck().check_count("quantized::conv2d(", 4, exactly=True).run(m.graph)

    def test_finalize_for_linear(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(5, 5).float()

            def forward(self, x):
                return self.fc(x)

        data = [[torch.rand((1, 5), dtype=torch.float)]]
        qconfig_dict = {"": default_qconfig}
        model = torch.jit.script(M()).eval()
        model = quantize_jit(model, qconfig_dict, test_only_eval_fn, [data])
        # make sure there is only one quantize_per_tensor for input
        # and linear_prepack is folded
        FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).check_not(
            "quantized::linear_prepack"
        ).check("quantized::linear").run(model.graph)

    def test_inplace_option(self):
        for tracing in [True, False]:
            model = get_script_module(
                torch.nn.Conv2d(3, 3, 3).float(), tracing, self.img_data_2d[0][0]
            )
            qconfig_dict = {"": default_qconfig}
            quantize_jit(
                model, qconfig_dict, test_only_eval_fn, [self.img_data_2d], inplace=True
            )
            FileCheck().check("quantized::conv2d").run(model.graph)

            FileCheck().check_not("aten::conv2d").run(model.graph)

    def test_finalize_debug(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3).float()
                self.avgpool = torch.nn.AvgPool2d(3)

            def forward(self, x):
                x = self.conv(x)
                x = self.avgpool(x)
                return x

        data = [[torch.rand((1, 3, 10, 10), dtype=torch.float)]]
        qconfig_dict = {"": default_qconfig}
        model = torch.jit.script(M()).eval()
        model = quantize_jit(model, qconfig_dict, test_only_eval_fn, [data], debug=True)
        FileCheck().check_not("quantized::conv2d").check("aten::conv2d").check(
            "aten::avg_pool2d"
        ).check("aten::q_scale").check_next("aten::q_zero_point").check_next(
            "prim::dtype"
        ).check_next(
            "aten::quantize_per_tensor"
        ).check(
            "aten::dequantize"
        ).run(
            model.graph
        )

    def test_module_list(self):
        class SimpleLinearLayer(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(5, 5).float()

            def forward(self, x):
                return self.fc(x)

        class ComplexModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = torch.nn.ModuleList(
                    [SimpleLinearLayer() for i in range(2)]
                )

            def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
                states = []
                for layer in self.layers:
                    val = layer(x)
                    states.append(val)
                return states

        data = torch.rand((1, 5), dtype=torch.float)
        qconfig_dict = {"": default_qconfig}
        model = torch.jit.script(ComplexModel()).eval()
        model = prepare_jit(model, qconfig_dict)
        assert len(attrs_with_prefix(model, "_observer")) == 3
        model(data)
        model = convert_jit(model, debug=False)
        FileCheck().check("quantized::linear").check("quantized::linear").run(
            model.graph
        )

    def test_conv_trace(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1d = torch.nn.Conv1d(3, 3, 3).float()
                self.conv2d = torch.nn.Conv2d(3, 3, 3).float()
                self.conv3d = torch.nn.Conv3d(3, 3, 3).float()

            def forward(self, x, y, z):
                a = self.conv1d(x)
                b = self.conv2d(y)
                c = self.conv3d(z)
                return (a, b, c)

        qconfig_dict = {"": default_qconfig}
        inputs = (
            torch.rand((1, 3, 10), dtype=torch.float),
            torch.rand((1, 3, 10, 10), dtype=torch.float),
            torch.rand((1, 3, 10, 10, 10), dtype=torch.float),
        )
        model = torch.jit.trace(M(), inputs).eval()
        m = prepare_jit(model, qconfig_dict)
        FileCheck().check("aten::conv1d").check_not("aten::_convolution").run(
            str(get_forward_graph(m.conv1d._c))
        )
        FileCheck().check("aten::conv2d").check_not("aten::_convolution").run(
            str(get_forward_graph(m.conv2d._c))
        )
        FileCheck().check("aten::conv3d").check_not("aten::_convolution").run(
            str(get_forward_graph(m.conv3d._c))
        )

    def test_convtranspose_trace(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.convtranspose1d = torch.nn.ConvTranspose1d(3, 3, 3).float()
                self.convtranspose2d = torch.nn.ConvTranspose2d(3, 3, 3).float()
                self.convtranspose3d = torch.nn.ConvTranspose3d(3, 3, 3).float()

            def forward(self, x, y, z):
                a = self.convtranspose1d(x)
                b = self.convtranspose2d(y)
                c = self.convtranspose3d(z)
                return (a, b, c)

        qconfig_dict = {"": default_qconfig}
        inputs = (
            torch.rand((1, 3, 10), dtype=torch.float),
            torch.rand((1, 3, 10, 10), dtype=torch.float),
            torch.rand((1, 3, 10, 10, 10), dtype=torch.float),
        )
        model = torch.jit.trace(M(), inputs).eval()
        m = prepare_jit(model, qconfig_dict)
        FileCheck().check("aten::conv_transpose1d").check_not("aten::_convolution").run(
            str(get_forward_graph(m.convtranspose1d._c))
        )
        FileCheck().check("aten::conv_transpose2d").check_not("aten::_convolution").run(
            str(get_forward_graph(m.convtranspose2d._c))
        )
        FileCheck().check("aten::conv_transpose3d").check_not("aten::_convolution").run(
            str(get_forward_graph(m.convtranspose3d._c))
        )

    @unittest.skipUnless(
        "fbgemm" in torch.backends.quantized.supported_engines,
1441        "Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
1442        " with AVX2 instruction set support or newer.",
1443    )
1444    def test_replicate_dequant_same_value(self):
1445        class Mul(torch.nn.Module):
1446            def __init__(self) -> None:
1447                super().__init__()
1448                self.conv = torch.nn.Conv2d(3, 3, 3).float()
1449
1450            def forward(self, x):
1451                x = self.conv(x)
1452                return x * x
1453
1454        data = [[torch.rand((1, 3, 10, 10), dtype=torch.float)]]
1455        qconfig_dict = {"": default_qconfig}
1456        model = torch.jit.script(Mul()).eval()
1457        m = quantize_jit(model, qconfig_dict, test_only_eval_fn, [data])
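            # Both operands of the mul are the same conv output, so the dequantize
            # feeding it has to be replicated per use; otherwise the quantized::mul
            # pattern would not match. The check below verifies the fusion happened.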
1458        FileCheck().check("quantized::mul(").check_not("aten::mul").run(m.graph)
1459
1460    def test_interface_with_fork(self):
1461        class SubModule(torch.nn.Module):
1462            def __init__(self) -> None:
1463                super().__init__()
1464                self.embedding1 = torch.nn.EmbeddingBag(
1465                    num_embeddings=10,
1466                    embedding_dim=12,
1467                    include_last_offset=True,
1468                    sparse=False,
1469                    mode="sum",
1470                )
1471
1472            def forward(self, x, y):
1473                return self.embedding1(x, y)
1474
1475        class OrigMod(torch.nn.Module):
1476            def __init__(self) -> None:
1477                super().__init__()
1478                self.embedding1 = torch.nn.EmbeddingBag(
1479                    num_embeddings=10,
1480                    embedding_dim=12,
1481                    include_last_offset=True,
1482                    sparse=False,
1483                    mode="sum",
1484                )
1485
1486            def forward(self, x, y):
1487                return self.embedding1(x, y)
1488
1489        @torch.jit.interface
1490        class ModInterface(torch.nn.Module):
1491            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
1492                pass
1493
1494        class TestModule(torch.nn.Module):
1495            proxy_mod: ModInterface
1496
1497            def __init__(self) -> None:
1498                super().__init__()
1499                self.proxy_mod = OrigMod()
1500                self.sub = SubModule()
1501
1502            def forward(self, x, y):
1503                a = self.proxy_mod(x, y)
1504                b = self.sub(x, y)
1505                return b
1506
1507        class MainModule(torch.nn.Module):
1508            def __init__(self) -> None:
1509                super().__init__()
1510                self.test = TestModule()
1511
1512            def forward(self, x, y):
1513                fut = torch.jit._fork(self.test.forward, x, y)
1514                z = torch.jit._wait(fut)
1515                return z
1516
1517        indices = torch.tensor(
1518            [
1519                9,
1520                6,
1521                5,
1522                7,
1523                8,
1524                8,
1525                9,
1526                2,
1527                8,
1528                6,
1529                6,
1530                9,
1531                1,
1532                6,
1533                8,
1534                8,
1535                3,
1536                2,
1537                3,
1538                6,
1539                3,
1540                6,
1541                5,
1542                7,
1543                0,
1544                8,
1545                4,
1546                6,
1547                5,
1548                8,
1549                2,
1550                3,
1551            ]
1552        )
1553        offsets = torch.tensor([0, 19, 20, 28, 28, 32])
1554        m = torch.jit.trace(MainModule(), (indices, offsets))
1555        m.eval()
1556
1557        int8_qconfig = QConfig(
1558            activation=PlaceholderObserver.with_args(
1559                dtype=torch.float, custom_op_name="embedding_bag_byte"
1560            ),
1561            weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_byte"),
1562        )
1563
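        # PlaceholderObserver with custom_op_name="embedding_bag_byte" selects 8-bit
        # row-wise quantization for the EmbeddingBag weights; the interface-typed
        # submodule and the fork/wait indirection should not prevent the rewrite to
        # quantized::embedding_bag_byte_rowwise_offsets (checked below).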
1564        m = prepare_jit(m, {"": int8_qconfig})
1565        m = convert_jit(m)
1566        FileCheck().check("quantized::embedding_bag_byte_rowwise_offsets").run(m.graph)
1567
1568    @skipIfNoFBGEMM
1569    def test_quantize_fork_wait(self):
1570        """Tests the case where fork and wait calls are in different subgraphs.
1571        Inlining fork-wait only removes the fork call and leaves aten::wait
1572        calls in the graph, with a Tensor as input (instead of Future[Tensor]).
1573        """
1574
1575        class MainModule(nn.Module):
1576            def __init__(self) -> None:
1577                super().__init__()
1578                self.fork_ops = ForkModule()
1579
1580            def init_values(self, x):
1581                shared_module = self.fork_ops(x)
1582                self.fork_dict = shared_module
1583
1584            def forward(self, x):
1585                val = torch.jit._wait(self.fork_ops(x))
1586                return val
1587
1588        class TestModule(torch.nn.Module):
1589            def forward(self, x):
1590                w = torch.ones(5, 5)
1591                b = torch.zeros(5)
1592                return torch.nn.functional.linear(x, w, b)
1593
1594        class ForkModule(nn.Module):
1595            def __init__(self) -> None:
1596                super().__init__()
1597                self.test = TestModule()
1598
1599            def forward(self, x):
1600                fut = torch.jit._fork(self.test.forward, x)
1601                return fut
1602
1603        model = MainModule().eval()
1604        traced = torch.jit.trace(model, (torch.randn(5, 5),))
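        # The F.linear lives inside the forked TestModule subgraph; dynamic
        # quantization is still expected to rewrite it to quantized::linear_dynamic
        # even though the fork and wait calls sit in different graphs (checked below).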
1605        model = prepare_dynamic_jit(traced, {"": default_qconfig})
1606        model = convert_dynamic_jit(model)
1607        FileCheck().check("quantized::linear_dynamic").run(model.graph)
1608        # Make sure model save works
1609        b = io.BytesIO()
1610        torch.jit.save(model, b)
1611
1612
1613class TestQuantizeJitOps(QuantizationTestCase):
1614    """Test that graph mode post-training static quantization works
1615    for individual ops end to end.
1616    """
1617
1618    @skipIfNoFBGEMM
1619    def test_linear(self):
1620        class ModuleLinear(torch.nn.Module):
1621            def __init__(self, has_relu=False, f_relu=False):
1622                super().__init__()
1623                self.linear = torch.nn.Linear(30, 4).float()
1624                if has_relu:
1625                    if f_relu:
1626                        self.relu = F.relu
1627                    else:
1628                        self.relu = torch.nn.ReLU()
1629                else:
1630                    self.relu = torch.nn.Identity()
1631
1632            def forward(self, x):
1633                return self.relu(self.linear(x))
1634
1635        class FuncLinear(torch.nn.Module):
1636            def __init__(self, has_relu=False, f_relu=False):
1637                super().__init__()
1638                self.w = torch.randn(4, 30)
1639                self.b = torch.randn(4)
1640                if has_relu:
1641                    if f_relu:
1642                        self.relu = F.relu
1643                    else:
1644                        self.relu = torch.nn.ReLU()
1645                else:
1646                    self.relu = torch.nn.Identity()
1647
1648            def forward(self, x):
1649                return self.relu(F.linear(x, self.w, self.b))
1650
1651        data = [[torch.rand((1, 30), dtype=torch.float)]]
1652        for model, tracing in itertools.product(
1653            [ModuleLinear(has_relu=False), FuncLinear(has_relu=False)], [True, False]
1654        ):
1655            model = self.checkGraphModeOp(model, data, "quantized::linear", tracing)
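                # Only the activation input should be quantized at runtime; the weight
                # quantization and the linear_prepack call are expected to be folded
                # away at convert time, so no quantized::linear_prepack should remain
                # (checked below).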
1656            FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run(
1657                model.graph
1658            )
1659            FileCheck().check_not("quantized::linear_prepack").run(model.graph)
1660
1661        for f_relu, tracing in itertools.product([True, False], [True, False]):
1662            for model in [
1663                ModuleLinear(has_relu=True, f_relu=f_relu),
1664                FuncLinear(has_relu=True, f_relu=f_relu),
1665            ]:
1666                model = self.checkGraphModeOp(
1667                    model, data, "quantized::linear_relu", tracing
1668                )
1669                checker = (
1670                    FileCheck()
1671                    .check_not("aten::linear")
1672                    .check_not("aten::relu")
1673                    .check_not("quantized::linear(")
1674                    .check_not("quantized::relu(")
1675                    .run(model.graph)
1676                )
1677
1678    @skipIfNoFBGEMM
1679    def test_quantized_conv(self):
1680        conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
1681
1682        class Conv(torch.nn.Module):
1683            def __init__(self, dim):
1684                super().__init__()
1685                self.conv = conv_module[dim](3, 3, 3).float()
1686
1687            def forward(self, x):
1688                return self.conv(x)
1689
1690        options = itertools.product([1, 2, 3], [True, False])
1691        for dim, tracing in options:
1692            model = self.checkGraphModeOp(
1693                Conv(dim),
1694                self.img_data_dict[dim],
1695                f"quantized::conv{dim}d",
1696                tracing,
1697            )
1698            # make sure there is only one quantize_per_tensor for the input
1699            # and that the conv{dim}d_prepack call has been folded
1700            FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run(
1701                model.graph
1702            )
1703
1704            FileCheck().check_not(f"quantized::conv{dim}d_prepack").run(model.graph)
1705
1706    @skipIfNoFBGEMM
1707    def test_quantized_conv_relu(self):
1708        """tests for conv1d_relu/conv2d_relu/conv3d_relu"""
1709        conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
1710
1711        class ConvNdRelu(torch.nn.Module):
1712            def __init__(self, dim, inplace):
1713                super().__init__()
1714                self.conv = conv_module[dim](3, 3, 3).float()
1715                self.relu = torch.nn.ReLU(inplace)
1716
1717            def forward(self, x):
1718                return self.relu(self.conv(x))
1719
1720        class ConvNdFunctionalRelu(torch.nn.Module):
1721            def __init__(self, dim):
1722                super().__init__()
1723                self.conv = conv_module[dim](3, 3, 3).float()
1724
1725            def forward(self, x):
1726                return F.relu(self.conv(x))
1727
1728        class ConvNdInplaceFunctionalRelu(torch.nn.Module):
1729            def __init__(self, dim):
1730                super().__init__()
1731                self.conv = conv_module[dim](3, 3, 3).float()
1732
1733            def forward(self, x):
1734                return F.relu(self.conv(x), True)
1735
1736        options = itertools.product([1, 2, 3], [True, False])
1737        for dim, tracing in options:
1738            for orig_m in [
1739                ConvNdRelu(dim, True),
1740                ConvNdRelu(dim, False),
1741                ConvNdFunctionalRelu(dim),
1742                ConvNdInplaceFunctionalRelu(dim),
1743            ]:
1744                conv_name = f"conv{dim}d"
1745                m = self.checkGraphModeOp(
1746                    orig_m,
1747                    self.img_data_dict[dim],
1748                    f"quantized::conv{dim}d_relu(",
1749                    tracing=tracing,
1750                )
1751
1752                FileCheck().check_not(f"aten::conv{dim}d(").check_not(
1753                    "aten::relu"
1754                ).check_not(f"quantized::conv{dim}d(").check_not(
1755                    "quantized::relu("
1756                ).run(
1757                    m.graph
1758                )
1759
1760    @skipIfNoFBGEMM
1761    def test_quantized_add_alpha(self):
1762        """Test quant fusion for multiple aten::add ops that use the same
1763        constant alpha as the third argument.
1764        """
1765
1766        class QuantizedAdd(torch.nn.Module):
1767            def __init__(self) -> None:
1768                super().__init__()
1769                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
1770                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
1771
1772            def forward(self, x, y):
1773                x = self.conv1(x)
1774                y = self.conv2(y)
1775                z = x + y
1776                w = y + z
1777                return z + w
1778
1779        data = [
1780            [
1781                torch.randn(1, 2, 5, 5, dtype=torch.float),
1782                torch.randn(1, 2, 5, 5, dtype=torch.float),
1783            ]
1784        ]
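        # The forward above has three adds, all using the same (default) alpha
        # constant; every one of them is expected to fuse to quantized::add, hence
        # the exact count of 3 below.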
1785        for tracing in [True, False]:
1786            m = self.checkGraphModeOp(QuantizedAdd(), data, "quantized::add", tracing)
1787            FileCheck().check_count("quantized::add", 3, exactly=True).run(m.graph)
1788            FileCheck().check_not("aten::add").check_not("aten::add_").run(m.graph)
1789
1790    @skipIfNoFBGEMM
1791    def test_quantized_add_relu_alpha(self):
1792        """Test quant fusion for multiple aten::add ops that use the same
1793        constant alpha as the third argument in the add_relu pattern.
1794        """
1795
1796        class AddRelu(torch.nn.Module):
1797            def __init__(self, inplace):
1798                super().__init__()
1799                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
1800                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
1801                self.relu = torch.nn.ReLU(inplace)
1802
1803            def forward(self, x, y):
1804                x = self.conv1(x)
1805                y = self.conv2(y)
1806                x = x + y
1807                x = self.relu(x)
1808                x = x + y
1809                return self.relu(x)
1810
1811        class InplaceAddRelu(torch.nn.Module):
1812            def __init__(self, inplace):
1813                super().__init__()
1814                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
1815                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
1816                self.relu = torch.nn.ReLU(inplace)
1817
1818            def forward(self, x, y):
1819                x = self.conv1(x)
1820                y = self.conv2(y)
1821                x += y
1822                x = self.relu(x)
1823                x += y
1824                return self.relu(x)
1825
1826        class AddFunctionalRelu(torch.nn.Module):
1827            def __init__(self) -> None:
1828                super().__init__()
1829                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
1830                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
1831
1832            def forward(self, x, y):
1833                x = self.conv1(x)
1834                y = self.conv2(y)
1835                x = x + y
1836                x = F.relu(x)
1837                x = x + y
1838                return F.relu(x)
1839
1840        class InplaceAddFunctionalRelu(torch.nn.Module):
1841            def __init__(self) -> None:
1842                super().__init__()
1843                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
1844                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
1845
1846            def forward(self, x, y):
1847                x = self.conv1(x)
1848                y = self.conv2(y)
1849                x += y
1850                x = F.relu(x)
1851                x += y
1852                return F.relu(x)
1853
1854        class AddInplaceFunctionalRelu(torch.nn.Module):
1855            def __init__(self) -> None:
1856                super().__init__()
1857                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
1858                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
1859
1860            def forward(self, x, y):
1861                x = self.conv1(x)
1862                y = self.conv2(y)
1863                x = x + y
1864                x = F.relu(x, True)
1865                x = x + y
1866                return F.relu(x, True)
1867
1868        class InplaceAddInplaceFunctionalRelu(torch.nn.Module):
1869            def __init__(self) -> None:
1870                super().__init__()
1871                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
1872                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
1873
1874            def forward(self, x, y):
1875                x = self.conv1(x)
1876                y = self.conv2(y)
1877                x += y
1878                x = F.relu(x, True)
1879                x += y
1880                return F.relu(x, True)
1881
1882        data = [
1883            [
1884                torch.rand((1, 2, 5, 5), dtype=torch.float),
1885                torch.rand((1, 2, 5, 5), dtype=torch.float),
1886            ]
1887        ]
1888        for m_orig in [
1889            AddRelu(True),
1890            AddRelu(False),
1891            InplaceAddRelu(True),
1892            InplaceAddRelu(False),
1893            AddFunctionalRelu(),
1894            InplaceAddFunctionalRelu(),
1895            AddInplaceFunctionalRelu(),
1896            InplaceAddInplaceFunctionalRelu(),
1897        ]:
1898            for tracing in [True, False]:
1899                m = self.checkGraphModeOp(
1900                    m_orig, data, "quantized::add_relu(", tracing=tracing
1901                )
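                    # Each forward above contains two add -> relu pairs, so exactly
                    # two quantized::add_relu nodes are expected after fusion.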
1902                FileCheck().check_count("quantized::add_relu(", 2, exactly=True).run(
1903                    m.graph
1904                )
1905                FileCheck().check_not("aten::add(").check_not("aten::add_(").check_not(
1906                    "aten::relu("
1907                ).check_not("aten::relu_(").check_not("quantized::add(").check_not(
1908                    "quantized::relu("
1909                ).run(
1910                    m.graph
1911                )
1912
1913    @skipIfNoFBGEMM
1914    def test_quantized_add(self):
1915        class QuantizedAdd(torch.nn.Module):
1916            def __init__(self) -> None:
1917                super().__init__()
1918                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
1919                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
1920
1921            def forward(self, x, y):
1922                x = self.conv1(x)
1923                y = self.conv2(y)
1924                return x + y
1925
1926        class QuantizedInplaceAdd(torch.nn.Module):
1927            def __init__(self) -> None:
1928                super().__init__()
1929                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
1930                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
1931
1932            def forward(self, x, y):
1933                x = self.conv1(x)
1934                y = self.conv2(y)
1935                x += y
1936                return x
1937
1938        class NonQuantizedAdd(torch.nn.Module):
1939            def forward(self, x, y):
1940                return x + y
1941
1942        class NonQuantizedInplaceAdd(torch.nn.Module):
1943            def forward(self, x, y):
1944                x += y
1945                return x
1946
1947        data = [
1948            [
1949                torch.randn(1, 2, 3, 3, dtype=torch.float),
1950                torch.randn(1, 2, 3, 3, dtype=torch.float),
1951            ]
1952        ]
1953        for m, quantized in [
1954            (QuantizedAdd(), True),
1955            (QuantizedInplaceAdd(), True),
1956            (NonQuantizedAdd(), False),
1957            (NonQuantizedInplaceAdd(), False),
1958        ]:
1959            for tracing in [True, False]:
1960                op = "quantized::add" if quantized else "aten::add"
1961                m = self.checkGraphModeOp(m, data, op, tracing)
1962                # TODO: remove after refactor of checkGraphModeOp
1963                if quantized:
1964                    FileCheck().check_not("aten::add").check_not("aten::add_").run(
1965                        m.graph
1966                    )
1967                else:
1968                    FileCheck().check_not("quantized::add").run(m.graph)
1969
1970    @skipIfNoFBGEMM
1971    def test_quantized_add_scalar(self):
1972        class QuantizedAddScalar(torch.nn.Module):
1973            def __init__(self) -> None:
1974                super().__init__()
1975                self.conv = torch.nn.Conv2d(2, 2, 2).float()
1976
1977            def forward(self, x):
1978                x = self.conv(x)
1979                return x + 3
1980
1981        class QuantizedInplaceAddScalar(torch.nn.Module):
1982            def __init__(self) -> None:
1983                super().__init__()
1984                self.conv = torch.nn.Conv2d(2, 2, 2).float()
1985
1986            def forward(self, x):
1987                x = self.conv(x)
1988                x += 3
1989                return x
1990
1991        class NonQuantizedAddScalar(torch.nn.Module):
1992            def forward(self, x):
1993                return x + 3
1994
1995        class NonQuantizedInplaceAddScalar(torch.nn.Module):
1996            def forward(self, x):
1997                x += 3
1998                return x
1999
2000        data = [[torch.randn(1, 2, 3, 3, dtype=torch.float)]]
2001        for m, quantized in [
2002            (QuantizedAddScalar(), True),
2003            (QuantizedInplaceAddScalar(), True),
2004            (NonQuantizedAddScalar(), False),
2005            (NonQuantizedInplaceAddScalar(), False),
2006        ]:
2007            for tracing in [True, False]:
2008                op = "quantized::add_scalar" if quantized else "aten::add"
2009                # we don't check the numerical consistency for add_scalar
2010                # since it's not supported
2011                m = self.checkGraphModeOp(m, data, op, tracing, check=False)
2012                # TODO: remove after refactor of checkGraphModeOp
2013                if quantized:
2014                    FileCheck().check_not("aten::add").check_not("aten::add_").run(
2015                        m.graph
2016                    )
2017                else:
2018                    FileCheck().check_not("quantized::add_scalar").run(m.graph)
2019
2020    @skipIfNoFBGEMM
2021    def test_quantized_add_relu(self):
2022        class AddRelu(torch.nn.Module):
2023            def __init__(self, inplace):
2024                super().__init__()
2025                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2026                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2027                self.relu = torch.nn.ReLU(inplace)
2028
2029            def forward(self, x, y):
2030                x = self.conv1(x)
2031                y = self.conv2(y)
2032                x = x + y
2033                return self.relu(x)
2034
2035        class InplaceAddRelu(torch.nn.Module):
2036            def __init__(self, inplace):
2037                super().__init__()
2038                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2039                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2040                self.relu = torch.nn.ReLU(inplace)
2041
2042            def forward(self, x, y):
2043                x = self.conv1(x)
2044                y = self.conv2(y)
2045                x += y
2046                return self.relu(x)
2047
2048        class AddFunctionalRelu(torch.nn.Module):
2049            def __init__(self) -> None:
2050                super().__init__()
2051                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2052                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2053
2054            def forward(self, x, y):
2055                x = self.conv1(x)
2056                y = self.conv2(y)
2057                x = x + y
2058                return F.relu(x)
2059
2060        class InplaceAddFunctionalRelu(torch.nn.Module):
2061            def __init__(self) -> None:
2062                super().__init__()
2063                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2064                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2065
2066            def forward(self, x, y):
2067                x = self.conv1(x)
2068                y = self.conv2(y)
2069                x += y
2070                return F.relu(x)
2071
2072        class AddInplaceFunctionalRelu(torch.nn.Module):
2073            def __init__(self) -> None:
2074                super().__init__()
2075                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2076                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2077
2078            def forward(self, x, y):
2079                x = self.conv1(x)
2080                y = self.conv2(y)
2081                x = x + y
2082                return F.relu(x, True)
2083
2084        class InplaceAddInplaceFunctionalRelu(torch.nn.Module):
2085            def __init__(self) -> None:
2086                super().__init__()
2087                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2088                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2089
2090            def forward(self, x, y):
2091                x = self.conv1(x)
2092                y = self.conv2(y)
2093                x += y
2094                return F.relu(x, True)
2095
2096        data = [
2097            [
2098                torch.rand((1, 2, 5, 5), dtype=torch.float),
2099                torch.rand((1, 2, 5, 5), dtype=torch.float),
2100            ]
2101        ]
2102        for m in [
2103            AddRelu(True),
2104            AddRelu(False),
2105            InplaceAddRelu(True),
2106            InplaceAddRelu(False),
2107            AddFunctionalRelu(),
2108            InplaceAddFunctionalRelu(),
2109            AddInplaceFunctionalRelu(),
2110            InplaceAddInplaceFunctionalRelu(),
2111        ]:
2112            for tracing in [True, False]:
2113                m = self.checkGraphModeOp(m, data, "quantized::add_relu(", tracing)
2114                FileCheck().check_not("aten::add(").check_not("aten::add_(").check_not(
2115                    "aten::relu("
2116                ).check_not("aten::relu_(").check_not("quantized::add(").check_not(
2117                    "quantized::relu("
2118                ).run(
2119                    m.graph
2120                )
2121
2122    @skipIfNoFBGEMM
2123    def test_quantized_add_scalar_relu(self):
2124        class AddScalarRelu(torch.nn.Module):
2125            def __init__(self, inplace):
2126                super().__init__()
2127                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2128                self.relu = torch.nn.ReLU(inplace)
2129
2130            def forward(self, x):
2131                x = self.conv(x)
2132                return self.relu(x + 3)
2133
2134        class InplaceAddScalarRelu(torch.nn.Module):
2135            def __init__(self, inplace):
2136                super().__init__()
2137                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2138                self.relu = torch.nn.ReLU(inplace)
2139
2140            def forward(self, x):
2141                x = self.conv(x)
2142                x += 3
2143                return self.relu(x)
2144
2145        class AddScalarFunctionalRelu(torch.nn.Module):
2146            def __init__(self) -> None:
2147                super().__init__()
2148                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2149
2150            def forward(self, x):
2151                x = self.conv(x)
2152                return F.relu(x + 3)
2153
2154        class InplaceAddScalarFunctionalRelu(torch.nn.Module):
2155            def __init__(self) -> None:
2156                super().__init__()
2157                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2158
2159            def forward(self, x):
2160                x = self.conv(x)
2161                x += 3
2162                return F.relu(x)
2163
2164        class AddScalarInplaceFunctionalRelu(torch.nn.Module):
2165            def __init__(self) -> None:
2166                super().__init__()
2167                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2168
2169            def forward(self, x):
2170                x = self.conv(x)
2171                return F.relu(x + 3, True)
2172
2173        class InplaceAddScalarInplaceFunctionalRelu(torch.nn.Module):
2174            def __init__(self) -> None:
2175                super().__init__()
2176                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2177
2178            def forward(self, x):
2179                x = self.conv(x)
2180                x += 3
2181                return F.relu(x, True)
2182
2183        data = [[torch.rand((1, 2, 5, 5), dtype=torch.float)]]
2184        for m in [
2185            AddScalarRelu(True),
2186            AddScalarRelu(False),
2187            InplaceAddScalarRelu(True),
2188            InplaceAddScalarRelu(False),
2189            AddScalarFunctionalRelu(),
2190            InplaceAddScalarFunctionalRelu(),
2191            AddScalarInplaceFunctionalRelu(),
2192            InplaceAddScalarInplaceFunctionalRelu(),
2193        ]:
2194            for tracing in [True, False]:
2195                # quantized::add_scalar_relu or quantized::add_scalar_relu_out
2196                # TODO: split this after refactor of checkGraphModeOp
2197                m = self.checkGraphModeOp(
2198                    m, data, "quantized::add_scalar_relu", tracing, check=False
2199                )
2200                FileCheck().check_not("aten::add(").check_not("aten::add_(").check_not(
2201                    "aten::relu("
2202                ).check_not("aten::relu_(").check_not(
2203                    "quantized::add_scalar("
2204                ).check_not(
2205                    "quantized::relu("
2206                ).run(
2207                    m.graph
2208                )
2209
2210    @skipIfNoFBGEMM
2211    def test_quantized_cat(self):
2212        """Quantization of the output of cat depends on the
2213        inputs of cat: we only quantize the output of cat when its inputs are quantized.
2214        """
2215
2216        class QuantizedCat(torch.nn.Module):
2217            def __init__(self) -> None:
2218                super().__init__()
2219                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2220                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2221
2222            def forward(self, x, y):
2223                x = self.conv1(x)
2224                y = self.conv2(y)
2225                return torch.cat([x, y], 1)
2226
2227        class NonQuantizedCat(torch.nn.Module):
2228            def forward(self, x, y):
2229                return torch.cat([x, y], 1)
2230
2231        data = [
2232            [
2233                torch.randn(1, 2, 5, 5, dtype=torch.float),
2234                torch.randn(1, 2, 5, 5, dtype=torch.float),
2235            ]
2236        ]
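        # QuantizedCat feeds cat with conv outputs (quantized inputs), so its output
        # is quantized and the op becomes quantized::cat; NonQuantizedCat sees only
        # unquantized graph inputs, so aten::cat is expected to be left alone.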
2237        for tracing in [True, False]:
2238            m = self.checkGraphModeOp(QuantizedCat(), data, "quantized::cat", tracing)
2239            FileCheck().check_not("aten::cat").run(m.graph)
2240
2241            m = self.checkGraphModeOp(NonQuantizedCat(), data, "aten::cat", tracing)
2242            FileCheck().check_not("quantized::cat").run(m.graph)
2243
2244    @skipIfNoFBGEMM
2245    def test_qbatch_norm(self):
2246        bn_module = {
2247            1: torch.nn.BatchNorm1d,
2248            2: torch.nn.BatchNorm2d,
2249            3: torch.nn.BatchNorm3d,
2250        }
2251
2252        class M(torch.nn.Module):
2253            def __init__(self, dim):
2254                super().__init__()
2255                self.bn = bn_module[dim](3).to(torch.float)
2256
2257            def forward(self, x):
2258                return self.bn(x)
2259
2260        options = itertools.product([True, False], [1, 2, 3])
2261        for tracing, dim in options:
2262            model = self.checkGraphModeOp(
2263                M(dim), self.img_data_dict[dim], "quantized::batch_norm", tracing
2264            )
2265
2266            FileCheck().check_not("aten::batch_norm").run(model.graph)
2267
2268    @skipIfNoFBGEMM
2269    def test_qbatch_norm_relu_BNRelu(self):
2270        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
2271
2272        class BNRelu(torch.nn.Module):
2273            def __init__(self, dim, inplace):
2274                super().__init__()
2275                self.bn = bn_module[dim](3).to(torch.float)
2276                self.relu = torch.nn.ReLU(inplace=inplace)
2277
2278            def forward(self, x):
2279                return self.relu(self.bn(x))
2280
2281        options = itertools.product([True, False], [2, 3])
2282        for tracing, dim in options:
2283            for instance in [BNRelu(dim, True), BNRelu(dim, False)]:
2284                model = self.checkGraphModeOp(
2285                    instance,
2286                    self.img_data_dict[dim],
2287                    "quantized::batch_norm_relu",
2288                    tracing,
2289                )
2290                FileCheck().check_not("aten::batch_norm").check_not(
2291                    "aten::relu"
2292                ).check_not("aten::relu_").run(model.graph)
2293
2294    @skipIfNoFBGEMM
2295    def test_qbatch_norm_relu_BNFuncRelu(self):
2296        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
2297
2298        class BNFuncRelu(torch.nn.Module):
2299            def __init__(self, dim):
2300                super().__init__()
2301                self.bn = bn_module[dim](3).to(torch.float)
2302
2303            def forward(self, x):
2304                return F.relu(self.bn(x), False)
2305
2306        options = itertools.product([True, False], [2, 3])
2307        for tracing, dim in options:
2308            instance = BNFuncRelu(dim)
2309            model = self.checkGraphModeOp(
2310                instance, self.img_data_dict[dim], "quantized::batch_norm_relu", tracing
2311            )
2312            FileCheck().check_not("aten::batch_norm").check_not("aten::relu").check_not(
2313                "aten::relu_"
2314            ).run(model.graph)
2315
2316    @skipIfNoFBGEMM
2317    def test_qbatch_norm_relu_BNFuncInplaceRelu(self):
2318        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
2319
2320        class BNFuncInplaceRelu(torch.nn.Module):
2321            def __init__(self, dim):
2322                super().__init__()
2323                self.bn = bn_module[dim](3).to(torch.float)
2324
2325            def forward(self, x):
2326                return F.relu(self.bn(x), True)
2327
2328        options = itertools.product([True, False], [2, 3])
2329        for tracing, dim in options:
2330            instance = BNFuncInplaceRelu(dim)
2331            model = self.checkGraphModeOp(
2332                instance, self.img_data_dict[dim], "quantized::batch_norm_relu", tracing
2333            )
2334            FileCheck().check_not("aten::batch_norm").check_not("aten::relu").check_not(
2335                "aten::relu_"
2336            ).run(model.graph)
2337
2338    @skipIfNoFBGEMM
2339    def test_quantized_mul(self):
2340        class QuantizedMul(torch.nn.Module):
2341            def __init__(self) -> None:
2342                super().__init__()
2343                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2344                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2345
2346            def forward(self, x, y):
2347                x = self.conv1(x)
2348                y = self.conv2(y)
2349                return x * y
2350
2351        class QuantizedInplaceMul(torch.nn.Module):
2352            def __init__(self) -> None:
2353                super().__init__()
2354                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2355                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2356
2357            def forward(self, x, y):
2358                x = self.conv1(x)
2359                y = self.conv2(y)
2360                x *= y
2361                return x
2362
2363        class NonQuantizedMul(torch.nn.Module):
2364            def forward(self, x, y):
2365                return x * y
2366
2367        class NonQuantizedInplaceMul(torch.nn.Module):
2368            def forward(self, x, y):
2369                x *= y
2370                return x
2371
2372        data = [
2373            [
2374                torch.randn(1, 2, 10, 10, dtype=torch.float),
2375                torch.randn(1, 2, 10, 10, dtype=torch.float),
2376            ]
2377        ]
2378        for m, quantized in [
2379            (QuantizedMul(), True),
2380            (QuantizedInplaceMul(), True),
2381            (NonQuantizedMul(), False),
2382            (NonQuantizedInplaceMul(), False),
2383        ]:
2384            for tracing in [True, False]:
2385                op = "quantized::mul" if quantized else "aten::mul"
2386                m = self.checkGraphModeOp(m, data, op, tracing)
2387                # TODO: remove after refactor of checkGraphModeOp
2388                if quantized:
2389                    FileCheck().check_not("aten::mul").check_not("aten::mul_").run(
2390                        m.graph
2391                    )
2392                else:
2393                    FileCheck().check_not("quantized::mul").run(m.graph)
2394
2395    @skipIfNoFBGEMM
2396    def test_quantized_mul_scalar(self):
2397        class QuantizedMulScalar(torch.nn.Module):
2398            def __init__(self) -> None:
2399                super().__init__()
2400                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2401
2402            def forward(self, x):
2403                x = self.conv(x)
2404                return x * 3
2405
2406        class QuantizedInplaceMulScalar(torch.nn.Module):
2407            def __init__(self) -> None:
2408                super().__init__()
2409                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2410
2411            def forward(self, x):
2412                x = self.conv(x)
2413                x *= 3
2414                return x
2415
2416        class NonQuantizedMulScalar(torch.nn.Module):
2417            def forward(self, x):
2418                return x * 3
2419
2420        class NonQuantizedInplaceMulScalar(torch.nn.Module):
2421            def forward(self, x):
2422                x *= 3
2423                return x
2424
2425        data = [[torch.randn(1, 2, 5, 5, dtype=torch.float)]]
2426        for m, quantized in [
2427            (QuantizedMulScalar(), True),
2428            (QuantizedInplaceMulScalar(), True),
2429            (NonQuantizedMulScalar(), False),
2430            (NonQuantizedInplaceMulScalar(), False),
2431        ]:
2432            for tracing in [True, False]:
2433                op = "quantized::mul_scalar" if quantized else "aten::mul"
2434                # we don't check the numerical consistency for mul_scalar
2435                # since it's not supported
2436                m = self.checkGraphModeOp(m, data, op, tracing, check=False)
2437                # TODO: remove after refactor of checkGraphModeOp
2438                if quantized:
2439                    FileCheck().check_not("aten::mul").check_not("aten::mul_").run(
2440                        m.graph
2441                    )
2442                else:
2443                    FileCheck().check_not("quantized::mul_scalar").run(m.graph)
2444
2445    @skipIfNoFBGEMM
2446    def test_quantized_mul_relu(self):
2447        class MulRelu(torch.nn.Module):
2448            def __init__(self, inplace):
2449                super().__init__()
2450                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2451                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2452                self.relu = torch.nn.ReLU(inplace)
2453
2454            def forward(self, x, y):
2455                x = self.conv1(x)
2456                y = self.conv2(y)
2457                x = x * y
2458                return self.relu(x)
2459
2460        class InplaceMulRelu(torch.nn.Module):
2461            def __init__(self, inplace):
2462                super().__init__()
2463                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2464                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2465                self.relu = torch.nn.ReLU(inplace)
2466
2467            def forward(self, x, y):
2468                x = self.conv1(x)
2469                y = self.conv2(y)
2470                x *= y
2471                return self.relu(x)
2472
2473        class MulFunctionalRelu(torch.nn.Module):
2474            def __init__(self) -> None:
2475                super().__init__()
2476                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2477                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2478
2479            def forward(self, x, y):
2480                x = self.conv1(x)
2481                y = self.conv2(y)
2482                x = x * y
2483                return F.relu(x)
2484
2485        class InplaceMulFunctionalRelu(torch.nn.Module):
2486            def __init__(self) -> None:
2487                super().__init__()
2488                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2489                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2490
2491            def forward(self, x, y):
2492                x = self.conv1(x)
2493                y = self.conv2(y)
2494                x *= y
2495                return F.relu(x)
2496
2497        class MulInplaceFunctionalRelu(torch.nn.Module):
2498            def __init__(self) -> None:
2499                super().__init__()
2500                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2501                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2502
2503            def forward(self, x, y):
2504                x = self.conv1(x)
2505                y = self.conv2(y)
2506                x = x * y
2507                return F.relu(x, True)
2508
2509        class InplaceMulInplaceFunctionalRelu(torch.nn.Module):
2510            def __init__(self) -> None:
2511                super().__init__()
2512                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
2513                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
2514
2515            def forward(self, x, y):
2516                x = self.conv1(x)
2517                y = self.conv2(y)
2518                x *= y
2519                return F.relu(x, True)
2520
2521        data = [
2522            [
2523                torch.rand((1, 2, 5, 5), dtype=torch.float),
2524                torch.rand((1, 2, 5, 5), dtype=torch.float),
2525            ]
2526        ]
2527        for m in [
2528            MulRelu(True),
2529            MulRelu(False),
2530            InplaceMulRelu(True),
2531            InplaceMulRelu(False),
2532            MulFunctionalRelu(),
2533            InplaceMulFunctionalRelu(),
2534            MulInplaceFunctionalRelu(),
2535            InplaceMulInplaceFunctionalRelu(),
2536        ]:
2537            for tracing in [True, False]:
2538                m = self.checkGraphModeOp(m, data, "quantized::mul_relu(", tracing)
2539                FileCheck().check_not("aten::mul(").check_not("aten::mul_(").check_not(
2540                    "aten::relu("
2541                ).check_not("aten::relu_(").check_not("quantized::mul(").check_not(
2542                    "quantized::relu("
2543                ).run(
2544                    m.graph
2545                )
2546
2547    @skipIfNoFBGEMM
2548    def test_quantized_mul_scalar_relu(self):
2549        class MulScalarRelu(torch.nn.Module):
2550            def __init__(self, inplace):
2551                super().__init__()
2552                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2553                self.relu = torch.nn.ReLU(inplace)
2554
2555            def forward(self, x):
2556                x = self.conv(x)
2557                return self.relu(x * 3)
2558
2559        class InplaceMulScalarRelu(torch.nn.Module):
2560            def __init__(self, inplace):
2561                super().__init__()
2562                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2563                self.relu = torch.nn.ReLU(inplace)
2564
2565            def forward(self, x):
2566                x = self.conv(x)
2567                x *= 3
2568                return self.relu(x)
2569
2570        class MulScalarFunctionalRelu(torch.nn.Module):
2571            def __init__(self) -> None:
2572                super().__init__()
2573                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2574
2575            def forward(self, x):
2576                x = self.conv(x)
2577                return F.relu(x * 3)
2578
2579        class InplaceMulScalarFunctionalRelu(torch.nn.Module):
2580            def __init__(self) -> None:
2581                super().__init__()
2582                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2583
2584            def forward(self, x):
2585                x = self.conv(x)
2586                x *= 3
2587                return F.relu(x)
2588
2589        class MulScalarInplaceFunctionalRelu(torch.nn.Module):
2590            def __init__(self) -> None:
2591                super().__init__()
2592                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2593
2594            def forward(self, x):
2595                x = self.conv(x)
2596                return F.relu(x * 3, True)
2597
2598        class InplaceMulScalarInplaceFunctionalRelu(torch.nn.Module):
2599            def __init__(self) -> None:
2600                super().__init__()
2601                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2602
2603            def forward(self, x):
2604                x = self.conv(x)
2605                x *= 3
2606                return F.relu(x, True)
2607
2608        data = [[torch.randn(1, 2, 5, 5, dtype=torch.float)]]
2609        for m in [
2610            MulScalarRelu(True),
2611            MulScalarRelu(False),
2612            InplaceMulScalarRelu(True),
2613            InplaceMulScalarRelu(False),
2614            MulScalarFunctionalRelu(),
2615            InplaceMulScalarFunctionalRelu(),
2616            MulScalarInplaceFunctionalRelu(),
2617            InplaceMulScalarInplaceFunctionalRelu(),
2618        ]:
2619            for tracing in [True, False]:
2620                # quantized::mul_scalar_relu or quantized::mul_scalar_relu_out
2621                m = self.checkGraphModeOp(
2622                    m, data, "quantized::mul_scalar_relu", tracing, check=False
2623                )
2624                FileCheck().check_not("aten::mul(").check_not("aten::mul_(").check_not(
2625                    "aten::relu("
2626                ).check_not("aten::relu_(").check_not(
2627                    "quantized::mul_scalar("
2628                ).check_not(
2629                    "quantized::relu("
2630                ).run(
2631                    m.graph
2632                )
2633
2634    @override_qengines
2635    def test_hardswish(self):
2636        class FunctionalHardswish(torch.nn.Module):
2637            def __init__(self, inplace):
2638                super().__init__()
2639                self.inplace = inplace
2640
2641            def forward(self, input):
2642                return torch.nn.functional.hardswish(input, inplace=self.inplace)
2643
2644        modules = [
2645            torch.nn.Hardswish(),
2646            FunctionalHardswish(True),
2647            FunctionalHardswish(False),
2648        ]
2649
2650        for test_case in itertools.product([True, False], modules):
2651            tracing, m = test_case
2652            m = self.checkGraphModeOp(
2653                m, self.img_data_2d, "quantized::hardswish", tracing
2654            )
2655            FileCheck().check_not("aten::hardswish").check_not("aten::hardswish_").run(
2656                m.graph
2657            )
2658
2659    @override_qengines
2660    def test_elu(self):
2661        class FunctionalELU(torch.nn.Module):
2662            def __init__(self, inplace=False):
2663                super().__init__()
2664                self.inplace = inplace
2665
2666            def forward(self, input):
2667                return torch.nn.functional.elu(input, inplace=self.inplace)
2668
2669        modules = [torch.nn.ELU, FunctionalELU]
2670        for test_case in itertools.product([True, False], [True, False], modules):
2671            tracing, inplace, mod_class = test_case
2672            m = mod_class(inplace=inplace)
2673            m = self.checkGraphModeOp(m, self.img_data_2d, "quantized::elu", tracing)
2674            FileCheck().check_not("aten::elu").check_not("aten::elu_").run(m.graph)
2675
2676    @override_qengines
2677    def test_layer_norm(self):
2678        data = [[torch.rand((1, 2, 5, 5), dtype=torch.float)] for _ in range(2)]
2679        layer_norm = torch.nn.LayerNorm([2, 5, 5])
2680        for tracing in [True, False]:
2681            m = self.checkGraphModeOp(
2682                layer_norm, data, "quantized::layer_norm", tracing
2683            )
2684            FileCheck().check_not("aten::layer_norm").run(m.graph)
2685
2686    @override_qengines
2687    def test_group_norm(self):
2688        data = [[torch.rand((1, 4, 5, 5), dtype=torch.float)] for _ in range(2)]
2689        group_norm = torch.nn.GroupNorm(2, 4)
2690        for tracing in [True, False]:
2691            m = self.checkGraphModeOp(
2692                group_norm, data, "quantized::group_norm", tracing
2693            )
2694            FileCheck().check_not("aten::group_norm").run(m.graph)
2695
2696    @override_qengines
2697    def test_instance_norm(self):
2698        data_1d = [[torch.rand((1, 4, 5), dtype=torch.float)] for _ in range(2)]
2699        data_2d = [[torch.rand((1, 4, 5, 1), dtype=torch.float)] for _ in range(2)]
2700        data_3d = [[torch.rand((1, 4, 5, 1, 1), dtype=torch.float)] for _ in range(2)]
2701        data = {1: data_1d, 2: data_2d, 3: data_3d}
2702        instance_norm_modules = {
2703            1: torch.nn.InstanceNorm1d,
2704            2: torch.nn.InstanceNorm2d,
2705            3: torch.nn.InstanceNorm3d,
2706        }
2707
2708        options = itertools.product([1, 2, 3], [True, False])
2709        for dim, tracing in options:
2710            instance_norm = instance_norm_modules[dim](4)
2711            m = self.checkGraphModeOp(
2712                instance_norm, data[dim], "quantized::instance_norm", tracing
2713            )
2714            FileCheck().check_not("aten::instance_norm").run(m.graph)
2715
2716    @skipIfNoFBGEMM
2717    def test_dequantize_tuple(self):
2718        """Make sure dequantize supports a Tuple of Tensors"""
2719
2720        class M(torch.nn.Module):
2721            def __init__(self) -> None:
2722                super().__init__()
2723                self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
2724                self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
2725
2726            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
2727                x1 = self.conv1(x)
2728                x2 = self.conv2(x)
2729                return x1, x2
2730
2731        for tracing in [True, False]:
2732            self.checkGraphModeOp(M(), self.img_data_2d, "quantized::conv2d", tracing)
2733
2734    @skipIfNoFBGEMM
2735    def test_clamp(self):
2736        class M(torch.nn.Module):
2737            def __init__(self) -> None:
2738                super().__init__()
2739                self.conv = torch.nn.Conv2d(2, 2, 2).float()
2740                self.relu6 = torch.nn.ReLU6()
2741                self.relu6_ = torch.nn.ReLU6(True)
2742                self.hardtanh = torch.nn.Hardtanh()
2743                self.hardtanh_ = torch.nn.Hardtanh(inplace=True)
2744
2745            def forward(self, x):
2746                x = self.conv(x)
2747                x = self.relu6(x)
2748                self.relu6_(x)
2749                x = F.relu6(x)
2750                x = torch.clamp(x, -3, 3)
2751                x = x.clamp(-2.5, 2.5)
2752                # x = x.clamp_(-2, 2)  # Enable when quantized `clamp_` is ready
2753                x = self.hardtanh(x)
2754                self.hardtanh_(x)
2755                x = F.hardtanh(x)
2756                F.hardtanh_(x)
2757                return x
2758
2759        data = [[torch.rand((1, 2, 5, 5), dtype=torch.float)]]
2760        options = itertools.product(
2761            ["aten::clamp", "aten::hardtanh", "aten::hardtanh_"], [True, False]
2762        )
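        # clamp / hardtanh / relu6 are expected to operate directly on the quantized
        # tensor, so the quantization of the conv output propagates through the whole
        # chain and only a single quantize_per_tensor / dequantize pair should remain.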
2763        for op, tracing in options:
2764            m = self.checkGraphModeOp(M(), data, op, tracing)
2765            FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run(
2766                m.graph
2767            )
2768
2769            FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph)
2770
2771    def test_general_shape_ops(self):
2772        """A test that checks that dequantize is swapped past
2773        all supported general shape ops such as aten::flatten,
2774        without actually checking the execution of these ops.
2775        """
2776
2777        class M(torch.nn.Module):
2778            def __init__(self) -> None:
2779                super().__init__()
2780                self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
2781                self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
2782                self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
2783                self.dropout = torch.nn.Dropout()
2784                self.conv1 = torch.nn.Conv2d(3, 3, 3)
2785                self.conv2 = torch.nn.Conv2d(3, 3, 3)
2786                self.relu = torch.nn.ReLU()
2787
2788            def forward(self, x):
2789                x = self.conv1(x)
2790                # add_scalar
2791                x = x + 3
2792                # mul_scalar
2793                x = x * 3
2794                # add_scalar_out
2795                x += 3
2796                # mul_scalar_out
2797                x *= 3
2798                # add_scalar_relu
2799                x = x + 3
2800                x = F.relu(x)
2801                # add_scalar_relu_out
2802                x += 3
2803                x = F.relu(x)
2804                # mul_scalar_relu
2805                x = x * 3
2806                x = F.relu(x)
2807                # mul_scalar_relu_out
2808                x *= 3
2809                x = F.relu(x)
2810                x = self.maxpool1d(x)
2811                x = self.maxpool2d(x)
2812                x = self.maxpool3d(x)
2813                x = torch.flatten(x)
2814                x = torch.max(x)
2815                x = torch.min(x)
2816                x = x.reshape([-1])
2817                x = x.resize_(1, 1, x.numel())
2818                x = x.view(-1)
2819                # prim::ListConstruct
2820                xs = [x, x]
2821                # prim::ListUnpack
2822                x, y = xs
2823                # prim::TupleConstruct
2824                xs = (x, x)
2825                # prim::TupleUnpack
2826                x, y = xs
2827                x = x.transpose(1, 2)
2828                x = x.contiguous()
2829                x, y = torch.chunk(x, 2)
2830                x = F.dropout(x)
2831                x = self.dropout(x)
2832                x, _ = torch.sort(x)
2833                x = x.permute(0, 2, 3, 1)
2834                x = torch.repeat_interleave(x, 3, 1)
2835                x = self.relu(x)
2836                x = F.relu(x)
2837                x.relu_()
2838                x = x.squeeze(0)
2839                x.squeeze_(0)
2840                x = torch.squeeze(x, 0)
2841                x = x.unsqueeze(0)
2842                x.unsqueeze_(0)
2843                x = torch.unsqueeze(x, 0)
2844                x = x.detach()
2845                x.detach_()
2846                x = x.repeat(4, 2)
2847                y = []
2848                y.append(x)
2849                z = torch.stack(y, 0)
2850                z = [z, z]
2851                x, _ = z
2852                x = self.conv2(x)
2853                return x
2854
2855        data = torch.rand(1, 3, 10, 10)
        # This model is not executable since we just put all ops
        # in the same forward, so we only test scripting
2858        m = torch.jit.script(M())
2859        qconfig = script_qconfig(default_qconfig)
2860        # dummy data to suppress warning
2861        get_forward(qconfig.activation)(data)
2862        get_forward(qconfig.weight)(data)
2863
2864        m = wrap_cpp_module(
2865            torch._C._jit_pass_insert_observers(
2866                m._c, "forward", {"": qconfig}, inplace=False
2867            )
2868        )
2869        m = convert_jit(m)
2870        # This checks that the dequantize from the output of first conv
        # This checks that the dequantize from the output of the first conv
        # is propagated to the end of the graph, so that we don't insert extra
        # observers and the two quantized::conv2d patterns are fused
        # successfully.
        # only one quantize_per_tensor, for the input
2876            m.graph
2877        )
2878
2879        FileCheck().check_count("quantized::conv2d(", 2, exactly=True).run(m.graph)
2880
2881        FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph)
2882
2883        FileCheck().check("quantized::add_scalar").check("quantized::mul_scalar").run(
2884            m.graph
2885        )
2886
2887    def test_general_value_ops(self):
        """A test that checks correct patterns are produced for
        all supported general value ops like aten::avg_pool2d
        without actually checking for execution of these ops
        """
2892
2893        class M(torch.nn.Module):
2894            def __init__(self) -> None:
2895                super().__init__()
2896                self.conv = torch.nn.Conv2d(3, 3, 3)
2897                self.avg_pool1d = torch.nn.AvgPool1d(3)
2898                self.avg_pool2d = torch.nn.AvgPool2d(3)
2899                self.avg_pool3d = torch.nn.AvgPool3d(3)
2900                self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d(1)
2901                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
2902                self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
2903                self.leaky_relu = torch.nn.LeakyReLU()
2904                self.hardsigmoid = torch.nn.Hardsigmoid()
2905                self.sigmoid = torch.nn.Sigmoid()
2906                self.tanh = torch.nn.Tanh()
2907
2908            def forward(self, x):
2909                x = self.conv(x)
2910                x = self.avg_pool1d(x)
2911                x = self.avg_pool2d(x)
2912                x = self.avg_pool3d(x)
2913                x = self.adaptive_avg_pool1d(x)
2914                x = self.adaptive_avg_pool2d(x)
2915                x = self.adaptive_avg_pool3d(x)
2916                x = F.avg_pool1d(x, 3)
2917                x = F.avg_pool2d(x, 3)
2918                x = F.avg_pool3d(x, 3)
2919                x = F.adaptive_avg_pool1d(x, (1))
2920                x = F.adaptive_avg_pool2d(x, (1, 1))
2921                x = F.adaptive_avg_pool3d(x, (1, 1, 1))
2922                x = torch.mean(x)
2923                x = torch.mean(x, [2, 3], False)
2924                x = x.mean()
2925                x = x.mean([2, 3], True)
2926                # interpolate node will introduce 3 quantize_per_tensor ops
2927                x = F.interpolate(x, 4, mode="nearest")  # interpolate node
2928                x = F.upsample(x, (32, 32))  # interpolate node
2929                x = F.upsample_nearest(x, (32, 32))  # interpolate node
2930                x = F.interpolate(x, 4, mode="linear")  # common node
2931                x = F.upsample_bilinear(x, (32, 32))  # common node
2932                x = self.leaky_relu(x)
2933                x = F.leaky_relu(x)
2934                x.leaky_relu_()
2935                x = self.hardsigmoid(x)
2936                x = F.hardsigmoid(x)
2937                x.hardsigmoid_()
2938                x = self.sigmoid(x)
2939                x = torch.sigmoid(x)
2940                # F.sigmoid is deprecated
2941                x = x.sigmoid()
2942                x.sigmoid_()
2943                x = self.tanh(x)
2944                # F.tanh is deprecated
2945                x = torch.tanh(x)
2946                x = x.tanh()
2947                x.tanh_()
2948                x = self.conv(x)
2949                return x
2950
        # This model is not executable since we just put all ops
        # in the same forward, so we only test scripting
2953        m = torch.jit.script(M())
2954        qconfig = script_qconfig(default_qconfig)
2955        # dummy data to suppress warning
2956        data = torch.rand(1, 3, 10, 10)
2957        get_forward(qconfig.activation)(data)
2958        get_forward(qconfig.weight)(data)
2959
2960        m = wrap_cpp_module(
2961            torch._C._jit_pass_insert_observers(
2962                m._c, "forward", {"": qconfig}, inplace=False
2963            )
2964        )
        # Check that the model before finalize contains unfused patterns
        # that numerically match the model after quantization, by counting
        # the aten::quantize_per_tensor calls:
        # conv has 3 quantize_per_tensor for activations and 1 for weight,
        # and for N general value ops between convs we should have
        # N + 1 quantize_per_tensor between these ops
        m1 = convert_jit(m, debug=True)
        # NB: this needs to be updated when we add more ops to the test.
        # Maps the number of quantize_per_tensor calls an op introduces (key)
        # to how many such ops the model contains (value); e.g. key `3` means
        # each of those ops introduces 3 quantize_per_tensor calls and there
        # are 3 of them.
2977        num_op_by_num_quant = {1: 32, 2: 2, 3: 3}
2978        num_quantize_per_tensor = 1  # for output
2979        for num_quant, num_op in num_op_by_num_quant.items():
2980            num_quantize_per_tensor += num_op * num_quant
2981        num_quantize_per_tensor -= 4  # constant propagation removes some prepacks
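        # With the current mapping this works out to
        # 1 + 32 * 1 + 2 * 2 + 3 * 3 - 4 = 42 quantize_per_tensor calls.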
2982        FileCheck().check_count(
2983            "aten::quantize_per_tensor(", num_quantize_per_tensor, exactly=True
2984        ).run(m1.graph)
2985
        # This checks that the dequantize from the output of the first conv
        # is propagated to the end of the graph, so that we don't insert extra
        # observers and the two quantized::conv2d patterns are fused
        # successfully.
        # only one quantize_per_tensor, for the input
2991        m2 = convert_jit(m, debug=False)
2992        FileCheck().check_count("aten::quantize_per_tensor(", 1, exactly=True).run(
2993            m2.graph
2994        )
2995        FileCheck().check_count("quantized::conv2d(", 2, exactly=True).check(
2996            "aten::dequantize("
2997        ).run(m2.graph)
2998
2999    @override_qengines
3000    def test_conv_with_benchmark_flag(self):
3001        r"""Verifies that convolutions get quantized when
3002        torch.backends.cudnn.benchmark is enabled
3003        """
3004        if not qengine_is_qnnpack():
3005            return
3006        with torch.backends.cudnn.flags(enabled=True):
3007            m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
3008            m.eval()
3009            m = torch.jit.trace(m, torch.rand(4, 1, 4, 4))
3010            qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
3011            prepared_model = torch.ao.quantization.prepare_jit(m, {"": qconfig})
3012            prepared_model(torch.rand(4, 1, 4, 4))
3013            converted_model = torch.ao.quantization.convert_jit(prepared_model)
3014            FileCheck().check("quantized::conv2d").run(converted_model.graph)
3015
3016    @skipIfNoFBGEMM
3017    def test_cat_linear(self):
3018        class LinearModel(torch.nn.Module):
3019            def __init__(self) -> None:
3020                super().__init__()
3021                self.weight = torch.randn(5, 5)
3022
3023            def forward(self, x, y):
3024                a = torch.cat([x, y])
3025                b = F.linear(a, self.weight)
3026                c = F.linear(b, self.weight)
3027                return b, c
3028
3029        model = LinearModel().eval()
3030        qconfig = {"": default_qconfig}
3031        float_model = torch.jit.script(model)
3032        prepared_model = prepare_jit(float_model, qconfig)
3033        prepared_model(torch.rand(5, 5), torch.rand(5, 5))
3034        converted_model = convert_jit(prepared_model)
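        # Both F.linear calls should be lowered to quantized::linear, even
        # though the first one consumes the output of aten::cat.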
3035        FileCheck().check("quantized::linear").check("quantized::linear").run(
3036            converted_model.graph
3037        )
3038
3039
3040class TestQuantizeDynamicJitPasses(QuantizationTestCase):
3041    def test_prepare_dynamic(self):
3042        class M(torch.nn.Module):
3043            def __init__(self) -> None:
3044                super().__init__()
3045                self.fc = torch.nn.Linear(5, 5)
3046
3047            def forward(self, x):
3048                return self.fc(x)
3049
3050        model = torch.jit.script(M())
3051        for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]:
3052            m = prepare_dynamic_jit(model, {"": qconfig})
3053
3054            # observer for weight
3055            assert len(attrs_with_prefix(m.fc, "_observer_")) == 1
3056
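            # The weight observer lives on the fc submodule; the branches below
            # additionally check that fp16 dynamic quant uses a
            # PlaceholderObserver for it, while int8 dynamic quant also
            # observes the input of the fc call at the top level.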
3057            if qconfig == float16_dynamic_qconfig:
3058                observer_name = 'PlaceholderObserver = prim::GetAttr[name="_observer_'
3059                FileCheck().check(observer_name).run(m.fc.graph)
3060            else:
3061                # for input of FC for dynamic quant
3062                assert len(attrs_with_prefix(m, "_observer_")) == 1
3063                observer_name = 'Observer = prim::GetAttr[name="_observer_'
3064                FileCheck().check(observer_name).check(
3065                    'prim::GetAttr[name="fc"]'
3066                ).check("prim::CallMethod").check_not(observer_name).run(m.graph)
3067
3068    def test_prepare_dynamic_child_qconfig(self):
3069        class Sub(torch.nn.Module):
3070            def __init__(self) -> None:
3071                super().__init__()
3072                self.fc = torch.nn.Linear(5, 5)
3073
3074            def forward(self, x):
3075                return self.fc(x)
3076
3077        class M(torch.nn.Module):
3078            def __init__(self) -> None:
3079                super().__init__()
3080                self.conv = torch.nn.Conv2d(3, 5, 3)
3081                self.sub = Sub()
3082
3083            def forward(self, x):
3084                return self.sub(self.conv(x))
3085
3086        m = torch.jit.script(M())
3087        # only quantize child module.
3088        m = prepare_dynamic_jit(m, {"sub.fc": default_dynamic_qconfig})
3089
3090        # input of sub for dynamic quant
3091        assert len(attrs_with_prefix(m, "_observer_")) == 1
3092        # not quantized
3093        assert len(attrs_with_prefix(m.conv, "_observer_")) == 0
        # no observers since we observe at the outermost call site
3095        assert len(attrs_with_prefix(m.sub, "_observer_")) == 0
3096        # weight of linear
3097        assert len(attrs_with_prefix(m.sub.fc, "_observer_")) == 1
3098        FileCheck().check('prim::GetAttr[name="sub').check("prim::CallMethod").check(
3099            'Observer = prim::GetAttr[name="_observer_'
3100        ).check("prim::CallMethod").check_not(
3101            'Observer = prim::GetAttr[name="_observer_'
3102        ).run(
3103            m.graph
3104        )
3105
3106    def test_insert_quant_dequant_linear_dynamic(self):
3107        class M(torch.nn.Module):
3108            def __init__(self) -> None:
3109                super().__init__()
3110                self.fc1 = torch.nn.Linear(5, 5).float()
3111                self.fc2 = torch.nn.Linear(5, 5).float()
3112
3113            def forward(self, x):
3114                x = self.fc1(x)
3115                return self.fc2(x)
3116
3117        for is_per_channel in [True, False]:
3118            m = torch.jit.script(M())
3119            qconfig = (
3120                per_channel_dynamic_qconfig
3121                if is_per_channel is True
3122                else default_dynamic_qconfig
3123            )
3124            m = quantize_dynamic_jit(m, {"": qconfig}, debug=True)
            assert (
                len(m._modules._c.items()) == 2
            ), "Expected to have two linear submodules"
3128
3129            wt_quant_func = (
3130                "aten::quantize_per_channel"
3131                if is_per_channel
3132                else "aten::quantize_per_tensor"
3133            )
3134            act_quant_func = "aten::quantize_per_tensor"
3135            # quantizing activations
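            # Expected debug-graph pattern: one runtime
            # choose_qparams -> quantize -> dequantize sequence per linear
            # input, followed by a single weight quantize -> dequantize pair
            # and no further weight quantization before the return.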
3136            FileCheck().check("aten::_choose_qparams_per_tensor").check_next(
3137                act_quant_func
3138            ).check_next("aten::dequantize").check(
3139                "aten::_choose_qparams_per_tensor"
3140            ).check_next(
3141                act_quant_func
3142            ).check_next(
3143                "aten::dequantize"
3144            ).check(
3145                wt_quant_func
3146            ).check_next(
3147                "aten::dequantize"
3148            ).check_not(
3149                wt_quant_func
3150            ).check(
3151                "return"
3152            ).run(
3153                m.graph
3154            )
3155
3156    @override_qengines
3157    def test_dynamic_multi_op(self):
3158        class M(torch.nn.Module):
3159            def __init__(self) -> None:
3160                super().__init__()
3161                self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
3162
3163            def forward(self, x):
3164                x = x + 5
3165                return self.fc1(x)
3166
3167        x = torch.randn(5, 5)
3168        for tracing in [True, False]:
3169            model = self.checkGraphModeOp(
3170                M(), x, "quantized::linear_dynamic", tracing=tracing, dynamic=True
3171            )
3172            # add op is not dynamically quantized.
3173            FileCheck().check("aten::add").run(model.graph)
3174
3175    @override_qengines
3176    def test_dynamic_quant_multi_uses(self):
3177        class M(torch.nn.Module):
3178            def __init__(self) -> None:
3179                super().__init__()
3180                self.fc = torch.nn.Linear(5, 5).float()
3181
3182            def forward(self, x):
3183                size1 = x.size()
3184                size2 = x.size()
3185                return self.fc(x), size1, size2
3186
3187        x = torch.randn(5, 5)
3188        for tracing in [True, False]:
3189            model = self.checkGraphModeOp(
3190                M(), x, "quantized::linear_dynamic", tracing=tracing, dynamic=True
3191            )
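            # In the converted model, qparam selection is folded into
            # quantized::linear_dynamic, so no standalone
            # aten::_choose_qparams_per_tensor should remain in the graph.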
3192            FileCheck().check_not("aten::_choose_qparams_per_tensor").run(model.graph)
3193
3194    @override_qengines
3195    def test_dynamic_shared_weights(self):
3196        class myMod(torch.nn.Module):
3197            def __init__(self, weight):
3198                super().__init__()
3199                self.linear = nn.Linear(5, 5)
3200                self.linear.weight = weight
3201
3202            def forward(self, x):
3203                return self.linear(x)
3204
3205        class DynamicModel(torch.nn.Module):
3206            def __init__(self) -> None:
3207                super().__init__()
3208                self.weight = torch.nn.Parameter(torch.ones(5, 5))
3209                self.mod1 = myMod(self.weight)
3210
3211            def forward(self, x):
3212                y = self.mod1(x)
3213                z = torch.nn.functional.linear(y, self.weight)
3214                return z
3215
3216        model = torch.jit.script(DynamicModel()).eval()
3217        data = torch.randn(5, 5, dtype=torch.float)
3218        quant_ops = ["mod1", ""]
3219        counts = [1, 2]
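        # Quantizing only "mod1" should yield one dynamic linear; quantizing
        # the whole model ("") should also cover the functional linear that
        # shares the same weight, yielding two.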
3220        for op, count in zip(quant_ops, counts):
3221            qconfig_dict = {op: default_dynamic_qconfig}
3222            m1 = quantize_dynamic_jit(model, qconfig_dict)
3223            out_graph = m1(data)
3224
3225            FileCheck().check_count(
3226                "quantized::linear_dynamic(", count, exactly=True
3227            ).check_not("aten::_choose_qparams_per_tensor").run(m1.graph)
3228
3229            # Explicitly call forward on model before convert
3230            m2 = prepare_dynamic_jit(model, qconfig_dict)
3231            m2(data)
3232            m2 = convert_dynamic_jit(m2, debug=False)
3233            out_ref = m2(data)
3234            self.assertEqual(out_graph, out_ref)
3235
3236    @override_qengines
3237    def test_dynamic_with_if(self):
3238        class Res(torch.nn.Module):
3239            def __init__(self) -> None:
3240                super().__init__()
3241                self.weight = torch.nn.Parameter(torch.ones(5, 5))
3242
3243            def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor:
3244                if cond:
3245                    return torch.nn.functional.linear(x, self.weight)
3246                else:
3247                    return torch.nn.functional.linear(x, self.weight)
3248
3249        class M(torch.nn.Module):
3250            def __init__(self) -> None:
3251                super().__init__()
3252                self.res1 = Res()
3253                self.res2 = Res()
3254
3255            def forward(self, x):
3256                x = self.res1(x, True)
3257                x = self.res2(x, False)
3258                return x
3259
3260        model = torch.jit.script(M()).eval()
3261        data = torch.randn(5, 5, dtype=torch.float)
3262        qconfig_dict = {"": default_dynamic_qconfig}
3263        for tracing in [True, False]:
3264            m1 = self.checkGraphModeOp(
3265                M(), data, "quantized::linear_dynamic", tracing=tracing, dynamic=True
3266            )
3267            FileCheck().check_count(
3268                "quantized::linear_dynamic(", 2, exactly=True
3269            ).check_not("aten::_choose_qparams_per_tensor").run(m1.graph)
3270
3271        # Check to make sure weight observers run correctly
3272        ref_qparams = []
3273        qconfig = script_qconfig(default_dynamic_qconfig)
3274        wt_module = wrap_cpp_module(qconfig.weight)
3275        for wt in [model.res1.weight, model.res2.weight]:
3276            wt_module(wt)
3277            qparams = wt_module.calculate_qparams()
3278            ref_qparams.append((qparams[0].item(), qparams[1].item()))
3279
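        # The debug graph stores each weight's computed scale/zero_point as
        # attributes on the corresponding submodule; compare them against the
        # reference qparams computed with the weight observer above.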
3280        m2 = quantize_dynamic_jit(model, qconfig_dict, debug=True)
3281        graph_params = []
3282        for x, obs in m2._modules._c.items():
3283            if x == "res1":
3284                graph_params.append(
3285                    (
3286                        obs.getattr("weight.2_scale_0"),
3287                        obs.getattr("weight.2_zero_point_0"),
3288                    )
3289                )
3290            elif x == "res2":
3291                graph_params.append(
3292                    (
3293                        obs.getattr("weight.4_scale_0"),
3294                        obs.getattr("weight.4_zero_point_0"),
3295                    )
3296                )
3297        self.assertEqual(ref_qparams, graph_params)
3298
3299    def test_dynamic_weight_observer(self):
3300        class M(torch.nn.Module):
3301            def __init__(self) -> None:
3302                super().__init__()
3303                self.fc = torch.nn.Linear(5, 5).float()
3304                self.fc2 = torch.nn.Linear(5, 5).float()
3305
3306            def forward(self, x):
3307                x = self.fc(x)
3308                return self.fc2(x)
3309
3310        qconfig_dict = {"": default_dynamic_qconfig}
3311        eager_model = M().eval()
3312        for tracing in [True, False]:
3313            x = torch.rand(5, 5)
3314            model = get_script_module(eager_model, tracing, x)
3315            ref_qparams = []
3316            for wt in [model.fc.weight, model.fc2.weight]:
3317                wt_module = default_dynamic_qconfig.weight()
3318                wt_module(wt)
3319                qparams = wt_module.calculate_qparams()
3320                ref_qparams.append((qparams[0].item(), qparams[1].item()))
3321            model = quantize_dynamic_jit(model, qconfig_dict, debug=True)
3322            graph_qparams = []
3323            for x, obs in model._modules._c.items():
3324                n = 2 if x == "fc" and tracing else 1
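                # the attribute suffix differs because traced and scripted
                # graphs number the internal weight values differently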
3325                graph_qparams.append(
3326                    (
3327                        obs.getattr(f"weight.{n}_scale_0"),
3328                        obs.getattr(f"weight.{n}_zero_point_0"),
3329                    )
3330                )
3331            self.assertEqual(ref_qparams, graph_qparams)
3332
3333    def test_convert_dynamic_fp16(self):
3334        class M(torch.nn.Module):
3335            def __init__(self) -> None:
3336                super().__init__()
3337                self.fc = torch.nn.Linear(5, 5)
3338
3339            def forward(self, x):
3340                return self.fc(x)
3341
3342        m = torch.jit.script(M())
3343        m = quantize_dynamic_jit(m, {"": float16_dynamic_qconfig}, debug=True)
3344        FileCheck().check("aten::_saturate_weight_to_fp16").check(
3345            "aten::linear"
3346        ).check_not("aten::dequantize").check_not("aten::quantize").run(m.graph)
3347
3348    def test_quantize_dynamic_fp16(self):
3349        class M(torch.nn.Module):
3350            def __init__(self) -> None:
3351                super().__init__()
3352                self.fc = torch.nn.Linear(5, 5)
3353
3354            def forward(self, x):
3355                return self.fc(x)
3356
3357        m = torch.jit.script(M())
3358        m = quantize_dynamic_jit(m, {"": float16_dynamic_qconfig})
3359
3360        FileCheck().check("quantized::linear_dynamic_fp16").check_not(
3361            "aten::linear"
3362        ).check_not("aten::dequantize").check_not("aten::quantize").run(m.graph)
3363
3364
3365class TestQuantizeDynamicJitOps(QuantizationTestCase):
    """Test that graph mode post-training dynamic quantization works
    for individual ops end to end.
    """
3369
3370    @override_qengines
3371    def test_linear(self):
3372        class FunctionalLinear(torch.nn.Module):
3373            def __init__(self, weight, bias):
3374                super().__init__()
3375                self.weight = weight
3376                self.bias = bias
3377
3378            def forward(self, x):
3379                return F.linear(x, self.weight, self.bias)
3380
3381        x = torch.rand(5, 5)
3382        for tracing in [True, False]:
3383            model = self.checkGraphModeOp(
3384                torch.nn.Linear(5, 5),
3385                x,
3386                "quantized::linear_dynamic",
3387                tracing=tracing,
3388                dynamic=True,
3389            )
3390
3391        weight = torch.rand(5, 5)
3392        b = torch.rand(5)
3393        for tracing, has_bias in itertools.product([True, False], [True, False]):
3394            bias = b if has_bias else None
3395            model = self.checkGraphModeOp(
3396                FunctionalLinear(weight, bias),
3397                x,
3398                "quantized::linear_dynamic",
3399                tracing=tracing,
3400                dynamic=True,
3401            )
3402
3403    @skipIfNoFBGEMM
3404    def test_embedding_bag(self):
3405        class M(torch.nn.Module):
3406            def __init__(self, weights):
3407                super().__init__()
3408                self.embedding1 = torch.nn.EmbeddingBag(
3409                    num_embeddings=10,
3410                    embedding_dim=12,
3411                    include_last_offset=True,
3412                    sparse=True,
3413                    _weight=weights,
3414                    mode="sum",
3415                )
3416
3417                self.embedding2 = torch.nn.EmbeddingBag(
3418                    num_embeddings=10,
3419                    embedding_dim=12,
3420                    include_last_offset=True,
3421                    sparse=True,
3422                    _weight=weights,
3423                    mode="sum",
3424                )
3425
3426            def forward(self, indices1, offsets1, indices2, offsets2):
3427                e1 = self.embedding1(indices1, offsets1)
3428                e2 = self.embedding2(indices2, offsets2)
3429                return e1, e2
3430
3431        weights = torch.randn(10, 12, dtype=torch.float32)
3432        module = M(weights)
3433
3434        indices = torch.tensor(
3435            [
3436                9,
3437                6,
3438                5,
3439                7,
3440                8,
3441                8,
3442                9,
3443                2,
3444                8,
3445                6,
3446                6,
3447                9,
3448                1,
3449                6,
3450                8,
3451                8,
3452                3,
3453                2,
3454                3,
3455                6,
3456                3,
3457                6,
3458                5,
3459                7,
3460                0,
3461                8,
3462                4,
3463                6,
3464                5,
3465                8,
3466                2,
3467                3,
3468            ]
3469        )
3470        offsets = torch.tensor([0, 19, 20, 28, 28, 32])
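        # With include_last_offset=True, offsets carries one extra entry whose
        # value (32) equals the total number of indices: 5 bags in total.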
3471        dummy_inputs = (indices, offsets, indices, offsets)
3472        for trace in [True, False]:
3473            if trace:
3474                m = torch.jit.trace(module, dummy_inputs)
3475            else:
3476                m = torch.jit.script(module)
3477            int4_qconfig = QConfig(
3478                activation=PlaceholderObserver.with_args(
3479                    dtype=torch.float, custom_op_name="embedding_bag_4bit"
3480                ),
3481                weight=PlaceholderObserver.with_args(
3482                    custom_op_name="embedding_bag_4bit"
3483                ),
3484            )
3485            int8_qconfig = QConfig(
3486                activation=PlaceholderObserver.with_args(
3487                    dtype=torch.float, custom_op_name="embedding_bag_byte"
3488                ),
3489                weight=PlaceholderObserver.with_args(
3490                    custom_op_name="embedding_bag_byte"
3491                ),
3492            )
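            # The custom_op_name set on the PlaceholderObserver tells the
            # convert pass which rowwise embedding_bag kernel (4-bit vs. byte)
            # to emit, matching the checks below.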
3493            m = prepare_jit(m, {"embedding1": int4_qconfig, "embedding2": int8_qconfig})
3494            m = convert_jit(m)
3495            FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets").check(
3496                "quantized::embedding_bag_byte_rowwise_offsets"
3497            ).run(m.graph)
3498            m(*dummy_inputs)
3499
3500    # Ensure that attempting to quantize an EmbeddingBag throws an error if
3501    # padding_idx is not None
3502    @skipIfNoFBGEMM
3503    def test_embedding_bag_padding_idx_error(self):
3504        class M(torch.nn.Module):
3505            def __init__(self, weights):
3506                super().__init__()
3507                self.embedding = torch.nn.EmbeddingBag(
3508                    num_embeddings=10,
3509                    embedding_dim=12,
3510                    include_last_offset=True,
3511                    sparse=True,
3512                    _weight=weights,
3513                    mode="sum",
3514                    padding_idx=0,
3515                )
3516
3517            def forward(self, indices, offsets):
3518                e = self.embedding(indices, offsets)
3519                return e
3520
3521        weights = torch.randn(10, 12, dtype=torch.float32)
3522        module = M(weights)
3523
3524        indices = torch.tensor([0, 1, 2, 3, 4])
3525        offsets = torch.tensor([0, 2, 5])
3526        dummy_inputs = (indices, offsets)
3527
3528        int4_qconfig = QConfig(
3529            activation=PlaceholderObserver.with_args(
3530                dtype=torch.float, custom_op_name="embedding_bag_4bit"
3531            ),
3532            weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_4bit"),
3533        )
3534        int8_qconfig = QConfig(
3535            activation=PlaceholderObserver.with_args(
3536                dtype=torch.float, custom_op_name="embedding_bag_byte"
3537            ),
3538            weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_byte"),
3539        )
3540
3541        error_msg = r"Expected aten::embedding_bag padding_idx input to be None"
3542        for trace, qconfig in itertools.product(
3543            [True, False], [int4_qconfig, int8_qconfig]
3544        ):
3545            if trace:
3546                m = torch.jit.trace(module, dummy_inputs)
3547            else:
3548                m = torch.jit.script(module)
3549            m = prepare_jit(m, {"embedding": qconfig})
3550            with self.assertRaisesRegex(RuntimeError, error_msg):
3551                m = convert_jit(m)
3552
3553
3554class TestQuantizeJit(QuantizationTestCase):
3555    @override_qengines
3556    def test_single_linear(self):
        r"""Compare the result of quantizing a single linear layer in
        eager mode and graph mode
        """
3560        # eager mode
3561        annotated_linear_model = AnnotatedSingleLayerLinearModel(
3562            torch.backends.quantized.engine
3563        ).eval()
3564        linear_model = SingleLayerLinearModel().eval()
3565        # copy the weight from eager mode so that we can
3566        # compare the result of the two quantized models later
3567        linear_model.fc1.weight = torch.nn.Parameter(
3568            annotated_linear_model.fc1.module.weight.detach()
3569        )
3570        linear_model.fc1.bias = torch.nn.Parameter(
3571            annotated_linear_model.fc1.module.bias.detach()
3572        )
3573        model_eager = quantize(
3574            annotated_linear_model, test_only_eval_fn, [self.calib_data]
3575        )
3576
3577        qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)}
3578        model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
3579        model_script = torch.jit.script(linear_model)
3580        result_eager = model_eager(self.calib_data[0][0])
3581        for model_under_test in [model_traced, model_script]:
3582            model_quantized = quantize_jit(
3583                model_under_test,
3584                qconfig_dict,
3585                test_only_eval_fn,
3586                [self.calib_data],
3587                inplace=False,
3588            )
3589            self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
3590
3591    @skipIfNoFBGEMM
3592    def test_observer_with_ignored_function(self):
3593        r"""Test observers with ignored function and make sure it works in
3594        graph mode
3595        """
3596        # eager mode
3597        annotated_linear_model = AnnotatedSingleLayerLinearModel("fbgemm").eval()
3598        for qconfig in [
3599            QConfig(activation=default_observer, weight=default_weight_observer),
3600            QConfig(
3601                activation=default_histogram_observer, weight=default_weight_observer
3602            ),
3603            QConfig(
3604                activation=default_observer, weight=default_per_channel_weight_observer
3605            ),
3606        ]:
3607            annotated_linear_model.qconfig = qconfig
3608            linear_model = SingleLayerLinearModel().eval()
3609            # copy the weight from eager mode so that we can
3610            # compare the result of the two quantized models later
3611            linear_model.fc1.weight = torch.nn.Parameter(
3612                annotated_linear_model.fc1.module.weight.detach()
3613            )
3614            linear_model.fc1.bias = torch.nn.Parameter(
3615                annotated_linear_model.fc1.module.bias.detach()
3616            )
3617            model_eager = quantize(
3618                annotated_linear_model, test_only_eval_fn, [self.calib_data]
3619            )
3620
3621            qconfig_dict = {"": qconfig}
3622            model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
3623            model_script = torch.jit.script(linear_model)
3624            result_eager = model_eager(self.calib_data[0][0])
3625            for model_under_test in [model_traced, model_script]:
3626                model_quantized = quantize_jit(
3627                    model_under_test,
3628                    qconfig_dict,
3629                    test_only_eval_fn,
3630                    [self.calib_data],
3631                    inplace=False,
3632                )
3633                self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
3634
3635    @override_qengines
3636    def test_conv(self):
        r"""Compare the result of quantizing a conv layer in
        eager mode and graph mode
        """
3640        # eager mode
3641        annotated_conv_model = AnnotatedConvModel(
3642            torch.backends.quantized.engine
3643        ).eval()
3644        conv_model = ConvModel().eval()
3645        # copy the weight from eager mode so that we can
3646        # compare the result of the two quantized models later
3647        conv_model.conv.weight = torch.nn.Parameter(
3648            annotated_conv_model.conv.weight.detach()
3649        )
3650        model_eager = quantize(
3651            annotated_conv_model, test_only_eval_fn, [self.img_data_2d]
3652        )
3653        qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)}
3654        model_traced = torch.jit.trace(conv_model, self.img_data_2d[0][0])
3655        model_script = torch.jit.script(conv_model)
3656        result_eager = model_eager(self.img_data_2d[0][0])
3657        for model_under_test in [model_traced, model_script]:
3658            model_quantized = quantize_jit(
3659                model_under_test,
3660                qconfig_dict,
3661                test_only_eval_fn,
3662                [self.img_data_2d],
3663                inplace=False,
3664            )
3665            self.assertEqual(model_quantized(self.img_data_2d[0][0]), result_eager)
3666
3667    @override_qengines
3668    def test_conv_transpose(self):
        r"""Compare the result of quantizing a conv_transpose layer in
        eager mode and graph mode
        """
3672        if not qengine_is_qnnpack():
3673            return  # Currently only qnnpack is supported
3674        # eager mode
3675        annotated_conv_model = AnnotatedConvTransposeModel(
3676            torch.backends.quantized.engine
3677        ).eval()
3678        conv_model = ConvTransposeModel().eval()
3679        # copy the weight from eager mode so that we can
3680        # compare the result of the two quantized models later
3681        conv_model.conv.weight = torch.nn.Parameter(
3682            annotated_conv_model.conv.weight.detach()
3683        )
3684        model_eager = quantize(
3685            annotated_conv_model, test_only_eval_fn, [self.img_data_2d]
3686        )
3687        qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)}
3688        model_traced = torch.jit.trace(conv_model, self.img_data_2d[0][0])
3689        model_script = torch.jit.script(conv_model)
3690        result_eager = model_eager(self.img_data_2d[0][0])
3691        for model_under_test in [model_traced, model_script]:
3692            model_quantized = quantize_jit(
3693                model_under_test,
3694                qconfig_dict,
3695                test_only_eval_fn,
3696                [self.img_data_2d],
3697                inplace=False,
3698            )
3699            self.assertEqual(model_quantized(self.img_data_2d[0][0]), result_eager)
3700
3701    @override_qengines
3702    def test_conv_bn(self):
        r"""Compare the result of quantizing a conv + bn layer in
        eager mode and graph mode
        """
3706        # eager mode
3707        conv_model = AnnotatedConvBnModel().eval()
3708        conv_model_to_script = ConvBnModel().eval()
3709        # copy the weight from eager mode so that we can
3710        # compare the result of the two quantized models later
3711        conv_model_to_script.conv.weight = torch.nn.Parameter(
3712            conv_model.conv.weight.detach()
3713        )
3714        fuse_modules(conv_model, ["conv", "bn"], inplace=True)
3715        model_eager = quantize(conv_model, test_only_eval_fn, [self.img_data_2d])
3716        qconfig_dict = {"": default_qconfig}
3717        model_script = quantize_jit(
3718            torch.jit.script(conv_model_to_script),
3719            qconfig_dict,
3720            test_only_eval_fn,
3721            [self.img_data_2d],
3722            inplace=False,
3723        )
3724        result_eager = model_eager(self.img_data_2d[0][0])
3725        result_script = model_script(self.img_data_2d[0][0])
3726        self.assertEqual(result_eager, result_script)
3727
3728    @override_qengines
3729    def test_nested(self):
3730        # Eager mode
3731        eager_model = AnnotatedNestedModel(torch.backends.quantized.engine).eval()
3732
3733        # Graph mode
3734        script_model = NestedModel().eval()
        # Copy weights from eager_model
3736        script_model.sub1.fc.weight = torch.nn.Parameter(
3737            eager_model.sub1.fc.weight.detach()
3738        )
3739        script_model.sub1.fc.bias = torch.nn.Parameter(
3740            eager_model.sub1.fc.bias.detach()
3741        )
3742        script_model.sub2.fc1.weight = torch.nn.Parameter(
3743            eager_model.sub2.fc1.module.weight.detach()
3744        )
3745        script_model.sub2.fc1.bias = torch.nn.Parameter(
3746            eager_model.sub2.fc1.module.bias.detach()
3747        )
3748        script_model.sub2.fc2.weight = torch.nn.Parameter(
3749            eager_model.sub2.fc2.weight.detach()
3750        )
3751        script_model.sub2.fc2.bias = torch.nn.Parameter(
3752            eager_model.sub2.fc2.bias.detach()
3753        )
3754        script_model.fc3.weight = torch.nn.Parameter(
3755            eager_model.fc3.module.weight.detach()
3756        )
3757        script_model.fc3.bias = torch.nn.Parameter(eager_model.fc3.module.bias.detach())
3758
3759        model_eager = quantize(eager_model, test_only_eval_fn, [self.calib_data])
3760        qconfig_dict = {
3761            "sub2.fc1": default_per_channel_qconfig
3762            if qengine_is_fbgemm()
3763            else default_qconfig,
3764            "fc3": default_qconfig,
3765        }
3766        model_traced = torch.jit.trace(script_model, self.calib_data[0][0])
3767        model_script = torch.jit.script(script_model)
3768        result_eager = model_eager(self.calib_data[0][0])
3769        for model_under_test in [model_traced, model_script]:
3770            model_quantized = quantize_jit(
3771                model_under_test,
3772                qconfig_dict,
3773                test_only_eval_fn,
3774                [self.calib_data],
3775                inplace=False,
3776            )
3777            self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
3778
3779    @override_qengines
3780    def test_skip_quant(self):
3781        """Test None qconfig"""
3782        # Eager mode
3783        eager_model = AnnotatedSkipQuantModel(torch.backends.quantized.engine).eval()
3784
3785        # Graph mode
3786        script_model = SkipQuantModel().eval()
        # Copy weights from eager_model
3788        script_model.sub.fc1.weight = torch.nn.Parameter(
3789            eager_model.sub.module.fc1.weight.detach()
3790        )
3791        script_model.sub.fc1.bias = torch.nn.Parameter(
3792            eager_model.sub.module.fc1.bias.detach()
3793        )
3794        script_model.sub.fc2.weight = torch.nn.Parameter(
3795            eager_model.sub.module.fc2.weight.detach()
3796        )
3797        script_model.sub.fc2.bias = torch.nn.Parameter(
3798            eager_model.sub.module.fc2.bias.detach()
3799        )
3800        script_model.fc.weight = torch.nn.Parameter(eager_model.fc.weight.detach())
3801        script_model.fc.bias = torch.nn.Parameter(eager_model.fc.bias.detach())
3802
3803        eager_model.fuse_modules()
3804
3805        model_eager = quantize(eager_model, test_only_eval_fn, [self.calib_data])
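        # Mapping "fc" to None skips quantization for self.fc, while the rest
        # of the model uses the default qconfig for the current engine.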
3806        qconfig_dict = {
3807            "": get_default_qconfig(torch.backends.quantized.engine),
3808            "fc": None,
3809        }
3810        model_traced = torch.jit.trace(script_model, self.calib_data[0][0])
3811        model_script = torch.jit.script(script_model)
3812        result_eager = model_eager(self.calib_data[0][0])
3813        for model_under_test in [model_traced, model_script]:
3814            model_quantized = quantize_jit(
3815                model_under_test,
3816                qconfig_dict,
3817                test_only_eval_fn,
3818                [self.calib_data],
3819                inplace=False,
3820            )
3821            self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
3822
3823    @override_qengines
3824    def test_single_linear_dynamic(self):
        r"""Compare the result of dynamic quantization of a single linear layer in
        eager mode and graph mode.
        """
3828        if qengine_is_qnnpack():
3829            # eager mode
3830            annotated_linear_model = AnnotatedSingleLayerLinearModel("qnnpack").eval()
3831            linear_model = SingleLayerLinearModel().eval()
3832            # copy the weight from eager mode so that we can
3833            # compare the result of the two quantized models later
3834            linear_model.fc1.weight = torch.nn.Parameter(
3835                annotated_linear_model.fc1.module.weight.detach()
3836            )
3837            linear_model.fc1.bias = torch.nn.Parameter(
3838                annotated_linear_model.fc1.module.bias.detach()
3839            )
3840            qconfig_dict = {"": default_dynamic_qconfig}
3841            model_eager = quantize_dynamic(annotated_linear_model, qconfig_dict)
3842
3843            model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
3844            model_script = torch.jit.script(linear_model)
3845            result_eager = model_eager(self.calib_data[0][0])
3846
3847            for model_under_test in [model_traced, model_script]:
3848                model_quantized = quantize_dynamic_jit(model_under_test, qconfig_dict)
3849                self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
3850
3851                # Check to make sure choose_qparams->quant->dequant->linear is numerically
3852                # equivalent to the final quantized model.
3853                model_fake_quantized = quantize_dynamic_jit(
3854                    model_under_test, qconfig_dict, debug=True
3855                )
3856                self.assertEqual(
3857                    model_fake_quantized(self.calib_data[0][0]), result_eager
3858                )
3859
3860    @skipIfNoFBGEMM
3861    def test_linear_dynamic_fp16(self):
3862        linear_model = SingleLayerLinearModel().eval()
3863        # Create weight tensor values that are beyond fp16 max
3864        x = torch.ones(5, 5) * 65532
3865        linear_model.fc1.weight = torch.nn.Parameter(x)
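        # fp16's largest finite value is 65504, so these weights saturate when
        # converted to half precision; any warnings emitted are captured below.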
3866        import warnings
3867
3868        model_eager = quantize_dynamic(linear_model, dtype=torch.float16)
3869        result_eager = model_eager(self.calib_data[0][0])
3870        for trace in [True]:
3871            with warnings.catch_warnings(record=True) as w:
3872                quantized_model = self.checkGraphModeOp(
3873                    linear_model,
3874                    self.calib_data[0][0],
3875                    "quantized::linear_dynamic_fp16",
3876                    tracing=trace,
3877                    dynamic=True,
3878                    qconfig=float16_dynamic_qconfig,
3879                )
3880            # compare result with eager mode
3881            self.assertEqual(quantized_model(self.calib_data[0][0]), result_eager)
3882