1# Owner(s): ["module: mkldnn"] 2import sys 3import torch 4import unittest 5import itertools 6import torch.nn as nn 7from functools import wraps 8from concurrent import futures 9import torch.nn.functional as F 10import torch.fx.experimental.optimization as optimization 11from torch.testing._internal.jit_utils import JitTestCase 12from torch.testing._internal.common_utils import run_tests, TEST_SCIPY, IS_WINDOWS, IS_MACOS 13from torch.testing._internal.common_device_type import ( 14 instantiate_device_type_tests, 15 onlyCPU, 16 dtypes 17) 18 19# We use this wrapper to run UTs of TorchVision models because of a memory-leak 20# issue with JIT tracing that causes traced model objects to persist in the 21# memory. Ref: https://github.com/pytorch/pytorch/issues/35600 22# Memory requirement for running these UTs was thus increasing cumulatively, and 23# invoked the Linux kernel OOM killer on linux.2xlarge PyTorch CI runners, which 24# only have 16 GB RAM. Cumulatively, these UTs had been using more than 14 GB 25# memory (as per psutils). So now we run each TorchVision model UTs in separate processes. 26def separate_process(func): 27 @wraps(func) 28 def wrapper(*args, **kwargs): 29 with futures.ProcessPoolExecutor() as executor: 30 future = executor.submit(func, *args, **kwargs) 31 futures.wait([future]) 32 return wrapper 33 34def is_avx512_supported(): 35 if sys.platform != 'linux': 36 return False 37 with open("/proc/cpuinfo", encoding="ascii") as f: 38 lines = f.read() 39 return "avx512" in lines 40 41IS_AVX512_UNSUPPORTED = not is_avx512_supported() 42 43LLGA_FUSION_GROUP = 'prim::oneDNNFusionGroup' 44LLGA_NOT_ENABLED = not torch.backends.mkldnn.is_available() or IS_WINDOWS or IS_MACOS 45 46def warmup_forward(f, *args, profiling_count=3): 47 for i in range(profiling_count): 48 results = f(*args) 49 50 return results 51 52class JitLlgaTestCase(JitTestCase): 53 54 def setUp(self): 55 # PyTorch has divergent op support for AMP in JIT & eager modes 56 # so we disable AMP for JIT & leverage eager-mode AMP. 57 # Ref: https://github.com/pytorch/pytorch/issues/75956 58 self.original_autocast_mode = torch._C._jit_set_autocast_mode(False) 59 torch.jit.enable_onednn_fusion(True) 60 61 def tearDown(self): 62 torch.jit.enable_onednn_fusion(False) 63 torch._C._jit_set_autocast_mode(self.original_autocast_mode) 64 65 def checkTrace(self, m, x, dtype=torch.float32, *args, **kwargs): 66 if isinstance(m, torch.nn.Module): 67 m.eval() 68 with torch.no_grad(), torch._jit_internal._disable_emit_hooks(): 69 if dtype == torch.bfloat16: 70 # We rely upon eager-mode AMP support for BF16 71 with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16): 72 traced = torch.jit.trace(m, x) 73 if isinstance(m, torch.nn.Module): 74 traced = torch.jit.freeze(traced) 75 warmup_forward(traced, *x) 76 ref_o = m(*x) 77 fwd_graph = traced.graph_for(*x) 78 else: 79 traced = torch.jit.trace(m, x) 80 if isinstance(m, torch.nn.Module): 81 traced = torch.jit.freeze(traced) 82 warmup_forward(traced, *x) 83 ref_o = m(*x) 84 fwd_graph = traced.graph_for(*x) 85 86 jit_o = traced(*x) 87 self.assertEqual(jit_o, ref_o) 88 return traced, fwd_graph 89 90 91 def assertFused(self, graph, fused_patterns): 92 for pat in fused_patterns: 93 self.assertGraphContainsExactly(graph, pat, 0) 94 95 def findFusionGroups(self, graph): 96 result = [] 97 for n in graph.nodes(): 98 if n.kind() == LLGA_FUSION_GROUP: 99 result.append(n.g('Subgraph')) 100 continue 101 for block in n.blocks(): 102 result += self.findFusionGroups(block) 103 return result 104 105 def checkPatterns(self, graph, patterns): 106 fusion_groups = self.findFusionGroups(graph) 107 assert len(fusion_groups) == len(patterns), "length of subgraphs not equal to length of given patterns" 108 109 for i in range(len(fusion_groups)): 110 for pattern in patterns[i]: 111 self.assertGraphContains(fusion_groups[i], pattern) 112 113try: 114 import torchvision 115 HAS_TORCHVISION = True 116except ImportError: 117 HAS_TORCHVISION = False 118except RuntimeError: 119 HAS_TORCHVISION = False 120skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, 'no torchvision') 121 122def get_eltwise_fn(name): 123 if hasattr(torch, name): 124 return getattr(torch, name) 125 elif hasattr(F, name): 126 return getattr(F, name) 127 elif name == 'hardswish_': 128 return torch.nn.Hardswish(inplace=True) 129 else: 130 raise NameError(f'Eltwise function {name} not found') 131 132 133@unittest.skipIf(IS_AVX512_UNSUPPORTED, "This test fails for BF16 on machines without AVX512.") 134@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") 135class TestOp(JitLlgaTestCase): 136 @onlyCPU 137 @dtypes(torch.float32, torch.bfloat16) 138 def test_conv2d(self, dtype): 139 for [spatial, in_channels, out_channels, kernel, padding, stride, dilation, g, bias] in itertools.product( 140 [7, 8], 141 [8, 15], 142 [7, 16], 143 [3, 4], 144 [0, 2], 145 [1, 2], 146 [1, 2], 147 [1, 2], 148 [True, False]): 149 150 m = nn.Conv2d(in_channels=in_channels * g, 151 out_channels=out_channels * g, 152 kernel_size=kernel, 153 padding=padding, 154 stride=stride, 155 dilation=dilation, 156 groups=g, 157 bias=bias) 158 159 x = torch.rand(1, in_channels * g, spatial, spatial) 160 _, graph = self.checkTrace(m, [x], dtype) 161 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 162 163 @onlyCPU 164 @dtypes(torch.float32, torch.bfloat16) 165 def test_bn2d(self, dtype): 166 m = nn.BatchNorm2d(32).eval() 167 x = torch.rand(1, 32, 28, 28) 168 _, graph = self.checkTrace(m, [x], dtype) 169 # single-op partition shouldn't be created for softmax 170 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) 171 172 @onlyCPU 173 @dtypes(torch.float32, torch.bfloat16) 174 def test_eltwise(self, dtype): 175 class M(nn.Module): 176 def __init__(self, eltwise_fn): 177 super().__init__() 178 self.eltwise = eltwise_fn 179 180 def forward(self, x): 181 return self.eltwise(x) 182 183 for eltwise in ['relu', 'gelu']: 184 eltwise_fn = get_eltwise_fn(eltwise) 185 m = M(eltwise_fn) 186 x = torch.rand(1, 32, 28, 28) 187 _, graph = self.checkTrace(m, [x], dtype) 188 # single-op partition shouldn't be created. 189 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) 190 191 @onlyCPU 192 @dtypes(torch.float32, torch.bfloat16) 193 def test_max_pool2d(self, dtype): 194 for [spatial, kernel, padding, stride, dilation, ceil_mode] in itertools.product( 195 [15, 16, 17, 18, 19], 196 [4, 5], 197 [0, 1, 2], 198 [1, 2], # [1, 2, 4], TODO: fix issue in pad calculation 199 [1], # [1, 2], TODO: backend support for dilation 200 [True, False]): 201 202 m = nn.MaxPool2d(kernel_size=kernel, 203 stride=stride, 204 padding=padding, 205 dilation=dilation, 206 ceil_mode=ceil_mode) 207 208 x = torch.rand(1, 4, spatial, spatial) 209 _, graph = self.checkTrace(m, [x], dtype) 210 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 211 212 @onlyCPU 213 @dtypes(torch.float32, torch.bfloat16) 214 def test_avg_pool2d(self, dtype): 215 for [spatial, kernel, padding, stride, ceil_mode, count_include_pad] in itertools.product( 216 [15, 16, 17, 18, 19], 217 [4, 5], 218 [0, 1, 2], 219 [1, 2, 4], 220 [False], # TODO: oneDNN Graph does not fully support ceil_mode=True 221 [True, False]): 222 223 m = nn.AvgPool2d(kernel_size=kernel, 224 stride=stride, 225 padding=padding, 226 ceil_mode=ceil_mode, 227 count_include_pad=count_include_pad) 228 229 x = torch.rand(1, 4, spatial, spatial) 230 _, graph = self.checkTrace(m, [x], dtype) 231 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 232 233 @onlyCPU 234 @dtypes(torch.float32, torch.bfloat16) 235 def test_variable_kernel_avg_pool2d(self, dtype): 236 class M(nn.Module): 237 def forward(self, x): 238 x = F.avg_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=0, count_include_pad=False) 239 return x 240 241 x = torch.randn(1, 1000, 1, 1) 242 m = M() 243 _, graph = self.checkTrace(m, [x], dtype) 244 # kernel_size is not Constant, shouldn't have any LLGA_FUSION_GROUP 245 # TODO: with shape specialization, should have 1 LLGA_FUSION_GROUP 246 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) 247 248 @onlyCPU 249 @dtypes(torch.float32, torch.bfloat16) 250 def test_softmax(self, dtype): 251 for dim in [-4, -3, -2, -1, 0, 1, 2, 3]: 252 m = nn.Softmax(dim=dim) 253 x = torch.rand(8, 12, 12, 12) 254 _, graph = self.checkTrace(m, [x], dtype) 255 # single-op partition shouldn't be created for softmax 256 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) 257 258 @onlyCPU 259 @dtypes(torch.float32, torch.bfloat16) 260 def test_linear(self, dtype): 261 for bias in [True, False]: 262 x = torch.rand(32, 28) 263 m = torch.nn.Linear(in_features=28, out_features=64, bias=bias) 264 _, graph = self.checkTrace(m, [x], dtype) 265 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 266 self.assertFused(graph, ['aten::linear']) 267 268 269 def _gen_binary_inputs(self, gen_permute=True): 270 for xshape, yshape in [ 271 [[1, 32, 28, 28], [1, 32, 28, 28]], 272 [[1, 32, 28, 28], [1, 1, 28, 28]], 273 [[1, 32, 28, 28], [28]], 274 [[1, 32, 28, 28], [1]], 275 276 ]: 277 yield torch.rand(xshape), torch.rand(yshape) 278 if gen_permute and xshape != yshape: 279 yield torch.rand(yshape), torch.rand(xshape) 280 281 @onlyCPU 282 @dtypes(torch.float32, torch.bfloat16) 283 def test_add(self, dtype): 284 def forward_add(x, y): 285 return torch.add(x, y, alpha=2) 286 287 for x, y in self._gen_binary_inputs(): 288 _, graph = self.checkTrace(forward_add, [x, y], dtype) 289 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 290 291 @onlyCPU 292 @dtypes(torch.float32, torch.bfloat16) 293 def test_add_scalar(self, dtype): 294 def add_scalar(x): 295 return 42 + x + 3.14 296 297 x = torch.rand(32, 32) 298 _, graph = self.checkTrace(add_scalar, [x], dtype) 299 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 300 301 @onlyCPU 302 @dtypes(torch.float32, torch.bfloat16) 303 def test_addmm(self, dtype): 304 # Just a sidenote - comparison of eager-mode & oneDNN Graph JIT outputs of 305 # addmm (which entails matmul-bias-add fusion) might require higher tolerance 306 # bounds for BF16. This is subject to change in the near future. 307 def addmm(x, y, z): 308 # alpha and beta are 1, by default 309 return torch.addmm(z, x, y) 310 311 x = torch.rand(64, 32) 312 y = torch.rand(32, 32) 313 z = torch.rand(64, 32) 314 _, graph = self.checkTrace(addmm, [x, y, z], dtype) 315 # single-op partition should be created for matmul with bias. 316 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 317 318 @onlyCPU 319 @dtypes(torch.float32, torch.bfloat16) 320 def test_mul(self, dtype): 321 def forward_mul(x, y): 322 return torch.mul(x, y) * 3 323 324 for x, y in self._gen_binary_inputs(): 325 _, graph = self.checkTrace(forward_mul, [x, y], dtype) 326 # single-op partitions shouldn't be created 327 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 328 329 @onlyCPU 330 @dtypes(torch.float32, torch.bfloat16) 331 def test_identity_binary(self, dtype): 332 def forward(x): 333 return x * 1 + 0.0 334 335 x = torch.rand(32) 336 _, graph = self.checkTrace(forward, [x], dtype) 337 self.assertFused(graph, ['aten::add', 'aten::mul']) 338 339 @onlyCPU 340 @dtypes(torch.float32, torch.bfloat16) 341 def test_layer_norm(self, dtype): 342 # TODO: support more normalized_shape 343 m = torch.nn.LayerNorm(10) 344 x = torch.randn(2, 5, 10, 10) 345 _, graph = self.checkTrace(m, [x], dtype) 346 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 347 348 @onlyCPU 349 @dtypes(torch.float32, torch.bfloat16) 350 def test_cat(self, dtype): 351 def cat_along_dim(d): 352 def forward_cat(*inputs): 353 return torch.cat(inputs, d) 354 return forward_cat 355 356 for xshape in [ 357 [8, 8, 8, 8], 358 [64, 8, 32], 359 [2048, 64], 360 ]: 361 for d in range(len(xshape)): 362 x = torch.rand(xshape) 363 _, graph = self.checkTrace(cat_along_dim(d), [x, x, x], dtype) 364 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 365 366 @onlyCPU 367 @dtypes(torch.float32, torch.bfloat16) 368 def test_typecheck(self, dtype): 369 x = torch.rand(32, 28, dtype=dtype) 370 m = torch.nn.Linear(in_features=28, out_features=64, bias=True, dtype=dtype) 371 traced, graph = self.checkTrace(m, [x], dtype) 372 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 373 self.assertFused(graph, ['aten::linear']) 374 # change the shape of the input, we should enter fallback graph 375 x = torch.rand(5, 28, dtype=dtype) 376 self.assertEqual(m(x), traced(x)) 377 378 379@unittest.skipIf(IS_AVX512_UNSUPPORTED, "This test fails for BF16 on machines without AVX512.") 380@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") 381class TestFusionPattern(JitLlgaTestCase): 382 @onlyCPU 383 @dtypes(torch.float32, torch.bfloat16) 384 def test_conv2d_eltwise(self, dtype): 385 class M(nn.Module): 386 def __init__(self, eltwise_fn): 387 super().__init__() 388 self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 389 self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=False) 390 self.eltwise = eltwise_fn 391 392 def forward(self, x): 393 x = self.conv1(x) 394 x = self.eltwise(x) 395 x = self.conv2(x) 396 x = self.eltwise(x) 397 return x 398 399 for eltwise in ['relu', 'leaky_relu', 'sigmoid', 'square', 400 'abs', 'exp', 'hardswish', 'tanh', 'hardtanh']: 401 for inplace in [True, False]: 402 eltwise_fn_name = eltwise + '_' if inplace else eltwise 403 eltwise_fn = get_eltwise_fn(eltwise_fn_name) 404 405 m = M(eltwise_fn) 406 x = torch.rand(1, 32, 28, 28) 407 _, graph = self.checkTrace(m, [x], dtype=dtype) 408 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2) 409 # test if relu_ is replace with relu by mutation removal pass 410 self.assertFused(graph, ['aten::' + eltwise_fn_name]) 411 # test if relu is fused into the fusion group 412 self.assertFused(graph, ['aten::' + eltwise]) 413 414 @onlyCPU 415 @dtypes(torch.float32, torch.bfloat16) 416 def test_conv2d_silu(self, dtype): 417 class M(nn.Module): 418 def __init__(self, inplace): 419 super().__init__() 420 self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 421 self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 422 self.eltwise = nn.SiLU(inplace=inplace) 423 424 def forward(self, x): 425 x = self.conv1(x) 426 x = self.eltwise(x) 427 x = self.conv2(x) 428 return x 429 for inplace in [False, True]: 430 for memory_format in [torch.contiguous_format, torch.channels_last]: 431 m = M(inplace) 432 x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format) 433 434 _, graph = self.checkTrace(m, [x], dtype) 435 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2) 436 # oneDNN graph does not have silu OP. The bridge will convert silu to sigmoid - mul 437 # Inplace op will become outplace op on the JIT graph 438 patterns = [ 439 ["aten::_convolution", 'aten::sigmoid', 'aten::mul'], 440 ["aten::_convolution"] 441 ] 442 silu_op = 'aten::silu_' if inplace else 'aten::silu' 443 self.assertFused(graph, ['aten::_convolution', silu_op]) 444 self.checkPatterns(graph, patterns) 445 446 @onlyCPU 447 @dtypes(torch.float32, torch.bfloat16) 448 def test_ensure_tensor_is_rewrapped(self, dtype): 449 class M(nn.Module): 450 def __init__(self, eltwise_fn): 451 super().__init__() 452 self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 453 self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 454 self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 455 self.conv4 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 456 self.eltwise = eltwise_fn 457 self.adaptive_avg_pool_2d = nn.AdaptiveAvgPool2d((5, 7)) 458 459 def forward(self, x, y): 460 x = self.conv1(x) 461 x = self.eltwise(x) 462 x = self.conv2(x) 463 x = self.eltwise(x) 464 y = self.conv3(y) 465 y = self.eltwise(y) 466 y = self.conv4(y) 467 y = self.eltwise(y) 468 469 x = torch.add(x, y) 470 x = self.adaptive_avg_pool_2d(x) 471 return x 472 473 eltwise_fn_name = 'relu' 474 eltwise_fn = get_eltwise_fn(eltwise_fn_name) 475 m = M(eltwise_fn) 476 m = m.to(memory_format=torch.channels_last) 477 x = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last) 478 y = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last) 479 # Simply test if the output is accurate 480 # The output of the second partition is input to adaptive_avg_pool2d, which is 481 # unsupported by LLGA. In resnext101 32x16d, we encountered an accuracy issue. 482 _, graph = self.checkTrace(m, [x, y], dtype) 483 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4) 484 485 @onlyCPU 486 @dtypes(torch.float32, torch.bfloat16) 487 def test_conv2d_clamp(self, dtype): 488 class M(nn.Module): 489 def __init__(self) -> None: 490 super().__init__() 491 self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 492 self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 493 self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 494 self.conv4 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 495 self.conv5 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 496 497 def forward(self, x): 498 x = self.conv1(x) 499 x = torch.clamp(x, min=float('-inf')) 500 x = self.conv2(x) 501 x = torch.clamp(x, min=-5) 502 x = self.conv3(x) 503 x = torch.clamp(x, min=0, max=float('inf')) 504 x = self.conv4(x) 505 x = torch.clamp(x, min=1, max=5) 506 x = self.conv5(x) 507 x = torch.clamp(x, max=2) 508 return x 509 510 for inplace in [False, True]: 511 for memory_format in [torch.contiguous_format, torch.channels_last]: 512 x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format) 513 m = M() 514 _, graph = self.checkTrace(m, [x], dtype) 515 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5) 516 self.assertFused(graph, ['aten::_convolution', "aten::clamp"]) 517 518 @onlyCPU 519 @dtypes(torch.float32, torch.bfloat16) 520 def test_conv2d_bn(self, dtype): 521 class M(nn.Module): 522 def __init__(self) -> None: 523 super().__init__() 524 self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 525 self.bn1 = nn.BatchNorm2d(32) 526 527 def forward(self, x): 528 x = self.conv1(x) 529 x = self.bn1(x) 530 return x 531 532 m = M().eval() 533 if dtype == torch.bfloat16: 534 m = optimization.fuse(m) 535 x = torch.rand(1, 32, 28, 28) 536 _, graph = self.checkTrace(m, [x], dtype) 537 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 538 self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm']) 539 540 @onlyCPU 541 @dtypes(torch.float32, torch.bfloat16) 542 def test_conv2d_bn_relu(self, dtype): 543 class M(nn.Module): 544 def __init__(self) -> None: 545 super().__init__() 546 self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 547 self.bn1 = nn.BatchNorm2d(32) 548 549 def forward(self, x): 550 x = self.conv1(x) 551 x = self.bn1(x) 552 x = F.relu(x) 553 return x 554 555 m = M().eval() 556 if dtype == torch.bfloat16: 557 m = optimization.fuse(m) 558 x = torch.rand(1, 32, 28, 28) 559 _, graph = self.checkTrace(m, [x], dtype) 560 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 561 self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm', 562 'aten::relu']) 563 564 @onlyCPU 565 @dtypes(torch.float32, torch.bfloat16) 566 def test_bn2d_eltwise(self, dtype): 567 class M(nn.Module): 568 def __init__(self, eltwise_fn): 569 super().__init__() 570 self.eltwise = eltwise_fn 571 self.bn = nn.BatchNorm2d(32) 572 573 def forward(self, x): 574 x = self.bn(x) 575 x = self.eltwise(x) 576 return x 577 578 for eltwise in ['relu']: 579 eltwise_fn = get_eltwise_fn(eltwise) 580 m = M(eltwise_fn).eval() 581 x = torch.rand(1, 32, 28, 28) 582 _, graph = self.checkTrace(m, [x], dtype) 583 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 584 self.assertFused(graph, ['aten::' + eltwise]) 585 586 @onlyCPU 587 @dtypes(torch.float32, torch.bfloat16) 588 def test_linear_eltwise(self, dtype): 589 class M(nn.Module): 590 def __init__(self, eltwise_fn, bias): 591 super().__init__() 592 self.linear = nn.Linear(28, 64, bias) 593 self.eltwise = eltwise_fn 594 595 def forward(self, x): 596 x = self.linear(x) 597 x = self.eltwise(x) 598 return x 599 600 for [has_bias, eltwise] in itertools.product( 601 [True, False], 602 ['relu', 'gelu', 'sigmoid', 'hardtanh', 'relu6', 'elu']): 603 604 eltwise_fn = get_eltwise_fn(eltwise) 605 m = M(eltwise_fn, has_bias) 606 x = torch.rand(32, 28, requires_grad=False) 607 _, graph = self.checkTrace(m, [x], dtype) 608 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 609 self.assertFused(graph, ['aten::' + eltwise]) 610 611 @onlyCPU 612 @dtypes(torch.float32, torch.bfloat16) 613 def test_conv2d_sum(self, dtype): 614 class M(nn.Module): 615 def __init__(self, bias=False): 616 super().__init__() 617 self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=bias) 618 self.bn1 = nn.BatchNorm2d(32) 619 self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=bias) 620 self.bn2 = nn.BatchNorm2d(32) 621 self.relu = nn.ReLU() 622 self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=bias) 623 self.bn3 = nn.BatchNorm2d(32) 624 625 def forward(self, x, y): 626 x = self.conv1(x) 627 x = self.bn1(x) 628 y = self.conv2(y) 629 y = self.bn2(y) 630 z = self.relu(x + y) 631 z = self.conv3(z) 632 z = self.bn3(z) 633 return z 634 635 for bias in [True, False]: 636 m = M(bias).eval() 637 if dtype == torch.bfloat16: 638 m = optimization.fuse(m) 639 x = torch.rand(1, 32, 16, 16, requires_grad=False) 640 y = torch.rand(1, 32, 16, 16, requires_grad=False) 641 _, graph = self.checkTrace(m, [x, y], dtype) 642 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3) 643 644 @onlyCPU 645 @dtypes(torch.float32, torch.bfloat16) 646 def test_wildcard(self, dtype): 647 class M(nn.Module): 648 def __init__(self) -> None: 649 super().__init__() 650 self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 651 self.eltwise = nn.ReLU() 652 653 def forward(self, x): 654 x = self.conv1(x) 655 y = self.eltwise(x) 656 return [x, y] 657 658 # The pattern is as the following: 659 # conv 660 # | \ 661 # eltwise \ 662 # | \ 663 # ListConstruct 664 # 665 # The output of conv is used by a wildcard op: ListConstruct. 666 # Thus conv-eltwise cannot be selected into the same Partition. 667 m = M() 668 x = torch.rand(1, 32, 28, 28) 669 _, graph = self.checkTrace(m, [x], dtype) 670 # conv can exist in a single-op oneDNN Graph partition but not relu 671 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) 672 self.assertFused(graph, ['aten::_convolution']) 673 674 @onlyCPU 675 @dtypes(torch.int32) 676 def test_wildcard_unsupported_dtype(self, dtype): 677 class M(nn.Module): 678 def forward(self, x): 679 y = x // 2 680 return y 681 682 # In shufflenet_v2_x1_0, channels_per_groups is computed as: 683 # channels_per_group = num_channels // groups 684 # JIT IR converts groups to Long dtype, which is unsupported 685 # by oneDNN Graph, viz. Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]() 686 # This test just ensures that the bridge code can handle 687 # unsupported dtypes for inputs to ops unsupported 688 # by oneDNN Graph. In this particular UT, aten::floor_divide 689 # would be added as a wildcard in graph-construction stage. 690 m = M() 691 x = torch.tensor([32], dtype=dtype) 692 _, graph = self.checkTrace(m, [x], dtype) 693 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) 694 695 @onlyCPU 696 @dtypes(torch.float32, torch.bfloat16) 697 def test_rewrap_tensor_input_to_pytorch(self, dtype): 698 class M(nn.Module): 699 def __init__(self, eltwise_fn): 700 super().__init__() 701 self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 702 self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True) 703 self.eltwise = eltwise_fn 704 self.adaptive_avg_pool_2d = nn.AdaptiveAvgPool2d((5, 7)) 705 706 def forward(self, x, y): 707 x = self.conv1(x) 708 x = self.eltwise(x) 709 x = self.conv2(x) 710 x = self.eltwise(x) 711 x = torch.add(x, y) 712 x = self.adaptive_avg_pool_2d(x) 713 return x 714 715 eltwise_fn_name = 'relu' 716 eltwise_fn = get_eltwise_fn(eltwise_fn_name) 717 m = M(eltwise_fn) 718 m = m.to(memory_format=torch.channels_last) 719 x = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last) 720 y = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last) 721 # Simply test if the output is accurate 722 # The output of the second partition is input to adaptive_avg_pool2d, which is 723 # unsupported by LLGA, so it must be handled by PyTorch, which should receive 724 # correct strides info of the channels-last tensor. 725 graph, _ = self.checkTrace(m, [x, y], dtype) 726 727@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") 728class TestEnableDisableLlgaFuser(JitTestCase): 729 def setUp(self): 730 super().setUp() 731 self.is_enabled = torch._C._jit_set_llga_enabled(False) 732 733 def tearDown(self): 734 torch._C._jit_set_llga_enabled(self.is_enabled) 735 super().tearDown() 736 737 def test_context_manager(self): 738 x = torch.randn(4, 8) 739 y = torch.randn(4, 8) 740 with torch.jit.fuser('fuser3'): 741 with torch.jit.fuser('fuser3'): 742 743 def t1(x, y): 744 o = x + y 745 o = o + 2.0 746 return o 747 t_jit = torch.jit.script(t1) 748 t_jit(x, y) 749 t_jit(x, y) 750 self.assertGraphContains(t_jit.graph_for(x, y), LLGA_FUSION_GROUP) 751 752 def t2(x, y): 753 o = x + y 754 o = o + 3.0 755 return o 756 t_jit_2 = torch.jit.script(t2) 757 t_jit_2(x, y) 758 t_jit_2(x, y) 759 self.assertGraphContains(t_jit_2.graph_for(x, y), LLGA_FUSION_GROUP) 760 761 def t3(x, y): 762 o = x + y 763 o = o + 4.0 764 return o 765 t_jit_3 = torch.jit.script(t3) 766 t_jit_3(x, y) 767 t_jit_3(x, y) 768 self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), LLGA_FUSION_GROUP, 0) 769 770 771@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") 772@unittest.skip("Enable when integration with dynamo aot_autograd is more stable") 773class TestDynamoAOT(JitTestCase): 774 def test_dynamo_aot_ts_onednn(self): 775 class Seq(nn.Module): 776 def __init__(self) -> None: 777 super().__init__() 778 self.layers = nn.Sequential( 779 nn.Linear(10, 10), 780 nn.ReLU(), 781 nn.Linear(10, 10), 782 nn.ReLU(), 783 ) 784 785 def forward(self, x): 786 return self.layers(x) 787 788 mod = Seq() 789 790 import torch._dynamo 791 aot_mod = torch._dynamo.optimize("aot_ts", nopython=True)(mod) 792 793 for _ in range(10): 794 with torch.jit.fuser("fuser3"): 795 loss = aot_mod(torch.rand([10, 10])).sum() 796 loss.backward() 797 798 torch._dynamo.reset() 799 800 801@unittest.skipIf(IS_AVX512_UNSUPPORTED, "This test fails for BF16 on machines without AVX512.") 802@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") 803class TestModel(JitLlgaTestCase): 804 @skipIfNoTorchVision 805 def _test_vision(self, model_name, dtype): 806 m = getattr(torchvision.models, model_name)().eval() 807 if dtype == torch.bfloat16: 808 m = optimization.fuse(m) 809 x = torch.rand(1, 3, 224, 224) / 10 810 _, graph = self.checkTrace(m, [x], dtype) 811 self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm', 812 'aten::relu', 'aten::linear', 813 'aten::avg_pool2d', 'aten::max_pool2d']) 814 815for model_name, enabled in [ 816 ['resnet50', True], 817 ['resnext50_32x4d', True], 818 ['resnext101_32x8d', True], 819 ['densenet121', True], 820 ['densenet161', True], 821 ['densenet169', True], 822 ['densenet201', True], 823 ['efficientnet_b0', True], 824 ['efficientnet_b1', True], 825 ['efficientnet_b2', True], 826 ['efficientnet_b3', True], 827 ['efficientnet_b4', True], 828 ['efficientnet_b5', True], 829 ['efficientnet_b6', True], 830 ['efficientnet_b7', True], 831 ['regnet_y_400mf', True], 832 ['googlenet', TEST_SCIPY], 833 ['mobilenet_v2', True], 834 ['mobilenet_v3_large', True], 835 ['mnasnet1_0', True], 836 ['squeezenet1_0', True], 837 ['vgg16', True], 838 ['alexnet', True], 839 ['shufflenet_v2_x1_0', True], 840 ['wide_resnet50_2', True], 841]: 842 def _wrapper(mname, dtype): 843 @unittest.skipIf(not enabled, 'Disabled') 844 @separate_process 845 def test(self, dtype=dtype): 846 return self._test_vision(mname, dtype) 847 return test 848 849 for dtype in [torch.bfloat16, torch.float32]: 850 setattr(TestModel, 'test_vision_{}_{}'.format(model_name, str(dtype).split("torch.")[1]), _wrapper(model_name, dtype)) 851 852 853instantiate_device_type_tests(TestFusionPattern, globals()) 854instantiate_device_type_tests(TestOp, globals()) 855 856if __name__ == '__main__': 857 run_tests() 858