xref: /aosp_15_r20/external/pytorch/test/test_jit_llga_fuser.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: mkldnn"]
import sys
import torch
import unittest
import itertools
import torch.nn as nn
from functools import wraps
from concurrent import futures
import torch.nn.functional as F
import torch.fx.experimental.optimization as optimization
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import run_tests, TEST_SCIPY, IS_WINDOWS, IS_MACOS
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    dtypes
)

# We use this wrapper to run UTs of TorchVision models because of a memory-leak
# issue with JIT tracing that causes traced model objects to persist in memory.
# Ref: https://github.com/pytorch/pytorch/issues/35600
# The memory requirement of running these UTs was thus increasing cumulatively and
# invoked the Linux kernel OOM killer on linux.2xlarge PyTorch CI runners, which
# only have 16 GB RAM. Cumulatively, these UTs had been using more than 14 GB of
# memory (as per psutil), so we now run each TorchVision model UT in a separate process.
def separate_process(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        with futures.ProcessPoolExecutor() as executor:
            future = executor.submit(func, *args, **kwargs)
            futures.wait([future])
    return wrapper

def is_avx512_supported():
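    # The BF16 variants of these tests need AVX512; on Linux, detect support by scanning /proc/cpuinfo.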
    if sys.platform != 'linux':
        return False
    with open("/proc/cpuinfo", encoding="ascii") as f:
        lines = f.read()
    return "avx512" in lines

IS_AVX512_UNSUPPORTED = not is_avx512_supported()

LLGA_FUSION_GROUP = 'prim::oneDNNFusionGroup'
LLGA_NOT_ENABLED = not torch.backends.mkldnn.is_available() or IS_WINDOWS or IS_MACOS

def warmup_forward(f, *args, profiling_count=3):
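    # Run the traced callable a few times so that the profiling graph executor
    # records input shapes and the oneDNN fusion pass is applied before the
    # tests inspect graph_for() or compare outputs.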
    for i in range(profiling_count):
        results = f(*args)

    return results

class JitLlgaTestCase(JitTestCase):

    def setUp(self):
        # PyTorch has divergent op support for AMP in JIT & eager modes
        # so we disable AMP for JIT & leverage eager-mode AMP.
        # Ref: https://github.com/pytorch/pytorch/issues/75956
        self.original_autocast_mode = torch._C._jit_set_autocast_mode(False)
        torch.jit.enable_onednn_fusion(True)

    def tearDown(self):
        torch.jit.enable_onednn_fusion(False)
        torch._C._jit_set_autocast_mode(self.original_autocast_mode)

    def checkTrace(self, m, x, dtype=torch.float32, *args, **kwargs):
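        # Trace (and, for nn.Modules, freeze) `m`, warm it up so the fusion pass runs,
        # compare the JIT output against eager mode, and return the traced callable
        # together with its optimized forward graph. For BF16, eager-mode autocast
        # provides the reference, since JIT autocast was disabled in setUp.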
        if isinstance(m, torch.nn.Module):
            m.eval()
        with torch.no_grad(), torch._jit_internal._disable_emit_hooks():
            if dtype == torch.bfloat16:
                # We rely upon eager-mode AMP support for BF16
                with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16):
                    traced = torch.jit.trace(m, x)
                    if isinstance(m, torch.nn.Module):
                        traced = torch.jit.freeze(traced)
                    warmup_forward(traced, *x)
                    ref_o = m(*x)
                    fwd_graph = traced.graph_for(*x)
            else:
                traced = torch.jit.trace(m, x)
                if isinstance(m, torch.nn.Module):
                    traced = torch.jit.freeze(traced)
                warmup_forward(traced, *x)
                ref_o = m(*x)
                fwd_graph = traced.graph_for(*x)

            jit_o = traced(*x)
            self.assertEqual(jit_o, ref_o)
            return traced, fwd_graph


    def assertFused(self, graph, fused_patterns):
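        # Each pattern must no longer appear in the outer graph, i.e. every such
        # node has been absorbed into a oneDNN fusion group.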
        for pat in fused_patterns:
            self.assertGraphContainsExactly(graph, pat, 0)

    def findFusionGroups(self, graph):
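        # Recursively collect the subgraphs of all prim::oneDNNFusionGroup nodes,
        # descending into the blocks of control-flow nodes.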
        result = []
        for n in graph.nodes():
            if n.kind() == LLGA_FUSION_GROUP:
                result.append(n.g('Subgraph'))
                continue
            for block in n.blocks():
                result += self.findFusionGroups(block)
        return result

    def checkPatterns(self, graph, patterns):
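        # `patterns` holds one list of expected node kinds per fusion group, in the
        # order the groups appear in the graph; each kind must occur in the
        # corresponding fused subgraph.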
        fusion_groups = self.findFusionGroups(graph)
        assert len(fusion_groups) == len(patterns), "length of subgraphs not equal to length of given patterns"

        for i in range(len(fusion_groups)):
            for pattern in patterns[i]:
                self.assertGraphContains(fusion_groups[i], pattern)

try:
    import torchvision
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
except RuntimeError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, 'no torchvision')

def get_eltwise_fn(name):
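    # Resolve an elementwise op by name: prefer torch.<name>, then torch.nn.functional.<name>.
    # 'hardswish_' has no functional counterpart, so fall back to the inplace module.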
    if hasattr(torch, name):
        return getattr(torch, name)
    elif hasattr(F, name):
        return getattr(F, name)
    elif name == 'hardswish_':
        return torch.nn.Hardswish(inplace=True)
    else:
        raise NameError(f'Eltwise function {name} not found')


@unittest.skipIf(IS_AVX512_UNSUPPORTED, "This test fails for BF16 on machines without AVX512.")
@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
class TestOp(JitLlgaTestCase):
    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_conv2d(self, dtype):
        for [spatial, in_channels, out_channels, kernel, padding, stride, dilation, g, bias] in itertools.product(
                [7, 8],
                [8, 15],
                [7, 16],
                [3, 4],
                [0, 2],
                [1, 2],
                [1, 2],
                [1, 2],
                [True, False]):

            m = nn.Conv2d(in_channels=in_channels * g,
                          out_channels=out_channels * g,
                          kernel_size=kernel,
                          padding=padding,
                          stride=stride,
                          dilation=dilation,
                          groups=g,
                          bias=bias)

            x = torch.rand(1, in_channels * g, spatial, spatial)
            _, graph = self.checkTrace(m, [x], dtype)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_bn2d(self, dtype):
        m = nn.BatchNorm2d(32).eval()
        x = torch.rand(1, 32, 28, 28)
        _, graph = self.checkTrace(m, [x], dtype)
        # a single-op partition shouldn't be created for batch_norm alone
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_eltwise(self, dtype):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super().__init__()
                self.eltwise = eltwise_fn

            def forward(self, x):
                return self.eltwise(x)

        for eltwise in ['relu', 'gelu']:
            eltwise_fn = get_eltwise_fn(eltwise)
            m = M(eltwise_fn)
            x = torch.rand(1, 32, 28, 28)
            _, graph = self.checkTrace(m, [x], dtype)
            # single-op partition shouldn't be created.
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_max_pool2d(self, dtype):
        for [spatial, kernel, padding, stride, dilation, ceil_mode] in itertools.product(
                [15, 16, 17, 18, 19],
                [4, 5],
                [0, 1, 2],
                [1, 2],  # [1, 2, 4], TODO: fix issue in pad calculation
                [1],     # [1, 2], TODO: backend support for dilation
                [True, False]):

            m = nn.MaxPool2d(kernel_size=kernel,
                             stride=stride,
                             padding=padding,
                             dilation=dilation,
                             ceil_mode=ceil_mode)

            x = torch.rand(1, 4, spatial, spatial)
            _, graph = self.checkTrace(m, [x], dtype)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_avg_pool2d(self, dtype):
        for [spatial, kernel, padding, stride, ceil_mode, count_include_pad] in itertools.product(
                [15, 16, 17, 18, 19],
                [4, 5],
                [0, 1, 2],
                [1, 2, 4],
                [False],  # TODO: oneDNN Graph does not fully support ceil_mode=True
                [True, False]):

            m = nn.AvgPool2d(kernel_size=kernel,
                             stride=stride,
                             padding=padding,
                             ceil_mode=ceil_mode,
                             count_include_pad=count_include_pad)

            x = torch.rand(1, 4, spatial, spatial)
            _, graph = self.checkTrace(m, [x], dtype)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_variable_kernel_avg_pool2d(self, dtype):
        class M(nn.Module):
            def forward(self, x):
                x = F.avg_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=0, count_include_pad=False)
                return x

        x = torch.randn(1, 1000, 1, 1)
        m = M()
        _, graph = self.checkTrace(m, [x], dtype)
        # kernel_size is not Constant, shouldn't have any LLGA_FUSION_GROUP
        # TODO: with shape specialization, should have 1 LLGA_FUSION_GROUP
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_softmax(self, dtype):
        for dim in [-4, -3, -2, -1, 0, 1, 2, 3]:
            m = nn.Softmax(dim=dim)
            x = torch.rand(8, 12, 12, 12)
            _, graph = self.checkTrace(m, [x], dtype)
            # single-op partition shouldn't be created for softmax
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_linear(self, dtype):
        for bias in [True, False]:
            x = torch.rand(32, 28)
            m = torch.nn.Linear(in_features=28, out_features=64, bias=bias)
            _, graph = self.checkTrace(m, [x], dtype)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.assertFused(graph, ['aten::linear'])


    def _gen_binary_inputs(self, gen_permute=True):
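        # Yield broadcastable (x, y) operand pairs; with gen_permute, also yield the
        # swapped shapes so that both broadcast directions are exercised.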
        for xshape, yshape in [
            [[1, 32, 28, 28], [1, 32, 28, 28]],
            [[1, 32, 28, 28], [1, 1, 28, 28]],
            [[1, 32, 28, 28], [28]],
            [[1, 32, 28, 28], [1]],

        ]:
            yield torch.rand(xshape), torch.rand(yshape)
            if gen_permute and xshape != yshape:
                yield torch.rand(yshape), torch.rand(xshape)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_add(self, dtype):
        def forward_add(x, y):
            return torch.add(x, y, alpha=2)

        for x, y in self._gen_binary_inputs():
            _, graph = self.checkTrace(forward_add, [x, y], dtype)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_add_scalar(self, dtype):
        def add_scalar(x):
            return 42 + x + 3.14

        x = torch.rand(32, 32)
        _, graph = self.checkTrace(add_scalar, [x], dtype)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_addmm(self, dtype):
        # Just a sidenote - comparison of eager-mode & oneDNN Graph JIT outputs of
        # addmm (which entails matmul-bias-add fusion) might require higher tolerance
        # bounds for BF16. This is subject to change in the near future.
        def addmm(x, y, z):
            # alpha and beta are 1, by default
            return torch.addmm(z, x, y)

        x = torch.rand(64, 32)
        y = torch.rand(32, 32)
        z = torch.rand(64, 32)
        _, graph = self.checkTrace(addmm, [x, y, z], dtype)
        # single-op partition should be created for matmul with bias.
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_mul(self, dtype):
        def forward_mul(x, y):
            return torch.mul(x, y) * 3

        for x, y in self._gen_binary_inputs():
            _, graph = self.checkTrace(forward_mul, [x, y], dtype)
            # the two muls should fuse into a single multi-op partition
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_identity_binary(self, dtype):
        def forward(x):
            return x * 1 + 0.0

        x = torch.rand(32)
        _, graph = self.checkTrace(forward, [x], dtype)
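        # Even identity-like binary ops (multiply by 1, add 0.0) should be claimed by
        # the fuser rather than left behind in the outer graph.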
        self.assertFused(graph, ['aten::add', 'aten::mul'])

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_layer_norm(self, dtype):
        # TODO: support more normalized_shape
        m = torch.nn.LayerNorm(10)
        x = torch.randn(2, 5, 10, 10)
        _, graph = self.checkTrace(m, [x], dtype)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_cat(self, dtype):
        def cat_along_dim(d):
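            # Build a closure that concatenates all of its inputs along dimension d.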
            def forward_cat(*inputs):
                return torch.cat(inputs, d)
            return forward_cat

        for xshape in [
            [8, 8, 8, 8],
            [64, 8, 32],
            [2048, 64],
        ]:
            for d in range(len(xshape)):
                x = torch.rand(xshape)
                _, graph = self.checkTrace(cat_along_dim(d), [x, x, x], dtype)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_typecheck(self, dtype):
        x = torch.rand(32, 28, dtype=dtype)
        m = torch.nn.Linear(in_features=28, out_features=64, bias=True, dtype=dtype)
        traced, graph = self.checkTrace(m, [x], dtype)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ['aten::linear'])
        # change the shape of the input, we should enter fallback graph
        x = torch.rand(5, 28, dtype=dtype)
        self.assertEqual(m(x), traced(x))


@unittest.skipIf(IS_AVX512_UNSUPPORTED, "This test fails for BF16 on machines without AVX512.")
@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
class TestFusionPattern(JitLlgaTestCase):
    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_conv2d_eltwise(self, dtype):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super().__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=False)
                self.eltwise = eltwise_fn

            def forward(self, x):
                x = self.conv1(x)
                x = self.eltwise(x)
                x = self.conv2(x)
                x = self.eltwise(x)
                return x

        for eltwise in ['relu', 'leaky_relu', 'sigmoid', 'square',
                        'abs', 'exp', 'hardswish', 'tanh', 'hardtanh']:
            for inplace in [True, False]:
                eltwise_fn_name = eltwise + '_' if inplace else eltwise
                eltwise_fn = get_eltwise_fn(eltwise_fn_name)

                m = M(eltwise_fn)
                x = torch.rand(1, 32, 28, 28)
                _, graph = self.checkTrace(m, [x], dtype=dtype)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                # test if relu_ is replaced with relu by the mutation removal pass
                self.assertFused(graph, ['aten::' + eltwise_fn_name])
                # test if relu is fused into the fusion group
                self.assertFused(graph, ['aten::' + eltwise])

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_conv2d_silu(self, dtype):
        class M(nn.Module):
            def __init__(self, inplace):
                super().__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = nn.SiLU(inplace=inplace)

            def forward(self, x):
                x = self.conv1(x)
                x = self.eltwise(x)
                x = self.conv2(x)
                return x
        for inplace in [False, True]:
            for memory_format in [torch.contiguous_format, torch.channels_last]:
                m = M(inplace)
                x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)

                _, graph = self.checkTrace(m, [x], dtype)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                # oneDNN Graph has no silu op; the bridge decomposes silu into sigmoid and mul.
                # The inplace op becomes an out-of-place op on the JIT graph.
                patterns = [
                    ["aten::_convolution", 'aten::sigmoid', 'aten::mul'],
                    ["aten::_convolution"]
                ]
                silu_op = 'aten::silu_' if inplace else 'aten::silu'
                self.assertFused(graph, ['aten::_convolution', silu_op])
                self.checkPatterns(graph, patterns)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_ensure_tensor_is_rewrapped(self, dtype):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super().__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv4 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = eltwise_fn
                self.adaptive_avg_pool_2d = nn.AdaptiveAvgPool2d((5, 7))

            def forward(self, x, y):
                x = self.conv1(x)
                x = self.eltwise(x)
                x = self.conv2(x)
                x = self.eltwise(x)
                y = self.conv3(y)
                y = self.eltwise(y)
                y = self.conv4(y)
                y = self.eltwise(y)

                x = torch.add(x, y)
                x = self.adaptive_avg_pool_2d(x)
                return x

        eltwise_fn_name = 'relu'
        eltwise_fn = get_eltwise_fn(eltwise_fn_name)
        m = M(eltwise_fn)
        m = m.to(memory_format=torch.channels_last)
        x = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last)
        y = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last)
        # Simply test if the output is accurate
        # The output of the second partition is input to adaptive_avg_pool2d, which is
        # unsupported by LLGA. In resnext101 32x16d, we encountered an accuracy issue.
        _, graph = self.checkTrace(m, [x, y], dtype)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_conv2d_clamp(self, dtype):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv4 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv5 = nn.Conv2d(32, 32, 3, padding=1, bias=True)

            def forward(self, x):
                x = self.conv1(x)
                x = torch.clamp(x, min=float('-inf'))
                x = self.conv2(x)
                x = torch.clamp(x, min=-5)
                x = self.conv3(x)
                x = torch.clamp(x, min=0, max=float('inf'))
                x = self.conv4(x)
                x = torch.clamp(x, min=1, max=5)
                x = self.conv5(x)
                x = torch.clamp(x, max=2)
                return x

        for inplace in [False, True]:
            for memory_format in [torch.contiguous_format, torch.channels_last]:
                x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)
                m = M()
                _, graph = self.checkTrace(m, [x], dtype)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
                self.assertFused(graph, ['aten::_convolution', "aten::clamp"])

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_conv2d_bn(self, dtype):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.bn1 = nn.BatchNorm2d(32)

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                return x

        m = M().eval()
        if dtype == torch.bfloat16:
            m = optimization.fuse(m)
        x = torch.rand(1, 32, 28, 28)
        _, graph = self.checkTrace(m, [x], dtype)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm'])

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_conv2d_bn_relu(self, dtype):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.bn1 = nn.BatchNorm2d(32)

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = F.relu(x)
                return x

        m = M().eval()
        if dtype == torch.bfloat16:
            m = optimization.fuse(m)
        x = torch.rand(1, 32, 28, 28)
        _, graph = self.checkTrace(m, [x], dtype)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm',
                                 'aten::relu'])

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_bn2d_eltwise(self, dtype):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super().__init__()
                self.eltwise = eltwise_fn
                self.bn = nn.BatchNorm2d(32)

            def forward(self, x):
                x = self.bn(x)
                x = self.eltwise(x)
                return x

        for eltwise in ['relu']:
            eltwise_fn = get_eltwise_fn(eltwise)
            m = M(eltwise_fn).eval()
            x = torch.rand(1, 32, 28, 28)
            _, graph = self.checkTrace(m, [x], dtype)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.assertFused(graph, ['aten::' + eltwise])

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_linear_eltwise(self, dtype):
        class M(nn.Module):
            def __init__(self, eltwise_fn, bias):
                super().__init__()
                self.linear = nn.Linear(28, 64, bias)
                self.eltwise = eltwise_fn

            def forward(self, x):
                x = self.linear(x)
                x = self.eltwise(x)
                return x

        for [has_bias, eltwise] in itertools.product(
                [True, False],
                ['relu', 'gelu', 'sigmoid', 'hardtanh', 'relu6', 'elu']):

            eltwise_fn = get_eltwise_fn(eltwise)
            m = M(eltwise_fn, has_bias)
            x = torch.rand(32, 28, requires_grad=False)
            _, graph = self.checkTrace(m, [x], dtype)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.assertFused(graph, ['aten::' + eltwise])

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_conv2d_sum(self, dtype):
        class M(nn.Module):
            def __init__(self, bias=False):
                super().__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn1 = nn.BatchNorm2d(32)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn2 = nn.BatchNorm2d(32)
                self.relu = nn.ReLU()
                self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn3 = nn.BatchNorm2d(32)

            def forward(self, x, y):
                x = self.conv1(x)
                x = self.bn1(x)
                y = self.conv2(y)
                y = self.bn2(y)
                z = self.relu(x + y)
                z = self.conv3(z)
                z = self.bn3(z)
                return z

        for bias in [True, False]:
            m = M(bias).eval()
            if dtype == torch.bfloat16:
                m = optimization.fuse(m)
            x = torch.rand(1, 32, 16, 16, requires_grad=False)
            y = torch.rand(1, 32, 16, 16, requires_grad=False)
            _, graph = self.checkTrace(m, [x, y], dtype)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_wildcard(self, dtype):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                y = self.eltwise(x)
                return [x, y]

        # The pattern in the graph is as follows:
        #      conv
        #     |    \
        # eltwise   \
        #    |       \
        #  ListConstruct
        #
        # The output of conv is used by a wildcard op: ListConstruct.
        # Thus conv-eltwise cannot be selected into the same Partition.
        m = M()
        x = torch.rand(1, 32, 28, 28)
        _, graph = self.checkTrace(m, [x], dtype)
        # conv can exist in a single-op oneDNN Graph partition but not relu
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ['aten::_convolution'])

    @onlyCPU
    @dtypes(torch.int32)
    def test_wildcard_unsupported_dtype(self, dtype):
        class M(nn.Module):
            def forward(self, x):
                y = x // 2
                return y

        # In shufflenet_v2_x1_0, channels_per_group is computed as:
        # channels_per_group = num_channels // groups
        # JIT IR converts groups to Long dtype, which is unsupported
        # by oneDNN Graph, viz. Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]()
        # This test just ensures that the bridge code can handle
        # unsupported dtypes for inputs to ops that are themselves unsupported
        # by oneDNN Graph. In this particular UT, aten::floor_divide
        # would be added as a wildcard at the graph-construction stage.
        m = M()
        x = torch.tensor([32], dtype=dtype)
        _, graph = self.checkTrace(m, [x], dtype)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_rewrap_tensor_input_to_pytorch(self, dtype):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super().__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = eltwise_fn
                self.adaptive_avg_pool_2d = nn.AdaptiveAvgPool2d((5, 7))

            def forward(self, x, y):
                x = self.conv1(x)
                x = self.eltwise(x)
                x = self.conv2(x)
                x = self.eltwise(x)
                x = torch.add(x, y)
                x = self.adaptive_avg_pool_2d(x)
                return x

        eltwise_fn_name = 'relu'
        eltwise_fn = get_eltwise_fn(eltwise_fn_name)
        m = M(eltwise_fn)
        m = m.to(memory_format=torch.channels_last)
        x = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last)
        y = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last)
        # Simply test if the output is accurate
        # The output of the second partition is input to adaptive_avg_pool2d, which is
        # unsupported by LLGA, so it must be handled by PyTorch, which should receive
        # correct strides info of the channels-last tensor.
        self.checkTrace(m, [x, y], dtype)

@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
class TestEnableDisableLlgaFuser(JitTestCase):
    def setUp(self):
        super().setUp()
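        # _jit_set_llga_enabled returns the previous setting; stash it so that tearDown can restore it.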
        self.is_enabled = torch._C._jit_set_llga_enabled(False)

    def tearDown(self):
        torch._C._jit_set_llga_enabled(self.is_enabled)
        super().tearDown()

    def test_context_manager(self):
        x = torch.randn(4, 8)
        y = torch.randn(4, 8)
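        # torch.jit.fuser('fuser3') selects the oneDNN Graph (LLGA) fuser; nesting the
        # context managers keeps it enabled for t1 and t2, while t3 below runs outside
        # the context and therefore should not be fused.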
        with torch.jit.fuser('fuser3'):
            with torch.jit.fuser('fuser3'):

                def t1(x, y):
                    o = x + y
                    o = o + 2.0
                    return o
                t_jit = torch.jit.script(t1)
                t_jit(x, y)
                t_jit(x, y)
                self.assertGraphContains(t_jit.graph_for(x, y), LLGA_FUSION_GROUP)

            def t2(x, y):
                o = x + y
                o = o + 3.0
                return o
            t_jit_2 = torch.jit.script(t2)
            t_jit_2(x, y)
            t_jit_2(x, y)
            self.assertGraphContains(t_jit_2.graph_for(x, y), LLGA_FUSION_GROUP)

        def t3(x, y):
            o = x + y
            o = o + 4.0
            return o
        t_jit_3 = torch.jit.script(t3)
        t_jit_3(x, y)
        t_jit_3(x, y)
        self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), LLGA_FUSION_GROUP, 0)


@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
@unittest.skip("Enable when integration with dynamo aot_autograd is more stable")
class TestDynamoAOT(JitTestCase):
    def test_dynamo_aot_ts_onednn(self):
        class Seq(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = nn.Sequential(
                    nn.Linear(10, 10),
                    nn.ReLU(),
                    nn.Linear(10, 10),
                    nn.ReLU(),
                )

            def forward(self, x):
                return self.layers(x)

        mod = Seq()

        import torch._dynamo
        aot_mod = torch._dynamo.optimize("aot_ts", nopython=True)(mod)

        for _ in range(10):
            with torch.jit.fuser("fuser3"):
                loss = aot_mod(torch.rand([10, 10])).sum()
                loss.backward()

        torch._dynamo.reset()


@unittest.skipIf(IS_AVX512_UNSUPPORTED, "This test fails for BF16 on machines without AVX512.")
@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled")
class TestModel(JitLlgaTestCase):
    @skipIfNoTorchVision
    def _test_vision(self, model_name, dtype):
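        # Build the TorchVision model in eval mode, pre-fuse conv+bn for BF16, then
        # trace it and check that the common compute ops end up inside fusion groups.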
        m = getattr(torchvision.models, model_name)().eval()
        if dtype == torch.bfloat16:
            m = optimization.fuse(m)
        x = torch.rand(1, 3, 224, 224) / 10
        _, graph = self.checkTrace(m, [x], dtype)
        self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm',
                                 'aten::relu', 'aten::linear',
                                 'aten::avg_pool2d', 'aten::max_pool2d'])

for model_name, enabled in [
    ['resnet50', True],
    ['resnext50_32x4d', True],
    ['resnext101_32x8d', True],
    ['densenet121', True],
    ['densenet161', True],
    ['densenet169', True],
    ['densenet201', True],
    ['efficientnet_b0', True],
    ['efficientnet_b1', True],
    ['efficientnet_b2', True],
    ['efficientnet_b3', True],
    ['efficientnet_b4', True],
    ['efficientnet_b5', True],
    ['efficientnet_b6', True],
    ['efficientnet_b7', True],
    ['regnet_y_400mf', True],
    ['googlenet', TEST_SCIPY],
    ['mobilenet_v2', True],
    ['mobilenet_v3_large', True],
    ['mnasnet1_0', True],
    ['squeezenet1_0', True],
    ['vgg16', True],
    ['alexnet', True],
    ['shufflenet_v2_x1_0', True],
    ['wide_resnet50_2', True],
]:
    def _wrapper(mname, dtype):
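        # Create a test method bound to this (model, dtype) pair at definition time;
        # it runs in a separate process (see separate_process above) and is skipped
        # when the model is disabled in the list above.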
        @unittest.skipIf(not enabled, 'Disabled')
        @separate_process
        def test(self, dtype=dtype):
            return self._test_vision(mname, dtype)
        return test

    for dtype in [torch.bfloat16, torch.float32]:
        setattr(TestModel, 'test_vision_{}_{}'.format(model_name, str(dtype).split("torch.")[1]), _wrapper(model_name, dtype))


instantiate_device_type_tests(TestFusionPattern, globals())
instantiate_device_type_tests(TestOp, globals())

if __name__ == '__main__':
    run_tests()