# Owner(s): ["module: intel"]

import itertools
import math
import unittest
from itertools import product

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch._C._dynamo.guards import assert_size_stride
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_is_not_fp32
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyXPU,
)
from torch.testing._internal.common_dtype import floating_types_and
from torch.testing._internal.common_nn import _test_module_empty_input, NNTestCase
from torch.testing._internal.common_utils import (
    dtype2prec_DONTUSE,
    gradcheck,
    gradgradcheck,
    parametrize as parametrize_test,
    run_tests,
    set_default_dtype,
    TEST_SCIPY,
    TEST_WITH_ROCM,
)


AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
if TEST_SCIPY:
    import scipy.ndimage
    import scipy.signal


class TestConvolutionNNDeviceType(NNTestCase):
    def run_conv_double_back_test(
        self,
        kern,
        stride,
        padding,
        chan_in,
        chan_out,
        batch_size,
        inp_size,
        dilation,
        no_weight,
        groups=1,
        use_xpu=False,
        use_bias=True,
        dtype=torch.double,
    ):
        device = torch.device("xpu" if use_xpu else "cpu")
        x = torch.randn(
            batch_size,
            chan_in,
            inp_size,
            inp_size,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        weight = torch.randn(
            chan_out,
            chan_in // groups,
            kern,
            kern,
            device=device,
            dtype=dtype,
            requires_grad=not no_weight,
        )
        if use_bias:
            bias = torch.randn(chan_out, device=device, dtype=dtype, requires_grad=True)
        else:
            bias = None

        def func(*inputs):
            if use_bias:
                lx, lweight, lbias = inputs
            else:
                lx, lweight = inputs
                lbias = None
            out = F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups)
            return out

        if use_bias:
            inputs = x, weight, bias
        else:
            inputs = x, weight

        dummy_out = func(*inputs)
        grad_y = torch.randn_like(
            dummy_out, device=device, dtype=dtype, requires_grad=True
        )

        if dtype == torch.float:
            (g,) = torch.autograd.grad(dummy_out.sum(), x, create_graph=True)
            return g.requires_grad

        return gradgradcheck(func, inputs, (grad_y,))

    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
    def test_Conv2d_large_workspace(self, device, dtype):
        sizes = [
            (1, 256, 109, 175),
            (1, 256, 80, 128),
            (1, 256, 120, 192),
        ]

        def run_test(benchmark):
            conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).to(device, dtype)
            for size in sizes:
                x = torch.randn(size, device=device, dtype=dtype)
                out = conv(x.detach().clone().requires_grad_())
                out.backward(torch.ones_like(out))

        run_test(benchmark=False)
        run_test(benchmark=True)

    @dtypes(torch.half, torch.float)
    def test_ConvTranspose2d_large_output_padding(self, device, dtype):
        net1 = torch.nn.ConvTranspose2d(
            128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        ).to(device=device, dtype=dtype)
        net2 = torch.nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
        ).to(device=device, dtype=dtype)
        net3 = torch.nn.ConvTranspose2d(
            32, 3, kernel_size=3, stride=2, padding=1, output_padding=1
        ).to(device=device, dtype=dtype)
        x = torch.rand(1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True)
        x = net1(x)
        x = net2(x)
        x = net3(x)
        x.backward(torch.randn_like(x))

    @dtypes(torch.float, torch.double, torch.half)
    def test_Conv2d_depthwise_naive_groups(self, device, dtype):
        if dtype == torch.half and "xpu" in device:
            self.skipTest(
                "The accuracy issue of dtype fp16 would be fixed in oneDNN v3.4"
            )
        for depth_multiplier in [1, 2]:
            m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
                device, dtype
            )
            i = (
                torch.randn(2, 2, 6, 6, device=device, dtype=dtype)
                .div_(2)
                .requires_grad_()
            )
            output = m(i)
            grad_output = (
                torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype)
                / 2
            )
            output.backward(grad_output)

            offset = 1 * depth_multiplier

            m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())

            self.assertEqual(
                output,
                torch.cat([output1, output2], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                i.grad.data,
                torch.cat([i1.grad.data, i2.grad.data], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.bias.grad.data,
                torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.weight.grad.data,
                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )

    @dtypes(torch.float, torch.double, torch.half)
    def test_Conv3d_depthwise_naive_groups(self, device, dtype):
        if dtype == torch.half and "xpu" in device:
            self.skipTest(
                "The accuracy issue of dtype fp16 would be fixed in oneDNN v3.4"
            )
        for depth_multiplier in [1, 2]:
            m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
                device, dtype
            )
            i = (
                torch.randn(2, 2, 6, 6, 6, device=device, dtype=dtype)
                .div_(2)
                .requires_grad_()
            )
            output = m(i)
            grad_output = (
                torch.randn(
                    2, 2 * depth_multiplier, 4, 4, 4, device=device, dtype=dtype
                )
                / 2
            )
            output.backward(grad_output)

            offset = 1 * depth_multiplier

            m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())
            atol, rtol = (3e-4, 3e-2)

            self.assertEqual(
                output, torch.cat([output1, output2], 1), atol=atol, rtol=rtol
            )
            self.assertEqual(
                i.grad.data,
                torch.cat([i1.grad.data, i2.grad.data], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.bias.grad.data,
                torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.weight.grad.data,
                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                atol=atol,
                rtol=rtol,
            )

    @dtypes(torch.float, torch.double, torch.half)
    def test_noncontig_conv_grad(self, device, dtype):
        module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype)
        input = torch.randn(
            2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True
        )
        output = module(input)

        grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1]
        assert not grad.is_contiguous()
        output.backward(grad, retain_graph=True)
        self.assertIsNotNone(input.grad)
        result = input.grad.data.clone()
        input.grad.data.zero_()

        output.backward(grad.contiguous())
        self.assertEqual(
            result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0
        )

    @dtypes(torch.double)
    def test_conv_double_backward(self, device, dtype):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            batch_size = 1
            for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]:
                for stride, padding, chan_in, chan_out, dilation in product(
                    [1], [2], [2], [3], dilations
                ):
                    no_weight = stride == 2
                    result = self.run_conv_double_back_test(
                        kern,
                        stride,
                        padding,
                        chan_in,
                        chan_out,
                        batch_size,
                        inp_size,
                        dilation,
                        no_weight,
                        use_xpu=True,
                        dtype=dtype,
                    )
                    self.assertTrue(result, "Conv double backward test failed")

    def test_conv_double_backward_no_bias(self):
        kern, stride = 3, 2
        chan_in, chan_out = 2, 4
        batch_size, inp_size = 2, 5
        padding, dilation = 1, 1
        no_weight, use_bias = False, True
        result = self.run_conv_double_back_test(
            kern,
            stride,
            padding,
            chan_in,
            chan_out,
            batch_size,
            inp_size,
            dilation,
            no_weight,
            use_bias=use_bias,
        )
        self.assertTrue(result, "Conv double backward test failed")

    def test_conv_double_backward_groups(self):
        kern, stride, padding = 3, 1, 2
        chan_in, chan_out = 2, 4
        batch_size, inp_size, dilation = 2, 6, 1
        no_weight = False
        groups = 2
        result = self.run_conv_double_back_test(
            kern,
            stride,
            padding,
            chan_in * groups,
            chan_out * groups,
            batch_size,
            inp_size,
            dilation,
            no_weight,
            groups=groups,
        )
        self.assertTrue(result, "Conv double backward test failed")

    def test_conv_double_backward_stride(self):
        batch_size = 2
        for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]:
            for stride, padding, chan_in, chan_out, dilation in product(
                [2], [0, 1], [1], [2], dilations
            ):
                no_weight = False
                self.run_conv_double_back_test(
                    kern,
                    stride,
                    padding,
                    chan_in,
                    chan_out,
                    batch_size,
                    inp_size,
                    dilation,
                    no_weight,
                )

    @dtypes(torch.float)
    def test_conv1d_same_padding(self, device, dtype):
        test_args = [
            range(50, 55),
            [1, 2, 3, 8],
            range(1, 4),
            [1],
        ]
        for in_size, k_size, dilation, stride in itertools.product(*test_args):
            x = torch.rand(1, 1, in_size, device=device, dtype=dtype)
            y = torch.rand(1, 1, k_size, device=device, dtype=dtype)
            z = F.conv1d(x, y, padding="same", dilation=dilation, stride=stride)
            self.assertEqual(z.size(2), int(math.ceil(in_size / stride)))
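
        # Compare padding="same" against the equivalent explicit padding values.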
        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 3, device=device, dtype=dtype)
        expect = F.conv1d(x, y, padding=1)
        actual = F.conv1d(x, y, padding="same")
        self.assertEqual(expect, actual)

        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
        expect = F.conv1d(x, y, padding=3, dilation=2)
        actual = F.conv1d(x, y, padding="same", dilation=2)
        self.assertEqual(expect, actual)

        expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:]
        actual = F.conv1d(x, y, padding="same", dilation=3)
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv3d_same_padding(self, device, dtype):
        rtol, atol = None, None
        x = torch.rand(1, 1, 10, 11, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 1, 2, 5, device=device, dtype=dtype)
        expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :]
        actual = F.conv3d(x, y, padding="same")
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
        actual = F.conv3d(x, y, padding="same", dilation=2)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        y = torch.rand(1, 1, 4, 4, 4, device=device, dtype=dtype)
        expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:]
        actual = F.conv3d(x, y, padding="same", dilation=3)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

    @dtypes(torch.float)
    def test_conv1d_valid_padding(self, device, dtype):
        x = torch.rand(1, 1, 10, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
        expect = F.conv1d(x, y)
        actual = F.conv1d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv2d_valid_padding(self, device, dtype):
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype)
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype)
        expect = F.conv2d(x, y)
        actual = F.conv2d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv3d_valid_padding(self, device, dtype):
        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device)
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device)
        expect = F.conv3d(x, y)
        actual = F.conv3d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv1d_same_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)

        z = F.conv1d(x, y, padding=3, dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding=2)[..., 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

    @dtypes(torch.float)
    def test_conv2d_same_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype, requires_grad=True)
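
        # With dilation, gradients through padding="same" should match those of the
        # equivalent explicit padding.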
        z = F.conv2d(x, y, padding=(3, 4), dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv2d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None

        y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True)
        z = F.conv2d(x, y, padding=2)[..., 1:, 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv2d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

    @dtypes(torch.double)
    def test_conv3d_same_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 1, 11, 12, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 1, 2, 5, dtype=dtype, device=device, requires_grad=True)
        z = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv3d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None
        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="same", dilation=2),
            (x, y),
            check_forward_ad=True,
            nondet_tol=1e-5,
        )
        gradgradcheck(
            lambda x, y: F.conv3d(x, y, padding="same", dilation=2),
            (x, y),
            check_fwd_over_rev=True,
        )

        y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True)
        z = F.conv3d(x, y, padding=2)[..., 1:, 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv3d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="same"),
            (x, y),
            check_forward_ad=True,
            nondet_tol=1e-5,
        )
        gradgradcheck(
            lambda x, y: F.conv3d(x, y, padding="same"),
            (x, y),
            check_fwd_over_rev=True,
        )

    @dtypes(torch.float)
    def test_conv1d_valid_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)
        F.conv1d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None
        F.conv1d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float)
    @parametrize_test("mode", ("valid", "same"))
    def test_conv1d_vs_scipy(self, device, dtype, mode):
        t = make_tensor((1, 10), device=device, dtype=dtype)
        feat_dim = t.shape[1]
        weight_even = make_tensor((1, 1, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            t_a = t.view(-1).cpu().numpy()
            w_a = weight.view(-1).cpu().numpy()
            expected = scipy.signal.convolve(t_a, w_a, mode=mode)

            kwargs = {"padding": mode}
            if mode == "same":
                p = weight.shape[2] // 2
                t = torch.nn.functional.pad(t, (p, p))
                kwargs.pop("padding")

            weight_flipped = torch.flip(weight, (2,))
            actual = torch.nn.functional.conv1d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:feat_dim]

            self.assertEqual(actual, expected, atol=2e-5, rtol=2e-5)

        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float)
    @parametrize_test("mode", ("valid", "same"))
    def test_conv2d_vs_scipy(self, device, dtype, mode):
        t = make_tensor((1, 5, 10), device=device, dtype=dtype)
        weight_even = make_tensor((1, 1, 2, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 3, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            t_a = t.squeeze(0).cpu().numpy()
            w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
            expected = scipy.signal.convolve2d(t_a, w_a, mode=mode)

            kwargs = {"padding": mode}
            if mode == "same":
                left_right_pad = weight.shape[3] // 2
                top_bottom_pad = weight.shape[2] // 2
                p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad)
                t = torch.nn.functional.pad(t, p)
                kwargs.pop("padding")

            weight_flipped = torch.flip(weight, (2, 3))
            actual = torch.nn.functional.conv2d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:5, :10]

            self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)

        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float)
    @parametrize_test("mode", ("valid", "same"))
    def test_conv3d_vs_scipy(self, device, dtype, mode):
        t = make_tensor((1, 5, 5, 10), device=device, dtype=dtype)
        weight_even = make_tensor((1, 1, 2, 2, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 2, 3, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            t_a = t.squeeze(0).cpu().numpy()
            w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
            expected = scipy.signal.convolve(t_a, w_a, mode=mode)
            kwargs = {"padding": mode}
            if mode == "same":
                left_right_pad = weight.shape[4] // 2
                top_bottom_pad = weight.shape[3] // 2
                front_back_pad = weight.shape[2] // 2
                p = (
                    left_right_pad,
                    left_right_pad,
                    top_bottom_pad,
                    top_bottom_pad,
                    front_back_pad,
                    front_back_pad,
                )
                t = torch.nn.functional.pad(t, p)
                kwargs.pop("padding")
            weight_flipped = torch.flip(weight, (2, 3, 4))
            actual = torch.nn.functional.conv3d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:5, :5, :10]
            self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)

        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @dtypes(torch.float)
    def test_conv2d_valid_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True)
        F.conv2d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None
        F.conv2d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)

    @dtypes(torch.double)
    def test_conv3d_valid_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True)
        F.conv3d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        F.conv3d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)
        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="valid"),
            (x, y),
            check_forward_ad=True,
        )
        gradgradcheck(
            lambda x, y: F.conv3d(x, y, padding="valid"),
            (x, y),
            check_fwd_over_rev=True,
        )

    @parametrize_test("N", range(2, 4), name_fn=lambda N: f"ConvTranspose{N}d")
    def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N):
        inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device)
        output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200)
        ConvTransposeNd = getattr(nn, f"ConvTranspose{N}d")
        m = ConvTransposeNd(
            1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device
        )
        output = m(inp, output_size=output_size)
        self.assertEqual(output.shape, output_size)

    @dtypes(torch.float)
    def test_conv_empty_channel(self, device, dtype):
        in_channels = 0
        mod = torch.nn.Conv1d(in_channels, 8, 2, stride=2, dtype=dtype).to(device)
        inp = torch.randn(2, 0, 15, device=device, dtype=dtype)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 0, device=device, dtype=dtype)
            mod(inp)

        mod = torch.nn.Conv2d(in_channels, 33, 3, stride=2, dtype=dtype).to(device)
        inp = torch.randn(2, 0, 50, 100, device=device, dtype=dtype)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 40, 0, device=device, dtype=dtype)
            mod(inp)

        mod = torch.nn.Conv3d(in_channels, 33, 3, stride=2, dtype=dtype).to(device)
        inp = torch.randn(2, 0, 50, 20, 40, device=device, dtype=dtype)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 50, 0, 40, device=device, dtype=dtype)
            mod(inp)

    def test_group_conv_empty(self, device):
        mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(
            device
        )
        inp = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

    def test_group_convTranspose_empty(self, device):
        mod = torch.nn.ConvTranspose2d(
            4, 4, stride=2, kernel_size=3, padding=1, groups=4
        ).to(device)
        inp = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

    def test_convTranspose_empty(self, device):
        mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1).to(
            device
        )
        inp = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

    def test_conv_large_nosplit(self, device):
        dtype = torch.half
        conv1 = nn.Conv2d(2, 2, 8, 8).to(device).to(dtype)
        input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device)
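        # Forward only: the test simply checks that convolution over a very large
        # input runs without error.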
        conv1(input_large)
        conv2 = torch.nn.Conv2d(1, 1024, 1, 1).to(device).to(dtype)
        input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device)
        conv2(input_large)

    def test_conv_noncontig_weights(self, device):
        for dim in (1, 2, 3):
            for grouped in (False, True):
                nc = 3
                groups = 3 if grouped else 1
                w = torch.randn([3] * dim, device=device)
                w = w.expand([nc, int(nc / groups)] + list(w.shape))
                w = w.detach().requires_grad_()
                x = torch.randn(
                    [1, nc] + ([5] * dim), device=device, requires_grad=True
                )
                y = getattr(F, f"conv{dim}d")(x, w, groups=groups)
                y.sum().backward()
                y = getattr(F, f"conv_transpose{dim}d")(x, w, groups=groups)
                y.sum().backward()

    def test_conv_noncontig_weights_and_bias(self, device):
        for bias in [True, False]:
            conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=bias).to(
                device, torch.float
            )
            input_nc = torch.randn(
                (1, 3, 224, 224, 2), device=device, dtype=torch.float
            )[:, :, :, :, 1]
            input_c = input_nc.contiguous()
            weight_nc = torch.randn((64, 3, 7, 7, 2), device=device, dtype=torch.float)[
                :, :, :, :, 1
            ]
            conv1.weight = nn.Parameter(weight_nc)
            weight_c = conv1.weight.contiguous()
            if bias:
                bias_nc = torch.randn((64, 2), device=device, dtype=torch.float)[:, 1]
                conv1.bias = nn.Parameter(bias_nc)
                bias_c = conv1.bias.contiguous()
            out1 = conv1(input_nc)
            conv1.weight = nn.Parameter(weight_c)
            if bias:
                conv1.bias = nn.Parameter(bias_c)
            out2 = conv1(input_c)
            self.assertEqual(out1, out2)

    def test_conv_transposed_large(self, device):
        dtype = torch.half if self.device_type == "cuda" else torch.float
        conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
        input_large = torch.randn(4096, 1, 512, 1024, dtype=dtype, device=device)
        ret = conv(input_large)
        maxdiff0 = (
            (ret.narrow(0, 0, 1024) - conv(input_large.narrow(0, 0, 1024)))
            .abs_()
            .max()
            .item()
        )
        maxdiff1 = (
            (ret.narrow(0, 1024, 1024) - conv(input_large.narrow(0, 1024, 1024)))
            .abs_()
            .max()
            .item()
        )
        maxdiff2 = (
            (ret.narrow(0, 2048, 1024) - conv(input_large.narrow(0, 2048, 1024)))
            .abs_()
            .max()
            .item()
        )
        maxdiff3 = (
            (ret.narrow(0, 3072, 1024) - conv(input_large.narrow(0, 3072, 1024)))
            .abs_()
            .max()
            .item()
        )
        self.assertEqual(maxdiff0, 0)
        self.assertEqual(maxdiff1, 0)
        self.assertEqual(maxdiff2, 0)
        self.assertEqual(maxdiff3, 0)

    def test_conv_large(self, device):
        dtype = torch.half if self.device_type == "cuda" else torch.float
        conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype)
        input_large = torch.randn(4097, 2, 512, 512, dtype=dtype, device=device)
        ret = conv(input_large)
        self.assertEqual(ret[:2048], conv(input_large[:2048]))
        self.assertEqual(ret[2048:4096], conv(input_large[2048:4096]))
        self.assertEqual(ret[4096:], conv(input_large[4096:]))

        conv.zero_grad()
        ret.view(4097, -1).max(dim=1).values.sum().backward()
        del ret
        grad1 = conv.weight.grad.detach().clone()
        conv.zero_grad()
        conv(input_large[:2048]).view(2048, -1).max(dim=1).values.sum().backward()
        conv(input_large[2048:4096]).view(2048, -1).max(dim=1).values.sum().backward()
        conv(input_large[4096:]).view(1, -1).max(dim=1).values.sum().backward()
        grad2 = conv.weight.grad.detach().clone()
        scale = 1 / grad2.abs().mean()
        grad1 = grad1 * scale
        grad2 = grad2 * scale
        self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)

    def test_Conv2d_size_1_kernel(self, device):
        x_cpu = torch.randn(2, 3, 5, 5)
        conv_cpu = torch.nn.Conv2d(3, 3, kernel_size=1)
        y_cpu = conv_cpu(x_cpu)
        y = torch.rand_like(y_cpu)
        y_cpu.backward(y)

        with cudnn.flags(enabled=False):
            conv_cuda = torch.nn.Conv2d(3, 3, kernel_size=1).to(device)
            conv_cuda.bias.data.copy_(conv_cpu.bias.data)
            conv_cuda.weight.data.copy_(conv_cpu.weight.data)
            y_cuda = conv_cuda(x_cpu.to(device))
            y_cuda.backward(y.to(device))

        self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
        self.assertEqual(
            conv_cpu.bias.grad.data,
            conv_cuda.bias.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )
        self.assertEqual(
            conv_cpu.weight.grad.data,
            conv_cuda.weight.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )

    def test_ConvTranspose2d_size_1_kernel(self, device):
        x_cpu = torch.randn(2, 3, 5, 5)
        conv_cpu = torch.nn.ConvTranspose2d(3, 3, kernel_size=1)
        y_cpu = conv_cpu(x_cpu)
        y = torch.rand_like(y_cpu)
        y_cpu.backward(y)
        conv_cuda = torch.nn.ConvTranspose2d(3, 3, kernel_size=1).to(device)
        conv_cuda.bias.data.copy_(conv_cpu.bias.data)
        conv_cuda.weight.data.copy_(conv_cpu.weight.data)
        y_cuda = conv_cuda(x_cpu.to(device))
        y_cuda.backward(y.to(device))

        self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
        self.assertEqual(
            conv_cpu.bias.grad.data,
            conv_cuda.bias.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )
        self.assertEqual(
            conv_cpu.weight.grad.data,
            conv_cuda.weight.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )

    def test_ConvTranspose3d_size_1_kernel(self, device):
        with set_default_dtype(torch.double):
            x_cpu = torch.randn(2, 3, 3, 5, 5)
            conv_cpu = torch.nn.ConvTranspose3d(3, 3, kernel_size=1)
            y_cpu = conv_cpu(x_cpu)
            y = torch.rand_like(y_cpu)
            y_cpu.backward(y)
            conv_cuda = torch.nn.ConvTranspose3d(3, 3, kernel_size=1).to(device)
            conv_cuda.bias.data.copy_(conv_cpu.bias.data)
            conv_cuda.weight.data.copy_(conv_cpu.weight.data)
            y_cuda = conv_cuda(x_cpu.to(device))
            y_cuda.backward(y.to(device))

            self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
            self.assertEqual(
                conv_cpu.bias.grad.data,
                conv_cuda.bias.grad.data,
                atol=1e-5,
                rtol=0,
                exact_device=False,
            )
            self.assertEqual(
                conv_cpu.weight.grad.data,
                conv_cuda.weight.grad.data,
                atol=1e-5,
                rtol=0,
                exact_device=False,
            )

    @dtypes(torch.float)
    def test_Conv2d_naive_groups(self, device, dtype):
        m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
        i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
        output = m(i)
        grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
        output.backward(grad_output)

        m1 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
        m1.weight.data.copy_(m.weight.data[:2])
        m1.bias.data.copy_(m.bias.data[:2])
        i1 = i.data[:, :2].contiguous().requires_grad_(True)
        output1 = m1(i1)
        output1.backward(grad_output[:, :2].contiguous())

        m2 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
        m2.weight.data.copy_(m.weight.data[2:])
        m2.bias.data.copy_(m.bias.data[2:])
        i2 = i.data[:, 2:].contiguous().requires_grad_(True)
        output2 = m2(i2)
        output2.backward(grad_output[:, 2:].contiguous())

        self.assertEqual(output, torch.cat([output1, output2], 1))
        self.assertEqual(
            i.grad.data,
            torch.cat([i1.grad.data, i2.grad.data], 1),
            atol=dtype2prec_DONTUSE[dtype],
            rtol=0,
        )
        self.assertEqual(
            m.bias.grad.data,
            torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
            atol=dtype2prec_DONTUSE[dtype],
            rtol=0,
        )
        self.assertEqual(
            m.weight.grad.data,
            torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
            atol=dtype2prec_DONTUSE[dtype],
            rtol=0,
        )

    @dtypes(torch.double)
    def test_Conv2d_backward_depthwise(self, device, dtype):
        x = torch.randn(2, 2, 4, 20, device=device, dtype=dtype, requires_grad=True)
        weight = torch.randn(2, 1, 3, 5, device=device, dtype=dtype, requires_grad=True)

        def conv2d_depthwise(x, weight):
            return torch.nn.functional.conv2d(
                x, weight, bias=None, stride=(1, 10), groups=2
            )

        torch.autograd.gradcheck(conv2d_depthwise, (x, weight))

    @dtypes(torch.half, torch.float)
    def test_conv_cudnn_nhwc(self, device, dtype):
        def helper(n, c, h, w, out_channels, kernel_size, groups):
            input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
                memory_format=torch.channels_last
            )
            input.requires_grad_()
            conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
                device=device, dtype=dtype, memory_format=torch.channels_last
            )
            for p in conv.parameters():
                p.data = torch.randint_like(p, -3, 3)

            ref_input = input.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups)
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device=device, dtype=torch.double, memory_format=torch.contiguous_format
            )

            out = conv(input)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -3, 3)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(
                conv.weight.grad.is_contiguous(memory_format=torch.channels_last)
            )

            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_input.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)

        helper(2, 8, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)

    @dtypes(torch.half, torch.float)
    def test_conv_cudnn_ndhwc(self, device, dtype):
        def helper(n, c, d, h, w, out_channels, kernel_size, groups):
            input = torch.randint(
                -2, 2, (n, c, d, h, w), dtype=dtype, device=device
            ).to(memory_format=torch.channels_last_3d)
            input.requires_grad_()
            conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups).to(
                device=device, dtype=dtype, memory_format=torch.channels_last_3d
            )
            for p in conv.parameters():
                p.data = torch.randint_like(p, -2, 2)

            ref_input = input.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups)
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device=device, dtype=torch.double, memory_format=torch.contiguous_format
            )

            out = conv(input)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -2, 2)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(
                input.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )
            self.assertTrue(
                conv.weight.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )

            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_input.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)

        helper(2, 8, 4, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=16)

    def _run_conv(
        self,
        layer,
        device,
        inp,
        grad,
        ref_conv,
        ref_input,
        ref_out,
        input_format,
        weight_format,
        grad_format,
        output_format,
    ):
        conv = (
            layer(inp.size(1), grad.size(1), ref_conv.weight.size(2)).float().to(device)
        )
        conv.load_state_dict(ref_conv.state_dict())
        weight_data = (
            conv.weight.detach().clone().contiguous(memory_format=weight_format)
        )
        conv.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=weight_format
        )
        input = inp.clone().contiguous(memory_format=input_format)
        input.resize_(input.size(), memory_format=input_format)
        input = input.requires_grad_()
        grad = grad.contiguous(memory_format=grad_format)
        grad.resize_(grad.size(), memory_format=grad_format)
        out = conv(input)
        out.backward(grad)
        self.assertTrue(out.is_contiguous(memory_format=output_format))
        self.assertEqual(out, ref_out)
        self.assertEqual(conv.weight.grad, ref_conv.weight.grad)
        self.assertEqual(conv.bias.grad, ref_conv.bias.grad)
        self.assertEqual(input.grad, ref_input.grad)

    def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
        data = torch.randint(1, 10, (n, c, h, w), dtype=torch.float32, device=device)
        ref_input = data.clone().contiguous().requires_grad_(True)
        ref_conv = layer(c, k, filter_size).float().to(device)
        ref_out = ref_conv(ref_input)
        grad = torch.randint(1, 10, ref_out.size(), dtype=torch.float32, device=device)
        ref_out.backward(grad)

        for w_f in [torch.contiguous_format, torch.channels_last]:
            for g_f in [torch.contiguous_format, torch.channels_last]:
                for input_format in [torch.contiguous_format, torch.channels_last]:
                    output_format = torch.contiguous_format
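                    # The output is expected to be channels_last whenever either
                    # the input or the weight is channels_last.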
                    if input_format == torch.channels_last:
                        output_format = torch.channels_last
                    if w_f == torch.channels_last:
                        output_format = torch.channels_last
                    self._run_conv(
                        layer,
                        device,
                        data,
                        grad,
                        ref_conv,
                        ref_input,
                        ref_out,
                        input_format,
                        w_f,
                        g_f,
                        output_format,
                    )

    @dtypes(torch.float, torch.double)
    def test_conv_cudnn_nhwc_support(self, device, dtype):
        input = torch.randn(
            (1, 16, 1, 1), dtype=dtype, device=device, requires_grad=True
        )
        weight = torch.randn(
            (8, 16, 3, 3), dtype=dtype, device=device, requires_grad=True
        )
        weight = weight.to(memory_format=torch.channels_last)
        o = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1)
        self.assertTrue(o.is_contiguous(memory_format=torch.channels_last))
        o.sum().backward()

    @dtypes(torch.float)
    def test_conv2d_no_grad(self, device, dtype):
        for batch in [1, 2, 3]:
            for groups in [1, 2, 4]:
                input = torch.rand(batch, groups, 8, 8, dtype=dtype, device=device)
                m = nn.Conv2d(
                    groups,
                    8,
                    kernel_size=(3, 3),
                    groups=groups,
                    dtype=dtype,
                    device=device,
                )
                with torch.no_grad():
                    output_ng = m(input)
                output = m(input)
                self.assertEqual(output, output_ng, rtol=1e-2, atol=1e-5)

    def test_conv_double_backward_strided_with_3D_input_and_weight(self, device):
        input = torch.randn(2, 3, 6, device=device)
        weight = torch.randn(3, 3, 3, device=device)
        bias = torch.randn(3, device=device)
        stride = (2,)
        padding = (1,)
        dilation = (1,)
        transposed = False
        output_padding = (0,)
        groups = 1
        output = torch.ops.aten.convolution(
            input,
            weight,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        )

        ggI = torch.randn(input.shape, device=device)
        ggW = torch.randn(weight.shape, device=device)
        ggB = torch.randn(bias.shape, device=device)
        gO = torch.randn(output.shape, device=device)
        output_mask = [True, True, True]
        (
            grad_grad_output,
            grad_input,
            grad_weight,
        ) = torch.ops.aten._convolution_double_backward(
            ggI,
            ggW,
            ggB,
            gO,
            weight,
            input,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            output_mask,
        )

        self.assertEqual(grad_grad_output.shape, gO.shape)
        self.assertEqual(grad_input.shape, input.shape)
        self.assertEqual(grad_weight.shape, weight.shape)

    @onlyXPU
    @dtypes(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    def test_channels_last_ouput_stride(self, device, dtype):
        input = torch.randn(
            (2, 3, 16, 16), device=device, dtype=dtype, requires_grad=True
        )
        weight = torch.randn(
            (512, 3, 3, 3), device=device, dtype=dtype, requires_grad=True
        )
        input = input.to(memory_format=torch.channels_last)
        weight = weight.to(memory_format=torch.channels_last)
        out = torch.conv2d(input, weight, None, (2, 2), (0, 0), (1, 1), 1)

        if dtype is torch.float64:
            # Like most conv backends, XPU does not support float64 for channels-last conv.
            # input NHWC, output NCHW
            assert_size_stride(out, (2, 512, 7, 7), (25088, 49, 7, 1))
        else:
            # input NHWC, output NHWC
            assert_size_stride(out, (2, 512, 7, 7), (25088, 1, 3584, 512))


instantiate_device_type_tests(
    TestConvolutionNNDeviceType, globals(), only_for="xpu", allow_xpu=True
)

if __name__ == "__main__":
    run_tests()