# xref: /aosp_15_r20/external/pytorch/test/xpu/test_conv.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: intel"]

import itertools
import math
import unittest
from itertools import product

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch._C._dynamo.guards import assert_size_stride
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_is_not_fp32
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyXPU,
)
from torch.testing._internal.common_dtype import floating_types_and
from torch.testing._internal.common_nn import _test_module_empty_input, NNTestCase
from torch.testing._internal.common_utils import (
    dtype2prec_DONTUSE,
    gradcheck,
    gradgradcheck,
    parametrize as parametrize_test,
    run_tests,
    set_default_dtype,
    TEST_SCIPY,
    TEST_WITH_ROCM,
)


AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
if TEST_SCIPY:
    import scipy.ndimage
    import scipy.signal


class TestConvolutionNNDeviceType(NNTestCase):
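    # Shared helper for the double-backward tests: builds an F.conv2d graph
    # on CPU or XPU and verifies second-order gradients with gradgradcheck.
    # For float inputs it only checks that the first-order gradient is itself
    # differentiable.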
    def run_conv_double_back_test(
        self,
        kern,
        stride,
        padding,
        chan_in,
        chan_out,
        batch_size,
        inp_size,
        dilation,
        no_weight,
        groups=1,
        use_xpu=False,
        use_bias=True,
        dtype=torch.double,
    ):
        device = torch.device("xpu" if use_xpu else "cpu")
        x = torch.randn(
            batch_size,
            chan_in,
            inp_size,
            inp_size,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        weight = torch.randn(
            chan_out,
            chan_in // groups,
            kern,
            kern,
            device=device,
            dtype=dtype,
            requires_grad=not no_weight,
        )
        if use_bias:
            bias = torch.randn(chan_out, device=device, dtype=dtype, requires_grad=True)
        else:
            bias = None

        def func(*inputs):
            if use_bias:
                lx, lweight, lbias = inputs
            else:
                lx, lweight = inputs
                lbias = None
            out = F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups)
            return out

        if use_bias:
            inputs = x, weight, bias
        else:
            inputs = x, weight

        dummy_out = func(*inputs)
        grad_y = torch.randn_like(
            dummy_out, device=device, dtype=dtype, requires_grad=True
        )

        if dtype == torch.float:
            (g,) = torch.autograd.grad(dummy_out.sum(), x, create_graph=True)
            return g.requires_grad

        return gradgradcheck(func, inputs, (grad_y,))

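    # Exercises a workspace-heavy 3x3 convolution over several large spatial
    # sizes. The benchmark argument appears to mirror the cudnn.benchmark
    # toggle used by the CUDA version of this test and is currently unused
    # here.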
    @dtypes(*floating_types_and(torch.half, torch.bfloat16))
    def test_Conv2d_large_workspace(self, device, dtype):
        sizes = [
            (1, 256, 109, 175),
            (1, 256, 80, 128),
            (1, 256, 120, 192),
        ]

        def run_test(benchmark):
            conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).to(device, dtype)
            for size in sizes:
                x = torch.randn(size, device=device, dtype=dtype)
                out = conv(x.detach().clone().requires_grad_())
                out.backward(torch.ones_like(out))

        run_test(benchmark=False)
        run_test(benchmark=True)

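    # Chains three stride-2 ConvTranspose2d layers with output_padding=1 and
    # backpropagates through the stack as a smoke test for large outputs with
    # output padding.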
    @dtypes(torch.half, torch.float)
    def test_ConvTranspose2d_large_output_padding(self, device, dtype):
        net1 = torch.nn.ConvTranspose2d(
            128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        ).to(device=device, dtype=dtype)
        net2 = torch.nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
        ).to(device=device, dtype=dtype)
        net3 = torch.nn.ConvTranspose2d(
            32, 3, kernel_size=3, stride=2, padding=1, output_padding=1
        ).to(device=device, dtype=dtype)
        x = torch.rand(1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True)
        x = net1(x)
        x = net2(x)
        x = net3(x)
        x.backward(torch.randn_like(x))

    @dtypes(torch.float, torch.double, torch.half)
    def test_Conv2d_depthwise_naive_groups(self, device, dtype):
        if dtype == torch.half and "xpu" in device:
            self.skipTest(
                "The fp16 accuracy issue is expected to be fixed in oneDNN v3.4"
            )
        for depth_multiplier in [1, 2]:
            m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
                device, dtype
            )
            i = (
                torch.randn(2, 2, 6, 6, device=device, dtype=dtype)
                .div_(2)
                .requires_grad_()
            )
            output = m(i)
            grad_output = (
                torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype)
                / 2
            )
            output.backward(grad_output)

            offset = 1 * depth_multiplier

            m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())

            self.assertEqual(
                output,
                torch.cat([output1, output2], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                i.grad.data,
                torch.cat([i1.grad.data, i2.grad.data], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.bias.grad.data,
                torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.weight.grad.data,
                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )

    @dtypes(torch.float, torch.double, torch.half)
    def test_Conv3d_depthwise_naive_groups(self, device, dtype):
        if dtype == torch.half and "xpu" in device:
            self.skipTest(
                "The fp16 accuracy issue is expected to be fixed in oneDNN v3.4"
            )
        for depth_multiplier in [1, 2]:
            m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
                device, dtype
            )
            i = (
                torch.randn(2, 2, 6, 6, 6, device=device, dtype=dtype)
                .div_(2)
                .requires_grad_()
            )
            output = m(i)
            grad_output = (
                torch.randn(
                    2, 2 * depth_multiplier, 4, 4, 4, device=device, dtype=dtype
                )
                / 2
            )
            output.backward(grad_output)

            offset = 1 * depth_multiplier

            m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())
            atol, rtol = (3e-4, 3e-2)

            self.assertEqual(
                output, torch.cat([output1, output2], 1), atol=atol, rtol=rtol
            )
            self.assertEqual(
                i.grad.data,
                torch.cat([i1.grad.data, i2.grad.data], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.bias.grad.data,
                torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.weight.grad.data,
                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                atol=atol,
                rtol=rtol,
            )

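    # Backward with a non-contiguous grad_output must match backward with its
    # contiguous copy.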
    @dtypes(torch.float, torch.double, torch.half)
    def test_noncontig_conv_grad(self, device, dtype):
        module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype)
        input = torch.randn(
            2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True
        )
        output = module(input)

        grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1]
        assert not grad.is_contiguous()
        output.backward(grad, retain_graph=True)
        self.assertIsNotNone(input.grad)
        result = input.grad.data.clone()
        input.grad.data.zero_()

        output.backward(grad.contiguous())
        self.assertEqual(
            result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0
        )

    @dtypes(torch.double)
    def test_conv_double_backward(self, device, dtype):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            batch_size = 1
            for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]:
                for stride, padding, chan_in, chan_out, dilation in product(
                    [1], [2], [2], [3], dilations
                ):
                    no_weight = stride == 2
                    result = self.run_conv_double_back_test(
                        kern,
                        stride,
                        padding,
                        chan_in,
                        chan_out,
                        batch_size,
                        inp_size,
                        dilation,
                        no_weight,
                        use_xpu=True,
                        dtype=dtype,
                    )
                    self.assertTrue(result, "Conv double backward test failed")

    def test_conv_double_backward_no_bias(self):
        kern, stride = 3, 2
        chan_in, chan_out = 2, 4
        batch_size, inp_size = 2, 5
        padding, dilation = 1, 1
        no_weight, use_bias = False, False
        result = self.run_conv_double_back_test(
            kern,
            stride,
            padding,
            chan_in,
            chan_out,
            batch_size,
            inp_size,
            dilation,
            no_weight,
            use_bias=use_bias,
        )
        self.assertTrue(result, "Conv double backward test failed")

    def test_conv_double_backward_groups(self):
        kern, stride, padding = 3, 1, 2
        chan_in, chan_out = 2, 4
        batch_size, inp_size, dilation = 2, 6, 1
        no_weight = False
        groups = 2
        result = self.run_conv_double_back_test(
            kern,
            stride,
            padding,
            chan_in * groups,
            chan_out * groups,
            batch_size,
            inp_size,
            dilation,
            no_weight,
            groups=groups,
        )
        self.assertTrue(result, "Conv double backward test failed")

    def test_conv_double_backward_stride(self):
        batch_size = 2
        for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]:
            for stride, padding, chan_in, chan_out, dilation in product(
                [2], [0, 1], [1], [2], dilations
            ):
                no_weight = False
                self.run_conv_double_back_test(
                    kern,
                    stride,
                    padding,
                    chan_in,
                    chan_out,
                    batch_size,
                    inp_size,
                    dilation,
                    no_weight,
                )

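    # The padding="same"/"valid" tests below compare string padding against
    # explicit numeric padding; when the required total padding is odd, the
    # explicitly padded result is over-padded symmetrically and the extra
    # leading output element is sliced off before comparison.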
    @dtypes(torch.float)
    def test_conv1d_same_padding(self, device, dtype):
        test_args = [
            range(50, 55),
            [1, 2, 3, 8],
            range(1, 4),
            [1],
        ]
        for in_size, k_size, dilation, stride in itertools.product(*test_args):
            x = torch.rand(1, 1, in_size, device=device, dtype=dtype)
            y = torch.rand(1, 1, k_size, device=device, dtype=dtype)
            z = F.conv1d(x, y, padding="same", dilation=dilation, stride=stride)
            self.assertEqual(z.size(2), int(math.ceil(in_size / stride)))

        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 3, device=device, dtype=dtype)
        expect = F.conv1d(x, y, padding=1)
        actual = F.conv1d(x, y, padding="same")
        self.assertEqual(expect, actual)

        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
        expect = F.conv1d(x, y, padding=3, dilation=2)
        actual = F.conv1d(x, y, padding="same", dilation=2)
        self.assertEqual(expect, actual)

        expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:]
        actual = F.conv1d(x, y, padding="same", dilation=3)
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv3d_same_padding(self, device, dtype):
        rtol, atol = None, None
        x = torch.rand(1, 1, 10, 11, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 1, 2, 5, device=device, dtype=dtype)
        expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :]
        actual = F.conv3d(x, y, padding="same")
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
        actual = F.conv3d(x, y, padding="same", dilation=2)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        y = torch.rand(1, 1, 4, 4, 4, device=device, dtype=dtype)
        expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:]
        actual = F.conv3d(x, y, padding="same", dilation=3)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

    @dtypes(torch.float)
    def test_conv1d_valid_padding(self, device, dtype):
        x = torch.rand(1, 1, 10, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
        expect = F.conv1d(x, y)
        actual = F.conv1d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv2d_valid_padding(self, device, dtype):
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype)
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype)
        expect = F.conv2d(x, y)
        actual = F.conv2d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv3d_valid_padding(self, device, dtype):
        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device)
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device)
        expect = F.conv3d(x, y)
        actual = F.conv3d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float)
    def test_conv1d_same_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)

        z = F.conv1d(x, y, padding=3, dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding=2)[..., 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

    @dtypes(torch.float)
    def test_conv2d_same_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype, requires_grad=True)

        z = F.conv2d(x, y, padding=(3, 4), dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv2d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None

        y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True)
        z = F.conv2d(x, y, padding=2)[..., 1:, 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv2d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

    @dtypes(torch.double)
    def test_conv3d_same_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 1, 11, 12, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 1, 2, 5, dtype=dtype, device=device, requires_grad=True)
        z = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv3d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None
        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="same", dilation=2),
            (x, y),
            check_forward_ad=True,
            nondet_tol=1e-5,
        )
        gradgradcheck(
            lambda x, y: F.conv3d(x, y, padding="same", dilation=2),
            (x, y),
            check_fwd_over_rev=True,
        )

        y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True)
        z = F.conv3d(x, y, padding=2)[..., 1:, 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv3d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="same"),
            (x, y),
            check_forward_ad=True,
            nondet_tol=1e-5,
        )
        gradgradcheck(
            lambda x, y: F.conv3d(x, y, padding="same"),
            (x, y),
            check_fwd_over_rev=True,
        )

    @dtypes(torch.float)
    def test_conv1d_valid_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)
        F.conv1d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None
        F.conv1d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)

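    # The SciPy comparison tests flip the weight before calling convNd
    # because torch convolutions are cross-correlations, while
    # scipy.signal.convolve performs a true convolution.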
    @unittest.skipIf(not TEST_SCIPY, "SciPy is required for this test.")
    @dtypes(torch.float)
    @parametrize_test("mode", ("valid", "same"))
    def test_conv1d_vs_scipy(self, device, dtype, mode):
        t = make_tensor((1, 10), device=device, dtype=dtype)
        feat_dim = t.shape[1]
        weight_even = make_tensor((1, 1, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            t_a = t.view(-1).cpu().numpy()
            w_a = weight.view(-1).cpu().numpy()
            expected = scipy.signal.convolve(t_a, w_a, mode=mode)

            kwargs = {"padding": mode}
            if mode == "same":
                p = weight.shape[2] // 2
                t = torch.nn.functional.pad(t, (p, p))
                kwargs.pop("padding")

            weight_flipped = torch.flip(weight, (2,))
            actual = torch.nn.functional.conv1d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:feat_dim]

            self.assertEqual(actual, expected, atol=2e-5, rtol=2e-5)

        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @unittest.skipIf(not TEST_SCIPY, "SciPy is required for this test.")
    @dtypes(torch.float)
    @parametrize_test("mode", ("valid", "same"))
    def test_conv2d_vs_scipy(self, device, dtype, mode):
        t = make_tensor((1, 5, 10), device=device, dtype=dtype)
        weight_even = make_tensor((1, 1, 2, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 3, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            t_a = t.squeeze(0).cpu().numpy()
            w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
            expected = scipy.signal.convolve2d(t_a, w_a, mode=mode)

            kwargs = {"padding": mode}
            if mode == "same":
                left_right_pad = weight.shape[3] // 2
                top_bottom_pad = weight.shape[2] // 2
                p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad)
                t = torch.nn.functional.pad(t, p)
                kwargs.pop("padding")

            weight_flipped = torch.flip(weight, (2, 3))
            actual = torch.nn.functional.conv2d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:5, :10]

            self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)

        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @unittest.skipIf(not TEST_SCIPY, "SciPy is required for this test.")
    @dtypes(torch.float)
    @parametrize_test("mode", ("valid", "same"))
    def test_conv3d_vs_scipy(self, device, dtype, mode):
        t = make_tensor((1, 5, 5, 10), device=device, dtype=dtype)
        weight_even = make_tensor((1, 1, 2, 2, 4), device=device, dtype=dtype)
        weight_odd = make_tensor((1, 1, 2, 3, 5), device=device, dtype=dtype)

        def _test(t, weight, mode):
            t_a = t.squeeze(0).cpu().numpy()
            w_a = weight.squeeze(0).squeeze(0).cpu().numpy()
            expected = scipy.signal.convolve(t_a, w_a, mode=mode)
            kwargs = {"padding": mode}
            if mode == "same":
                left_right_pad = weight.shape[4] // 2
                top_bottom_pad = weight.shape[3] // 2
                front_back_pad = weight.shape[2] // 2
                p = (
                    left_right_pad,
                    left_right_pad,
                    top_bottom_pad,
                    top_bottom_pad,
                    front_back_pad,
                    front_back_pad,
                )
                t = torch.nn.functional.pad(t, p)
                kwargs.pop("padding")
            weight_flipped = torch.flip(weight, (2, 3, 4))
            actual = torch.nn.functional.conv3d(t, weight_flipped, **kwargs).squeeze(0)
            if mode == "same":
                actual = actual[:5, :5, :10]
            self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6)

        with set_default_dtype(torch.float):
            _test(t, weight_even, mode)
            _test(t, weight_odd, mode)

    @dtypes(torch.float)
    def test_conv2d_valid_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True)
        F.conv2d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None
        F.conv2d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)

    @dtypes(torch.double)
    def test_conv3d_valid_padding_backward(self, device, dtype):
        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True)
        F.conv3d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        F.conv3d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)
        gradcheck(
            lambda x, y: F.conv3d(x, y, padding="valid"),
            (x, y),
            check_forward_ad=True,
        )
        gradgradcheck(
            lambda x, y: F.conv3d(x, y, padding="valid"),
            (x, y),
            check_fwd_over_rev=True,
        )

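    # ConvTransposeNd with a large stride admits several valid output shapes;
    # passing output_size selects one explicitly. The input here has no batch
    # dimension.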
    @parametrize_test("N", range(2, 4), name_fn=lambda N: f"ConvTranspose{N}d")
    def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N):
        inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device)
        output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200)
        ConvTransposeNd = getattr(nn, f"ConvTranspose{N}d")
        m = ConvTransposeNd(
            1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device
        )
        output = m(inp, output_size=output_size)
        self.assertEqual(output.shape, output_size)

    @dtypes(torch.float)
    def test_conv_empty_channel(self, device, dtype):
        in_channels = 0
        mod = torch.nn.Conv1d(in_channels, 8, 2, stride=2, dtype=dtype).to(device)
        inp = torch.randn(2, 0, 15, device=device, dtype=dtype)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 0, device=device, dtype=dtype)
            mod(inp)

        mod = torch.nn.Conv2d(in_channels, 33, 3, stride=2, dtype=dtype).to(device)
        inp = torch.randn(2, 0, 50, 100, device=device, dtype=dtype)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 40, 0, device=device, dtype=dtype)
            mod(inp)

        mod = torch.nn.Conv3d(in_channels, 33, 3, stride=2, dtype=dtype).to(device)
        inp = torch.randn(2, 0, 50, 20, 40, device=device, dtype=dtype)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 50, 0, 40, device=device, dtype=dtype)
            mod(inp)

    def test_group_conv_empty(self, device):
        mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(
            device
        )
        inp = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

    def test_group_convTranspose_empty(self, device):
        mod = torch.nn.ConvTranspose2d(
            4, 4, stride=2, kernel_size=3, padding=1, groups=4
        ).to(device)
        inp = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

    def test_convTranspose_empty(self, device):
        mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1).to(
            device
        )
        inp = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

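    # The "large" tests below run convolutions on very large inputs and,
    # where results are compared, check that the full-batch output matches
    # the output computed on narrowed slices of the same input.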
    def test_conv_large_nosplit(self, device):
        dtype = torch.half
        conv1 = nn.Conv2d(2, 2, 8, 8).to(device).to(dtype)
        input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device)
        conv1(input_large)
        conv2 = torch.nn.Conv2d(1, 1024, 1, 1).to(device).to(dtype)
        input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device)
        conv2(input_large)

    def test_conv_noncontig_weights(self, device):
        for dim in (1, 2, 3):
            for grouped in (False, True):
                nc = 3
                groups = 3 if grouped else 1
                w = torch.randn([3] * dim, device=device)
                w = w.expand([nc, int(nc / groups)] + list(w.shape))
                w = w.detach().requires_grad_()
                x = torch.randn(
                    [1, nc] + ([5] * dim), device=device, requires_grad=True
                )
                y = getattr(F, f"conv{dim}d")(x, w, groups=groups)
                y.sum().backward()
                y = getattr(F, f"conv_transpose{dim}d")(x, w, groups=groups)
                y.sum().backward()

    def test_conv_noncontig_weights_and_bias(self, device):
        for bias in [True, False]:
            conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=bias).to(
                device, torch.float
            )
            input_nc = torch.randn(
                (1, 3, 224, 224, 2), device=device, dtype=torch.float
            )[:, :, :, :, 1]
            input_c = input_nc.contiguous()
            weight_nc = torch.randn((64, 3, 7, 7, 2), device=device, dtype=torch.float)[
                :, :, :, :, 1
            ]
            conv1.weight = nn.Parameter(weight_nc)
            weight_c = conv1.weight.contiguous()
            if bias:
                bias_nc = torch.randn((64, 2), device=device, dtype=torch.float)[:, 1]
                conv1.bias = nn.Parameter(bias_nc)
                bias_c = conv1.bias.contiguous()
            out1 = conv1(input_nc)
            conv1.weight = nn.Parameter(weight_c)
            if bias:
                conv1.bias = nn.Parameter(bias_c)
            out2 = conv1(input_c)
            self.assertEqual(out1, out2)

    def test_conv_transposed_large(self, device):
        dtype = torch.half if self.device_type == "cuda" else torch.float
        conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
        input_large = torch.randn(4096, 1, 512, 1024, dtype=dtype, device=device)
        ret = conv(input_large)
        maxdiff0 = (
            (ret.narrow(0, 0, 1024) - conv(input_large.narrow(0, 0, 1024)))
            .abs_()
            .max()
            .item()
        )
        maxdiff1 = (
            (ret.narrow(0, 1024, 1024) - conv(input_large.narrow(0, 1024, 1024)))
            .abs_()
            .max()
            .item()
        )
        maxdiff2 = (
            (ret.narrow(0, 2048, 1024) - conv(input_large.narrow(0, 2048, 1024)))
            .abs_()
            .max()
            .item()
        )
        maxdiff3 = (
            (ret.narrow(0, 3072, 1024) - conv(input_large.narrow(0, 3072, 1024)))
            .abs_()
            .max()
            .item()
        )
        self.assertEqual(maxdiff0, 0)
        self.assertEqual(maxdiff1, 0)
        self.assertEqual(maxdiff2, 0)
        self.assertEqual(maxdiff3, 0)

    def test_conv_large(self, device):
        dtype = torch.half if self.device_type == "cuda" else torch.float
        conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype)
        input_large = torch.randn(4097, 2, 512, 512, dtype=dtype, device=device)
        ret = conv(input_large)
        self.assertEqual(ret[:2048], conv(input_large[:2048]))
        self.assertEqual(ret[2048:4096], conv(input_large[2048:4096]))
        self.assertEqual(ret[4096:], conv(input_large[4096:]))

        conv.zero_grad()
        ret.view(4097, -1).max(dim=1).values.sum().backward()
        del ret
        grad1 = conv.weight.grad.detach().clone()
        conv.zero_grad()
        conv(input_large[:2048]).view(2048, -1).max(dim=1).values.sum().backward()
        conv(input_large[2048:4096]).view(2048, -1).max(dim=1).values.sum().backward()
        conv(input_large[4096:]).view(1, -1).max(dim=1).values.sum().backward()
        grad2 = conv.weight.grad.detach().clone()
        scale = 1 / grad2.abs().mean()
        grad1 = grad1 * scale
        grad2 = grad2 * scale
        self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)

    def test_Conv2d_size_1_kernel(self, device):
        x_cpu = torch.randn(2, 3, 5, 5)
        conv_cpu = torch.nn.Conv2d(3, 3, kernel_size=1)
        y_cpu = conv_cpu(x_cpu)
        y = torch.rand_like(y_cpu)
        y_cpu.backward(y)

        with cudnn.flags(enabled=False):
            conv_cuda = torch.nn.Conv2d(3, 3, kernel_size=1).to(device)
            conv_cuda.bias.data.copy_(conv_cpu.bias.data)
            conv_cuda.weight.data.copy_(conv_cpu.weight.data)
            y_cuda = conv_cuda(x_cpu.to(device))
            y_cuda.backward(y.to(device))

        self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
        self.assertEqual(
            conv_cpu.bias.grad.data,
            conv_cuda.bias.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )
        self.assertEqual(
            conv_cpu.weight.grad.data,
            conv_cuda.weight.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )

    def test_ConvTranspose2d_size_1_kernel(self, device):
        x_cpu = torch.randn(2, 3, 5, 5)
        conv_cpu = torch.nn.ConvTranspose2d(3, 3, kernel_size=1)
        y_cpu = conv_cpu(x_cpu)
        y = torch.rand_like(y_cpu)
        y_cpu.backward(y)
        conv_cuda = torch.nn.ConvTranspose2d(3, 3, kernel_size=1).to(device)
        conv_cuda.bias.data.copy_(conv_cpu.bias.data)
        conv_cuda.weight.data.copy_(conv_cpu.weight.data)
        y_cuda = conv_cuda(x_cpu.to(device))
        y_cuda.backward(y.to(device))

        self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
        self.assertEqual(
            conv_cpu.bias.grad.data,
            conv_cuda.bias.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )
        self.assertEqual(
            conv_cpu.weight.grad.data,
            conv_cuda.weight.grad.data,
            atol=1e-5,
            rtol=0,
            exact_device=False,
        )

    def test_ConvTranspose3d_size_1_kernel(self, device):
        with set_default_dtype(torch.double):
            x_cpu = torch.randn(2, 3, 3, 5, 5)
            conv_cpu = torch.nn.ConvTranspose3d(3, 3, kernel_size=1)
            y_cpu = conv_cpu(x_cpu)
            y = torch.rand_like(y_cpu)
            y_cpu.backward(y)
            conv_cuda = torch.nn.ConvTranspose3d(3, 3, kernel_size=1).to(device)
            conv_cuda.bias.data.copy_(conv_cpu.bias.data)
            conv_cuda.weight.data.copy_(conv_cpu.weight.data)
            y_cuda = conv_cuda(x_cpu.to(device))
            y_cuda.backward(y.to(device))

            self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False)
            self.assertEqual(
                conv_cpu.bias.grad.data,
                conv_cuda.bias.grad.data,
                atol=1e-5,
                rtol=0,
                exact_device=False,
            )
            self.assertEqual(
                conv_cpu.weight.grad.data,
                conv_cuda.weight.grad.data,
                atol=1e-5,
                rtol=0,
                exact_device=False,
            )

    @dtypes(torch.float)
    def test_Conv2d_naive_groups(self, device, dtype):
        m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
        i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
        output = m(i)
        grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
        output.backward(grad_output)

        m1 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
        m1.weight.data.copy_(m.weight.data[:2])
        m1.bias.data.copy_(m.bias.data[:2])
        i1 = i.data[:, :2].contiguous().requires_grad_(True)
        output1 = m1(i1)
        output1.backward(grad_output[:, :2].contiguous())

        m2 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
        m2.weight.data.copy_(m.weight.data[2:])
        m2.bias.data.copy_(m.bias.data[2:])
        i2 = i.data[:, 2:].contiguous().requires_grad_(True)
        output2 = m2(i2)
        output2.backward(grad_output[:, 2:].contiguous())

        self.assertEqual(output, torch.cat([output1, output2], 1))
        self.assertEqual(
            i.grad.data,
            torch.cat([i1.grad.data, i2.grad.data], 1),
            atol=dtype2prec_DONTUSE[dtype],
            rtol=0,
        )
        self.assertEqual(
            m.bias.grad.data,
            torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
            atol=dtype2prec_DONTUSE[dtype],
            rtol=0,
        )
        self.assertEqual(
            m.weight.grad.data,
            torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
            atol=dtype2prec_DONTUSE[dtype],
            rtol=0,
        )

    @dtypes(torch.double)
    def test_Conv2d_backward_depthwise(self, device, dtype):
        x = torch.randn(2, 2, 4, 20, device=device, dtype=dtype, requires_grad=True)
        weight = torch.randn(2, 1, 3, 5, device=device, dtype=dtype, requires_grad=True)

        def conv2d_depthwise(x, weight):
            return torch.nn.functional.conv2d(
                x, weight, bias=None, stride=(1, 10), groups=2
            )

        torch.autograd.gradcheck(conv2d_depthwise, (x, weight))

    @dtypes(torch.half, torch.float)
    def test_conv_cudnn_nhwc(self, device, dtype):
        def helper(n, c, h, w, out_channels, kernel_size, groups):
            input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
                memory_format=torch.channels_last
            )
            input.requires_grad_()
            conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
                device=device, dtype=dtype, memory_format=torch.channels_last
            )
            for p in conv.parameters():
                p.data = torch.randint_like(p, -3, 3)

            ref_input = input.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups)
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device=device, dtype=torch.double, memory_format=torch.contiguous_format
            )

            out = conv(input)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -3, 3)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(
                conv.weight.grad.is_contiguous(memory_format=torch.channels_last)
            )

            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_input.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)

        helper(2, 8, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)

    @dtypes(torch.half, torch.float)
    def test_conv_cudnn_ndhwc(self, device, dtype):
        def helper(n, c, d, h, w, out_channels, kernel_size, groups):
            input = torch.randint(
                -2, 2, (n, c, d, h, w), dtype=dtype, device=device
            ).to(memory_format=torch.channels_last_3d)
            input.requires_grad_()
            conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups).to(
                device=device, dtype=dtype, memory_format=torch.channels_last_3d
            )
            for p in conv.parameters():
                p.data = torch.randint_like(p, -2, 2)

            ref_input = input.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups)
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device=device, dtype=torch.double, memory_format=torch.contiguous_format
            )

            out = conv(input)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -2, 2)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(
                input.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )
            self.assertTrue(
                conv.weight.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )

            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_input.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)

        helper(2, 8, 4, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=16)

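    # _run_conv and _test_conv_cudnn_nhwc_nchw sweep all combinations of
    # contiguous and channels_last memory formats for the input, weight, and
    # grad_output, comparing each run against a contiguous reference.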
    def _run_conv(
        self,
        layer,
        device,
        inp,
        grad,
        ref_conv,
        ref_input,
        ref_out,
        input_format,
        weight_format,
        grad_format,
        output_format,
    ):
        conv = (
            layer(inp.size(1), grad.size(1), ref_conv.weight.size(2)).float().to(device)
        )
        conv.load_state_dict(ref_conv.state_dict())
        weight_data = (
            conv.weight.detach().clone().contiguous(memory_format=weight_format)
        )
        conv.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=weight_format
        )
        input = inp.clone().contiguous(memory_format=input_format)
        input.resize_(input.size(), memory_format=input_format)
        input = input.requires_grad_()
        grad = grad.contiguous(memory_format=grad_format)
        grad.resize_(grad.size(), memory_format=grad_format)
        out = conv(input)
        out.backward(grad)
        self.assertTrue(out.is_contiguous(memory_format=output_format))
        self.assertEqual(out, ref_out)
        self.assertEqual(conv.weight.grad, ref_conv.weight.grad)
        self.assertEqual(conv.bias.grad, ref_conv.bias.grad)
        self.assertEqual(input.grad, ref_input.grad)

    def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
        data = torch.randint(1, 10, (n, c, h, w), dtype=torch.float32, device=device)
        ref_input = data.clone().contiguous().requires_grad_(True)
        ref_conv = layer(c, k, filter_size).float().to(device)
        ref_out = ref_conv(ref_input)
        grad = torch.randint(1, 10, ref_out.size(), dtype=torch.float32, device=device)
        ref_out.backward(grad)

        for w_f in [torch.contiguous_format, torch.channels_last]:
            for g_f in [torch.contiguous_format, torch.channels_last]:
                for input_format in [torch.contiguous_format, torch.channels_last]:
                    output_format = torch.contiguous_format
                    if input_format == torch.channels_last:
                        output_format = torch.channels_last
                    if w_f == torch.channels_last:
                        output_format = torch.channels_last
                    self._run_conv(
                        layer,
                        device,
                        data,
                        grad,
                        ref_conv,
                        ref_input,
                        ref_out,
                        input_format,
                        w_f,
                        g_f,
                        output_format,
                    )

    @dtypes(torch.float, torch.double)
    def test_conv_cudnn_nhwc_support(self, device, dtype):
        input = torch.randn(
            (1, 16, 1, 1), dtype=dtype, device=device, requires_grad=True
        )
        weight = torch.randn(
            (8, 16, 3, 3), dtype=dtype, device=device, requires_grad=True
        )
        weight = weight.to(memory_format=torch.channels_last)
        o = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1)
        self.assertTrue(o.is_contiguous(memory_format=torch.channels_last))
        o.sum().backward()

    @dtypes(torch.float)
    def test_conv2d_no_grad(self, device, dtype):
        for batch in [1, 2, 3]:
            for groups in [1, 2, 4]:
                input = torch.rand(batch, groups, 8, 8, dtype=dtype, device=device)
                m = nn.Conv2d(
                    groups,
                    8,
                    kernel_size=(3, 3),
                    groups=groups,
                    dtype=dtype,
                    device=device,
                )
                with torch.no_grad():
                    output_ng = m(input)
                output = m(input)
                self.assertEqual(output, output_ng, rtol=1e-2, atol=1e-5)

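    # Calls aten._convolution_double_backward directly for a strided 1d
    # convolution (3D input and weight) and checks the shapes of the returned
    # gradients.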
    def test_conv_double_backward_strided_with_3D_input_and_weight(self, device):
        input = torch.randn(2, 3, 6, device=device)
        weight = torch.randn(3, 3, 3, device=device)
        bias = torch.randn(3, device=device)
        stride = (2,)
        padding = (1,)
        dilation = (1,)
        transposed = False
        output_padding = (0,)
        groups = 1
        output = torch.ops.aten.convolution(
            input,
            weight,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        )

        ggI = torch.randn(input.shape, device=device)
        ggW = torch.randn(weight.shape, device=device)
        ggB = torch.randn(bias.shape, device=device)
        gO = torch.randn(output.shape, device=device)
        output_mask = [True, True, True]
        (
            grad_grad_output,
            grad_input,
            grad_weight,
        ) = torch.ops.aten._convolution_double_backward(
            ggI,
            ggW,
            ggB,
            gO,
            weight,
            input,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            output_mask,
        )

        self.assertEqual(grad_grad_output.shape, gO.shape)
        self.assertEqual(grad_input.shape, input.shape)
        self.assertEqual(grad_weight.shape, weight.shape)

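    # Checks which memory format the XPU convolution produces: float64 does
    # not take the channels-last path, so its output is NCHW; the other
    # dtypes should produce channels-last (NHWC) output.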
    @onlyXPU
    @dtypes(torch.float16, torch.bfloat16, torch.float32, torch.float64)
    def test_channels_last_output_stride(self, device, dtype):
        input = torch.randn(
            (2, 3, 16, 16), device=device, dtype=dtype, requires_grad=True
        )
        weight = torch.randn(
            (512, 3, 3, 3), device=device, dtype=dtype, requires_grad=True
        )
        input = input.to(memory_format=torch.channels_last)
        weight = weight.to(memory_format=torch.channels_last)
        out = torch.conv2d(input, weight, None, (2, 2), (0, 0), (1, 1), 1)

        if dtype is torch.float64:
            # Like most conv backends, XPU does not support float64 for channels-last conv.
            # input NHWC, output NCHW
            assert_size_stride(out, (2, 512, 7, 7), (25088, 49, 7, 1))
        else:
            # input NHWC, output NHWC
            assert_size_stride(out, (2, 512, 7, 7), (25088, 1, 3584, 512))


instantiate_device_type_tests(
    TestConvolutionNNDeviceType, globals(), only_for="xpu", allow_xpu=True
)

if __name__ == "__main__":
    run_tests()
