xref: /aosp_15_r20/external/pytorch/test/test_nn.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: nn"]
2
3import contextlib
4import math
5import random
6import unittest
7import io
8import itertools
9import warnings
10import pickle
11import re
12from copy import deepcopy
13from itertools import product
14from functools import partial
15from collections import OrderedDict
16from unittest import SkipTest
17
18import torch
19from torch import inf, nan
20import torch.autograd.forward_ad as fwAD
21import torch.backends.cudnn as cudnn
22import torch.nn as nn
23import torch.nn.functional as F
24import torch.nn.utils.rnn as rnn_utils
25from torch.nn.utils import clip_grad_norm_, clip_grad_value_
26from torch.nn.utils import parameters_to_vector, vector_to_parameters
27from torch.nn.utils.fusion import fuse_conv_bn_weights
28from torch.nn.utils.fusion import fuse_linear_bn_weights
29from torch.nn import Buffer, Parameter
30from torch.nn.parallel._functions import Broadcast
31from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes, floating_types
32from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
33    TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \
34    download_file, get_function_arglist, load_tests, skipIfMps, \
35    IS_PPC, \
36    parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
37    skipIfTorchDynamo, gcIfJetson, set_default_dtype
38from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION
39from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
40    module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
41    ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
42from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \
43    dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
44    skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \
45    onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, expectedFailureMPS, \
46    skipMeta, get_all_device_types
47
48from hypothesis import given
49import torch.testing._internal.hypothesis_utils as hu
50from torch.testing._internal.common_utils import _assertGradAndGradgradChecks, gradcheck, gradgradcheck, \
51    GRADCHECK_NONDET_TOL
52from torch.testing._internal.common_utils import dtype2prec_DONTUSE
53from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on
54from torch.types import _TensorOrTensors
55from torch.testing._internal.common_mkldnn import bf32_on_and_off
56
57AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
58
59# load_tests from common_utils is used to automatically filter tests for
60# sharding on sandcastle. This line silences flake warnings
61load_tests = load_tests
62
63if TEST_SCIPY:
64    import scipy.signal
65    import scipy.ndimage
66
67if TEST_NUMPY:
68    import numpy as np
69
70
71# WARNING: If you add a new top-level test case to this file, you MUST
72# update test/run_test.py to list it, otherwise it will NOT be run in
73# CI.
74
75class TestNN(NNTestCase):
76    _do_cuda_memory_leak_check = True
77    _do_cuda_non_default_stream = True
78
79    def _forward(self, module, input: _TensorOrTensors):
80        with freeze_rng_state():
81            if isinstance(input, tuple):
82                return module(*input)
83            else:
84                return module(input)
85
86    def _backward(self, module, input: _TensorOrTensors, output, grad_output, create_graph=False):
87        output.backward(grad_output, retain_graph=True, create_graph=create_graph)
88        if isinstance(input, tuple):
89            return tuple(i.grad.data if i.grad is not None else None for i in input)
90        else:
91            return input.grad.data if input.grad is not None else None
92
93    def _forward_criterion(self, criterion, input, target, extra_args=None):
94        if extra_args is None:
95            extra_args = ()
96        if isinstance(input, tuple):
97            args = input + (target,) + extra_args
98            output = criterion(*args)
99        else:
100            output = criterion(input, target, *extra_args)
101        return output
102
103    def _backward_criterion(self, criterion, input, output, target, gradOutput=None, extra_args=None):
104        if extra_args is None:
105            extra_args = ()
106        input_tuple = input if isinstance(input, tuple) else (input,)
107        output_tuple = output if isinstance(output, tuple) else (output,)
108        for i in input_tuple:
109            if i.grad is not None:
110                i.grad.data.zero_()
111        args = input_tuple + (target,) + extra_args
112        if gradOutput is None:
113            gradOutput = torch.ones(())
114        criterion(*args).backward(gradOutput.to(output_tuple[0]))
115        if isinstance(input, tuple):
116            return tuple(i.grad.data for i in input)
117        else:
118            return input.grad.data
119
120    def _zero_grad_parameters(self, module):
121        for p in module.parameters():
122            if p.grad is not None:
123                with torch.no_grad():
124                    p.grad.zero_()
125                p.grad.detach_()
126
127    def _get_parameters(self, module):
128        params = []
129        d_params = []
130        for p in module.parameters():
131            params.append(p)
132            d_params.append(p.grad)
133        return params, d_params
134
135    def test_parse_to(self):
136        # Test for buggy use of THPMemoryFormat_New
137        self.assertEqual(
138            repr(torch._C._nn._parse_to(memory_format=torch.contiguous_format)[3]),
139            "torch.contiguous_format"
140        )
141
142    def test_requires_grad_(self):
143        m = _create_basic_net()[-1]
144        assert len(list(m.buffers())) > 0, 'invalid test'
145        assert all(not b.requires_grad for b in m.buffers()) > 0, 'invalid test'
146        assert len(list(m.parameters())) > 0, 'invalid test'
147        assert all(p.requires_grad for p in m.parameters()) > 0, 'invalid test'
148        for requires_grad in (False, True):
149            self.assertIs(m.requires_grad_(requires_grad), m)
150            for p in m.parameters():
151                self.assertEqual(p.requires_grad, requires_grad)
152            for b in m.buffers():
153                self.assertFalse(b.requires_grad)
154
155    def test_module_backcompat(self):
156        from torch.serialization import SourceChangeWarning
157        path = download_file('https://download.pytorch.org/test_data/linear.pt')
158        with warnings.catch_warnings():
159            warnings.simplefilter('ignore', SourceChangeWarning)
160            # weights_only=False as this is legacy code that saves the model
161            m = torch.load(path, weights_only=False)
162        input = torch.randn(2, 3, dtype=torch.float)
163        self.assertEqual(m(input).size(), (2, 5))
164
165    def test_module_super_init(self):
166        class MyMixin:
167            def __init__(self, *a, **kw):
168                super().__init__(*a, **kw)
169                self.mixin_init = True
170
171        class MyModuleWithMixinBefore(MyMixin, nn.Module):
172            pass
173
174        class MyModuleWithMixinAfter(nn.Module, MyMixin):
175            pass
176
177        self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init'))
178        self.assertFalse(hasattr(MyModuleWithMixinAfter(), 'mixin_init'))
179
180        nn.Module.call_super_init = True
181        self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init'))
182        self.assertTrue(hasattr(MyModuleWithMixinAfter(), 'mixin_init'))
183        nn.Module.call_super_init = False
184
185        MyModuleWithMixinBefore.call_super_init = True
186        MyModuleWithMixinAfter.call_super_init = True
187        self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init'))
188        self.assertTrue(hasattr(MyModuleWithMixinAfter(), 'mixin_init'))
189        MyModuleWithMixinBefore.call_super_init = False
190        MyModuleWithMixinAfter.call_super_init = False
191
192    def test_share_memory(self):
193        class Net(nn.Module):
194            def __init__(self) -> None:
195                super().__init__()
196                self.p = nn.Parameter(torch.eye(5))
197                self.par = nn.ParameterList()
198                self.par.append(nn.Parameter(torch.randn(10)))
199
200            def forward(self, inp):
201                # NB: dead code
202                return inp.clone()
203
204        net = Net()
205        for p in net.parameters():
206            self.assertFalse(p.storage().is_shared())
207        for b in net.buffers():
208            self.assertFalse(b.storage().is_shared())
209        net.share_memory()
210        for p in net.parameters():
211            self.assertTrue(p.storage().is_shared())
212        for b in net.buffers():
213            self.assertTrue(b.storage().is_shared())
214
215    def test_to(self):
216        m = nn.Linear(3, 5)
217        self.assertIs(m, m.to('cpu'))
218        self.assertIs(m, m.to('cpu', dtype=torch.float32))
219        self.assertEqual(m.double(), m.to(torch.float64))
220        self.assertRaises(RuntimeError, lambda: m.to('cpu', copy=True))
221
222        if torch.cuda.is_available():
223            for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
224                m2 = m.cuda(device=cuda)
225                self.assertIs(m2, m2.to(cuda))
226                self.assertEqual(m, m2.to('cpu'))
227                self.assertEqual(m2, m.to(cuda))
228                self.assertIs(m2, m2.to(dtype=torch.float32))
229                self.assertEqual(m2.double(), m2.to(dtype=torch.float64))
230
231    def test_zero_grad(self):
232        i = torch.randn(2, 5, requires_grad=True)
233        module = nn.Linear(5, 5)
234        for p in module.parameters():
235            p.requires_grad = False
236        module.zero_grad()
237
238        module.weight.requires_grad = True
239        module.zero_grad()
240        self.assertIsNone(module.weight.grad)  # uninitialized grad
241
242        module(i).sum().backward()
243        self.assertIsNotNone(module.weight.grad)
244        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
245        module.zero_grad()
246        self.assertIsNone(module.weight.grad)
247
248        module.bias.requires_grad = True
249        module.zero_grad()
250        self.assertIsNone(module.weight.grad)
251        self.assertIsNone(module.bias.grad)
252        module(i).sum().backward()
253        self.assertIsNotNone(module.weight.grad)
254        self.assertIsNotNone(module.bias.grad)
255        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
256        self.assertGreater(module.bias.grad.data.abs().sum(), 0)
257        module.zero_grad(set_to_none=False)   # Force set to zeros.
258        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
259        self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
260
261        module.zero_grad()
262        self.assertIsNone(module.weight.grad)
263        self.assertIsNone(module.bias.grad)
264
265    def test_no_grad(self):
266        for dtype in [torch.bfloat16, torch.float, torch.double]:
267            module = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype)
268            input = torch.randn(1, 2, 10, 10).to(dtype)
269            x = input
270            y = input.clone()
271
272            output = module(x)
273            self.assertTrue(output.requires_grad)
274            output.backward(torch.ones(1, 5, 10, 10))
275
276            with torch.no_grad():
277                output2 = module(y)
278                self.assertFalse(output2.requires_grad)
279                self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10)))
280
281    def test_parameters_and_named_parameters(self):
282        def names(named_parameters):
283            return [k for k, _ in named_parameters]
284
285        l, n, s = _create_basic_net()
286
287        self.assertEqual(len(list(l.parameters())), 1)
288        self.assertEqual(
289            names(l.named_parameters()),
290            ['layer_dummy_param'])
291
292        self.assertEqual(len(list(n.parameters())), 2)
293        self.assertEqual(
294            names(n.named_parameters()),
295            ['dummy_param', 'l1.layer_dummy_param'])
296
297        self.assertEqual(len(list(n.parameters(recurse=False))), 1)
298        self.assertEqual(
299            names(n.named_parameters(recurse=False)),
300            ['dummy_param'])
301
302        self.assertEqual(len(list(s.parameters())), 2)
303        self.assertEqual(
304            names(s.named_parameters()),
305            ['0.dummy_param', '0.l1.layer_dummy_param'])
306
307    def test_named_parameters_remove_duplicate(self):
308        def names(named_parameters):
309            return [k for k, _ in named_parameters]
310
311        class M1(nn.Module):
312            def __init__(self) -> None:
313                super().__init__()
314                self.param1 = nn.Parameter(torch.empty(3, 3))
315                self.param2 = self.param1
316
317        m1 = M1()
318        self.assertEqual(names(m1.named_parameters()),
319                         ["param1"])
320        self.assertEqual(names(m1.named_parameters(remove_duplicate=False)),
321                         ["param1", "param2"])
322
323        class M2(nn.Module):
324            def __init__(self) -> None:
325                super().__init__()
326                self.mod1 = nn.Linear(3, 4, bias=False)
327                self.mod2 = self.mod1
328
329        m2 = M2()
330        self.assertEqual(names(m2.named_parameters()),
331                         ["mod1.weight"])
332        self.assertEqual(names(m2.named_parameters(remove_duplicate=False)),
333                         ["mod1.weight", "mod2.weight"])
334
335    def test_buffers_and_named_buffers(self):
336        def names(named_buffers):
337            return [k for k, _ in named_buffers]
338
339        l, n, s = _create_basic_net()
340
341        self.assertEqual(len(list(l.buffers())), 1)
342        self.assertEqual(
343            names(l.named_buffers()),
344            ['layer_dummy_buf'])
345
346        self.assertEqual(len(list(n.buffers())), 2)
347        self.assertEqual(
348            names(n.named_buffers()),
349            ['dummy_buf', 'l1.layer_dummy_buf'])
350
351        self.assertEqual(len(list(n.buffers(recurse=False))), 1)
352        self.assertEqual(
353            names(n.named_buffers(recurse=False)),
354            ['dummy_buf'])
355
356        self.assertEqual(len(list(s.buffers())), 2)
357        self.assertEqual(
358            names(s.named_buffers()),
359            ['0.dummy_buf', '0.l1.layer_dummy_buf'])
360
361        # test remove_duplicate
362        class M(nn.Module):
363            def __init__(self) -> None:
364                super().__init__()
365                self.buffer1 = Buffer(torch.empty(3, 5))
366                self.buffer2 = self.buffer1
367
368        m = M()
369        self.assertEqual(names(m.named_buffers()),
370                         ["buffer1"])
371        self.assertEqual(names(m.named_buffers(remove_duplicate=False)),
372                         ["buffer1", "buffer2"])
373
374    def test_buffer_bad_module_subclass(self):
375        class MyBadModule(nn.Linear):
376            def __init__(self) -> None:
377                super().__init__(2, 2)
378                self.bar = Buffer(torch.rand(2, 2))
379
380            def register_buffer(self, name, value):
381                # persistent is explicitly missing!
382                super().register_buffer(name, value, True)
383
384        foo = MyBadModule()
385        self.assertIsNotNone(foo.bar)
386
387    def test_call_supports_python_dict_output(self):
388        class Net(nn.Module):
389            def __init__(self) -> None:
390                super().__init__()
391                self.l1 = nn.Linear(10, 20)
392                self.register_backward_hook(self.hook)
393                self.check_backward_hook_flag = False
394
395            def hook(self, module, grad_out, grad_in):
396                self.check_backward_hook_flag = True
397
398            def forward(self, inputs):
399                return {"output": self.l1(inputs).sum()}
400
401        net = Net()
402        model_output = net(torch.randn([5, 10]))
403        model_output["output"].backward()
404        self.assertTrue(net.check_backward_hook_flag)
405
406    def test_children(self):
407        l1 = nn.Linear(2, 2)
408        l2 = nn.Linear(2, 2)
409        l3 = nn.Linear(2, 2)
410        l4 = nn.Linear(2, 2)
411        subnet = nn.Sequential(l3, l4)
412        s = nn.Sequential(l1, l2, l1, l2, subnet)
413        self.assertEqual(list(s.children()), [l1, l2, subnet])
414
415    def test_train_errors_for_invalid_mode(self):
416        class SubclassNet(nn.Module):
417            def __init__(self) -> None:
418                super().__init__()
419                self.l1 = nn.Linear(2, 2)
420
421            def forward(self, inputs):
422                return self.l1(inputs)
423
424        subclass_net = SubclassNet()
425        sequential_net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
426
427        error_modes = ["invalid_str", torch.device('cpu')]
428        modules_to_check = [subclass_net, sequential_net]
429
430        for error_mode, module in itertools.product(error_modes, modules_to_check):
431            with self.assertRaises(ValueError):
432                module.train(error_mode)
433
434    def test_dir(self):
435        linear = nn.Linear(2, 2)
436        linear._test_submodule = nn.Linear(2, 2)
437        linear._test_parameter = Parameter(torch.empty(2, 2))
438        linear._test_buffer = Buffer(torch.empty(2, 2))
439        keys = dir(linear)
440        self.assertIn('_test_submodule', keys)
441        self.assertIn('_test_parameter', keys)
442        self.assertIn('_test_buffer', keys)
443
444        for key in keys:
445            self.assertTrue(hasattr(linear, key))
446
447    def test_repr(self):
448        # no extra information or sub-modules
449        empty_sequential = nn.Sequential()
450        expected_repr_empty = 'Sequential()'
451        self.assertEqual(repr(empty_sequential), expected_repr_empty)
452
453        # one liner extra information
454        linear = nn.Linear(1, 1)
455        expected_repr_linear = 'Linear(in_features=1, out_features=1, bias=True)'
456        self.assertEqual(repr(linear), expected_repr_linear)
457
458        # sub-modules repr
459        sequential = nn.Sequential(linear)
460        expected_repr_sequential = 'Sequential(\n' \
461            '  (0): Linear(in_features=1, out_features=1, bias=True)\n' \
462            ')'
463        self.assertEqual(repr(sequential), expected_repr_sequential)
464
465    def test_dir_digit(self):
466        model = nn.Sequential(nn.Linear(2, 2))
467        keys = dir(model)
468        self.assertNotIn('0', keys)
469
470    def test_named_children(self):
471        l1 = nn.Linear(2, 2)
472        l2 = nn.Linear(2, 2)
473        l3 = nn.Linear(2, 2)
474        l4 = nn.Linear(2, 2)
475        subnet = nn.Sequential(l3, l4)
476        s = nn.Sequential()
477        with self.assertRaises(KeyError):
478            s.add_module('', l1)
479        with self.assertRaises(KeyError):
480            s.add_module('name.with.dot', l1)
481        s.add_module('layer1', l1)
482        s.add_module('layer2', l2)
483        s.add_module('layer3', l1)
484        s.add_module('layer4', l2)
485        s.add_module('subnet', subnet)
486        self.assertEqual(list(s.named_children()), [('layer1', l1), ('layer2', l2), ('subnet', subnet)])
487
488    def test_modules(self):
489        class Net(nn.Module):
490            def __init__(self) -> None:
491                super().__init__()
492                self.l1 = l
493                self.l2 = l
494                self.param = torch.empty(3, 5)
495
496        l = nn.Linear(10, 20)
497        n = Net()
498        s = nn.Sequential(n, n, n, n)
499        self.assertEqual(list(s.modules()), [s, n, l])
500
501    def test_named_modules(self):
502        class Net(nn.Module):
503            def __init__(self) -> None:
504                super().__init__()
505                self.l1 = l
506                self.l2 = l
507                self.param = torch.empty(3, 5)
508                self.block = block
509        l = nn.Linear(10, 20)
510        l1 = nn.Linear(10, 20)
511        l2 = nn.Linear(10, 20)
512        block = nn.Sequential()
513        block.add_module('linear1', l1)
514        block.add_module('linear2', l2)
515        n = Net()
516        s = nn.Sequential(n, n)
517        self.assertEqual(list(s.named_modules()), [('', s), ('0', n), ('0.l1', l),
518                                                   ('0.block', block), ('0.block.linear1', l1),
519                                                   ('0.block.linear2', l2)])
520        # test the option to not remove duplicate module instances
521        self.assertEqual(list(s.named_modules(remove_duplicate=False)), [
522            ('', s), ('0', n), ('0.l1', l), ('0.l2', l),
523            ('0.block', block), ('0.block.linear1', l1),
524            ('0.block.linear2', l2),
525            ('1', n), ('1.l1', l), ('1.l2', l),
526            ('1.block', block), ('1.block.linear1', l1),
527            ('1.block.linear2', l2)])
528
529    def test_register_buffer_raises_error_if_name_is_not_string(self):
530        m = nn.Module()
531        expected_error = 'buffer name should be a string. Got '
532        with self.assertRaisesRegex(TypeError, expected_error + 'int'):
533            m.register_buffer(1, torch.rand(5))
534        with self.assertRaisesRegex(TypeError, expected_error + 'NoneType'):
535            m.register_buffer(None, torch.rand(5))
536
537    def test_register_buffer_raises_error_if_attr_exists(self):
538        m = nn.Module()
539        m.attribute_name = 5
540        with self.assertRaises(KeyError):
541            m.register_buffer('attribute_name', torch.rand(5))
542
543        with self.assertRaises(KeyError):
544            m.attribute_name = Buffer(torch.rand(5))
545
546        del m.attribute_name
547        m.register_parameter('attribute_name', nn.Parameter())
548        with self.assertRaises(KeyError):
549            m.register_buffer('attribute_name', torch.rand(5))
550
551        del m.attribute_name
552        m.add_module('attribute_name', nn.Module())
553        with self.assertRaises(KeyError):
554            m.register_buffer('attribute_name', torch.rand(5))
555
556    def test_register_buffer_raises_error_if_not_tensor(self):
557        m = nn.Module()
558        with self.assertRaises(TypeError):
559            m.register_buffer('attribute_name', 5)
560
561    def test_register_buffer_allows_overwriting_with_same_name(self):
562        m = nn.Module()
563        buffer1 = torch.rand(5)
564        buffer2 = buffer1 + 5
565        buffer3 = None
566        m.register_buffer('buffer_name', buffer1)
567        self.assertEqual(m.buffer_name, buffer1)
568        m.register_buffer('buffer_name', buffer2)
569        self.assertEqual(m.buffer_name, buffer2)
570        m.register_buffer('buffer_name', buffer3)
571        self.assertEqual(m.buffer_name, buffer3)
572        m.buffer_name = Buffer(buffer1)
573        self.assertEqual(m.buffer_name, Buffer(buffer1))
574        m.buffer_name = Buffer(buffer2)
575        self.assertEqual(m.buffer_name, Buffer(buffer2))
576        m.buffer_name = Buffer(buffer3)
577        self.assertEqual(m.buffer_name, Buffer(buffer3))
578
579    def test_get_buffer(self):
580        m = nn.Module()
581        buffer1 = torch.randn(2, 3)
582        buffer2 = torch.randn(4, 5)
583        m.foo = Buffer(buffer1)
584        m.register_buffer('bar', buffer2)
585        self.assertEqual(buffer1, m.get_buffer('foo'))
586        self.assertEqual(buffer2, m.get_buffer('bar'))
587
588    def test_get_buffer_from_submodules(self):
589        class MyModule(nn.Module):
590            def __init__(self, foo, bar):
591                super().__init__()
592                self.sub = Sub(foo, bar)
593
594        class Sub(nn.Module):
595            def __init__(self, foo, bar):
596                super().__init__()
597                self.foo = Buffer(foo)
598                self.subsub = SubSub(bar)
599
600        class SubSub(nn.Module):
601            def __init__(self, bar):
602                super().__init__()
603                self.bar = Buffer(bar)
604
605        foo = torch.randn(2, 3)
606        bar = torch.randn(4, 5)
607        m = MyModule(foo, bar)
608        self.assertEqual(foo, m.get_buffer('sub.foo'))
609        self.assertEqual(bar, m.get_buffer('sub.subsub.bar'))
610
611    def test_buffer_not_persistent(self):
612        m = nn.Module()
613        m.buf = nn.Buffer(torch.rand(5), persistent=False)
614        self.assertTrue(len(list(m.buffers())) == 1)
615        self.assertTrue(len(m.state_dict()) == 0)
616
617    def test_buffer_not_persistent_del(self):
618        m = nn.Module()
619        m.buf = nn.Buffer(torch.rand(5), persistent=False)
620        del m.buf
621        self.assertTrue(len(list(m.buffers())) == 0)
622
623    def test_buffer_not_persistent_overwrite(self):
624        m = nn.Module()
625        m.buf = nn.Buffer(torch.rand(5), persistent=False)
626        m.buf = nn.Buffer(torch.rand(5))
627
628        # can we overwrite a non-persistent buffer with a persistent one?
629        self.assertTrue(len(list(m.buffers())) == 1)
630        self.assertTrue(len(m.state_dict()) == 1)
631
632        # can we overwrite a persistent buffer with a non-persistent one?
633        m.buf = nn.Buffer(torch.rand(5), persistent=False)
634        self.assertTrue(len(list(m.buffers())) == 1)
635        self.assertTrue(len(m.state_dict()) == 0)
636
637    def test_buffer_not_persistent_assign(self):
638        m = nn.Module()
639        m.buf = nn.Buffer(torch.rand(5), persistent=False)
640        self.assertTrue(len(list(m.buffers())) == 1)
641        self.assertTrue(len(m.state_dict()) == 0)
642
643        # Assigning None removes the buffer but if we then assign a new Tensor
644        # to the same property, it should still be marked as a buffer.
645        m.buf = None
646        self.assertTrue(len(list(m.buffers())) == 0)
647        self.assertTrue(len(m.state_dict()) == 0)
648        m.buf = torch.rand(5)
649        self.assertTrue(len(list(m.buffers())) == 1)
650        self.assertTrue(len(m.state_dict()) == 0)
651
652        # Assigning a Parameter removes the buffer.
653        m.buf = nn.Parameter(torch.rand(5))
654        self.assertTrue(len(list(m.buffers())) == 0)
655        self.assertTrue(len(m.state_dict()) == 1)
656
657    def test_buffer_not_persistent_load(self):
658        m = nn.Module()
659        m.buf = nn.Buffer(torch.rand(5), persistent=False)
660        m.load_state_dict({})
661
662    def test_register_parameter_raises_error_if_name_is_not_string(self):
663        m = nn.Module()
664        expected_error = 'parameter name should be a string. Got '
665        with self.assertRaisesRegex(TypeError, expected_error + 'int'):
666            m.register_parameter(1, nn.Parameter())
667        with self.assertRaisesRegex(TypeError, expected_error + 'NoneType'):
668            m.register_parameter(None, nn.Parameter())
669
670    def test_register_parameter_raises_error_if_attr_exists(self):
671        m = nn.Module()
672        m.attribute_name = 5
673        with self.assertRaises(KeyError):
674            m.register_parameter('attribute_name', nn.Parameter())
675
676        del m.attribute_name
677        m.register_buffer('attribute_name', torch.rand(5))
678        with self.assertRaises(KeyError):
679            m.register_parameter('attribute_name', nn.Parameter())
680
681        del m.attribute_name
682        m.attribute_name = Buffer(torch.rand(5))
683        with self.assertRaises(KeyError):
684            m.register_parameter('attribute_name', nn.Parameter())
685
686        del m.attribute_name
687        m.add_module('attribute_name', nn.Module())
688        with self.assertRaises(KeyError):
689            m.register_parameter('attribute_name', nn.Parameter())
690
691    def test_register_parameter_allows_overwriting_with_same_name(self):
692        m = nn.Module()
693        param1 = nn.Parameter(torch.rand(5))
694        param2 = nn.Parameter(param1.data + 5)
695        param3 = None
696        m.register_parameter('param_name', param1)
697        self.assertEqual(m.param_name, param1)
698        m.register_parameter('param_name', param2)
699        self.assertEqual(m.param_name, param2)
700        m.register_parameter('param_name', param3)
701        self.assertEqual(m.param_name, param3)
702
703    def test_add_module_raises_error_if_attr_exists(self):
704        methods_to_test = ['add_module', 'register_module']
705        for fn in methods_to_test:
706            m = nn.Module()
707            m.attribute_name = 5
708            with self.assertRaises(KeyError):
709                getattr(m, fn)('attribute_name', nn.Module())
710
711            del m.attribute_name
712            m.register_buffer('attribute_name', torch.rand(5))
713            with self.assertRaises(KeyError):
714                getattr(m, fn)('attribute_name', nn.Module())
715
716            del m.attribute_name
717            m.register_parameter('attribute_name', nn.Parameter())
718            with self.assertRaises(KeyError):
719                getattr(m, fn)('attribute_name', nn.Module())
720
721    @unittest.expectedFailure
722    def test_getattr_with_property(self):
723        class Model(nn.Module):
724            @property
725            def some_property(self):
726                return self.something_that_doesnt_exist
727
728        model = Model()
729
730        with self.assertRaisesRegex(
731                AttributeError,
732                r"'Model' object has no attribute 'something_that_doesnt_exist'"):
733            model.some_property
734
735    def test_Sequential_getitem(self):
736        l1 = nn.Linear(10, 20)
737        l2 = nn.Linear(20, 30)
738        l3 = nn.Linear(30, 40)
739        l4 = nn.Linear(40, 50)
740        n = nn.Sequential(l1, l2, l3, l4)
741        self.assertIs(n[0], l1)
742        self.assertIs(n[1], l2)
743        self.assertIs(n[2], l3)
744        self.assertIs(n[3], l4)
745        self.assertIs(n[torch.tensor(3, dtype=torch.int64)], l4)
746        self.assertEqual(n[1:], nn.Sequential(l2, l3, l4))
747        self.assertEqual(n[3:], nn.Sequential(l4))
748        self.assertEqual(n[:-1], nn.Sequential(l1, l2, l3))
749        self.assertEqual(n[:-3], nn.Sequential(l1))
750        self.assertEqual(n[::-1], nn.Sequential(l4, l3, l2, l1))
751
752    def test_Sequential_setitem(self):
753        l1 = nn.Linear(10, 20)
754        l2 = nn.Linear(20, 30)
755        l3 = nn.Linear(30, 40)
756        l4 = nn.Linear(40, 50)
757        n = nn.Sequential(l1, l2, l3)
758        n[0] = l4
759        n[-1] = l4
760        n[torch.tensor(1, dtype=torch.int16)] = l1
761        self.assertIs(n[0], l4)
762        self.assertIs(n[1], l1)
763        self.assertIs(n[2], l4)
764
765    def test_Sequential_setitem_named(self):
766        l1 = nn.Linear(10, 20)
767        l2 = nn.Linear(20, 30)
768        l3 = nn.Linear(30, 40)
769        l4 = nn.Linear(40, 50)
770        n = nn.Sequential(OrderedDict([
771            ('linear1', l1),
772            ('linear2', l2),
773            ('linear3', l3),
774        ]))
775
776        n[0] = l4
777        n[-1] = l4
778        self.assertEqual(n.linear1, l4)
779        self.assertEqual(n.linear3, l4)
780
781    def test_Sequential_delitem(self):
782        l1 = nn.Linear(10, 20)
783        l2 = nn.Linear(20, 30)
784        l3 = nn.Linear(30, 40)
785        l4 = nn.Linear(40, 50)
786        n = nn.Sequential(l1, l2, l3, l4)
787        del n[-1]
788        self.assertEqual(n, nn.Sequential(l1, l2, l3))
789        del n[1::2]
790        self.assertEqual(n, nn.Sequential(l1, l3))
791
792    def test_Sequential_add(self):
793        l1 = nn.Linear(1, 2)
794        l2 = nn.Linear(2, 3)
795        l3 = nn.Linear(3, 4)
796        l4 = nn.Linear(4, 5)
797        n = nn.Sequential(l1, l2)
798        other = nn.Sequential(l3, l4)
799        self.assertEqual(n + other, nn.Sequential(l1, l2, l3, l4))
800
801    def test_Sequential_iadd(self):
802        l1 = nn.Linear(10, 20)
803        l2 = nn.Linear(20, 30)
804        l3 = nn.Linear(30, 40)
805        l4 = nn.Linear(40, 50)
806        n = nn.Sequential(l1, l2, l3)
807        n2 = nn.Sequential(l4)
808        n += n2
809        n2 += n
810        self.assertEqual(n, nn.Sequential(l1, l2, l3, l4))
811        self.assertEqual(n2, nn.Sequential(l4, l1, l2, l3, l4))
812
813    def test_Sequential_mul(self):
814        l1 = nn.Linear(10, 20)
815        l2 = nn.Linear(20, 30)
816        l3 = nn.Linear(30, 40)
817        l4 = nn.Linear(40, 50)
818        n = nn.Sequential(l1, l2, l3, l4)
819        n2 = n * 2
820        self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4))
821
822    def test_Sequential_rmul(self):
823        l1 = nn.Linear(10, 20)
824        l2 = nn.Linear(20, 30)
825        l3 = nn.Linear(30, 40)
826        l4 = nn.Linear(40, 50)
827        n = nn.Sequential(l1, l2, l3, l4)
828        n2 = 2 * n
829        self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4))
830
831    def test_Sequential_imul(self):
832        l1 = nn.Linear(10, 20)
833        l2 = nn.Linear(20, 30)
834        l3 = nn.Linear(30, 40)
835        l4 = nn.Linear(40, 50)
836        n = nn.Sequential(l1, l2, l3, l4)
837        n *= 2
838        self.assertEqual(n, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4))
839        n *= 2
840        self.assertEqual(
841            n,
842            nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4, l1, l2, l3, l4, l1, l2, l3, l4)
843        )
844
845    def test_Sequential_append(self):
846        l1 = nn.Linear(10, 20)
847        l2 = nn.Linear(20, 30)
848        l3 = nn.Linear(30, 40)
849        l4 = nn.Linear(40, 50)
850        n = nn.Sequential(l1, l2, l3)
851        n2 = n.append(l4)
852        self.assertEqual(n, nn.Sequential(l1, l2, l3, l4))
853        self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4))
854        self.assertEqual(nn.Sequential(l1).append(l2).append(l4), nn.Sequential(l1, l2, l4))
855
856    def test_Sequential_pop(self):
857        l1 = nn.Linear(1, 2)
858        l2 = nn.Linear(2, 3)
859        l3 = nn.Linear(3, 4)
860        l4 = nn.Linear(4, 5)
861        n1 = nn.Sequential(l1, l2, l3, l4)
862        self.assertEqual(l4, n1.pop(3))
863        n2 = nn.Sequential(l1, l2, l3)
864        self.assertEqual(n1, n2)
865        # check order of the index
866        for k, mod in zip(range(len(n1)), n1):
867            self.assertIs(n1[k], mod)
868
869    def test_Sequential_insert(self):
870        l1 = nn.Linear(1, 2)
871        l2 = nn.Linear(2, 3)
872        l3 = nn.Linear(3, 4)
873
874        n1 = nn.Sequential(l1, l2, l3)
875        module_1 = nn.Linear(4, 5)
876        n2 = nn.Sequential(l1, module_1, l2, l3)
877        self.assertEqual(n1.insert(1, module_1), n2)
878
879        # test for negative support
880        n3 = nn.Sequential(l1, l2, l3)
881        module_2 = nn.Linear(5, 6)
882        n4 = nn.Sequential(l1, module_2, l2, l3)
883        self.assertEqual(n3.insert(-2, module_2), n4)
884
885    def test_Sequential_insert_fail_case(self):
886        l1 = nn.Linear(1, 2)
887        l2 = nn.Linear(2, 3)
888        l3 = nn.Linear(3, 4)
889
890        module = nn.Linear(5, 6)
891
892        # test for error case
893        n1 = nn.Sequential(l1, l2, l3)
894        with self.assertRaises(IndexError):
895            n1.insert(-5, module)
896
897        with self.assertRaises(AssertionError):
898            n1.insert(1, [nn.Linear(6, 7)])
899
900    def test_Sequential_extend(self):
901        l1 = nn.Linear(10, 20)
902        l2 = nn.Linear(20, 30)
903        l3 = nn.Linear(30, 40)
904        l4 = nn.Linear(40, 50)
905        n1 = nn.Sequential(l1, l2)
906        n2 = nn.Sequential(l3, l4)
907        n3 = nn.Sequential(l1, l2)
908        for l in n2:
909            n1.append(l)
910        n3.extend(n2)
911        self.assertEqual(n3, n1)
912
913    def test_ModuleList(self):
914        modules = [nn.ReLU(), nn.Linear(5, 5)]
915        module_list = nn.ModuleList(modules)
916
917        def check():
918            self.assertEqual(len(module_list), len(modules))
919            for m1, m2 in zip(modules, module_list):
920                self.assertIs(m1, m2)
921            for m1, m2 in zip(modules, module_list.children()):
922                self.assertIs(m1, m2)
923            for i in range(len(modules)):
924                self.assertIs(module_list[i], modules[i])
925
926        check()
927        modules += [nn.Conv2d(3, 4, 3)]
928        module_list += [modules[-1]]
929        check()
930        modules = modules + [nn.Conv2d(3, 4, 3, bias=False), nn.GELU()]
931        module_list = module_list + nn.ModuleList(modules[-2:])
932        check()
933        modules.insert(1, nn.Linear(3, 2))
934        module_list.insert(1, modules[1])
935        check()
936        modules.append(nn.Tanh())
937        module_list.append(modules[-1])
938        check()
939        next_modules = [nn.Linear(5, 5), nn.Sigmoid()]
940        modules.extend(next_modules)
941        module_list.extend(next_modules)
942        check()
943        modules[2] = nn.Conv2d(5, 3, 2)
944        module_list[2] = modules[2]
945        check()
946        modules[-1] = nn.Conv2d(5, 2, 1)
947        module_list[-1] = modules[-1]
948        check()
949        idx = torch.tensor(2, dtype=torch.int32)
950        modules[2] = nn.Conv2d(5, 3, 2)
951        module_list[idx] = modules[2]
952        self.assertIs(module_list[idx], modules[2])
953        check()
954        self.assertEqual(module_list[1:], nn.ModuleList(modules[1:]))
955        self.assertEqual(module_list[3:], nn.ModuleList(modules[3:]))
956        self.assertEqual(module_list[:-1], nn.ModuleList(modules[:-1]))
957        self.assertEqual(module_list[:-3], nn.ModuleList(modules[:-3]))
958        self.assertEqual(module_list[::-1], nn.ModuleList(modules[::-1]))
959        del module_list[-1]
960        self.assertEqual(module_list, nn.ModuleList(modules[:-1]))
961        del module_list[1::2]
962        self.assertEqual(module_list, nn.ModuleList(modules[:-1][0::2]))
963
964        with self.assertRaises(TypeError):
965            module_list += nn.ReLU()
966        with self.assertRaises(TypeError):
967            module_list.extend(nn.ReLU())
968
969        l1 = nn.Linear(1, 2)
970        l2 = nn.Linear(2, 3)
971        l3 = nn.Linear(3, 2)
972        l4 = nn.Linear(2, 3)
973        subnet = nn.Sequential(l3, l4)
974        s = nn.Sequential(
975            OrderedDict([
976                ("layer1", l1),
977                ("layer2", l2),
978                ("layer3", l3),
979                ("layer4", l4),
980                ("subnet_layer", subnet)
981            ])
982        )
983        modules = list(s.modules())
984        module_list = nn.ModuleList()
985        module_list.extend(s.modules())
986        check()
987
988        modules = [nn.ReLU(), nn.Linear(5, 5), nn.Conv2d(3, 4, 3)]
989        module_list = nn.ModuleList(modules)
990        self.assertEqual(modules.pop(1), module_list.pop(1))
991        self.assertEqual(modules, module_list)
992        # check order of the index
993        for k, mod in zip(range(len(module_list)), module_list):
994            self.assertIs(module_list[k], mod)
995
996        # verify the right exception is thrown when trying to "forward" through a ModuleList
997        self.assertRaises(NotImplementedError, module_list)
998        self.assertRaises(NotImplementedError, module_list, torch.rand(1, 3))
999
1000    def test_ModuleDict(self):
1001        modules = OrderedDict([
1002            ('act', nn.ReLU()),
1003            ('conv', nn.Conv2d(10, 10, 5)),
1004            ('fc', nn.Linear(5, 5)),
1005        ])
1006
1007        module_dict = nn.ModuleDict(modules)
1008
1009        def check():
1010            self.assertEqual(len(module_dict), len(modules))
1011            for k1, m2 in zip(modules, module_dict.children()):
1012                self.assertIs(modules[k1], m2)
1013            for k1, k2 in zip(modules, module_dict):
1014                self.assertIs(modules[k1], module_dict[k2])
1015            for k in module_dict:
1016                self.assertIs(module_dict[k], modules[k])
1017            for k in module_dict.keys():
1018                self.assertIs(module_dict[k], modules[k])
1019            for k, v in module_dict.items():
1020                self.assertIs(modules[k], v)
1021            for k1, m2 in zip(modules, module_dict.values()):
1022                self.assertIs(modules[k1], m2)
1023            for k in modules.keys():
1024                self.assertTrue(k in module_dict)
1025        check()
1026
1027        modules['conv'] = nn.Conv2d(3, 4, 3)
1028        module_dict['conv'] = modules['conv']
1029        check()
1030
1031        next_modules = [
1032            ('fc2', nn.Linear(5, 5)),
1033            ('act', nn.Sigmoid()),
1034        ]
1035        modules.update(next_modules)
1036        module_dict.update(next_modules)
1037        check()
1038
1039        next_modules = OrderedDict([
1040            ('fc3', nn.Linear(5, 5)),
1041            ('act2', nn.Sigmoid()),
1042        ])
1043        modules.update(next_modules)
1044        module_dict.update(next_modules)
1045        check()
1046
1047        next_modules = {
1048            'fc4': nn.Linear(5, 5),
1049            'act3': nn.Sigmoid()
1050        }
1051        modules.update(next_modules.items())
1052        module_dict.update(next_modules)
1053        check()
1054
1055        next_modules = nn.ModuleDict([
1056            ('fc5', nn.Linear(5, 5)),
1057            ('act4', nn.Sigmoid()),
1058        ])
1059        modules.update(next_modules)
1060        module_dict.update(next_modules)
1061        check()
1062
1063        del module_dict['fc']
1064        del modules['fc']
1065        check()
1066
1067        with self.assertRaises(TypeError):
1068            module_dict.update(nn.ReLU())
1069
1070        with self.assertRaises(TypeError):
1071            module_dict.update([nn.ReLU()])
1072
1073        with self.assertRaises(ValueError):
1074            module_dict.update([[nn.ReLU()]])
1075
1076        with self.assertRaises(TypeError):
1077            module_dict[1] = nn.ReLU()
1078
1079        s = nn.Sequential(modules)
1080        module_dict = nn.ModuleDict(s.named_children())
1081        check()
1082
1083        c = module_dict.pop('conv')
1084        self.assertIs(c, modules['conv'])
1085        modules.pop('conv')
1086        check()
1087
1088        module_dict.clear()
1089        self.assertEqual(len(module_dict), 0)
1090        modules.clear()
1091        check()
1092
1093        # verify the right exception is thrown when trying to "forward" through a ModuleDict
1094        self.assertRaises(NotImplementedError, module_dict)
1095        self.assertRaises(NotImplementedError, module_dict, torch.rand(1, 3))
1096
1097    @skipIfTorchDynamo()
1098    def test_ParameterList(self):
1099        def make_param():
1100            return Parameter(torch.randn(2, 2))
1101        parameters = [make_param(), make_param()]
1102        param_list = nn.ParameterList(parameters)
1103
1104        def check():
1105            self.assertEqual(len(parameters), len(param_list))
1106            for p1, p2 in zip(parameters, param_list):
1107                self.assertIs(p1, p2)
1108            for p1, p2 in zip(filter(lambda x: isinstance(x, Parameter), parameters), param_list.parameters()):
1109                self.assertIs(p1, p2)
1110            for i in range(len(parameters)):
1111                self.assertIs(parameters[i], param_list[i])
1112
1113        check()
1114        parameters += [make_param()]
1115        param_list += [parameters[-1]]
1116        check()
1117        parameters.append(make_param())
1118        param_list.append(parameters[-1])
1119        check()
1120        next_params = [make_param(), make_param()]
1121        parameters.extend(next_params)
1122        param_list.extend(next_params)
1123        check()
1124        parameters[2] = make_param()
1125        param_list[2] = parameters[2]
1126        check()
1127        parameters[-1] = make_param()
1128        param_list[-1] = parameters[-1]
1129        check()
1130        idx = torch.tensor(2, dtype=torch.int32)
1131        parameters[2] = make_param()
1132        param_list[idx] = parameters[2]
1133        self.assertIs(param_list[idx], parameters[2])
1134        check()
1135        self.assertEqual(param_list[1:], nn.ParameterList(parameters[1:]))
1136        self.assertEqual(param_list[3:], nn.ParameterList(parameters[3:]))
1137        self.assertEqual(param_list[:-1], nn.ParameterList(parameters[:-1]))
1138        self.assertEqual(param_list[:-3], nn.ParameterList(parameters[:-3]))
1139        self.assertEqual(param_list[::-1], nn.ParameterList(parameters[::-1]))
1140
1141        with self.assertRaises(TypeError):
1142            param_list += make_param()
1143        with self.assertRaises(TypeError):
1144            param_list.extend(make_param())
1145
1146        l1 = nn.Linear(1, 2)
1147        l2 = nn.Linear(2, 3)
1148        l3 = nn.Linear(3, 2)
1149        l4 = nn.Linear(2, 3)
1150        subnet = nn.Sequential(l3, l4)
1151        s = nn.Sequential(
1152            OrderedDict([
1153                ("layer1", l1),
1154                ("layer2", l2),
1155                ("layer3", l3),
1156                ("layer4", l4),
1157                ("subnet_layer", subnet)
1158            ])
1159        )
1160        parameters = list(s.parameters())
1161        param_list = nn.ParameterList()
1162        param_list.extend(s.parameters())
1163        check()
1164
1165        param_list.append(torch.rand(2, 2))
1166        self.assertIsInstance(param_list[-1], Parameter)
1167        parameters.append(param_list[-1])
1168
1169        param_list.extend([torch.rand(2, 2), "foo"])
1170        self.assertIsInstance(param_list[-2], Parameter)
1171        self.assertIsInstance(param_list[-1], str)
1172        parameters.extend(param_list[-2:])
1173
1174        param_list += ["bar", torch.rand(2, 2)]
1175        self.assertIsInstance(param_list[-2], str)
1176        self.assertIsInstance(param_list[-1], Parameter)
1177        parameters += param_list[-2:]
1178        check()
1179
1180    def test_ParameterList_meta(self):
1181        p = torch.nn.Parameter(torch.empty(1, device='meta'))
1182        self.assertExpectedInline(str(p), """\
1183Parameter containing:
1184tensor(..., device='meta', size=(1,), requires_grad=True)""")
1185        pl = torch.nn.ParameterList([p])
1186        self.assertExpectedInline(str(pl), """ParameterList(  (0): Parameter containing: [torch.float32 of size 1])""")
1187
1188    def test_ParameterList_replication(self):
1189        # The actual replication code from DP cannot be used on CPU so doing it manually here
1190        def make_param():
1191            return Parameter(torch.randn(2, 2))
1192        parameters = [make_param(), make_param()]
1193        param_list = nn.ParameterList(parameters)
1194
1195        new_param_list = param_list._replicate_for_data_parallel()
1196
1197        for n, p in param_list.named_parameters():
1198            # Do a view here so that we can check the base later
1199            setattr(new_param_list, n, p.view_as(p))
1200
1201        for p, p2 in zip(param_list, new_param_list):
1202            self.assertEqual(p, p2)
1203            self.assertIsNotNone(p2.grad_fn)
1204            self.assertIs(p2._base, p)
1205
1206    def test_ParameterDict(self):
1207        parameters = OrderedDict([
1208            ('p1', Parameter(torch.randn(10, 10))),
1209            ('p2', Parameter(torch.randn(10, 10))),
1210            ('p3', Parameter(torch.randn(10, 10))),
1211        ])
1212
1213        parameter_dict = nn.ParameterDict(parameters)
1214
1215        def check():
1216            self.assertEqual(len(parameter_dict), len(parameters))
1217            for i, (k1, (k2, m2)) in enumerate(zip(parameters, parameter_dict.named_parameters())):
1218                self.assertEqual(k1, k2)
1219                self.assertIs(parameters[k1], m2)
1220            for k1, k2 in zip(parameters, parameter_dict):
1221                self.assertIs(parameters[k1], parameter_dict[k2])
1222            for k in parameter_dict:
1223                self.assertIs(parameter_dict[k], parameters[k])
1224            for k in parameter_dict.keys():
1225                self.assertIs(parameter_dict[k], parameters[k])
1226            for k, v in parameter_dict.items():
1227                self.assertIs(v, parameters[k])
1228            for k1, m2 in zip(parameters, parameter_dict.values()):
1229                self.assertIs(parameters[k1], m2)
1230            for k in parameters.keys():
1231                self.assertTrue(k in parameter_dict)
1232
1233        check()
1234
1235        parameters['p4'] = Parameter(torch.randn(10, 10))
1236        parameter_dict['p4'] = parameters['p4']
1237        check()
1238
1239        next_parameters = [
1240            ('p5', Parameter(torch.randn(10, 10))),
1241            ('p2', Parameter(torch.randn(10, 10))),
1242        ]
1243        parameters.update(next_parameters)
1244        parameter_dict.update(next_parameters)
1245        check()
1246
1247        next_parameters = OrderedDict([
1248            ('p6', Parameter(torch.randn(10, 10))),
1249            ('p5', Parameter(torch.randn(10, 10))),
1250        ])
1251        parameters.update(next_parameters)
1252        parameter_dict.update(next_parameters)
1253        check()
1254
1255        next_parameters = {
1256            'p8': Parameter(torch.randn(10, 10)),
1257            'p7': Parameter(torch.randn(10, 10))
1258        }
1259        parameters.update(sorted(next_parameters.items()))
1260        parameter_dict.update(next_parameters)
1261        check()
1262
1263        next_parameters = nn.ParameterDict([
1264            ('p10', Parameter(torch.randn(10, 10))),
1265            ('p9', Parameter(torch.randn(10, 10))),
1266        ])
1267        parameters.update(next_parameters)
1268        parameter_dict.update(next_parameters)
1269        check()
1270
1271        del parameter_dict['p3']
1272        del parameters['p3']
1273        check()
1274
1275        with self.assertRaises(TypeError):
1276            parameter_dict.update(1)
1277
1278        with self.assertRaises(TypeError):
1279            parameter_dict.update([1])
1280
1281        with self.assertRaises(ValueError):
1282            parameter_dict.update(Parameter(torch.randn(10, 10)))
1283
1284        p_pop = parameter_dict.pop('p4')
1285        self.assertIs(p_pop, parameters['p4'])
1286        parameters.pop('p4')
1287        check()
1288
1289        # Check reverse works
1290        forward = list(iter(parameter_dict))
1291        backward = list(reversed(parameter_dict))
1292        self.assertEqual(len(forward), len(backward))
1293        n = len(forward)
1294        for i in range(n):
1295            self.assertIs(forward[i], backward[n - i - 1])
1296        check()
1297
1298        # Check copy works
1299        copy = parameter_dict.copy()
1300
1301        # Check all keys are present and have shallow copied values
1302        for key in parameter_dict:
1303            self.assertTrue(key in copy)
1304            self.assertEqual(parameter_dict[key], copy[key])
1305            self.assertIs(parameter_dict[key], copy[key])
1306        check()
1307
1308        parameter_dict["p20"] = Parameter(torch.randn(10, 10))
1309        copy["p21"] = Parameter(torch.randn(9, 10))
1310
1311        self.assertTrue("p20" in parameter_dict)
1312        self.assertFalse("p20" in copy)
1313        self.assertFalse("p21" in parameter_dict)
1314        self.assertTrue("p21" in copy)
1315        parameter_dict.pop("p20")
1316        check()
1317
1318        p = Parameter(torch.randn(10, 10))
1319        parameter_dict['p12'] = p
1320        p_popitem = parameter_dict.popitem()
1321        self.assertEqual(p_popitem[0], 'p12')
1322        self.assertIs(p_popitem[1], p)
1323        check()
1324
1325        # Unit test for set_default
1326        # 1. Ensure parameter is correctly inserted when
1327        #    the key is not present in `ParameterDict`
1328        assert 'p11' not in parameter_dict
1329        assert 'p11' not in parameters
1330        parameters['p11'] = Parameter(torch.randn(10, 10))
1331        p_setdefault = parameter_dict.setdefault('p11', parameters['p11'])
1332        self.assertIs(p_setdefault, parameters['p11'])
1333        self.assertIs(p_setdefault, parameter_dict['p11'])
1334        check()
1335        # 2. Ensure parameter is NOT inserted when the
1336        #    key is already present in `ParameterDict`
1337        p = Parameter(torch.randn(10, 10))
1338        self.assertFalse(parameter_dict.setdefault('p11', p) is p)
1339        check()
1340        # 3. Ensure `None` is inserted when the key is not
1341        #    present in `Parameter` and parameter is not specified
1342        self.assertIs(parameter_dict.setdefault('p26'), None)
1343        del parameter_dict['p26']
1344        check()
1345
1346        parameters2 = OrderedDict([
1347            ('p13', Parameter(torch.randn(10, 10))),
1348            ('p2', Parameter(torch.randn(10, 10))),
1349            ('p3', Parameter(torch.randn(10, 10))),
1350        ])
1351        parameter_dict2 = nn.ParameterDict(parameters2)
1352        parameters.update(parameters2)
1353        parameter_dict |= parameter_dict2
1354        check()
1355
1356        parameters2 = OrderedDict()
1357        parameter_dict2 = nn.ParameterDict(parameters2)
1358        parameters.update(parameters2)
1359        parameter_dict |= parameter_dict2
1360        check()
1361
1362        parameters2 = OrderedDict([
1363            ('p14', Parameter(torch.randn(10, 10))),
1364            ('p15', Parameter(torch.randn(10, 10))),
1365            ('p13', Parameter(torch.randn(10, 10))),
1366        ])
1367        parameter_dict2 = nn.ParameterDict(parameters2)
1368        parameters.update(parameters2)
1369        parameter_dict |= parameter_dict2
1370        check()
1371
1372        # Check __or__ and __ror__ works
1373        parameters2 = OrderedDict([
1374            ('p20', Parameter(torch.randn(10, 10))),
1375            ('p21', Parameter(torch.randn(10, 10))),
1376            ('p22', Parameter(torch.randn(10, 10))),
1377        ])
1378        parameter_dict2 = nn.ParameterDict(parameters2)
1379        parameters.update(parameters2)
1380        parameter_dict = parameter_dict | parameter_dict2
1381        check()
1382
1383        parameters2 = OrderedDict([
1384            ('p23', Parameter(torch.randn(10, 10))),
1385            ('p24', Parameter(torch.randn(10, 10))),
1386            ('p25', Parameter(torch.randn(10, 10))),
1387        ])
1388        parameter_dict2 = nn.ParameterDict(parameters2)
1389        parameters2.update(parameters)
1390        parameters = parameters2
1391        parameter_dict = parameter_dict2 | parameter_dict
1392        check()
1393
1394        parameters['p17'] = Parameter(torch.randn(10, 10))
1395        parameter_dict['p17'] = parameters['p17']
1396        self.assertIs(parameters['p17'], parameter_dict.get('p17'))
1397        temp_param = Parameter(torch.randn(10, 10))
1398        self.assertIs(parameters['p17'], parameter_dict.get('p17', temp_param))
1399        self.assertIs(None, parameter_dict.get('p18'))
1400        self.assertIs(temp_param, parameter_dict.get('p18', temp_param))
1401        check()
1402
1403        parameter_dict.clear()
1404        self.assertEqual(len(parameter_dict), 0)
1405        parameters.clear()
1406        check()
1407
1408        parameter_dict2 = parameter_dict.fromkeys(['p19', 'p20'])
1409        self.assertEqual({'p19': None, 'p20': None}, parameter_dict2)
1410        check()
1411
1412        parameter_dict2 = parameter_dict.fromkeys(['p19', 'p20'], temp_param)
1413        self.assertEqual({'p19': temp_param, 'p20': temp_param}, parameter_dict2)
1414        check()
1415
1416        parameter_dict['p21'] = torch.rand(2, 2)
1417        self.assertIsInstance(parameter_dict['p21'], Parameter)
1418        parameters['p21'] = parameter_dict['p21']
1419
1420        parameter_dict.update({'p22': torch.rand(2, 2), 'foo': 'bar'})
1421        self.assertIsInstance(parameter_dict['p22'], Parameter)
1422        self.assertIsInstance(parameter_dict['foo'], str)
1423        parameters['p22'] = parameter_dict['p22']
1424        parameters['foo'] = parameter_dict['foo']
1425
1426    def test_ParameterDict_replication(self):
1427        # The actual replication code from DP cannot be used on CPU so doing it manually here
1428        def make_param():
1429            return Parameter(torch.randn(2, 2))
1430        parameters = {"foo": make_param(), "bar": make_param()}
1431        param_dict = nn.ParameterDict(parameters)
1432
1433        new_param_dict = param_dict._replicate_for_data_parallel()
1434
1435        for n, p in param_dict.named_parameters():
1436            # Do a view here so that we can check the base later
1437            setattr(new_param_dict, n, p.view_as(p))
1438
1439        for (k, p), (k2, p2) in zip(param_dict.items(), new_param_dict.items()):
1440            self.assertEqual(k, k2)
1441            self.assertEqual(p, p2)
1442            self.assertIsNotNone(p2.grad_fn)
1443            self.assertIs(p2._base, p)
1444
1445        self.assertEqual(param_dict["foo"], new_param_dict["foo"])
1446
1447    def test_add_module(self):
1448        methods_to_test = ['add_module', 'register_module']
1449        for fn in methods_to_test:
1450            l = nn.Linear(10, 20)
1451            net = nn.Module()
1452            net.l = l
1453            net.l2 = l
1454            getattr(net, fn)('empty', None)
1455            self.assertEqual(net.l, l)
1456            self.assertEqual(net.l2, l)
1457            self.assertEqual(net.empty, None)
1458            getattr(net, fn)('l3', l)
1459            self.assertEqual(net.l3, l)
1460            l3 = nn.Linear(20, 10)
1461            getattr(net, fn)('l', l3)
1462            self.assertEqual(net.l, l3)
1463            self.assertRaises(TypeError, lambda: getattr(net, fn)('x', 'non-module'))
1464            self.assertRaisesRegex(TypeError, 'module name should be a string. Got int',
1465                                   lambda: getattr(net, fn)(1, l))
1466            self.assertRaisesRegex(TypeError, 'module name should be a string. Got NoneType',
1467                                   lambda: getattr(net, fn)(None, l))
1468
1469    def test_set_submodule(self):
1470        net = nn.Module()
1471        net.t = nn.Module()
1472        l = nn.Linear(1, 2)
1473        target = "t.l"
1474        net.set_submodule(target, l)
1475        self.assertEqual(net.get_submodule(target), l)
1476        l2 = nn.Linear(2, 1)
1477        net.set_submodule(target, l2)
1478        self.assertEqual(net.get_submodule(target), l2)
1479        self.assertRaises(ValueError, net.set_submodule, "", l)
1480        self.assertRaises(AttributeError, net.set_submodule, "a.l", l)
1481
1482    def test_module_to_argparse(self):
1483        net = nn.Sequential(nn.Linear(3, 3))
1484        cpu = torch.device('cpu')
1485        with self.assertRaises(TypeError):
1486            net.to(cpu, True)
1487        with self.assertRaises(TypeError):
1488            net.to(torch.long)
1489        with self.assertRaises(TypeError):
1490            net.to(None, True)
1491        with self.assertRaises(TypeError):
1492            net.to(cpu, torch.long, True)
1493        with self.assertRaises(TypeError):
1494            net.to(cpu, dtype=torch.long, non_blocking=True)
1495        with self.assertRaises(TypeError):
1496            net.to([])
1497        with self.assertRaises(TypeError):
1498            net.to({}, non_blocking=True)
1499        with self.assertRaises(TypeError):
1500            net.to(torch.tensor(3, dtype=torch.long), non_blocking=True)
1501        with self.assertRaises(TypeError):
1502            net.to(cpu, torch.tensor(3, dtype=torch.long), non_blocking=True)
1503
1504    def test_RNN_nonlinearity(self):
1505        rnn = torch.nn.RNN(1, 10)
1506        self.assertEqual(rnn.nonlinearity, 'tanh')
1507
1508        rnn = torch.nn.RNN(1, 10, nonlinearity='relu')
1509        self.assertEqual(rnn.nonlinearity, 'relu')
1510
1511        with self.assertRaisesRegex(ValueError, 'Unknown nonlinearity'):
1512            rnn = torch.nn.RNN(1, 10, nonlinearity='garbage')
1513
1514    def test_RNN_nonlinearity_passed_as_arg(self):
1515        rnn = torch.nn.RNN(2, 3, 1, 'relu')
1516        self.assertEqual(rnn.nonlinearity, 'relu')
1517
1518    def test_module_apply_inplace_op(self):
1519        def add_one_inplace(t):
1520            return t.add_(1.0)
1521
1522        # Test that applying an in-place operation to a module would bump
1523        # the module's parameters' version counter.
1524        m = nn.Linear(20, 10)
1525        pvm = m.weight.mul(m.weight)
1526        m_weight_version_saved = m.weight._version
1527        m = m._apply(add_one_inplace)
1528        self.assertGreater(m.weight._version, m_weight_version_saved)
1529        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
1530            pvm.backward(torch.randn(10, 20))
1531
1532        # Test that applying an in-place operation to a module would bump
1533        # the module's parameters' gradients' version counter.
1534        m = nn.Linear(20, 10)
1535        m.weight.grad = torch.randn(10, 20).requires_grad_()
1536        pgm = m.weight.grad.mul(m.weight.grad)
1537        m_weight_grad_version_saved = m.weight.grad._version
1538        m = m._apply(add_one_inplace)
1539        self.assertGreater(m.weight.grad._version, m_weight_grad_version_saved)
1540        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
1541            pgm.backward(torch.randn(10, 20))
1542
1543    def test_overwrite_module_params_on_conversion(self):
1544        # Test that if the conversion function passed to `module._apply()`
1545        # changes the TensorImpl type of `module`'s parameters, the `module`'s
1546        # parameters are always overwritten, regardless of the value of
1547        # `torch.__future__.get_overwrite_module_params_on_conversion()`.
1548        m = nn.Linear(20, 10)
1549        m.weight.grad = torch.randn(10, 20)
1550        weight_ref = m.weight
1551        weight_grad_ref = m.weight.grad
1552        m = m._apply(lambda t: torch.sparse_coo_tensor(torch.zeros([2, 1]), torch.ones([1]), torch.Size([10, 20])))
1553        self.assertNotEqual(weight_ref.layout, m.weight.layout)
1554        self.assertNotEqual(weight_grad_ref.layout, m.weight.grad.layout)
1555
1556        # Test that under the current default settings
1557        # (`torch.__future__.get_overwrite_module_params_on_conversion() == False`),
1558        # a view to a module's parameters is not pointing to the same storage as
1559        # its base variable after converting the module to a different dtype.
1560        m = nn.Linear(20, 10).float()
1561        mw = m.weight[:]
1562        m.double()
1563        with torch.no_grad():
1564            mw[0][0] = 5
1565        self.assertTrue(mw[0][0].dtype == torch.float)
1566        self.assertTrue(mw._base[0][0].dtype == torch.double)
1567
1568        try:
1569            torch.__future__.set_overwrite_module_params_on_conversion(True)
1570
1571            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
1572            # a view to a module's parameters is still pointing to the same storage as
1573            # its base variable after converting the module to a different dtype.
1574            m = nn.Linear(20, 10).float()
1575            mw = m.weight[:]
1576            m.double()
1577            with torch.no_grad():
1578                mw[0][0] = 5
1579            self.assertTrue(mw[0][0] == mw._base[0][0])
1580
1581            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
1582            # `float_module.double()` doesn't preserve previous references to
1583            # `float_module`'s parameters or gradients.
1584            m = nn.Linear(20, 10).float()
1585            m.weight.grad = torch.randn(10, 20).float()
1586            weight_ref = m.weight
1587            weight_grad_ref = m.weight.grad
1588            m.double()
1589            self.assertNotEqual(weight_ref.dtype, m.weight.dtype)
1590            self.assertNotEqual(weight_grad_ref.dtype, m.weight.grad.dtype)
1591
1592            def add_one_inplace(t):
1593                return t.add_(1.0)
1594
1595            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
1596            # applying an in-place operation to a module would bump the module's
1597            # original parameters' version counter.
1598            m = nn.Linear(20, 10)
1599            pvm = m.weight.mul(m.weight)
1600            weight_ref = m.weight
1601            m_weight_version_saved = weight_ref._version
1602            m = m._apply(add_one_inplace)
1603            # Test that the in-place operation bumps the original parameter's version counter
1604            self.assertGreater(weight_ref._version, m_weight_version_saved)
1605            with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
1606                pvm.backward(torch.randn(10, 20))
1607
1608            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
1609            # applying an in-place operation to a module would bump the module's
1610            # original parameters' gradients' version counter.
1611            m = nn.Linear(20, 10)
1612            m.weight.grad = torch.randn(10, 20).requires_grad_()
1613            pgm = m.weight.grad.mul(m.weight.grad)
1614            weight_grad_ref = m.weight.grad
1615            m_weight_grad_version_saved = weight_grad_ref._version
1616            m = m._apply(add_one_inplace)
1617            self.assertGreater(weight_grad_ref._version, m_weight_grad_version_saved)
1618            with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
1619                pgm.backward(torch.randn(10, 20))
1620
1621            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
1622            # applying an out-of-place operation to a module doesn't bump
1623            # the module's original parameters' version counter.
1624            m = nn.Linear(20, 10)
1625            weight_ref = m.weight
1626            m_weight_version_saved = weight_ref._version
1627            m = m._apply(lambda t: torch.randn(t.shape))
1628            self.assertEqual(weight_ref._version, m_weight_version_saved)
1629
1630            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
1631            # applying an out-of-place operation to a module doesn't bump
1632            # the module's original parameters' gradients' version counter.
1633            m = nn.Linear(20, 10)
1634            m.weight.grad = torch.randn(10, 20).requires_grad_()
1635            weight_grad_ref = m.weight.grad
1636            m_weight_grad_version_saved = weight_grad_ref._version
1637            m = m._apply(lambda t: torch.randn(t.shape))
1638            self.assertEqual(weight_grad_ref._version, m_weight_grad_version_saved)
1639        finally:
1640            torch.__future__.set_overwrite_module_params_on_conversion(False)
1641
1642    def test_swap_module_params_poisons_acc_grad(self):
1643        try:
1644            torch.__future__.set_swap_module_params_on_conversion(True)
1645            # (1) backward cannot be run after _apply
1646            # forward will init AccumulateGrad nodes, which bumps use_count of parameters' at::Tensors
1647            # additionally, if any Tensors are saved for backward, their use_count will be bumped
1648            m = torch.nn.Linear(2, 3)
1649            inp = torch.randn(2, 2)
1650            out = m(inp)
1651            m.half()
1652            self.assertTrue(all(p.dtype == torch.float16 for p in m.parameters()))
1653            with self.assertRaisesRegex(RuntimeError, "Trying to execute AccumulateGrad node that was poisoned by swap_tensors"):
1654                out.sum().backward()
1655            # (2) _apply can be run after backward()
1656            # After running backward, all the references generated by "save for backward" will be cleared
1657            # So the use_count will be 2 (1 from Tensor itself, and 1 from AccumulateGrad node), swap_tensors
1658            # should allow this.
1659            inp2 = torch.randn(2, 2, dtype=torch.half)
1660            out2 = m(inp2)
1661            out2.sum().backward()
1662            m.float()
1663            self.assertTrue(all(p.dtype == torch.float32 for p in m.parameters()))
1664            out3 = m(inp)
1665        finally:
1666            torch.__future__.set_swap_module_params_on_conversion(False)
1667
1668    def test_type(self):
1669        l = nn.Linear(10, 20)
1670        net = nn.Module()
1671        net.l = l
1672        net.l2 = l
1673        net.add_module('empty', None)
1674        net.indices = Buffer(torch.LongTensor(1))
1675        net.float()
1676        self.assertIsInstance(l.weight.data, torch.FloatTensor)
1677        self.assertIsInstance(l.bias.data, torch.FloatTensor)
1678        self.assertIsInstance(net.indices, torch.LongTensor)
1679        net.double()
1680        self.assertIsInstance(l.weight.data, torch.DoubleTensor)
1681        self.assertIsInstance(l.bias.data, torch.DoubleTensor)
1682        self.assertIsInstance(net.indices, torch.LongTensor)
1683        net.to(torch.half)
1684        self.assertIsInstance(l.weight.data, torch.HalfTensor)
1685        self.assertIsInstance(l.bias.data, torch.HalfTensor)
1686        self.assertIsInstance(net.indices, torch.LongTensor)
1687        if TEST_CUDA:
1688            net.float().cuda()
1689            self.assertIsInstance(l.weight.data, torch.cuda.FloatTensor)
1690            self.assertIsInstance(l.bias.data, torch.cuda.FloatTensor)
1691            self.assertIsInstance(net.indices, torch.cuda.LongTensor)
1692            net.cpu()
1693            self.assertIsInstance(l.weight.data, torch.FloatTensor)
1694            self.assertIsInstance(l.bias.data, torch.FloatTensor)
1695            self.assertIsInstance(net.indices, torch.LongTensor)
1696            net.to("cuda", torch.double, True)
1697            self.assertIsInstance(l.weight.data, torch.cuda.DoubleTensor)
1698            self.assertIsInstance(l.bias.data, torch.cuda.DoubleTensor)
1699            self.assertIsInstance(net.indices, torch.cuda.LongTensor)
1700            net.to(torch.empty(1, device="cuda:0", dtype=torch.half))
1701            self.assertIsInstance(l.weight.data, torch.cuda.HalfTensor)
1702            self.assertIsInstance(l.bias.data, torch.cuda.HalfTensor)
1703            self.assertIsInstance(net.indices, torch.cuda.LongTensor)
1704        net.to(torch.device("cpu"), non_blocking=True)
1705        self.assertIsInstance(l.weight.data, torch.HalfTensor)
1706        self.assertIsInstance(l.bias.data, torch.HalfTensor)
1707        self.assertIsInstance(net.indices, torch.LongTensor)
1708        net.to(torch.float)
1709        self.assertIsInstance(l.weight.data, torch.FloatTensor)
1710        self.assertIsInstance(l.bias.data, torch.FloatTensor)
1711        net.to(torch.DoubleTensor(1))
1712        self.assertIsInstance(l.weight.data, torch.DoubleTensor)
1713        self.assertIsInstance(l.bias.data, torch.DoubleTensor)
1714        if TEST_CUDA:
1715            net.to(device='cuda', dtype=torch.float)
1716            self.assertIsInstance(l.weight.data, torch.cuda.FloatTensor)
1717            self.assertIsInstance(l.bias.data, torch.cuda.FloatTensor)
1718
1719    def test_non_leaf_parameters(self):
1720        l1 = nn.Linear(10, 10)
1721        l2 = nn.Linear(10, 10)
1722
1723        def assign_weight():
1724            l2.weight = l1.weight + 2
1725
1726        self.assertRaises(TypeError, assign_weight)
1727        # This should work though
1728        l2.weight = Parameter(torch.randn(10, 10))
1729
1730    def test_parameters_to_vector(self):
1731        conv1 = nn.Conv2d(3, 10, 5)
1732        fc1 = nn.Linear(10, 20)
1733        model = nn.Sequential(conv1, fc1)
1734
1735        vec = parameters_to_vector(model.parameters())
1736        self.assertEqual(vec.size(0), 980)
1737
1738    def test_vector_to_parameters(self):
1739        conv1 = nn.Conv2d(3, 10, 5)
1740        fc1 = nn.Linear(10, 20)
1741        model = nn.Sequential(conv1, fc1)
1742
1743        vec = torch.arange(0., 980)
1744        vector_to_parameters(vec, model.parameters())
1745
1746        sample = next(model.parameters())[0, 0, 0]
1747        self.assertTrue(torch.equal(sample.data, vec.data[:5]))
1748
1749    def test_rnn_weight_norm(self):
1750        def check_weight_norm(l, name, num_params):
1751            # This Module has 4 or 5 parameters called:
1752            # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', weight_hr_l0
1753
1754            # Applying weight norm on one of them causes it to become a tensor
1755            l = torch.nn.utils.weight_norm(l, name=name)
1756            self.assertEqual(
1757                sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights),
1758                num_params - 1,
1759            )
1760
1761            # Removing the weight norm reparametrization restores the Parameter
1762            l = torch.nn.utils.remove_weight_norm(l, name=name)
1763            self.assertEqual(
1764                sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights),
1765                num_params,
1766            )
1767
1768            # Make sure that, upon removal of the reparametrization, the
1769            # `._parameters` and `.named_parameters` contain the right params.
1770            # Specifically, the original weight ('weight_ih_l0') should be placed
1771            # back in the parameters, while the reparametrization components
1772            # ('weight_ih_l0_v' and 'weight_ih_l0_g') should be removed.
1773            self.assertTrue(name in l._parameters)
1774            self.assertIsNotNone(l._parameters[name])
1775            self.assertTrue(name + '_v' not in l._parameters)
1776            self.assertTrue(name + '_g' not in l._parameters)
1777            self.assertTrue(name in dict(l.named_parameters()))
1778            self.assertIsNotNone(dict(l.named_parameters())[name])
1779            self.assertTrue(name + '_v' not in dict(l.named_parameters()))
1780            self.assertTrue(name + '_g' not in dict(l.named_parameters()))
1781
1782        check_weight_norm(torch.nn.LSTM(32, 32), 'weight_ih_l0', 4)
1783        check_weight_norm(torch.nn.LSTM(32, 32, proj_size=16), 'weight_hr_l0', 5)
1784
1785
1786    def test_weight_norm(self):
1787        for dtype in [torch.float, torch.bfloat16]:
1788            input = torch.randn(3, 4, dtype=dtype)
1789            m = nn.Linear(4, 5).to(dtype=dtype)
1790            expected_output = m(input)
1791
1792            # add weight normalization
1793            m = torch.nn.utils.weight_norm(m)
1794            self.assertEqual(m.weight_v.size(), m.weight.size())
1795            self.assertEqual(m.weight_g.size(), (5, 1))
1796            self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0)
1797
1798            # remove weight norm
1799            m = torch.nn.utils.remove_weight_norm(m)
1800            self.assertFalse(hasattr(m, 'weight_g'))
1801            self.assertFalse(hasattr(m, 'weight_v'))
1802            self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0)
1803
1804            # test with dim=1
1805            m = torch.nn.utils.weight_norm(m, dim=1)
1806            self.assertEqual(m.weight_v.size(), m.weight.size())
1807            self.assertEqual(m.weight_g.size(), (1, 4))
1808            self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0)
1809
1810            # test with dim=None
1811            m = nn.Linear(4, 5).to(dtype=dtype)
1812            expected_output = m(input)
1813            m = torch.nn.utils.weight_norm(m, dim=None)
1814            self.assertEqual(m(input), expected_output)
1815
1816            with self.assertRaisesRegex(RuntimeError, 'register two weight_norm hooks'):
1817                m = torch.nn.utils.weight_norm(m)
1818                m = torch.nn.utils.weight_norm(m)
1819
1820        # For float16, the forward of the Module doesn't work but we must still be able
1821        # to register the weight norm as this is often done before sending the Module to
1822        # CUDA.
1823        m = nn.Linear(4, 5, dtype=torch.float16)
1824        m = torch.nn.utils.weight_norm(m)
1825
1826    def test_parameterlistdict_setting_attributes(self):
1827        with warnings.catch_warnings(record=True) as w:
1828            mod = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
1829        self.assertTrue(len(w) == 0)
1830
1831        with warnings.catch_warnings(record=True) as w:
1832            mod.train()
1833            mod.eval()
1834        self.assertTrue(len(w) == 0)
1835
1836        with warnings.catch_warnings(record=True) as w:
1837            mod = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
1838        self.assertTrue(len(w) == 0)
1839
1840        with warnings.catch_warnings(record=True) as w:
1841            mod.train()
1842            mod.eval()
1843        self.assertTrue(len(w) == 0)
1844
1845    def test_parameterlistdict_pickle(self):
1846        m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
1847        with warnings.catch_warnings(record=True) as w:
1848            m = pickle.loads(pickle.dumps(m))
1849        self.assertTrue(len(w) == 0)
1850
1851        # Test whether loading from older checkpoints works without triggering warnings
1852        m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
1853        del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
1854        with warnings.catch_warnings(record=True) as w:
1855            m = pickle.loads(pickle.dumps(m))
1856        self.assertTrue(len(w) == 0)
1857
1858        m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
1859        with warnings.catch_warnings(record=True) as w:
1860            m = pickle.loads(pickle.dumps(m))
1861        self.assertTrue(len(w) == 0)
1862
1863        # Test whether loading from older checkpoints works without triggering warnings
1864        m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
1865        del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
1866        with warnings.catch_warnings(record=True) as w:
1867            m = pickle.loads(pickle.dumps(m))
1868        self.assertTrue(len(w) == 0)
1869
1870    def test_weight_norm_pickle(self):
1871        m = torch.nn.utils.weight_norm(nn.Linear(5, 7))
1872        m = pickle.loads(pickle.dumps(m))
1873        self.assertIsInstance(m, nn.Linear)
1874
1875    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
1876    @set_default_dtype(torch.double)
1877    def test_spectral_norm(self):
1878        input = torch.randn(3, 5)
1879        m = nn.Linear(5, 7)
1880        m = torch.nn.utils.spectral_norm(m)
1881
1882        self.assertEqual(m.weight_u.size(), torch.Size([m.weight.size(0)]))
1883        # weight_orig should be trainable
1884        self.assertTrue(hasattr(m, 'weight_orig'))
1885        self.assertTrue('weight_orig' in m._parameters)
1886        # weight_u should be just a reused buffer
1887        self.assertTrue(hasattr(m, 'weight_u'))
1888        self.assertTrue('weight_u' in m._buffers)
1889        self.assertTrue('weight_v' in m._buffers)
1890        # weight should be a plain attribute, not counted as a buffer or a param
1891        self.assertFalse('weight' in m._buffers)
1892        self.assertFalse('weight' in m._parameters)
1893        # it should also be sharing storage as `weight_orig`
1894        self.assertEqual(m.weight_orig.storage(), m.weight.storage())
1895        self.assertEqual(m.weight_orig.size(), m.weight.size())
1896        self.assertEqual(m.weight_orig.stride(), m.weight.stride())
1897
1898        m = torch.nn.utils.remove_spectral_norm(m)
1899        self.assertFalse(hasattr(m, 'weight_orig'))
1900        self.assertFalse(hasattr(m, 'weight_u'))
1901        # weight should be converted back as a parameter
1902        self.assertTrue(hasattr(m, 'weight'))
1903        self.assertTrue('weight' in m._parameters)
1904
1905        with self.assertRaisesRegex(RuntimeError, 'register two spectral_norm hooks'):
1906            m = torch.nn.utils.spectral_norm(m)
1907            m = torch.nn.utils.spectral_norm(m)
1908
1909        # test correctness in training/eval modes and cpu/multi-gpu settings
1910        for apply_dp in (True, False):
1911            if apply_dp:
1912                if not TEST_MULTIGPU:
1913                    continue
1914                device = torch.device('cuda:0')
1915
1916                def maybe_wrap(m):
1917                    return torch.nn.DataParallel(m, [0, 1])
1918            else:
1919                device = torch.device('cpu')
1920
1921                def maybe_wrap(m):
1922                    return m
1923
1924            for requires_grad in (True, False):
1925                m = nn.Linear(3, 4).to(device)
1926                m.weight.requires_grad_(requires_grad)
1927                m = torch.nn.utils.spectral_norm(m)
1928                wrapped_m = maybe_wrap(m)
1929                self.assertTrue(hasattr(m, 'weight_u'))
1930                u0 = m.weight_u.clone()
1931                v0 = m.weight_v.clone()
1932
1933                # TEST TRAINING BEHAVIOR
1934
1935                # assert that u and v are updated
1936                input = torch.randn(2, 3, device=device)
1937                out = wrapped_m(input)
1938                self.assertNotEqual(u0, m.weight_u)
1939                self.assertNotEqual(v0, m.weight_v)
1940
1941                # assert that backprop reaches weight_orig
1942                # can't use gradcheck because the function changes as we
1943                # activate through it in training mode
1944                if requires_grad:
1945                    torch.autograd.grad(out.sum(), m.weight_orig)
1946
1947                # test backward works with multiple forwards
1948                # it uses training mode so we need to reset `u` and `v` vectors
1949                # to same value at beginning for finite difference test to pass
1950                saved_u = m.weight_u.clone()
1951                saved_v = m.weight_v.clone()
1952
1953                def fn(input):
1954                    m.weight_u.data.copy_(saved_u)
1955                    m.weight_v.data.copy_(saved_v)
1956                    out0 = wrapped_m(input)
1957                    out1 = wrapped_m(input)
1958                    return out0 + out1
1959
1960                gradcheck(fn, (input.clone().requires_grad_(),), check_batched_grad=False)
1961
1962                # test removing
1963                pre_remove_out = wrapped_m(input)
1964                m = torch.nn.utils.remove_spectral_norm(m)
1965                self.assertEqual(wrapped_m(input), pre_remove_out)
1966
1967                m = torch.nn.utils.spectral_norm(m)
1968                for _ in range(3):
1969                    pre_remove_out = wrapped_m(input)
1970                m = torch.nn.utils.remove_spectral_norm(m)
1971                self.assertEqual(wrapped_m(input), pre_remove_out)
1972
1973                # TEST EVAL BEHAVIOR
1974
1975                m = torch.nn.utils.spectral_norm(m)
1976                wrapped_m(input)
1977                last_train_out = wrapped_m(input)
1978                last_train_u = m.weight_u.clone()
1979                last_train_v = m.weight_v.clone()
1980                wrapped_m.zero_grad()
1981                wrapped_m.eval()
1982
1983                eval_out0 = wrapped_m(input)
1984                # assert eval gives same result as last training iteration
1985                self.assertEqual(eval_out0, last_train_out)
1986                # assert doing more iteartion in eval don't change things
1987                self.assertEqual(eval_out0, wrapped_m(input))
1988                self.assertEqual(last_train_u, m.weight_u)
1989                self.assertEqual(last_train_v, m.weight_v)
1990
1991                # FIXME: the code below is flaky when executed with DataParallel
1992                # see https://github.com/pytorch/pytorch/issues/13818
1993                if apply_dp:
1994                    continue
1995
1996                # test backward works with multiple forwards in mixed training
1997                # and eval modes
1998                # it uses training mode so we need to reset `u` and `v` vectors
1999                # to same value at beginning for finite difference test to pass
2000                saved_u = m.weight_u.clone()
2001                saved_v = m.weight_v.clone()
2002
2003                def fn(input):
2004                    m.weight_u.data.copy_(saved_u)
2005                    m.weight_v.data.copy_(saved_v)
2006                    wrapped_m.train()
2007                    out0 = wrapped_m(input)
2008                    wrapped_m.eval()
2009                    out1 = wrapped_m(input)
2010                    wrapped_m.train()
2011                    out2 = wrapped_m(input)
2012                    wrapped_m.eval()
2013                    out3 = wrapped_m(input)
2014                    return out0 + out1 + out2 + out3
2015
2016                gradcheck(fn, (input.clone().requires_grad_(),))
2017
2018                # assert that backprop reaches weight_orig in eval
2019                if requires_grad:
2020                    def fn(weight):
2021                        return wrapped_m(input)
2022
2023                    gradcheck(fn, (m.weight_orig,))
2024
2025    @skipIfNoLapack
2026    def test_spectral_norm_load_state_dict(self):
2027        inp = torch.randn(2, 3)
2028        for activate_times in (0, 3):
2029            # Test backward compatibility
2030            # At version None -> 1: weight becomes not a buffer and v vector becomes a buffer
2031            m = nn.Linear(3, 5)
2032            snm = torch.nn.utils.spectral_norm(m)
2033            snm.train()
2034            for _ in range(activate_times):
2035                snm(inp)
2036
2037            version_latest_ref_state_dict = deepcopy(snm.state_dict())
2038            self.assertEqual({'weight_orig', 'bias', 'weight_u', 'weight_v'}, set(version_latest_ref_state_dict.keys()))
2039
2040            # test that non-strict loading works
2041            non_strict_state_dict = deepcopy(version_latest_ref_state_dict)
2042            non_strict_state_dict['nonsense'] = 'nonsense'
2043            with self.assertRaisesRegex(RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'):
2044                snm.load_state_dict(non_strict_state_dict, strict=True)
2045            snm.load_state_dict(non_strict_state_dict, strict=False)
2046            del non_strict_state_dict['weight_orig']
2047            snm.load_state_dict(non_strict_state_dict, strict=False)
2048            del non_strict_state_dict['weight_u']
2049            snm.load_state_dict(non_strict_state_dict, strict=False)
2050            del non_strict_state_dict['weight_v']
2051            snm.load_state_dict(non_strict_state_dict, strict=False)
2052            non_strict_state_dict['weight'] = snm.weight.detach().clone()  # set W as a buffer
2053            snm.load_state_dict(non_strict_state_dict, strict=False)
2054            del non_strict_state_dict._metadata['']['spectral_norm']       # remove metadata info
2055            snm.load_state_dict(non_strict_state_dict, strict=False)
2056            del non_strict_state_dict['weight']                            # remove W buffer
2057            snm.load_state_dict(non_strict_state_dict, strict=False)
2058            del non_strict_state_dict['bias']
2059            snm.load_state_dict(non_strict_state_dict, strict=False)
2060
2061            # craft a version None state_dict
2062            version_none_state_dict = deepcopy(version_latest_ref_state_dict)
2063            self.assertIn('spectral_norm', version_none_state_dict._metadata[''])
2064            del version_none_state_dict._metadata['']['spectral_norm']       # remove metadata info
2065            del version_none_state_dict['weight_v']                          # remove v vector
2066            version_none_state_dict['weight'] = snm.weight.detach().clone()  # set W as a buffer
2067
2068            # normal state_dict
2069            for version_latest_with_metadata in [True, False]:
2070                version_latest_state_dict = deepcopy(version_latest_ref_state_dict)
2071
2072                if not version_latest_with_metadata:
2073                    # We want to still load a user-crafted state_dict, one without metadata
2074                    del version_latest_state_dict._metadata['']['spectral_norm']
2075
2076                # test that re-wrapping does not matter
2077                m = torch.nn.utils.remove_spectral_norm(snm)
2078                snm = torch.nn.utils.spectral_norm(m)
2079
2080                snm.load_state_dict(version_latest_ref_state_dict)
2081                with torch.no_grad():
2082                    snm.eval()
2083                    out0_eval = snm(inp)
2084                    snm.train()
2085                    out1_train = snm(inp)
2086                    out2_train = snm(inp)
2087                    snm.eval()
2088                    out3_eval = snm(inp)
2089
2090                # test that re-wrapping does not matter
2091                m = torch.nn.utils.remove_spectral_norm(snm)
2092                snm = torch.nn.utils.spectral_norm(m)
2093
2094                snm.load_state_dict(version_none_state_dict)
2095                if activate_times > 0:
2096                    # since in loading version None state dict, we assume that the
2097                    # values in the state dict have gone through at lease one
2098                    # forward, we only test for equivalence when activate_times > 0.
2099                    with torch.no_grad():
2100                        snm.eval()
2101                        self.assertEqual(out0_eval, snm(inp))
2102                        snm.train()
2103                        self.assertEqual(out1_train, snm(inp))
2104                        self.assertEqual(out2_train, snm(inp))
2105                        snm.eval()
2106                        self.assertEqual(out3_eval, snm(inp))
2107
2108                # test that re-wrapping does not matter
2109                m = torch.nn.utils.remove_spectral_norm(snm)
2110                snm = torch.nn.utils.spectral_norm(m)
2111
2112                # Test normal loading
2113                snm.load_state_dict(version_latest_state_dict)
2114                with torch.no_grad():
2115                    snm.eval()
2116                    self.assertEqual(out0_eval, snm(inp))
2117                    snm.train()
2118                    self.assertEqual(out1_train, snm(inp))
2119                    self.assertEqual(out2_train, snm(inp))
2120                    snm.eval()
2121                    self.assertEqual(out3_eval, snm(inp))
2122
2123    def test_spectral_norm_dim(self):
2124        inp = torch.randn(2, 3, 10, 12)
2125        m = nn.ConvTranspose2d(3, 4, (5, 6))
2126        m = torch.nn.utils.spectral_norm(m)
2127        # this should not run into incompatible shapes
2128        x = m(inp)
2129        # check that u refers to the same dimension
2130        self.assertEqual(m.weight_u.shape, m.weight_orig[0, :, 0, 0].shape)
2131
2132    def test_spectral_norm_forward(self):
2133        input = torch.randn(3, 5)
2134        m = nn.Linear(5, 7)
2135        m = torch.nn.utils.spectral_norm(m)
2136        # naive forward
2137        _weight, _bias, _u = m.weight_orig, m.bias, m.weight_u
2138        _weight_mat = _weight.view(_weight.size(0), -1)
2139        _v = torch.mv(_weight_mat.t(), _u)
2140        _v = F.normalize(_v, dim=0, eps=1e-12)
2141        _u = torch.mv(_weight_mat, _v)
2142        _u = F.normalize(_u, dim=0, eps=1e-12)
2143        _weight.data /= torch.dot(_u, torch.matmul(_weight_mat, _v))
2144        out_hat = torch.nn.functional.linear(input, _weight, _bias)
2145        expect_out = m(input)
2146        self.assertEqual(expect_out, out_hat)
2147
2148    def test_spectral_norm_pickle(self):
2149        m = torch.nn.utils.spectral_norm(nn.Linear(5, 7))
2150        m = pickle.loads(pickle.dumps(m))
2151        self.assertIsInstance(m, nn.Linear)
2152
2153    def test_threshold_int(self):
2154        x = torch.tensor([-3, -2, -1, 0, 1, 2, 3])
2155        expected = torch.tensor([99, 99, 99, 99, 1, 2, 3])
2156        self.assertEqual(F.threshold(x, 0, 99), expected)
2157
2158    def test_threshold_bfloat16_half(self):
2159        x = torch.randn(100)
2160        for dtype in [torch.bfloat16, torch.half]:
2161            for threshold in [0, -0.5, 0.5, float('inf'), float('-inf'), float('nan')]:
2162                expected = F.threshold(x, threshold, 0).to(dtype=dtype).float()
2163                res_bf16 = F.threshold(x.to(dtype=dtype), threshold, 0).float()
2164                self.assertEqual(res_bf16, expected)
2165
2166    @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
2167                         'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs'
2168                         ' with instruction set support avx2 or newer.')
2169    def test_fb_fc_packed(self):
2170        X = np.random.rand(16, 16).astype(np.float32) - 0.5
2171        W = np.random.rand(16, 16).astype(np.float32) - 0.5
2172        b = np.random.rand(16).astype(np.float32) - 0.5
2173
2174        def fc_op(X, W, b):
2175            return np.dot(X, W.T) + b
2176
2177        x_tensor = torch.tensor(X)
2178        w_tensor = torch.tensor(W)
2179        b_tensor = torch.tensor(b)
2180        packed_w_tensor = torch.fbgemm_pack_gemm_matrix_fp16(w_tensor)
2181        actual_output = torch.fbgemm_linear_fp16_weight(x_tensor, packed_w_tensor, b_tensor)
2182        expected_output = fc_op(X, W, b)
2183        torch.testing.assert_close(torch.from_numpy(expected_output), actual_output.cpu(), atol=1e-3, rtol=1e-3)
2184
2185    def test_pad_scalar_error(self):
2186        inputs = torch.tensor(0., requires_grad=True)
2187        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (1, 1)))
2188        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (1,)))
2189
2190    def test_nested_tensor_from_mask(self):
2191        N, L, D = 10, 12, 14
2192
2193        input = torch.rand(N, L, D)
2194        mask = torch.ones(N, L, dtype=torch.bool)
2195        # Leave first row be all True to maintain the nt's size unchanged
2196        for i in range(1, N):
2197            end = torch.randint(1, L, size=()).item()
2198            mask[i, end:] = False
2199
2200        nt = torch._nested_tensor_from_mask(input, mask)
2201        input_convert = nt.to_padded_tensor(0.)
2202        input.masked_fill_(mask.reshape(N, L, 1).logical_not(), 0.)
2203
2204        self.assertEqual(input, input_convert)
2205
2206    def test_nested_tensor_from_mask_error(self):
2207        N, L, D = 10, 12, 14
2208
2209        input = torch.rand(N, L, D)
2210        # Mask is not bool
2211        mask = torch.zeros(N, L, dtype=torch.float)
2212        self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask))
2213
2214        # Mask size is not 2
2215        mask = torch.zeros(N, L, D, dtype=torch.bool)
2216        self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask))
2217
2218        # Input size is not 3
2219        mask = torch.zeros(N, L, dtype=torch.bool)
2220        input = torch.rand(N, L)
2221        self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask))
2222
2223        # Mask size does not match input
2224        mask = torch.zeros(N + 1, L + 1, dtype=torch.bool)
2225        input = torch.rand(N, L, D)
2226        self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask))
2227
2228        # Mask is not padding format
2229        mask = torch.ones(N, L, dtype=torch.bool)
2230        mask[0, 0] = False
2231        mask[0, 2] = False
2232        self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask))
2233
2234    def test_normalize(self):
2235        inputs = torch.randn(1, 3, 4, 4, requires_grad=True, dtype=torch.double)
2236        self.assertTrue(gradcheck(lambda x: F.normalize(x, p=1, dim=-1), (inputs,)))
2237        self.assertTrue(gradcheck(lambda x: F.normalize(x, p=2, dim=-2), (inputs,)))
2238
2239        inputs = torch.randn((), requires_grad=True)
2240        self.assertTrue(gradcheck(lambda x: F.normalize(x, p=1, dim=-1), (inputs,)))
2241
2242    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
2243    # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
2244    @skipIfRocm
2245    def test_broadcast_double_backwards_gpu(self):
2246        tensors = (torch.randn(4, 4, device='cuda', requires_grad=True, dtype=torch.double),
2247                   torch.randn(4, 4, device='cuda', requires_grad=True, dtype=torch.double),
2248                   torch.randn(4, 4, device='cuda', requires_grad=True, dtype=torch.double))
2249        # TODO(#50743): the following segfaults with check_batched_grad=True
2250        _assertGradAndGradgradChecks(self, lambda *i: Broadcast.apply((0, 1), *i), tensors,
2251                                     check_batched_grad=False)
2252
2253    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
2254    def test_broadcast_not_requiring_grad(self):
2255        variables = [
2256            torch.randn(1, 2, device='cuda', requires_grad=True),
2257            torch.randn(1, 2, device='cuda', requires_grad=False),
2258            torch.randn(1, 2, device='cuda', requires_grad=False),
2259            torch.randn(1, 2, device='cuda', requires_grad=True),
2260            torch.randn(1, 2, device='cuda', requires_grad=True),
2261        ]
2262        broadcasted_variables = Broadcast.apply((0, 1), *variables)
2263        for output_idx, broadcasted_var in enumerate(broadcasted_variables):
2264            input_var = variables[output_idx % len(variables)]
2265            self.assertEqual(input_var.requires_grad, broadcasted_var.requires_grad)
2266
2267    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
2268    def test_broadcast_no_grad(self):
2269        x = torch.randn(1, 2, dtype=torch.float32, requires_grad=True, device='cuda')
2270        with torch.no_grad():
2271            broadcasted = Broadcast.apply((0, 1), x)
2272        self.assertTrue(x.requires_grad)
2273        for output in broadcasted:
2274            self.assertFalse(output.requires_grad)
2275
2276    def test_state_dict(self):
2277        l = nn.Linear(5, 5)
2278        block = nn.Module()
2279        block.conv = nn.Conv2d(3, 3, 3, bias=False)
2280        net = nn.Module()
2281        net.linear1 = l
2282        net.linear2 = l
2283        net.bn = nn.BatchNorm2d(2)
2284        net.block = block
2285        net.add_module('empty', None)
2286
2287        state_dict = net.state_dict()
2288        self.assertEqual(len(state_dict), 10)
2289        self.assertEqual(len(state_dict._metadata), 6)
2290        self.assertIn('', state_dict._metadata)
2291        self.assertIn('linear1', state_dict._metadata)
2292        self.assertIn('linear1.weight', state_dict)
2293        self.assertIn('linear1.bias', state_dict)
2294        self.assertIn('linear2', state_dict._metadata)
2295        self.assertIn('linear2.weight', state_dict)
2296        self.assertIn('linear2.bias', state_dict)
2297        self.assertIn('block', state_dict._metadata)
2298        self.assertIn('block.conv', state_dict._metadata)
2299        self.assertIn('block.conv.weight', state_dict)
2300        self.assertIn('block.conv.weight', state_dict)
2301        self.assertNotIn('block.conv.bias', state_dict)
2302        self.assertIn('bn', state_dict._metadata)
2303        self.assertIn('bn.weight', state_dict)
2304        self.assertIn('bn.bias', state_dict)
2305        self.assertIn('bn.running_var', state_dict)
2306        self.assertIn('bn.running_mean', state_dict)
2307        self.assertIn('bn.num_batches_tracked', state_dict)
2308        self.assertFalse(any(k.startswith('empty') for k in state_dict.keys()))
2309        for k, v in state_dict.items():
2310            param = net
2311            for component in k.split('.'):
2312                param = getattr(param, component)
2313                if isinstance(param, Parameter):
2314                    param = param.data
2315            self.assertEqual(v.data_ptr(), param.data_ptr())
2316
2317        l = nn.Linear(5, 5)
2318        state_dict = l.state_dict()
2319        self.assertEqual(len(state_dict), 2)
2320        self.assertEqual(len(state_dict._metadata), 1)
2321        self.assertIn('', state_dict._metadata)
2322        self.assertTrue(state_dict._metadata['']['version'] >= 0)
2323        self.assertEqual(state_dict['weight'].data_ptr(), l.weight.data_ptr())
2324        self.assertEqual(state_dict['bias'].data_ptr(), l.bias.data_ptr())
2325
2326        # Reference https://github.com/pytorch/pytorch/pull/75507#issuecomment-1110291545
2327        self.assertNotWarn(lambda: l.state_dict(destination={}), "Should not warn kwarg destination w/o _metadata")
2328
2329    def test_extra_state(self):
2330
2331        class SubModule(torch.nn.Module):
2332            def __init__(self, foo):
2333                super().__init__()
2334                self.foo = foo
2335
2336            def get_extra_state(self):
2337                return {
2338                    'foo': self.foo
2339                }
2340
2341            def set_extra_state(self, state):
2342                self.foo = state['foo']
2343
2344        class MyModule(torch.nn.Module):
2345            def __init__(self, foo, bar):
2346                super().__init__()
2347                self.sub = SubModule(foo)
2348                self.bar = bar
2349
2350            def get_extra_state(self):
2351                return {
2352                    'bar': self.bar
2353                }
2354
2355            def set_extra_state(self, state):
2356                self.bar = state['bar']
2357
2358        # Ensure state_dict contains the extra state by loading it into another module.
2359        m = MyModule(3, 'something')
2360        m2 = MyModule(5, 'something else')
2361        m2.load_state_dict(m.state_dict())
2362        self.assertEqual(m.state_dict(), m2.state_dict())
2363        self.assertEqual(m2.bar, m.bar)
2364        self.assertEqual(m2.sub.foo, m.sub.foo)
2365
2366    def test_extra_state_non_dict(self):
2367
2368        class MyModule(torch.nn.Module):
2369            def __init__(self, foo):
2370                super().__init__()
2371                self.foo = foo
2372
2373            def get_extra_state(self):
2374                return self.foo
2375
2376            def set_extra_state(self, state):
2377                self.foo = state
2378
2379        # Test various types of extra state.
2380        for state in ('something', 5, MyModule(3)):
2381            m = MyModule(state)
2382            m2 = MyModule('something else')
2383            m2.load_state_dict(m.state_dict())
2384            self.assertEqual(m.state_dict(), m2.state_dict())
2385            self.assertEqual(m.foo, m2.foo)
2386
2387    def test_extra_state_missing_set_extra_state(self):
2388
2389        class MyModule(torch.nn.Module):
2390            def get_extra_state(self):
2391                return {
2392                    'foo': 5
2393                }
2394
2395        m = MyModule()
2396        with self.assertRaisesRegex(RuntimeError, 'Unexpected key'):
2397            m.load_state_dict(m.state_dict())
2398
2399    def test_extra_state_missing_get_extra_state(self):
2400
2401        class MyModule(torch.nn.Module):
2402            def set_extra_state(self):
2403                pass
2404
2405        m = MyModule()
2406        with self.assertRaisesRegex(RuntimeError, 'Missing key'):
2407            m.load_state_dict(m.state_dict())
2408
2409    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
2410    def test_parameter_assignment(self):
2411        l = nn.Linear(5, 5)
2412
2413        def num_params():
2414            return len(list(l.parameters()))
2415
2416        self.assertEqual(num_params(), 2)
2417
2418        new_param = Parameter(torch.randn(5, 5))
2419        l.param_name = new_param
2420        self.assertEqual(num_params(), 3)
2421        self.assertObjectIn(new_param, l.parameters())
2422
2423        var = torch.randn(5, 5)
2424        l.var_name = var
2425        self.assertEqual(num_params(), 3)
2426        self.assertNotIn(id(var), map(id, l.parameters()))
2427
2428        # Make sure Variables are not saved as parameters
2429        l.variable_attr = torch.empty(5, 5)
2430        self.assertEqual(num_params(), 3)
2431        l.param_attr = Parameter(torch.empty(5, 5))
2432        self.assertEqual(num_params(), 4)
2433
2434        # It shouldn't be possible to replace a parameter with a Variable
2435        def assign_var():
2436            l.param_attr = torch.empty(5, 5)
2437
2438        self.assertRaises(TypeError, assign_var)
2439        # But replacing it with None should be fine
2440        l.param_attr = None
2441        self.assertEqual(num_params(), 3)
2442
2443    def test_assignment(self):
2444        l = nn.Module()
2445        a = nn.Parameter(torch.randn(2))
2446        b = nn.Parameter(torch.randn(3))
2447        c = nn.Parameter(torch.randn(4))
2448        q = nn.Linear(4, 4)
2449        r = nn.Linear(5, 5)
2450        w = nn.Linear(6, 6)
2451
2452        def test_assignments(get_list, a, b, c):
2453            # Check that None can be shadowed
2454            l.a = None
2455            self.assertIsNone(l.a)
2456            self.assertIn('a', l.__dict__)
2457            l.a = a
2458            self.assertIs(l.a, a)
2459            self.assertEqual(get_list(), [a])
2460            self.assertNotIn('a', l.__dict__)
2461
2462            # Assign second object
2463            l.b = None
2464            self.assertIsNone(l.b)
2465            self.assertIn('b', l.__dict__)
2466            l.b = b
2467            self.assertIs(l.b, b)
2468            self.assertEqual(get_list(), [a, b])
2469            self.assertNotIn('b', l.__dict__)
2470
2471            # Remove and add the object back. Order should be unchanged.
2472            l.a = None
2473            self.assertIsNone(l.a)
2474            self.assertEqual(get_list(), [b])
2475            l.a = a
2476            self.assertIs(l.a, a)
2477            self.assertEqual(get_list(), [a, b])
2478
2479            # Replace object with another one. Order should be unchanged.
2480            l.a = c
2481            self.assertIs(l.a, c)
2482            self.assertEqual(get_list(), [c, b])
2483
2484            # Remove and reassign an attribute. It should appear at the end of the list now.
2485            del l.a
2486            self.assertFalse(hasattr(l, 'a'))
2487            l.a = a
2488            self.assertIs(l.a, a)
2489            self.assertEqual(get_list(), [b, a])
2490
2491        test_assignments(lambda: list(l.parameters()), a, b, c)
2492        del l.a, l.b
2493        self.assertEqual(list(l.parameters()), [])
2494
2495        test_assignments(lambda: list(l.children()), q, r, w)
2496        del l.a, l.b
2497        self.assertEqual(list(l.children()), [])
2498
2499        buf = Buffer(torch.randn(10))
2500        l.buf = buf
2501        self.assertIs(l.buf, buf)
2502        l.buf = None
2503        self.assertIs(l.buf, None)
2504        self.assertNotIn('buf', l.__dict__)  # should be stored in l._buffers
2505        l.buf = buf
2506        self.assertIn('buf', l.state_dict())
2507        self.assertEqual(l.state_dict()['buf'], buf)
2508
2509    def test_container_copy(self):
2510        class Model(nn.Module):
2511            def __init__(self) -> None:
2512                super().__init__()
2513                self.linear = nn.Linear(4, 5)
2514
2515            def forward(self, input):
2516                return self.linear(input)
2517
2518        input = torch.randn(2, 4)
2519
2520        model = Model()
2521        model_cp = deepcopy(model)
2522        self.assertEqual(model(input).data, model_cp(input).data)
2523
2524        model_cp.linear.weight.data[:] = 2
2525        self.assertNotEqual(model(input).data, model_cp(input).data)
2526
2527    def test_RNN_cell(self):
2528        # this is just a smoke test; these modules are implemented through
2529        # autograd so no Jacobian test is needed
2530        for module in (nn.RNNCell, nn.GRUCell):
2531            for bias in (True, False):
2532                input = torch.randn(3, 10)
2533                hx = torch.randn(3, 20)
2534                cell = module(10, 20, bias=bias)
2535                for _ in range(6):
2536                    hx = cell(input, hx)
2537
2538                hx.sum().backward()
2539
2540    def test_RNN_cell_forward_zero_hidden_size(self):
2541        input = torch.randn(3, 10)
2542        hx = torch.randn(3, 0)
2543        cell_shared_param = (10, 0)
2544        for cell in (nn.RNNCell(*cell_shared_param, nonlinearity="relu"),
2545                     nn.RNNCell(*cell_shared_param, nonlinearity="tanh"),
2546                     nn.GRUCell(*cell_shared_param)):
2547            self.assertEqual(cell(input, hx).shape, torch.Size([3, 0]))
2548
2549    def _test_loss_equal_input_target_shape(self, cast):
2550        # Tests losses whose inputs should have the same size.
2551        losses = {
2552            'mse_loss': lambda x, y: F.mse_loss(x, y),
2553            'l1_loss': lambda x, y: F.l1_loss(x, y),
2554            'smooth_l1_loss': lambda x, y: F.smooth_l1_loss(x, y),
2555            'huber_loss': lambda x, y: F.huber_loss(x, y),
2556            'kl_div': lambda x, y: F.kl_div(x, y),
2557            'poisson_nll_loss': lambda x, y: F.poisson_nll_loss(x, y),
2558        }
2559
2560        input = cast(torch.randn(3, 5))
2561        target = cast(torch.randn(5, 3))
2562        for fn in losses.values():
2563            self.assertRaises(Exception, lambda: fn(input, target))
2564
2565    def test_loss_equal_input_target_shape(self):
2566        self._test_loss_equal_input_target_shape(lambda x: x)
2567
2568    def test_mse_loss_size_warning(self):
2569        i = torch.randn((10, 1), requires_grad=True)
2570        t = torch.randn((10,))
2571        with warnings.catch_warnings(record=True) as w:
2572            # Ensure warnings are being shown
2573            warnings.simplefilter("always")
2574            # Trigger Warning
2575            F.mse_loss(i, t)
2576            # Check warning occurs
2577            self.assertEqual(len(w), 1)
2578            self.assertIn('Please ensure they have the same size.', str(w[0]))
2579
2580    def test_gaussian_nll_loss_broadcasting(self):
2581        input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
2582        target_full = torch.tensor([[1., 2., 3.], [1., 2., 3.]])
2583        target_part = torch.tensor([[1., 2., 3.]])
2584        var_full = torch.tensor([[0.5, 0.5, 0.5], [1.5, 1.5, 1.5]])
2585        var_part1 = torch.tensor([[0.5], [1.5]])
2586        var_part2 = torch.tensor([0.5, 1.5])
2587        component_wise_loss = 0.5 * (torch.log(var_full) + (input - target_full)**2 / var_full)
2588        self.assertEqual(component_wise_loss,
2589                         F.gaussian_nll_loss(input, target_part, var_full, reduction='none'))
2590        self.assertEqual(component_wise_loss,
2591                         F.gaussian_nll_loss(input, target_full, var_part1, reduction='none'))
2592        self.assertEqual(component_wise_loss,
2593                         F.gaussian_nll_loss(input, target_full, var_part2, reduction='none'))
2594        self.assertEqual(component_wise_loss,
2595                         F.gaussian_nll_loss(input, target_part, var_part1, reduction='none'))
2596        self.assertEqual(component_wise_loss,
2597                         F.gaussian_nll_loss(input, target_part, var_part2, reduction='none'))
2598
2599    def test_gaussian_nll_loss_args(self):
2600        input = torch.randn(3, 5)
2601        with self.assertRaisesRegex(ValueError, 'var is of incorrect size'):
2602            target = torch.randn(3, 5)
2603            var = torch.ones(3, 3)
2604            torch.nn.functional.gaussian_nll_loss(input, target, var)
2605        with self.assertRaisesRegex(ValueError, 'var has negative entry/entries'):
2606            var = -1 * torch.ones(3, 5)
2607            torch.nn.functional.gaussian_nll_loss(input, target, var)
2608
2609    def test_KLDivLoss_batch_mean(self):
2610        input_shape = (2, 5)
2611        log_prob1 = F.log_softmax(torch.randn(input_shape), 1)
2612        prob2 = F.softmax(torch.randn(input_shape), 1)
2613
2614        loss = nn.KLDivLoss(reduction='batchmean')
2615        l = loss(log_prob1, prob2)
2616
2617        loss_none_reduce = nn.KLDivLoss(reduction='sum')(log_prob1, prob2)
2618        expected = loss_none_reduce / input_shape[0]
2619
2620        self.assertEqual(l, expected)
2621
2622    def test_KLDivLoss_batch_mean_log_target(self):
2623        input_shape = (2, 5)
2624        log_prob1 = F.log_softmax(torch.randn(input_shape), 1)
2625        log_prob2 = F.log_softmax(torch.randn(input_shape), 1)
2626
2627        loss = nn.KLDivLoss(reduction='batchmean', log_target=True)
2628        l = loss(log_prob1, log_prob2)
2629
2630        loss_none_reduce = nn.KLDivLoss(reduction='sum', log_target=True)(log_prob1, log_prob2)
2631        expected = loss_none_reduce / input_shape[0]
2632
2633        self.assertEqual(l, expected)
2634
2635    def test_CTCLoss_typechecks(self):
2636        target_lengths = torch.tensor([30, 25, 20])
2637        input_lengths = torch.tensor([50, 50, 50])
2638        targets = torch.randint(1, 15, (sum(target_lengths),), dtype=torch.int)
2639        log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
2640        with self.assertRaises(RuntimeError):
2641            _input_lengths = input_lengths.to(dtype=torch.float)
2642            torch.nn.functional.ctc_loss(log_probs, targets, _input_lengths, target_lengths)
2643        with self.assertRaises(RuntimeError):
2644            target_lengths = target_lengths.to(dtype=torch.float)
2645            torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
2646
2647    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
2648    def test_CTCLoss_lengthchecks_cuda(self):
2649        for target_lengths in [[30, 25, 20], [-1, -1, -1]]:
2650            for input_lengths in [[50, 50, 50], [-1, -1, -1]]:
2651                targets = torch.randint(1, 15, (3, 29), dtype=torch.long, device='cuda')
2652                log_probs = torch.randn(50, 3, 15, dtype=torch.float, device='cuda').log_softmax(2)
2653                with self.assertRaises(RuntimeError):
2654                    torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
2655
2656    def test_CTCLoss_lengthchecks_cpu(self):
2657        for target_lengths in [[30, 25, 20], [-1, -1, -1]]:
2658            for input_lengths in [[50, 50, 50], [-1, -1, -1]]:
2659                targets = torch.randint(1, 15, (3, 29), dtype=torch.int)
2660                log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
2661                with self.assertRaises(RuntimeError):
2662                    torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
2663
2664    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
2665    def test_CTCLoss_long_targets(self):
2666        input_length = 4000
2667        vocab_size = 3
2668        batch_size = 4
2669        target_length = 1200
2670
2671        log_probs = torch.randn(input_length, batch_size, vocab_size, dtype=torch.double).log_softmax(2).requires_grad_()
2672        targets = torch.randint(low=1, high=vocab_size - 1, size=(batch_size, target_length), dtype=torch.long)
2673        input_lengths = batch_size * [input_length]
2674        target_lengths = batch_size * [target_length]
2675
2676        res_cpu = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths,
2677                                               reduction='sum', zero_infinity=True)
2678        grad_out = torch.randn_like(res_cpu)
2679        grad_cpu, = torch.autograd.grad(res_cpu, log_probs, grad_out)
2680
2681        with torch.backends.cudnn.flags(enabled=False):
2682            res_gpu = torch.nn.functional.ctc_loss(log_probs.cuda(), targets.cuda(), input_lengths, target_lengths,
2683                                                   reduction='sum', zero_infinity=True)
2684            grad_gpu, = torch.autograd.grad(res_gpu, log_probs, grad_out.cuda())
2685        self.assertEqual(res_cpu, res_gpu, atol=1e-4, rtol=0)
2686        self.assertEqual(grad_cpu, grad_gpu, atol=1e-4, rtol=0)
2687
2688    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
2689    def test_CTCLoss_critical_target_len(self):
2690        # cudnn has an unexpected problem with target length 256, see issue #53505
2691        N = 1
2692        S = 256
2693        C = 10
2694        T = 500
2695        target = torch.randint(low=1, high=C, size=(S,), dtype=torch.int)
2696        input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.int)
2697        target_lengths = torch.tensor(S, dtype=torch.int)
2698        inp = torch.randn(T, N, C, dtype=torch.float, device='cuda').log_softmax(2).requires_grad_()
2699        with cudnn.flags(enabled=True):
2700            res_gpu = torch.nn.functional.ctc_loss(inp, target, input_lengths, target_lengths, reduction='none')
2701        res_cpu = torch.nn.functional.ctc_loss(inp.cpu(), target, input_lengths, target_lengths, reduction='none')
2702        self.assertEqual(res_cpu, res_gpu, atol=1e-3, rtol=0)
2703
2704    def test_CTCLoss_zero_lengths(self):
2705        devices = ['cpu']
2706        devices += ['cuda'] if TEST_CUDA else []
2707        N = 3
2708        S = 2
2709        C = 200
2710        T = 1
2711        target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.int)
2712        input_lengths = torch.full(size=(N,), fill_value=0, dtype=torch.int)
2713        target_lengths = torch.full(size=(N,), fill_value=0, dtype=torch.int)
2714        for device in devices:
2715            inp = torch.randn(T, N, C, dtype=torch.float, device=device).log_softmax(2).requires_grad_()
2716            res = torch.nn.functional.ctc_loss(inp, target, input_lengths, target_lengths, reduction='none')
2717            self.assertTrue((res == 0).all().item())
2718            res.sum().backward()
2719            self.assertTrue((inp.grad == 0).all().item())
2720        target_lengths = torch.full(size=(N,), fill_value=1, dtype=torch.int)
2721        for device in devices:
2722            inp = torch.randn(T, N, C, dtype=torch.float, device=device).log_softmax(2).requires_grad_()
2723            res = torch.nn.functional.ctc_loss(inp, target, input_lengths, target_lengths, reduction='none')
2724            self.assertTrue((res == torch.inf).all().item())
2725            res.sum().backward()
2726            self.assertTrue((inp.grad == 0).all().item())
2727
2728    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
2729    def test_CTCLoss_zero_infinity(self):
2730        target_lengths = [60, 25, 20]
2731        input_lengths = [50, 50, 50]
2732        targets = torch.randint(1, 15, (sum(target_lengths),), dtype=torch.int, device='cuda')
2733        log_probs = torch.randn(50, 3, 15, dtype=torch.float, device='cuda').log_softmax(2).requires_grad_()
2734        res = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths,
2735                                           reduction='sum', zero_infinity=True)
2736        with torch.backends.cudnn.flags(enabled=False):
2737            res2 = torch.nn.functional.ctc_loss(log_probs, targets.cuda().long(), input_lengths, target_lengths,
2738                                                reduction='sum', zero_infinity=True)
2739        res_cpu = torch.nn.functional.ctc_loss(log_probs.cpu(), targets.cpu(), input_lengths, target_lengths,
2740                                               reduction='sum', zero_infinity=True)
2741
2742        self.assertEqual(res2, res, atol=1e-4, rtol=0)
2743        self.assertEqual(res_cpu, res.cpu(), atol=1e-4, rtol=0)
2744        g1, = torch.autograd.grad(res, log_probs)
2745        g2, = torch.autograd.grad(res2, log_probs)
2746        g3, = torch.autograd.grad(res_cpu, log_probs)
2747        self.assertEqual(g2, g3, atol=1e-4, rtol=0)
2748        self.assertEqual(g1, g2, atol=1e-4, rtol=0)
2749        self.assertTrue((g1 == g1).all().item())  # check that we don't have NaN
2750
2751    def test_RNN_cell_no_broadcasting(self):
2752        def test(cell_module, input, hx, input_size, hidden_size):
2753            cell = cell_module(input_size, hidden_size)
2754            self.assertRaises(RuntimeError, lambda: cell(input, hx))
2755
2756        def test_all(hidden_size, bad_hx, good_hx, input_size, input):
2757            test(nn.RNNCell, input, bad_hx, input_size, hidden_size)
2758            test(nn.GRUCell, input, bad_hx, input_size, hidden_size)
2759            test(nn.LSTMCell, input, (bad_hx, good_hx), input_size, hidden_size)
2760            test(nn.LSTMCell, input, (good_hx, bad_hx), input_size, hidden_size)
2761
2762        hidden_size = 20
2763        input_size = 10
2764        input = torch.randn(3, input_size)
2765        bad_hx = torch.randn(1, hidden_size)
2766        good_hx = torch.randn(3, hidden_size)
2767
2768        # Test hidden/input batch size broadcasting
2769        test_all(hidden_size, bad_hx, good_hx, input_size, input)
2770
2771        # Test hx's hidden_size vs module's hidden_size broadcasting
2772        bad_hx = torch.randn(3, 1)
2773        test_all(hidden_size, bad_hx, good_hx, input_size, input)
2774
2775        # Test input's input_size vs module's input_size broadcasting
2776        bad_input = torch.randn(3, 1)
2777        test_all(hidden_size, good_hx, good_hx, input_size, bad_input)
2778
2779    def test_LSTM_cell(self):
2780        # this is just a smoke test; these modules are implemented through
2781        # autograd so no Jacobian test is needed
2782        for bias in (True, False):
2783            input = torch.randn(3, 10)
2784            hx = torch.randn(3, 20)
2785            cx = torch.randn(3, 20)
2786            lstm = nn.LSTMCell(10, 20, bias=bias)
2787            for _ in range(6):
2788                hx, cx = lstm(input, (hx, cx))
2789
2790            (hx + cx).sum().backward()
2791
2792    def test_LSTM_cell_forward_input_size(self):
2793        input = torch.randn(3, 11)
2794        hx = torch.randn(3, 20)
2795        cx = torch.randn(3, 20)
2796        lstm = nn.LSTMCell(10, 20)
2797        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
2798
2799    def test_LSTM_cell_forward_hidden_size(self):
2800        input = torch.randn(3, 10)
2801        hx = torch.randn(3, 21)
2802        cx = torch.randn(3, 20)
2803        lstm = nn.LSTMCell(10, 20)
2804        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
2805        self.assertRaises(Exception, lambda: lstm(input, (cx, hx)))
2806
2807
2808    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
2809    def test_pack_sequence_batch_sizes_throw(self):
2810        with self.assertRaisesRegex(ValueError, r"batch_sizes should always be on CPU"):
2811            m = nn.LSTM(3, 4, bidirectional=True, num_layers=2).to('cuda')
2812            a = torch.rand(5, 3, device='cuda')
2813            b = torch.tensor([1, 1, 1, 1, 1], device='cuda')
2814            input = nn.utils.rnn.PackedSequence(a, b)
2815
2816    def test_Transformer_cell(self):
2817        # this is just a smoke test; these modules are implemented through
2818        # autograd so no Jacobian test is needed
2819        d_model = 512
2820        nhead = 16
2821        num_encoder_layers = 4
2822        num_decoder_layers = 3
2823        dim_feedforward = 256
2824        dropout = 0.3
2825        bsz = 8
2826        seq_length = 35
2827        tgt_length = 15
2828        for batch_first, src_size, tgt_size in zip((True, False),
2829                                                   [(bsz, seq_length, d_model),
2830                                                    (seq_length, bsz, d_model)],
2831                                                   [(bsz, tgt_length, d_model),
2832                                                    (tgt_length, bsz, d_model)]):
2833            transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers,
2834                                         dim_feedforward, dropout, batch_first=batch_first,
2835                                         dtype=torch.double)
2836            src = torch.randn(src_size, dtype=torch.double)
2837            src_mask = transformer.generate_square_subsequent_mask(seq_length).double()
2838            tgt = torch.randn(tgt_size, dtype=torch.double)
2839            tgt_mask = transformer.generate_square_subsequent_mask(tgt_length).double()
2840            memory_mask = torch.randn(tgt_length, seq_length).double()
2841            src_key_padding_mask = torch.rand(bsz, seq_length) >= 0.5
2842            tgt_key_padding_mask = torch.rand(bsz, tgt_length) >= 0.5
2843            memory_key_padding_mask = torch.rand(bsz, seq_length) >= 0.5
2844
2845            output = transformer(src, tgt,
2846                                 src_mask=src_mask,
2847                                 tgt_mask=tgt_mask,
2848                                 memory_mask=memory_mask,
2849                                 src_key_padding_mask=src_key_padding_mask,
2850                                 tgt_key_padding_mask=tgt_key_padding_mask,
2851                                 memory_key_padding_mask=memory_key_padding_mask)
2852            output.sum().backward()
2853
2854    def test_transformerdecoderlayer(self):
2855        # this is a deterministic test for TransformerDecoderLayer
2856        d_model = 4
2857        nhead = 2
2858        dim_feedforward = 16
2859        dropout = 0.0
2860        bsz = 2
2861        seq_length = 5
2862        tgt_length = 3
2863
2864        for batch_first in (False, True):
2865            def perm_fn(x):
2866                return x.transpose(1, 0) if batch_first else x
2867
2868            model = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
2869                                               batch_first=batch_first)
2870
2871            # set constant weights of the model
2872            for idx, p in enumerate(model.parameters()):
2873                x = p.data
2874                sz = x.view(-1).size(0)
2875                shape = x.shape
2876                x = torch.cos(torch.arange(0, sz).float().view(shape))
2877                p.data.copy_(x)
2878
2879            # deterministic input
2880            decoder_input = torch.tensor([[[20., 30., 40., 50.]]])
2881            memory_input = torch.tensor([[[60., 70., 80., 90.]]])
2882            result = model(decoder_input, memory_input)
2883            ref_output = torch.tensor([[[2.314351, 0.094805, -0.671322, 0.101977]]])
2884            result = result.detach().numpy()
2885            ref_output = ref_output.detach().numpy()
2886            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
2887            np.testing.assert_allclose(result, ref_output, atol=1e-5)
2888
2889            # deterministic input
2890            decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
2891                                                  [[11., 12., 13., 14.]]]))
2892            memory_input = torch.tensor([[[1., 2., 3., 4.]]])
2893            result = model(decoder_input, memory_input)
2894            result = result.detach().numpy()
2895            ref_output = perm_fn(torch.tensor([[[2.422245, 0.051716, -0.606338, -0.024756]],
2896                                               [[2.422245, 0.051716, -0.606338, -0.024756]]]))
2897            ref_output = ref_output.detach().numpy()
2898            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
2899            np.testing.assert_allclose(result, ref_output, atol=1e-5)
2900
2901            # deterministic input
2902            decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
2903                                                  [[5., 6., 7., 8.]]]))
2904            memory_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
2905                                                 [[11., 12., 13., 14.]]]))
2906            result = model(decoder_input, memory_input)
2907            ref_output = perm_fn(torch.tensor([[[2.343536, 0.085561, -0.654954, 0.074991]],
2908                                               [[2.343536, 0.085561, -0.654954, 0.074991]]]))
2909            result = result.detach().numpy()
2910            ref_output = ref_output.detach().numpy()
2911            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
2912            np.testing.assert_allclose(result, ref_output, atol=1e-5)
2913
2914            # deterministic input
2915            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
2916                                                   [0.2678, 0.3677, 0.4459, 0.7166]],
2917                                                  [[0.8100, 0.3716, 0.4096, 0.1976],
2918                                                   [0.6958, 0.8844, 0.6081, 0.8315]],
2919                                                  [[0.0494, 0.9343, 0.5955, 0.3830],
2920                                                   [0.5404, 0.3464, 0.9378, 0.6200]]]))
2921            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
2922                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
2923                                                 [[0.8335, 0.2799, 0.5031, 0.2947],
2924                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
2925                                                 [[0.6333, 0.9344, 0.1376, 0.9938],
2926                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
2927                                                 [[0.9897, 0.6915, 0.3154, 0.1733],
2928                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
2929                                                 [[0.8117, 0.2366, 0.4838, 0.7881],
2930                                                  [0.3718, 0.4945, 0.9511, 0.0864]]]))
2931            result = model(decoder_input, memory_input)
2932            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
2933                                                [2.431935, 0.028907, -0.599809, -0.072488]],
2934                                               [[2.428457, 0.027053, -0.602275, -0.073462],
2935                                                [2.431970, 0.029387, -0.599789, -0.071621]],
2936                                               [[2.431934, 0.028196, -0.599802, -0.073809],
2937                                                [2.432306, 0.028858, -0.599542, -0.072846]]]))
2938            result = result.detach().numpy()
2939            ref_output = ref_output.detach().numpy()
2940            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
2941            np.testing.assert_allclose(result, ref_output, atol=1e-5)
2942
2943            # key_padding_mask
2944            key_padding_mask = torch.zeros(2, 3) == 1
2945            result = model(decoder_input, memory_input, tgt_key_padding_mask=key_padding_mask)
2946            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
2947                                                [2.431935, 0.028907, -0.599809, -0.072488]],
2948                                               [[2.428457, 0.027053, -0.602275, -0.073462],
2949                                                [2.431970, 0.029387, -0.599789, -0.071621]],
2950                                               [[2.431934, 0.028196, -0.599802, -0.073809],
2951                                                [2.432306, 0.028858, -0.599542, -0.072846]]]))
2952            result = result.detach().numpy()
2953            ref_output = ref_output.detach().numpy()
2954            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
2955            np.testing.assert_allclose(result, ref_output, atol=1e-5)
2956
2957            # key_padding_mask
2958            key_padding_mask[0, 2] = 1
2959            key_padding_mask[1, 1] = 1
2960            key_padding_mask[1, 2] = 1
2961            result = model(decoder_input, memory_input, tgt_key_padding_mask=key_padding_mask)
2962            ref_output = perm_fn(torch.tensor([[[2.430025, 0.027643, -0.601164, -0.073476],
2963                                                [2.4323, 0.029375, -0.599553, -0.071881]],
2964                                               [[2.428523, 0.026838, -0.602226, -0.07391],
2965                                                [2.432634, 0.029842, -0.599318, -0.071253]],
2966                                               [[2.432278, 0.028152, -0.599555, -0.074139],
2967                                                [2.432659, 0.029244, -0.599294, -0.072382]]]))
2968            result = result.detach().numpy()
2969            ref_output = ref_output.detach().numpy()
2970            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
2971            np.testing.assert_allclose(result, ref_output, atol=1e-5)
2972
2973            # memory_key_padding_mask
2974            key_padding_mask = torch.zeros(2, 5) == 1
2975            result = model(decoder_input, memory_input, memory_key_padding_mask=key_padding_mask)
2976            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
2977                                                [2.431935, 0.028907, -0.599809, -0.072488]],
2978                                               [[2.428457, 0.027053, -0.602275, -0.073462],
2979                                                [2.431970, 0.029387, -0.599789, -0.071621]],
2980                                               [[2.431934, 0.028196, -0.599802, -0.073809],
2981                                                [2.432306, 0.028858, -0.599542, -0.072846]]]))
2982            result = result.detach().numpy()
2983            ref_output = ref_output.detach().numpy()
2984            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
2985            np.testing.assert_allclose(result, ref_output, atol=1e-5)
2986
2987            # memory_key_padding_mask
2988            key_padding_mask[0, 4] = 1
2989            key_padding_mask[1, 3] = 1
2990            key_padding_mask[1, 4] = 1
2991            result = model(decoder_input, memory_input, memory_key_padding_mask=key_padding_mask)
2992            ref_output = perm_fn(torch.tensor([[[2.429757, 0.027358, -0.601351, -0.073816],
2993                                                [2.432692, 0.028583, -0.599263, -0.073634]],
2994                                               [[2.428247, 0.02662, -0.602419, -0.074123],
2995                                                [2.432657, 0.029055, -0.599293, -0.072732]],
2996                                               [[2.431515, 0.027687, -0.600096, -0.074459],
2997                                                [2.433075, 0.028543, -0.598987, -0.073985]]]))
2998            result = result.detach().numpy()
2999            ref_output = ref_output.detach().numpy()
3000            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3001            np.testing.assert_allclose(result, ref_output, atol=1e-5)
3002
3003    @set_default_dtype(torch.double)
3004    def test_transformerdecoderlayer_gelu(self):
3005        # this is a deterministic test for TransformerDecoderLayer with gelu activation
3006        d_model = 4
3007        nhead = 2
3008        dim_feedforward = 16
3009        dropout = 0.0
3010        bsz = 2
3011        seq_length = 5
3012        tgt_length = 3
3013
3014        for activation, batch_first in product(('gelu', F.gelu, nn.GELU()), (True, False)):
3015            def perm_fn(x):
3016                return x.transpose(1, 0) if batch_first else x
3017
3018            model = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
3019                                               activation, batch_first=batch_first)
3020
3021            # set constant weights of the model
3022            for idx, p in enumerate(model.parameters()):
3023                x = p.data
3024                sz = x.view(-1).size(0)
3025                shape = x.shape
3026                x = torch.cos(torch.arange(0, sz).float().view(shape))
3027                p.data.copy_(x)
3028
3029            # deterministic input
3030            decoder_input = torch.tensor([[[20., 30., 40., 50.]]])
3031            memory_input = torch.tensor([[[60., 70., 80., 90.]]])
3032            result = model(decoder_input, memory_input)
3033            ref_output = torch.tensor([[[2.306435, 0.095946, -0.675796, 0.10687]]])
3034            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
3035
3036            # deterministic input
3037            decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
3038                                                  [[11., 12., 13., 14.]]]))
3039            memory_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]]]))
3040            result = model(decoder_input, memory_input)
3041            ref_output = perm_fn(torch.tensor([[[2.415448, 0.054389, -0.610932, -0.0156613]],
3042                                               [[2.415448, 0.054389, -0.610932, -0.0156613]]]))
3043            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
3044
3045            # deterministic input
3046            decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
3047                                                  [[5., 6., 7., 8.]]]))
3048            memory_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
3049                                                 [[11., 12., 13., 14.]]]))
3050            result = model(decoder_input, memory_input)
3051            ref_output = perm_fn(torch.tensor([[[2.338531, 0.087709, -0.65776, 0.080646]],
3052                                               [[2.338531, 0.087709, -0.65776, 0.080646]]]))
3053            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
3054
3055            # deterministic input
3056            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
3057                                                   [0.2678, 0.3677, 0.4459, 0.7166]],
3058                                                  [[0.8100, 0.3716, 0.4096, 0.1976],
3059                                                   [0.6958, 0.8844, 0.6081, 0.8315]],
3060                                                  [[0.0494, 0.9343, 0.5955, 0.3830],
3061                                                   [0.5404, 0.3464, 0.9378, 0.6200]]]))
3062            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
3063                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
3064                                                 [[0.8335, 0.2799, 0.5031, 0.2947],
3065                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
3066                                                 [[0.6333, 0.9344, 0.1376, 0.9938],
3067                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
3068                                                 [[0.9897, 0.6915, 0.3154, 0.1733],
3069                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
3070                                                 [[0.8117, 0.2366, 0.4838, 0.7881],
3071                                                  [0.3718, 0.4945, 0.9511, 0.0864]]]))
3072            result = model(decoder_input, memory_input)
3073            ref_output = perm_fn(torch.tensor([[[2.42049104, 0.03443088, -0.60793706, -0.05436271],
3074                                                [2.42210631, 0.03546578, -0.60679895, -0.05357488]],
3075                                               [[2.41907674, 0.0336104, -0.60892977, -0.05490462],
3076                                                [2.42216881, 0.03586554, -0.6067524, -0.05289126]],
3077                                               [[2.42205716, 0.03488046, -0.60683681, -0.05460596],
3078                                                [2.42240309, 0.0354595, -0.60659063, -0.05378816]]]))
3079            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
3080
3081    @skipIfRocm(msg='Large numerical errors')
3082    def test_transformerdecoder(self):
3083        def get_a_test_layer(use_cuda, activation, batch_first=False):
3084            d_model = 4
3085            nhead = 2
3086            dim_feedforward = 16
3087            dropout = 0.0
3088            device = torch.device("cuda" if use_cuda else "cpu")
3089
3090            layer = nn.TransformerDecoderLayer(
3091                d_model,
3092                nhead,
3093                dim_feedforward=dim_feedforward,
3094                dropout=dropout,
3095                activation=activation,
3096                batch_first=batch_first).to(device)
3097
3098            with torch.no_grad():
3099                # set constant weights of the model
3100                for idx, p in enumerate(layer.parameters()):
3101                    x = p.data
3102                    sz = x.view(-1).size(0)
3103                    shape = x.shape
3104                    x = torch.cos(torch.arange(0, sz).float().view(shape))
3105                    p.data.copy_(x)
3106
3107            return layer
3108
3109        # this is a deterministic test for TransformerDecoder
3110        for batch_first in (False, True):
3111            def perm_fn(x):
3112                return x.transpose(1, 0) if batch_first else x
3113            activation = F.relu
3114            use_cuda = torch.cuda.is_available()
3115            device = torch.device("cuda" if use_cuda else "cpu")
3116
3117            decoder_layer = get_a_test_layer(use_cuda=use_cuda, activation=activation,
3118                                             batch_first=batch_first)
3119
3120            model = nn.TransformerDecoder(decoder_layer, 1).to(device)
3121
3122            # deterministic input
3123            decoder_input = torch.tensor([[[20., 30., 40., 50.]]]).to(device)
3124            memory_input = torch.tensor([[[60., 70., 80., 90.]]]).to(device)
3125            result = model(decoder_input, memory_input)
3126            ref_output = torch.tensor(
3127                [[[2.314351, 0.094805, -0.671322, 0.101977]]]).to(device)
3128            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3129            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)
3130
3131            # deterministic input
3132            decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
3133                                                  [[11., 12., 13., 14.]]])).to(device)
3134            memory_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]]])).to(device)
3135            result = model(decoder_input, memory_input)
3136            ref_output = perm_fn(torch.tensor([[[2.422245, 0.051716, -0.606338, -0.024756]],
3137                                               [[2.422245, 0.051716, -0.606338, -0.024756]]]
3138                                              )).to(device)
3139            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3140            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)
3141
3142            # deterministic input
3143            decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
3144                                                  [[5., 6., 7., 8.]]])).to(device)
3145            memory_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
3146                                                 [[11., 12., 13., 14.]]])).to(device)
3147            result = model(decoder_input, memory_input)
3148            ref_output = perm_fn(torch.tensor([[[2.343536, 0.085561, -0.654954, 0.074991]],
3149                                               [[2.343536, 0.085561, -0.654954, 0.074991]]]
3150                                              )).to(device)
3151            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3152            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)
3153
3154            # deterministic input
3155            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
3156                                                   [0.2678, 0.3677, 0.4459, 0.7166]],
3157                                                  [[0.8100, 0.3716, 0.4096, 0.1976],
3158                                                   [0.6958, 0.8844, 0.6081, 0.8315]],
3159                                                  [[0.0494, 0.9343, 0.5955, 0.3830],
3160                                                   [0.5404, 0.3464, 0.9378, 0.6200]]]
3161                                                 )).to(device)
3162            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
3163                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
3164                                                 [[0.8335, 0.2799, 0.5031, 0.2947],
3165                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
3166                                                 [[0.6333, 0.9344, 0.1376, 0.9938],
3167                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
3168                                                 [[0.9897, 0.6915, 0.3154, 0.1733],
3169                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
3170                                                 [[0.8117, 0.2366, 0.4838, 0.7881],
3171                                                  [0.3718, 0.4945, 0.9511, 0.0864]]]
3172                                                )).to(device)
3173            result = model(decoder_input, memory_input)
3174            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
3175                                                [2.431935, 0.028907, -0.599809, -0.072488]],
3176                                               [[2.428457, 0.027053, -0.602275, -0.073462],
3177                                                [2.431970, 0.029387, -0.599789, -0.071621]],
3178                                               [[2.431934, 0.028196, -0.599802, -0.073809],
3179                                                [2.432306, 0.028858, -0.599542, -0.072846]]]
3180                                              )).to(device)
3181            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3182            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
3183
3184            # key_padding_mask
3185            key_padding_mask = torch.zeros(2, 3).to(device) == 1
3186            result = model(decoder_input, memory_input,
3187                           tgt_key_padding_mask=key_padding_mask)
3188            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
3189                                                [2.431935, 0.028907, -0.599809, -0.072488]],
3190                                               [[2.428457, 0.027053, -0.602275, -0.073462],
3191                                                [2.431970, 0.029387, -0.599789, -0.071621]],
3192                                               [[2.431934, 0.028196, -0.599802, -0.073809],
3193                                                [2.432306, 0.028858, -0.599542, -0.072846]]]
3194                                              )).to(device)
3195            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3196            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
3197
3198            # key_padding_mask
3199            key_padding_mask[0, 2] = 1
3200            key_padding_mask[1, 1] = 1
3201            key_padding_mask[1, 2] = 1
3202            result = model(decoder_input, memory_input,
3203                           tgt_key_padding_mask=key_padding_mask)
3204            ref_output = perm_fn(torch.tensor([[[2.430025, 0.027643, -0.601164, -0.073476],
3205                                                [2.4323, 0.029375, -0.599553, -0.071881]],
3206                                               [[2.428523, 0.026838, -0.602226, -0.07391],
3207                                                [2.432634, 0.029842, -0.599318, -0.071253]],
3208                                               [[2.432278, 0.028152, -0.599555, -0.074139],
3209                                                [2.432659, 0.029244, -0.599294, -0.072382]]]
3210                                              )).to(device)
3211            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3212            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
3213
3214            # memory_key_padding_mask
3215            key_padding_mask = torch.zeros(2, 5).to(device) == 1
3216            result = model(decoder_input, memory_input,
3217                           memory_key_padding_mask=key_padding_mask)
3218            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
3219                                                [2.431935, 0.028907, -0.599809, -0.072488]],
3220                                               [[2.428457, 0.027053, -0.602275, -0.073462],
3221                                                [2.431970, 0.029387, -0.599789, -0.071621]],
3222                                               [[2.431934, 0.028196, -0.599802, -0.073809],
3223                                                [2.432306, 0.028858, -0.599542, -0.072846]]]
3224                                              )).to(device)
3225            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3226            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
3227
3228            # memory_key_padding_mask
3229            key_padding_mask[0, 4] = 1
3230            key_padding_mask[1, 3] = 1
3231            key_padding_mask[1, 4] = 1
3232            result = model(decoder_input,
3233                           memory_input,
3234                           memory_key_padding_mask=key_padding_mask)
3235            ref_output = perm_fn(torch.tensor([[[2.429757, 0.027358, -0.601351, -0.073816],
3236                                                [2.432692, 0.028583, -0.599263, -0.073634]],
3237                                               [[2.428247, 0.02662, -0.602419, -0.074123],
3238                                                [2.432657, 0.029055, -0.599293, -0.072732]],
3239                                               [[2.431515, 0.027687, -0.600096, -0.074459],
3240                                                [2.433075, 0.028543, -0.598987, -0.073985]]]
3241                                              )).to(device)
3242            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3243            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
3244
3245            # multiple layers no norm
3246            model = nn.TransformerDecoder(decoder_layer, 2).to(device)
3247
3248            # deterministic input
3249            decoder_input = torch.tensor([[[20., 30., 40., 50.]]]).to(device)
3250            memory_input = torch.tensor([[[60., 70., 80., 90.]]]).to(device)
3251            result = model(decoder_input, memory_input)
3252            ref_output = torch.tensor(
3253                [[[2.31316, 0.0950293, -0.671995, 0.102802]]]).to(device)
3254            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3255            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)
3256
3257            # multiple layers no norm
3258            model = nn.TransformerDecoder(decoder_layer, 6).to(device)
3259
3260            # deterministic input
3261            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
3262                                                   [0.2678, 0.3677, 0.4459, 0.7166]],
3263                                                  [[0.8100, 0.3716, 0.4096, 0.1976],
3264                                                   [0.6958, 0.8844, 0.6081, 0.8315]],
3265                                                  [[0.0494, 0.9343, 0.5955, 0.3830],
3266                                                   [0.5404, 0.3464, 0.9378, 0.6200]]]
3267                                                 )).to(device)
3268            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
3269                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
3270                                                 [[0.8335, 0.2799, 0.5031, 0.2947],
3271                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
3272                                                 [[0.6333, 0.9344, 0.1376, 0.9938],
3273                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
3274                                                 [[0.9897, 0.6915, 0.3154, 0.1733],
3275                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
3276                                                 [[0.8117, 0.2366, 0.4838, 0.7881],
3277                                                  [0.3718, 0.4945, 0.9511, 0.0864]]]
3278                                                )).to(device)
3279            result = model(decoder_input, memory_input)
3280            ref_output = perm_fn(torch.tensor([[[2.42794, 0.026164, -0.60263, -0.0747591],
3281                                                [2.43113, 0.0279516, -0.600376, -0.0736896]],
3282                                               [[2.42794, 0.026164, -0.60263, -0.0747591],
3283                                                [2.43113, 0.0279516, -0.600376, -0.0736896]],
3284                                               [[2.42794, 0.026164, -0.60263, -0.0747591],
3285                                                [2.43113, 0.0279516, -0.600376, -0.0736896]]]
3286                                              )).to(device)
3287            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3288            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
3289
3290            # multiple layers with norm
3291            # d_model = 4
3292            norm = nn.LayerNorm(4)
3293            model = nn.TransformerDecoder(decoder_layer, 2, norm=norm).to(device)
3294
3295            # deterministic input
3296            decoder_input = torch.tensor([[[20., 30., 40., 50.]]]).to(device)
3297            memory_input = torch.tensor([[[60., 70., 80., 90.]]]).to(device)
3298            result = model(decoder_input, memory_input)
3299            ref_output = torch.tensor(
3300                [[[1.66166, -0.326986, -1.01466, -0.320017]]]).to(device)
3301            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3302            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)
3303
3304            # multiple layers with norm
3305            model = nn.TransformerDecoder(decoder_layer, 6, norm=norm).to(device)
3306
3307            # deterministic input
3308            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
3309                                                   [0.2678, 0.3677, 0.4459, 0.7166]],
3310                                                  [[0.8100, 0.3716, 0.4096, 0.1976],
3311                                                   [0.6958, 0.8844, 0.6081, 0.8315]],
3312                                                  [[0.0494, 0.9343, 0.5955, 0.3830],
3313                                                   [0.5404, 0.3464, 0.9378, 0.6200]]]
3314                                                 )).to(device)
3315            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
3316                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
3317                                                 [[0.8335, 0.2799, 0.5031, 0.2947],
3318                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
3319                                                 [[0.6333, 0.9344, 0.1376, 0.9938],
3320                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
3321                                                 [[0.9897, 0.6915, 0.3154, 0.1733],
3322                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
3323                                                 [[0.8117, 0.2366, 0.4838, 0.7881],
3324                                                  [0.3718, 0.4945, 0.9511, 0.0864]]]
3325                                                )).to(device)
3326            result = model(decoder_input, memory_input)
3327            ref_output = perm_fn(torch.tensor([[[1.69559, -0.357291, -0.894741, -0.443553],
3328                                                [1.69571, -0.357363, -0.894154, -0.444196]],
3329                                               [[1.69559, -0.357291, -0.894741, -0.443553],
3330                                                [1.69571, -0.357363, -0.894154, -0.444196]],
3331                                               [[1.69559, -0.357291, -0.894741, -0.443553],
3332                                                [1.69571, -0.357363, -0.894154, -0.444196]]]
3333                                              )).to(device)
3334            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3335            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
3336
3337            # gelu activation test cases
3338            activation = "gelu"
3339            use_cuda = torch.cuda.is_available()
3340            device = torch.device("cuda" if use_cuda else "cpu")
3341
3342            decoder_layer = get_a_test_layer(use_cuda=use_cuda, activation=activation,
3343                                             batch_first=batch_first)
3344
3345            model = nn.TransformerDecoder(decoder_layer, 1).to(device)
3346
3347            # deterministic input
3348            decoder_input = torch.tensor([[[20., 30., 40., 50.]]]).to(device)
3349            memory_input = torch.tensor([[[60., 70., 80., 90.]]]).to(device)
3350            result = model(decoder_input, memory_input)
3351            ref_output = torch.tensor([[[2.306435, 0.095946, -0.675796, 0.10687]]]).to(device)
3352            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3353            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)
3354
3355            # deterministic input
3356            decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
3357                                                  [[11., 12., 13., 14.]]])).to(device)
3358            memory_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]]])).to(device)
3359            result = model(decoder_input, memory_input)
3360            ref_output = perm_fn(torch.tensor([[[2.415448, 0.054389, -0.610932, -0.0156613]],
3361                                               [[2.415448, 0.054389, -0.610932, -0.0156613]]])).to(device)
3362            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3363            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)
3364
3365            # deterministic input
3366            decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
3367                                                  [[5., 6., 7., 8.]]])).to(device)
3368            memory_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
3369                                                 [[11., 12., 13., 14.]]])).to(device)
3370            result = model(decoder_input, memory_input)
3371            ref_output = perm_fn(torch.tensor([[[2.338531, 0.087709, -0.65776, 0.080646]],
3372                                               [[2.338531, 0.087709, -0.65776, 0.080646]]])).to(device)
3373            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3374            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)
3375
3376            # deterministic input
3377            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
3378                                                   [0.2678, 0.3677, 0.4459, 0.7166]],
3379                                                  [[0.8100, 0.3716, 0.4096, 0.1976],
3380                                                   [0.6958, 0.8844, 0.6081, 0.8315]],
3381                                                  [[0.0494, 0.9343, 0.5955, 0.3830],
3382                                                   [0.5404, 0.3464, 0.9378, 0.6200]]]
3383                                                 )).to(device)
3384            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
3385                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
3386                                                 [[0.8335, 0.2799, 0.5031, 0.2947],
3387                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
3388                                                 [[0.6333, 0.9344, 0.1376, 0.9938],
3389                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
3390                                                 [[0.9897, 0.6915, 0.3154, 0.1733],
3391                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
3392                                                 [[0.8117, 0.2366, 0.4838, 0.7881],
3393                                                  [0.3718, 0.4945, 0.9511, 0.0864]]]
3394                                                )).to(device)
3395            result = model(decoder_input, memory_input)
3396            ref_output = perm_fn(torch.tensor([[[2.42049104, 0.03443088, -0.60793706, -0.05436271],
3397                                                [2.42210631, 0.03546578, -0.60679895, -0.05357488]],
3398                                               [[2.41907674, 0.0336104, -0.60892977, -0.05490462],
3399                                                [2.42216881, 0.03586554, -0.6067524, -0.05289126]],
3400                                               [[2.42205716, 0.03488046, -0.60683681, -0.05460596],
3401                                                [2.42240309, 0.0354595, -0.60659063, -0.05378816]]]
3402                                              )).to(device)
3403            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3404            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
3405
3406    @unittest.skipIf(not (TEST_CUDNN and TEST_MULTIGPU), 'CUDNN or multi-gpu not available')
3407    def test_cudnn_rnn_dropout_states_device(self):
3408        rnn = nn.RNN(10, 20, num_layers=2, dropout=.5)
3409        device = 1
3410        input = torch.randn(5, 4, 10).cuda(device)
3411        rnn.cuda(device)
3412        hx = torch.randn(2, 4, 20).cuda(device)
3413        output = rnn(input, hx)
3414
3415    def test_cudnn_forward_exception(self):
3416        rnns = [
3417            (nn.LSTM(10, 20, batch_first=True), (torch.zeros(1, 2, 19), torch.zeros(1, 2, 19))),
3418            (nn.LSTM(10, 20, batch_first=True, proj_size=10), (torch.zeros(1, 2, 19), torch.zeros(1, 2, 19))),
3419            (nn.GRU(10, 20, batch_first=True), torch.zeros(1, 2, 19)),
3420            (nn.RNN(10, 20, batch_first=True), torch.zeros(1, 2, 19)),
3421        ]
3422        x_wrong = torch.randn(2, 3, 3)
3423        x_right = torch.randn(2, 3, 10)
3424        for rnn, hidden in rnns:
3425            self.assertRaisesRegex(RuntimeError, "Expected hidden.*size.*got", rnn, x_right, hidden)
3426            self.assertRaisesRegex(RuntimeError, re.escape("input.size(-1) must be equal to input_size"), rnn, x_wrong)
3427
3428    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
3429    @skipIfRocm
3430    def test_cudnn_weight_format(self):
3431        rnns = [
3432            nn.LSTM(10, 20, batch_first=True),
3433            nn.LSTM(10, 20, batch_first=True, proj_size=10),
3434            nn.GRU(10, 20, batch_first=True),
3435            nn.RNN(10, 20, batch_first=True)
3436        ]
3437        first_warn = True
3438        for rnn in rnns:
3439            rnn.cuda()
3440            input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")
3441            hx = torch.randn(1, 5, 20, requires_grad=True, device="cuda")
3442            all_vars = [input, hx] + list(rnn.parameters())
3443            if isinstance(rnn, nn.LSTM):
3444                # LSTM with projections has different hx size
3445                if rnn.proj_size > 0:
3446                    hx = torch.randn(1, 5, 10, requires_grad=True, device="cuda")
3447                    all_vars[1] = hx
3448                cx = torch.randn(1, 5, 20, requires_grad=True, device="cuda")
3449                all_vars[2:2] = [cx]
3450                hx = (hx, cx)
3451
3452            output = rnn(input, hx)
3453            output[0].sum().backward()
3454            grads = [v.grad.data.clone() for v in all_vars]
3455            for v in all_vars:
3456                v.grad.data.zero_()
3457
3458            # Weights will no longer view onto the same chunk of memory
3459            weight = all_vars[4]
3460            weight_data = weight.data.clone()
3461            with torch.no_grad():
3462                weight.set_(weight_data)
3463
3464            for _ in range(2):
3465                with warnings.catch_warnings(record=True) as w:
3466                    output_noncontig = rnn(input, hx)
3467                if first_warn:
3468                    self.assertEqual(len(w), 1)
3469                    self.assertIn('weights are not part of single contiguous chunk of memory', w[0].message.args[0])
3470                    first_warn = False
3471                    warnings.resetwarnings()
3472                output_noncontig[0].sum().backward()
3473                grads_noncontig = [v.grad.data.clone() for v in all_vars]
3474                for v in all_vars:
3475                    v.grad.data.zero_()
3476                self.assertEqual(output, output_noncontig)
3477                self.assertEqual(grads_noncontig, grads)
3478
3479            # Make sure these still share storage
3480            weight_data[:] = 4
3481            self.assertEqual(weight_data, all_vars[4].data)
3482
3483    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
3484    def test_cudnn_weight_tying(self):
3485        rnns = [
3486            nn.LSTM(10, 20, batch_first=True, bidirectional=True),
3487            nn.LSTM(10, 20, batch_first=True, bidirectional=True, proj_size=10),
3488            nn.GRU(10, 20, batch_first=True, bidirectional=True),
3489            nn.RNN(10, 20, batch_first=True, bidirectional=True)
3490        ]
3491        for rnn in rnns:
3492            rnn.bias_ih_l0_reverse = rnn.bias_ih_l0
3493            rnn.cuda()
3494            input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")
3495            hx = torch.randn(2, 5, 20, requires_grad=True, device="cuda")
3496            all_vars = [input, hx] + list(rnn.parameters())
3497            opt = torch.optim.SGD(rnn.parameters(), lr=0.1)
3498            opt.zero_grad()
3499            if isinstance(rnn, nn.LSTM):
3500                # LSTM with projections has different hx size
3501                if rnn.proj_size > 0:
3502                    hx = torch.randn(2, 5, 10, requires_grad=True, device="cuda")
3503                    all_vars[1] = hx
3504                cx = torch.randn(2, 5, 20, requires_grad=True, device="cuda")
3505                all_vars[2:2] = [cx]
3506                hx = (hx, cx)
3507
3508            with warnings.catch_warnings(record=True) as w:
3509                output = rnn(input, hx)
3510            output[0].sum().backward()
3511
3512            opt.step()
3513            with warnings.catch_warnings(record=True) as w:
3514                output_cuda = rnn(input, hx)
3515            rnn.cpu()
3516            hx = (hx[0].cpu(), hx[1].cpu()) if isinstance(rnn, nn.LSTM) else hx.cpu()
3517            output_cpu = rnn(input.cpu(), hx)
3518            self.assertEqual(output_cuda, output_cpu)
3519
3520
3521    def test_transformer_args_check(self):
3522        model_name = 'Transformer'
3523        d_model = 128
3524        nhead = 4
3525        num_encoder_layers = 2
3526        num_decoder_layers = 3
3527        dim_feedforward = 65
3528        dropout = 0.3
3529        bsz = 3
3530        seq_len = 35
3531        tgt_len = 15
3532        activations = [F.relu, F.gelu]
3533
3534        wrong_bsz = 7
3535        wrong_d_model = 63
3536        wrong_nhead = 5
3537        wrong_activation = "abc"
3538
3539        def test(encoder_input_shape, decoder_input_shape,
3540                 src_mask_len=None, tgt_mask_len=None, memory_mask_size=None,
3541                 src_key_padding_mask_size=None, tgt_key_padding_mask_size=None,
3542                 memory_key_padding_mask_size=None,
3543                 src_is_causal=False, tgt_is_causal=False,
3544                 memory_is_causal=False):
3545
3546            encoder_input = torch.randn(encoder_input_shape)
3547            decoder_input = torch.randn(decoder_input_shape)
3548            model = getattr(nn, model_name)(d_model, nhead, num_encoder_layers,
3549                                            num_decoder_layers, dim_feedforward, dropout)
3550
3551            if src_mask_len is not None:
3552                src_mask = model.generate_square_subsequent_mask(src_mask_len)
3553            else:
3554                src_mask = None
3555
3556            if tgt_mask_len is not None:
3557                tgt_mask = model.generate_square_subsequent_mask(tgt_mask_len)
3558            else:
3559                tgt_mask = None
3560
3561            if memory_mask_size is not None:
3562                memory_task = torch.rand(memory_mask_size)
3563            else:
3564                memory_task = None
3565
3566            if src_key_padding_mask_size is not None:
3567                src_key_padding_mask = torch.rand(src_key_padding_mask_size) >= 0.5
3568            else:
3569                src_key_padding_mask = None
3570
3571            if tgt_key_padding_mask_size is not None:
3572                tgt_key_padding_mask = torch.rand(tgt_key_padding_mask_size) >= 0.5
3573            else:
3574                tgt_key_padding_mask = None
3575
3576            if memory_key_padding_mask_size is not None:
3577                memory_key_padding_mask = torch.rand(memory_key_padding_mask_size) >= 0.5
3578            else:
3579                memory_key_padding_mask = None
3580
3581            with self.assertRaises(RuntimeError):
3582                model(encoder_input, decoder_input,
3583                      src_mask=src_mask,
3584                      tgt_mask=tgt_mask,
3585                      memory_mask=memory_task,
3586                      src_key_padding_mask=src_key_padding_mask,
3587                      tgt_key_padding_mask=tgt_key_padding_mask,
3588                      memory_key_padding_mask=memory_key_padding_mask,
3589                      src_is_causal=src_is_causal,
3590                      tgt_is_causal=tgt_is_causal,
3591                      memory_is_causal=memory_is_causal)
3592
3593
3594        correct_encoder_input_shape = (seq_len, bsz, d_model)
3595        correct_decoder_input_shape = (tgt_len, bsz, d_model)
3596
3597        def update_shape(shape, dim, new_dim_size):
3598            new_shape = list(shape)
3599            new_shape[dim] = new_dim_size
3600            return tuple(new_shape)
3601
3602        # Incorrect encoder_input batch size
3603        encoder_input_shape = update_shape(correct_encoder_input_shape, 1, wrong_bsz)
3604        decoder_input_shape = correct_decoder_input_shape
3605        test(encoder_input_shape, decoder_input_shape)
3606
3607        # Incorrect decoder_input batch size
3608        encoder_input_shape = correct_encoder_input_shape
3609        decoder_input_shape = update_shape(correct_decoder_input_shape, 1, wrong_bsz)
3610        test(encoder_input_shape, decoder_input_shape)
3611
3612        # Incorrect encoder_input input size
3613        encoder_input_shape = update_shape(correct_encoder_input_shape, 2, wrong_d_model)
3614        decoder_input_shape = correct_decoder_input_shape
3615        test(encoder_input_shape, decoder_input_shape)
3616
3617        # Incorrect decoder_input input size
3618        encoder_input_shape = correct_encoder_input_shape
3619        decoder_input_shape = update_shape(correct_decoder_input_shape, 2, wrong_d_model)
3620        test(encoder_input_shape, decoder_input_shape)
3621
3622        # Incorrect nhead
3623        encoder_input_shape = correct_encoder_input_shape
3624        decoder_input_shape = correct_decoder_input_shape
3625        with self.assertRaises(AssertionError):
3626            model = getattr(nn, model_name)(d_model, wrong_nhead, num_encoder_layers,
3627                                            num_decoder_layers, dim_feedforward, dropout)
3628
3629        # Incorrect src_mask
3630        encoder_input_shape = correct_encoder_input_shape
3631        decoder_input_shape = correct_decoder_input_shape
3632        wrong_src_mask_size = seq_len + 1
3633        test(encoder_input_shape, decoder_input_shape, src_mask_len=wrong_src_mask_size)
3634
3635        # Incorrect tgt_mask
3636        encoder_input_shape = correct_encoder_input_shape
3637        decoder_input_shape = correct_decoder_input_shape
3638        wrong_tgt_mask_size = tgt_len + 1
3639        test(encoder_input_shape, decoder_input_shape, tgt_mask_len=wrong_tgt_mask_size)
3640
3641        # Incorrect memory_mask
3642        encoder_input_shape = correct_encoder_input_shape
3643        decoder_input_shape = correct_decoder_input_shape
3644        wrong_tgt_mask_size = tgt_len + 1
3645        test(encoder_input_shape, decoder_input_shape,
3646             memory_mask_size=(wrong_tgt_mask_size, wrong_src_mask_size))
3647
3648        # Incorrect src_key_padding_mask
3649        encoder_input_shape = correct_encoder_input_shape
3650        decoder_input_shape = correct_decoder_input_shape
3651        with self.assertRaises(AssertionError):
3652            test(encoder_input_shape, decoder_input_shape,
3653                 src_key_padding_mask_size=(wrong_bsz, wrong_src_mask_size))
3654
3655        # Incorrect tgt_key_padding_mask
3656        encoder_input_shape = correct_encoder_input_shape
3657        decoder_input_shape = correct_decoder_input_shape
3658        with self.assertRaises(AssertionError):
3659            test(encoder_input_shape, decoder_input_shape,
3660                 tgt_key_padding_mask_size=(wrong_bsz, wrong_tgt_mask_size))
3661
3662        # Incorrect memory_key_padding_mask
3663        encoder_input_shape = correct_encoder_input_shape
3664        decoder_input_shape = correct_decoder_input_shape
3665        with self.assertRaises(AssertionError):
3666            test(encoder_input_shape, decoder_input_shape,
3667                 memory_key_padding_mask_size=(wrong_bsz, wrong_src_mask_size))
3668
3669        # Correct activations
3670        for activation in activations:
3671            model = getattr(nn, model_name)(d_model, nhead, num_encoder_layers, num_decoder_layers,
3672                                            dim_feedforward, dropout, activation)
3673        # Incorrect activation
3674        with self.assertRaises(RuntimeError):
3675            model = getattr(nn, model_name)(d_model, nhead, num_encoder_layers, num_decoder_layers,
3676                                            dim_feedforward, dropout, wrong_activation)
3677
3678
3679    def test_transformer_layer_args_check(self):
3680        model_names = ['TransformerEncoderLayer', 'TransformerDecoderLayer']
3681        d_model = 128
3682        nhead = 4
3683        dim_feedforward = 65
3684        dropout = 0.3
3685        bsz = 3
3686        seq_len = 35
3687        tgt_len = 15
3688        activations = [F.relu, F.gelu]
3689
3690        wrong_activation = "abc"
3691
3692        encoder_input_shape = (seq_len, bsz, d_model)
3693        decoder_input_shape = (tgt_len, bsz, d_model)
3694
3695        encoder_input = torch.randn(encoder_input_shape)
3696        decoder_input = torch.randn(decoder_input_shape)
3697
3698        for model_name in model_names:
3699            for activation in activations:
3700                model = getattr(nn, model_name)(d_model, nhead, dim_feedforward,
3701                                                dropout, activation)
3702        # Incorrect activation
3703        for model_name in model_names:
3704            with self.assertRaises(RuntimeError):
3705                model = getattr(nn, model_name)(d_model, nhead, dim_feedforward,
3706                                                dropout, wrong_activation)
3707
3708    def test_rnn_args_check(self):
3709        input_size = 3
3710        hidden_size = 5
3711        num_layers = 2
3712        batch_size = 4
3713        seq_len = 6
3714        num_directions = 1
3715        bad_size = 7  # prime number so that no size can divide it.
3716
3717        def test(input_shape, hidden_shape, mode):
3718            for input, hidden in get_inputs(input_shape, hidden_shape, mode):
3719                model = getattr(nn, mode)(input_size, hidden_size, num_layers)
3720                self.assertRaises(RuntimeError, lambda: model(input, hidden))
3721
3722        correct_input_shape = (seq_len, batch_size, input_size)
3723        correct_hidden_shape = (num_layers * num_directions, batch_size, hidden_size)
3724
3725        def update_shape(shape, dim, new_dim_size):
3726            new_shape = list(shape)
3727            new_shape[dim] = new_dim_size
3728            return tuple(new_shape)
3729
3730        def get_inputs(input_shape, hidden_shape, mode):
3731            '''returns list( tuple(input, hidden) )
3732            where input, hidden are inputs to a model'''
3733            input = torch.randn(input_shape)
3734            hidden = torch.randn(hidden_shape)
3735            if mode != 'LSTM':
3736                return [(input, hidden)]
3737            if hidden_shape == correct_hidden_shape:
3738                return [(input, (hidden, hidden))]
3739            good_hidden = torch.randn(correct_hidden_shape)
3740            return [
3741                (input, (hidden, good_hidden)),
3742                (input, (good_hidden, hidden)),
3743            ]
3744
3745        rnn_modes = ['RNN', 'GRU', 'LSTM']
3746        for mode in rnn_modes:
3747            # Incorrect input batch size
3748            input_shape = update_shape(correct_input_shape, 1, bad_size)
3749            hidden_shape = correct_hidden_shape
3750            test(input_shape, hidden_shape, mode)
3751
3752            # Incorrect hidden batch size
3753            input_shape = correct_input_shape
3754            hidden_shape = update_shape(correct_hidden_shape, 1, bad_size)
3755            test(input_shape, hidden_shape, mode)
3756
3757            # Incorrect input size
3758            input_shape = update_shape(correct_input_shape, 2, bad_size)
3759            hidden_shape = correct_hidden_shape
3760            test(input_shape, hidden_shape, mode)
3761
3762            # Incorrect hidden size
3763            input_shape = correct_input_shape
3764            hidden_shape = update_shape(correct_hidden_shape, 2, bad_size)
3765            test(input_shape, hidden_shape, mode)
3766
3767            # Incorrect hidden[0]
3768            input_shape = correct_input_shape
3769            hidden_shape = update_shape(correct_hidden_shape, 0, bad_size)
3770            test(input_shape, hidden_shape, mode)
3771
3772    def test_projections_lstm_args_check(self):
3773        input_size = 3
3774        hidden_size = 5
3775        proj_size = 2
3776        num_layers = 2
3777        batch_size = 4
3778        seq_len = 6
3779        num_directions = 1
3780        bad_size = 7  # prime number so that no size can divide it.
3781
3782        def test(input_shape, hidden_h_shape, hidden_c_shape):
3783            for input, hidden in get_inputs(input_shape, hidden_h_shape, hidden_c_shape):
3784                model = nn.LSTM(input_size, hidden_size, num_layers, proj_size=proj_size)
3785                self.assertRaises(RuntimeError, lambda: model(input, hidden))
3786
3787        correct_input_shape = (seq_len, batch_size, input_size)
3788        correct_hidden_h_shape = (num_layers * num_directions, batch_size, proj_size)
3789        correct_hidden_c_shape = (num_layers * num_directions, batch_size, hidden_size)
3790
3791        def update_shape(shape, dim, new_dim_size):
3792            new_shape = list(shape)
3793            new_shape[dim] = new_dim_size
3794            return tuple(new_shape)
3795
3796        def get_inputs(input_shape, hidden_h_shape, hidden_c_shape):
3797            '''returns list( tuple(input, hidden) )
3798            where input, hidden are inputs to a model'''
3799            input = torch.randn(input_shape)
3800            hidden_h = torch.randn(hidden_h_shape)
3801            hidden_c = torch.randn(hidden_c_shape)
3802            return [(input, (hidden_h, hidden_c))]
3803
3804        # Incorrect input batch size
3805        input_shape = update_shape(correct_input_shape, 1, bad_size)
3806        test(input_shape, correct_hidden_h_shape, correct_hidden_c_shape)
3807
3808        # Incorrect hidden batch size
3809        input_shape = correct_input_shape
3810        hidden_h_shape = update_shape(correct_hidden_h_shape, 1, bad_size)
3811        hidden_c_shape = update_shape(correct_hidden_c_shape, 1, bad_size)
3812        test(input_shape, hidden_h_shape, hidden_c_shape)
3813
3814        # Incorrect input size
3815        input_shape = update_shape(correct_input_shape, 2, bad_size)
3816        test(input_shape, correct_hidden_h_shape, correct_hidden_c_shape)
3817
3818        # Incorrect hidden size
3819        input_shape = correct_input_shape
3820        hidden_h_shape = update_shape(correct_hidden_h_shape, 2, bad_size)
3821        hidden_c_shape = update_shape(correct_hidden_c_shape, 2, bad_size)
3822        test(input_shape, hidden_h_shape, hidden_c_shape)
3823
3824        # Incorrect hidden[0]
3825        input_shape = correct_input_shape
3826        hidden_h_shape = update_shape(correct_hidden_h_shape, 0, bad_size)
3827        hidden_c_shape = update_shape(correct_hidden_c_shape, 0, bad_size)
3828        test(input_shape, hidden_h_shape, hidden_c_shape)
3829
3830        # Incorrect proj size = hidden size
3831        input_shape = correct_input_shape
3832        hidden_h_shape = update_shape(correct_hidden_h_shape, 0, hidden_size)
3833        hidden_c_shape = correct_hidden_c_shape
3834        test(input_shape, hidden_h_shape, hidden_c_shape)
3835
3836        # Incorrect proj size != hidden size
3837        input_shape = correct_input_shape
3838        hidden_h_shape = update_shape(correct_hidden_h_shape, 0, bad_size)
3839        hidden_c_shape = correct_hidden_c_shape
3840        test(input_shape, hidden_h_shape, hidden_c_shape)
3841
3842        # Incorrect cell size != hidden size
3843        input_shape = correct_input_shape
3844        hidden_h_shape = correct_hidden_h_shape
3845        hidden_c_shape = update_shape(correct_hidden_c_shape, 0, bad_size)
3846        test(input_shape, hidden_h_shape, hidden_c_shape)
3847
3848    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
3849    def test_rnn_check_device(self):
3850        import copy
3851        input_size = 3
3852        hidden_size = 5
3853        num_layers = 2
3854        batch_size = 4
3855        seq_len = 6
3856        num_directions = 1
3857
3858        correct_input_shape = (seq_len, batch_size, input_size)
3859        correct_hidden_shape = (num_layers * num_directions, batch_size, hidden_size)
3860        rnn_modes = ['RNN', 'GRU', 'LSTM']
3861
3862        for mode in rnn_modes:
3863            model = getattr(nn, mode)(input_size, hidden_size, num_layers)
3864            model_cuda = copy.deepcopy(model).to('cuda:0')
3865            input = torch.randn(correct_input_shape)
3866            hidden = torch.randn(correct_hidden_shape)
3867
3868            # input and weights are not at the same device
3869            with self.assertRaisesRegex(RuntimeError,
3870                                        "Input and parameter tensors are not at the same device"):
3871                model(input.to('cuda:0'))
3872            with self.assertRaisesRegex(RuntimeError,
3873                                        "Input and parameter tensors are not at the same device"):
3874                model_cuda(input)
3875
3876            # input and hiddens are not at the same device
3877            with self.assertRaisesRegex(RuntimeError,
3878                                        r"Input and hidden tensors are not at the same device"):
3879                if mode == 'LSTM':
3880                    model(input, (hidden.to('cuda:0'), hidden.to('cuda:0')))
3881                else:
3882                    model(input, (hidden.to('cuda:0')))
3883            with self.assertRaisesRegex(RuntimeError,
3884                                        r"Input and hidden tensors are not at the same device"):
3885                if mode == 'LSTM':
3886                    model_cuda(input.to('cuda:0'), (hidden, hidden))
3887                else:
3888                    model_cuda(input.to('cuda:0'), (hidden))
3889
3890            # hidden tensors are not at the same CUDA device
3891            if mode == 'LSTM':
3892                with self.assertRaisesRegex(RuntimeError,
3893                                            "Input and hidden tensors are not at the same device"):
3894                    model(input.to('cuda:0'), (hidden.to('cuda:0'), hidden.to('cuda:1')))
3895
3896    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
3897    def test_projections_lstm_check_device(self):
3898        input_size = 3
3899        hidden_size = 5
3900        proj_size = 2
3901        num_layers = 2
3902        batch_size = 4
3903        seq_len = 6
3904        num_directions = 1
3905
3906        correct_input_shape = (seq_len, batch_size, input_size)
3907        correct_hidden_h_shape = (num_layers * num_directions, batch_size, proj_size)
3908        correct_hidden_c_shape = (num_layers * num_directions, batch_size, hidden_size)
3909
3910        model = nn.LSTM(input_size, hidden_size, num_layers, proj_size=proj_size)
3911        input = torch.randn(correct_input_shape)
3912        hidden_h = torch.randn(correct_hidden_h_shape)
3913        hidden_c = torch.randn(correct_hidden_c_shape)
3914
3915        # input and weights are not at the same device
3916        with self.assertRaisesRegex(RuntimeError,
3917                                    "Input and parameter tensors are not at the same device"):
3918            model(input.to('cuda:0'))
3919
3920        # input and hiddens are not at the same device
3921        with self.assertRaisesRegex(RuntimeError,
3922                                    r"Input and hidden tensors are not at the same device"):
3923            model(input, (hidden_h.to('cuda:0'), hidden_c.to('cuda:0')))
3924
3925        # hidden tensors are not at the same CUDA device
3926        with self.assertRaisesRegex(RuntimeError,
3927                                    "Input and hidden tensors are not at the same device"):
3928            model(input.to('cuda:0'), (hidden_h.to('cuda:0'), hidden_c.to('cuda:1')))
3929
3930    def test_rnn_initial_hidden_state(self):
3931        rnn_modes = ['RNN', 'GRU', 'LSTM']
3932        for mode in rnn_modes:
3933            rnn = getattr(nn, mode)(30, 20, 2)
3934            input = torch.randn(10, 32, 30)
3935            hidden = torch.zeros(2, 32, 20)
3936
3937            if mode == 'LSTM':
3938                hidden = (hidden, hidden)
3939            output1, hidden1 = rnn(input, hidden)
3940            output2, hidden2 = rnn(input)
3941            self.assertEqual(output1, output2)
3942            self.assertEqual(hidden1, hidden2)
3943
3944    def test_projections_lstm_initial_hidden_state(self):
3945        for bidir in [False, True]:
3946            rnn = nn.LSTM(30, 20, 2, bidirectional=bidir, proj_size=10)
3947            num_dirs = 2 if bidir else 1
3948            input = torch.randn(10, 32, 30)
3949            hidden_h = torch.zeros(2 * num_dirs, 32, 10)
3950            hidden_c = torch.zeros(2 * num_dirs, 32, 20)
3951            hidden = (hidden_h, hidden_c)
3952            output1, hidden1 = rnn(input, hidden)
3953            output2, hidden2 = rnn(input)
3954            self.assertEqual(output1, output2)
3955            self.assertEqual(hidden1, hidden2)
3956
3957    def test_projections_errors_on_gru_and_rnn(self):
3958        error_msg = "proj_size argument is only supported for LSTM, not RNN or GRU"
3959        for mode in ['RNN', 'GRU']:
3960            with self.assertRaisesRegex(ValueError, error_msg):
3961                rnn = getattr(nn, mode)(30, 20, 2, proj_size=10)
3962
3963    def _test_RNN_cpu_vs_cudnn(self, dropout, dtype=torch.double):
3964
3965        def forward_backward(cuda, rnn, input_val, grad_output, weights_val, hx_val, grad_hy,
3966                             cx_val=None, grad_cy=None):
3967            is_lstm = isinstance(rnn, nn.LSTM)
3968
3969            for x_layer, y_layer in zip(rnn.all_weights, weights_val):
3970                for x, y in zip(x_layer, y_layer):
3971                    x.data.copy_(y.data)
3972
3973            if isinstance(input_val, rnn_utils.PackedSequence):
3974                input = rnn_utils.PackedSequence(
3975                    input_val.data.data.requires_grad_(True), input_val.batch_sizes)
3976                input_var = input.data
3977            else:
3978                input = input_val.clone().requires_grad_(True)
3979                input_var = input
3980            if is_lstm:
3981                if cx_val is None:
3982                    hx = (hx_val.clone().requires_grad_(True),
3983                          hx_val.add(1).requires_grad_(True))
3984                else:
3985                    hx = (hx_val.clone().requires_grad_(True),
3986                          cx_val.add(1).requires_grad_(True))
3987            else:
3988                hx = hx_val.clone().requires_grad_(True)
3989
3990            if cuda:
3991                rnn.cuda()
3992                input_var.data = input_var.data.cuda()
3993                if is_lstm:
3994                    hx[0].data = hx[0].data.cuda()
3995                    hx[1].data = hx[1].data.cuda()
3996                else:
3997                    hx.data = hx.data.cuda()
3998                grad_hy = grad_hy.cuda()
3999                if grad_cy is not None:
4000                    grad_cy = grad_cy.cuda()
4001                grad_output = grad_output.cuda()
4002
4003            output, hy = rnn(input, hx)
4004
4005            if isinstance(output, rnn_utils.PackedSequence):
4006                output = output.data
4007
4008            if is_lstm:
4009                if grad_cy is None:
4010                    torch.autograd.backward([output, hy[0], hy[1]], [grad_output, grad_hy, grad_hy + 1])
4011                else:
4012                    torch.autograd.backward([output, hy[0], hy[1]], [grad_output, grad_hy, grad_cy + 1])
4013            else:
4014                torch.autograd.backward([output, hy], [grad_output, grad_hy])
4015
4016            return {'output': output.data,
4017                    'hy': hy[0].data if is_lstm else hy.data,
4018                    'weights': rnn.all_weights,
4019                    'grad_input': input_var.grad.data,
4020                    'grad_hx': hx[0].grad.data if is_lstm else hx.grad.data,
4021                    'cy': hy[1].data if is_lstm else None,
4022                    'grad_cx': hx[1].grad.data if is_lstm else None}
4023
4024        input_size = 10
4025        hidden_size = 6
4026        proj_size = 3
4027        num_layers = 2
4028        seq_length = 7
4029        batch = 6
4030
4031        def make_noncontig(tensor):
4032            ndim = tensor.dim()
4033            return torch.stack([tensor.clone().zero_(), tensor], ndim).select(ndim, 1)
4034
4035        def compare_cpu_gpu(outputs_cpu, outputs_gpu):
4036            self.assertEqual(list(outputs_cpu.keys()), list(outputs_gpu.keys()))
4037            for key in outputs_cpu.keys():
4038                if key != 'weights':
4039                    self.assertEqual(outputs_cpu[key], outputs_gpu[key], atol=5e-5, rtol=0, msg=key)
4040
4041            # check grad weights separately, as nested dict
4042            for cpu_layer_weight, gpu_layer_weight in zip(outputs_cpu['weights'], outputs_gpu['weights']):
4043                for (cpu_weight, gpu_weight) in zip(cpu_layer_weight, gpu_layer_weight):
4044                    self.assertEqual(cpu_weight.grad.data, gpu_weight.grad.data, atol=5e-5, rtol=0)
4045
4046        for module in (nn.RNN, nn.LSTM, nn.GRU):
4047            for bias, bidirectional, batch_first, contig, variable_len, lens_as_tensor \
4048                    in product((True, False), repeat=6):
4049
4050                num_directions = 2 if bidirectional else 1
4051                if batch_first:
4052                    input_val = torch.randn(batch, seq_length, input_size, dtype=dtype)
4053                    grad_output = torch.randn(batch, seq_length, hidden_size * num_directions, dtype=dtype)
4054                else:
4055                    input_val = torch.randn(seq_length, batch, input_size, dtype=dtype)
4056                    grad_output = torch.randn(seq_length, batch, hidden_size * num_directions, dtype=dtype)
4057
4058                hx_val = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype)
4059                grad_hy = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype)
4060
4061                if not contig:
4062                    grad_output = make_noncontig(grad_output)
4063                    grad_hy = make_noncontig(grad_hy)
4064                    input_var = make_noncontig(input_val)
4065                    hx_val = make_noncontig(hx_val)
4066
4067                if variable_len:
4068                    lengths = [7, 5, 5, 2, 1, 1]
4069                    if lens_as_tensor:
4070                        lengths = torch.tensor(lengths, dtype=torch.long)
4071                    input_val = rnn_utils.pack_padded_sequence(input_val, lengths, batch_first=batch_first)
4072                    grad_output = rnn_utils.pack_padded_sequence(grad_output, lengths, batch_first=batch_first).data
4073
4074                rnn = module(input_size,
4075                             hidden_size,
4076                             num_layers,
4077                             bias=bias,
4078                             dropout=dropout,
4079                             bidirectional=bidirectional,
4080                             batch_first=batch_first).to(dtype)
4081
4082                outputs_cpu = forward_backward(
4083                    False, rnn, input_val, grad_output, rnn.all_weights, hx_val, grad_hy)
4084
4085                rnn_gpu = module(input_size,
4086                                 hidden_size,
4087                                 num_layers,
4088                                 bias=bias,
4089                                 dropout=dropout,
4090                                 bidirectional=bidirectional,
4091                                 batch_first=batch_first).to(dtype)
4092
4093                outputs_gpu = forward_backward(
4094                    True, rnn_gpu, input_val, grad_output, rnn.all_weights, hx_val, grad_hy)
4095
4096                compare_cpu_gpu(outputs_cpu, outputs_gpu)
4097
4098        for nonlinearity in ('tanh', 'relu'):
4099            hx_val = torch.randn(num_layers, batch, hidden_size, dtype=dtype)
4100            input_val = torch.randn(seq_length, batch, input_size, dtype=dtype)
4101            grad_output = torch.randn(
4102                seq_length, batch, hidden_size * num_directions, dtype=dtype)
4103            grad_hy = torch.randn(
4104                num_layers * num_directions, batch, hidden_size, dtype=dtype)
4105
4106            rnn = nn.RNN(input_size, hidden_size, num_layers, bias=bias, nonlinearity=nonlinearity).to(dtype)
4107            outputs_cpu = forward_backward(False, rnn, input_val, grad_output, rnn.all_weights, hx_val, grad_hy)
4108
4109            rnn_gpu = nn.RNN(input_size, hidden_size, num_layers, bias=bias, nonlinearity=nonlinearity).to(dtype)
4110            outputs_gpu = forward_backward(True, rnn_gpu, input_val, grad_output, rnn.all_weights, hx_val, grad_hy)
4111
4112            compare_cpu_gpu(outputs_cpu, outputs_gpu)
4113
4114        # checking LSTM with projections
4115        for bias, bidirectional, batch_first, contig, variable_len, lens_as_tensor \
4116                in product((True, False), repeat=6):
4117            num_directions = 2 if bidirectional else 1
4118            if batch_first:
4119                input_val = torch.randn(batch, seq_length, input_size, dtype=dtype)
4120                grad_output = torch.randn(batch, seq_length, proj_size * num_directions, dtype=dtype)
4121            else:
4122                input_val = torch.randn(seq_length, batch, input_size, dtype=dtype)
4123                grad_output = torch.randn(seq_length, batch, proj_size * num_directions, dtype=dtype)
4124
4125            hx_val = torch.randn(num_layers * num_directions, batch, proj_size, dtype=dtype)
4126            cx_val = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype)
4127            grad_hy = torch.randn(num_layers * num_directions, batch, proj_size, dtype=dtype)
4128            grad_cy = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype)
4129
4130            if not contig:
4131                grad_output = make_noncontig(grad_output)
4132                grad_hy = make_noncontig(grad_hy)
4133                grad_cy = make_noncontig(grad_cy)
4134                input_var = make_noncontig(input_val)
4135                hx_val = make_noncontig(hx_val)
4136                cx_val = make_noncontig(cx_val)
4137
4138            if variable_len:
4139                lengths = [7, 5, 5, 2, 1, 1]
4140                if lens_as_tensor:
4141                    lengths = torch.tensor(lengths, dtype=torch.long)
4142                input_val = rnn_utils.pack_padded_sequence(input_val, lengths, batch_first=batch_first)
4143                grad_output = rnn_utils.pack_padded_sequence(grad_output, lengths, batch_first=batch_first).data
4144
4145            rnn = nn.LSTM(input_size,
4146                          hidden_size,
4147                          num_layers,
4148                          bias=bias,
4149                          dropout=dropout,
4150                          bidirectional=bidirectional,
4151                          batch_first=batch_first,
4152                          proj_size=proj_size).to(dtype)
4153
4154            outputs_cpu = forward_backward(
4155                False, rnn, input_val, grad_output, rnn.all_weights,
4156                hx_val, grad_hy, cx_val, grad_cy)
4157
4158            rnn_gpu = nn.LSTM(input_size,
4159                              hidden_size,
4160                              num_layers,
4161                              bias=bias,
4162                              dropout=dropout,
4163                              bidirectional=bidirectional,
4164                              batch_first=batch_first,
4165                              proj_size=proj_size).to(dtype)
4166
4167            outputs_gpu = forward_backward(
4168                True, rnn_gpu, input_val, grad_output, rnn.all_weights,
4169                hx_val, grad_hy, cx_val, grad_cy)
4170            compare_cpu_gpu(outputs_cpu, outputs_gpu)
4171
4172    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4173    def test_RNN_cpu_vs_cudnn_no_dropout(self):
4174        dtype = torch.double
4175        self._test_RNN_cpu_vs_cudnn(0, dtype)
4176
4177    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4178    def test_RNN_cpu_vs_cudnn_with_dropout(self):
4179        # Because of dropout randomness, can only compare dropout=0 and dropout=1
4180        self._test_RNN_cpu_vs_cudnn(1)
4181
4182    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4183    def test_RNN_cudnn_weight_norm(self):
4184        input_size = 10
4185        hidden_size = 6
4186        num_layers = 2
4187        seq_length = 7
4188        batch = 6
4189
4190        # runs on CPU to acquire expected output
4191        def check_weight_norm(m, name):
4192            input = torch.randn(seq_length, batch, input_size)
4193            expected_output = m(input)
4194
4195            # adds weight normalization
4196            m = torch.nn.utils.weight_norm(m, name=name)
4197
4198            # moves to CUDA
4199            m = m.cuda()
4200            input = input.cuda()
4201
4202            # otherwise, subsequent warnings will be hidden, and further tests rely on them
4203            warnings.simplefilter("always")
4204            self.assertEqual(m(input), expected_output)
4205
4206            # remove weight norm
4207            m = torch.nn.utils.remove_weight_norm(m, name=name)
4208            self.assertEqual(m(input), expected_output)
4209
4210        check_weight_norm(nn.LSTM(input_size, hidden_size, num_layers), 'weight_hh_l0')
4211        check_weight_norm(nn.LSTM(input_size, hidden_size, num_layers, proj_size=3), 'weight_hr_l0')
4212
4213    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
4214    def test_partial_flat_weights(self):
4215        input_size = 10
4216        hidden_size = 6
4217        num_layers = 2
4218
4219        m = nn.LSTM(input_size, hidden_size, num_layers)
4220        inp = torch.randn(3, 2, 10)
4221        out_expected = m(inp)
4222        # deletes an attribute of original LSTM
4223        weight_orig = m.weight_hh_l0
4224        del m.weight_hh_l0
4225        self.assertFalse(hasattr(m, "weight_hh_l0"))
4226        # verifies that moving to CUDA with only some attributes defined
4227        # does not throw an error
4228        m.cuda()
4229        # recompute the weight and make sure that module can be used
4230        m.weight_hh_l0 = weight_orig.cuda()
4231        inp = inp.cuda()
4232        # otherwise, subsequent warnings will be hidden, and further tests rely on them
4233        warnings.simplefilter("always")
4234        self.assertEqual(m(inp)[0].cpu(), out_expected[0])
4235
4236    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4237    @set_default_dtype(torch.double)
4238    def test_RNN_dropout(self):
4239        # checking the assumption that cuDNN sticks dropout in between
4240        # RNN layers
4241        for p in (0, 0.276, 0.731, 1):
4242            for train in (True, False):
4243                for cuda in (True, False):
4244                    rnn = nn.RNN(10, 1000, 2, bias=False, dropout=p, nonlinearity='relu')
4245                    if cuda:
4246                        rnn.cuda()
4247
4248                    if train:
4249                        rnn.train()
4250                    else:
4251                        rnn.eval()
4252                    rnn.weight_ih_l0.data.fill_(1)
4253                    rnn.weight_hh_l0.data.fill_(1)
4254                    rnn.weight_ih_l1.data.fill_(1)
4255                    rnn.weight_hh_l1.data.fill_(1)
4256                    input = torch.ones(1, 1, 10)
4257                    hx = torch.zeros(2, 1, 1000)
4258                    if cuda:
4259                        input = input.cuda()
4260                        hx = hx.cuda()
4261
4262                    output, hy = rnn(input, hx)
4263                    self.assertEqual(output.data.min(), output.data.max())
4264                    output_val = output.data[0][0][0]
4265                    if p == 0 or not train:
4266                        self.assertEqual(output_val, 10000)
4267                    elif p == 1:
4268                        self.assertEqual(output_val, 0)
4269                    else:
4270                        self.assertGreater(output_val, 8000)
4271                        self.assertLess(output_val, 12000)
4272                        denorm_mod = (output_val * (1 - p)) % 10
4273                        self.assertLess(min(denorm_mod, 10 - denorm_mod), 1e-2)
4274
4275                    self.assertEqual(hy[0].data.min(), hy[0].data.max())
4276                    self.assertEqual(hy[1].data.min(), hy[1].data.max())
4277                    self.assertEqual(hy.data[0][0][0], 10)
4278                    self.assertEqual(hy.data[1][0][0], output_val)
4279
4280    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4281    @set_default_dtype(torch.double)
4282    def test_error_RNN_seq_len_zero(self):
4283        # checking error message when RNN has seq_len = 0
4284        for module in (nn.RNN, nn.LSTM, nn.GRU):
4285            for bidirectional in [True, False]:
4286                for device in get_all_device_types():
4287                    input = torch.ones(0, 10, 5)
4288                    rnn = module(5, 6, bidirectional=bidirectional)
4289                    if device == 'cuda':
4290                        rnn.cuda()
4291                        input = input.cuda()
4292
4293                    with self.assertRaisesRegex(RuntimeError, "Expected sequence length to be larger than 0 in RNN"):
4294                        rnn(input)
4295
4296    def test_RNN_input_size_zero(self):
4297        for module in (nn.RNN, nn.LSTM, nn.GRU):
4298            for device in get_all_device_types():
4299                input = torch.zeros((5, 0, 3))
4300                rnn = module(input_size=3, hidden_size=4)
4301                if device == 'cuda':
4302                    rnn.cuda()
4303                    input = input.cuda()
4304                outs = rnn(input)
4305                self.assertEqual(outs[0].shape, torch.Size([5, 0, 4]))
4306                # Check that backward does not cause a hard error
4307                outs[0].sum().backward()
4308
4309    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4310    def test_RNN_dropout_state(self):
4311        for p in (0, 0.1234):
4312            for train in (True, False):
4313                for cuda in (True, False):
4314                    rnn = nn.RNN(100, 100, 2, bias=False, dropout=p, nonlinearity='relu')
4315                    if cuda:
4316                        rnn.cuda()
4317
4318                    if train:
4319                        rnn.train()
4320                    else:
4321                        rnn.eval()
4322                    input = torch.rand(1, 1, 100)
4323                    hx = torch.rand(2, 1, 100)
4324                    if cuda:
4325                        input = input.cuda()
4326                        hx = hx.cuda()
4327
4328                    output1, hy1 = rnn(input, hx)
4329                    output2, hy2 = rnn(input, hx)
4330
4331                    buf = io.BytesIO()
4332                    rnn_pickle = torch.save(rnn, buf)
4333                    buf.seek(0)
4334                    # weights_only=False as this is legacy code that saves the model
4335                    rnn2 = torch.load(buf, weights_only=False)
4336                    rnn2.flatten_parameters()
4337                    output3, hy3 = rnn2(input, hx)
4338
4339                    if p == 0 or not train:
4340                        self.assertEqual(output1, output2)
4341                        self.assertEqual(output1, output3)
4342                        self.assertEqual(hy1, hy2)
4343                        self.assertEqual(hy1, hy3)
4344                    else:
4345                        self.assertNotEqual(output1, output2)
4346                        self.assertNotEqual(output1, output3)
4347                        self.assertNotEqual(hy1, hy2)
4348                        self.assertNotEqual(hy1, hy3)
4349
4350    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4351    @set_default_dtype(torch.double)
4352    def test_RNN_change_dropout(self):
4353        for train, cuda in product((True, False), repeat=2):
4354            rnn = nn.RNN(100, 100, 2, dropout=0, nonlinearity='relu')
4355            input = torch.rand(3, 2, 100)
4356            if cuda:
4357                input.data = input.data.cuda()
4358                rnn.cuda()
4359
4360            if train:
4361                rnn.train()
4362            else:
4363                rnn.eval()
4364
4365            prev_output = None
4366            for p in (0, 0.5, 0, 0.7, 0.2, 1, 0.2, 0):
4367                rnn.dropout = p
4368                output1, hy1 = rnn(input)
4369                output2, hy2 = rnn(input)
4370
4371                if p == 0 or p == 1 or not train:
4372                    self.assertEqual(output1, output2)
4373                    self.assertEqual(hy1, hy2)
4374                else:
4375                    self.assertNotEqual(output1, output2)
4376                    self.assertNotEqual(hy1, hy2)
4377
4378                if prev_output is not None:
4379                    if not train:
4380                        self.assertEqual(output1.data, prev_output)
4381                        self.assertEqual(output2.data, prev_output)
4382                    else:
4383                        self.assertNotEqual(output1.data, prev_output)
4384                        self.assertNotEqual(output2.data, prev_output)
4385                prev_output = output1.data
4386
4387    def test_inplace_thnn(self):
4388        modules = [nn.ReLU, nn.ELU, nn.SELU, nn.CELU, nn.RReLU]
4389        for mod in modules:
4390            r = mod(inplace=True)
4391            input = torch.randn(5, 5, requires_grad=True)
4392            output = r(input + 0)
4393            grad_output = torch.randn(5, 5)
4394            grad_output_clone = grad_output.clone()
4395            output.backward(grad_output)
4396            self.assertEqual(grad_output, grad_output_clone)
4397
4398
4399    def test_pixel_shuffle_unshuffle(self):
4400        def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
4401                                                 upscale_factor=None):
4402            # Function to imperatively ensure pixels are shuffled to the correct locations.
4403            # Used to validate the batch operations in pixel_shuffle.
4404            def _verify_pixel_shuffle(input, output, upscale_factor):
4405                for c in range(output.size(-3)):
4406                    for h in range(output.size(-2)):
4407                        for w in range(output.size(-1)):
4408                            height_idx = h // upscale_factor
4409                            weight_idx = w // upscale_factor
4410                            channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \
4411                                          (c * upscale_factor ** 2)
4412                            self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx])
4413
4414            upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
4415            # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
4416            channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
4417            height = random.randint(5, 10)
4418            width = random.randint(5, 10)
4419
4420            if num_input_dims == 1:
4421                input = torch.rand(channels, requires_grad=True)
4422            elif num_input_dims == 2:
4423                input = torch.rand(height, width, requires_grad=True)
4424            else:
4425                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
4426                input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True)
4427            ps = nn.PixelShuffle(upscale_factor)
4428            pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)
4429
4430            if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
4431                output = ps(input)
4432                _verify_pixel_shuffle(input, output, upscale_factor)
4433                output.backward(output.data)
4434                self.assertEqual(input.data, input.grad.data)
4435
4436                # Ensure unshuffle properly inverts shuffle.
4437                unshuffle_output = pus(output)
4438                self.assertEqual(input, unshuffle_output)
4439            else:
4440                self.assertRaises(RuntimeError, lambda: ps(input))
4441
4442        def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
4443                                                    downscale_factor=None):
4444            downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
4445            channels = random.randint(1, 4)
4446            # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
4447            height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
4448            # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
4449            width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)
4450
4451            if num_input_dims == 1:
4452                input = torch.rand(channels, requires_grad=True)
4453            elif num_input_dims == 2:
4454                input = torch.rand(height, width, requires_grad=True)
4455            else:
4456                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
4457                input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True)
4458
4459            pus = nn.PixelUnshuffle(downscale_factor)
4460            self.assertRaises(RuntimeError, lambda: pus(input))
4461
4462        def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
4463            # For 1D - 2D, this is an error case.
4464            # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
4465            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims)
4466
4467            # Error cases for pixel_shuffle.
4468            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, valid_channels_dim=False)
4469            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=0)
4470            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=-2)
4471
4472            # Error cases for pixel_unshuffle.
4473            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
4474            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
4475            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
4476            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
4477
4478        def test_pixel_shuffle_unshuffle_1D():
4479            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
4480
4481        def test_pixel_shuffle_unshuffle_2D():
4482            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)
4483
4484        def test_pixel_shuffle_unshuffle_3D():
4485            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)
4486
4487        def test_pixel_shuffle_unshuffle_4D():
4488            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)
4489
4490        def test_pixel_shuffle_unshuffle_5D():
4491            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
4492
4493        test_pixel_shuffle_unshuffle_1D()
4494        test_pixel_shuffle_unshuffle_2D()
4495        test_pixel_shuffle_unshuffle_3D()
4496        test_pixel_shuffle_unshuffle_4D()
4497        test_pixel_shuffle_unshuffle_5D()
4498
4499    @set_default_dtype(torch.double)
4500    def test_pixel_shuffle_nhwc_cpu(self):
4501        input = torch.randn(3, 18, 4, 4, device='cpu')
4502        input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
4503        grad = torch.randn(3, 18, 4, 4, device='cpu')
4504        ps = torch.nn.PixelShuffle(3)
4505        pus = torch.nn.PixelUnshuffle(3)
4506
4507        ref_input = input.detach().clone().contiguous().requires_grad_(True)
4508        ref_grad = grad.detach().clone().contiguous()
4509        ref_ps = torch.nn.PixelShuffle(3)
4510        ref_pus = torch.nn.PixelUnshuffle(3)
4511
4512        out = pus(ps(input))
4513        out.backward(grad)
4514        ref_out = ref_pus(ref_ps(ref_input))
4515        ref_out.backward(ref_grad)
4516
4517        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
4518        self.assertTrue(ref_out.is_contiguous())
4519        self.assertEqual(out, ref_out)
4520        self.assertEqual(input.grad, ref_input.grad)
4521
4522    # These tests should be OpInfo'd
4523    def test_elu_inplace_on_view(self):
4524        v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True, dtype=torch.double)
4525
4526        def func(root):
4527            x = root.clone()
4528            view = x.narrow(0, 1, 2)
4529            res = F.elu(view, inplace=True)
4530            self.assertIs(res, view)
4531            return x
4532
4533        gradcheck(func, [v])
4534        gradgradcheck(func, [v])
4535
4536    def test_elu_inplace_gradgrad(self):
4537        v = torch.randn(8, requires_grad=True, dtype=torch.double)
4538
4539        def func(root):
4540            x = root.clone()
4541            return F.elu(x, inplace=True)
4542
4543        gradcheck(func, [v])
4544        gradgradcheck(func, [v])
4545
4546    def test_relu_inplace_on_view(self):
4547        v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True, dtype=torch.double)
4548
4549        def func(root):
4550            x = root.clone()
4551            view = x.narrow(0, 1, 2)
4552            res = F.relu(view, inplace=True)
4553            self.assertIs(res, view)
4554            return x
4555
4556        gradcheck(func, [v])
4557        gradgradcheck(func, [v])
4558
4559    def test_PReLU_backward_requires_grad_false(self):
4560        devices = ['cpu']
4561        devices += ['cuda'] if TEST_CUDA else []
4562        for d in devices:
4563            m = nn.PReLU().to(d)
4564            x = torch.randn(2, 3, 4, 5, device=d, requires_grad=False)
4565            y = m(x)
4566            y.mean().backward()
4567            self.assertEqual(x.grad, None)
4568
4569    def test_bce_loss_always_nonnegative(self):
4570        target = torch.ones(5)
4571        input = torch.ones(5)
4572        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)
4573
4574        target = torch.zeros(5)
4575        input = torch.zeros(5)
4576        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)
4577
4578    def test_bce_with_logits_raises_if_target_and_input_are_different_size(self):
4579        target = torch.rand(5)
4580        input = torch.rand(5, 1)
4581        with self.assertRaises(ValueError):
4582            nn.BCEWithLogitsLoss()(input, target)
4583
4584        target = torch.rand(5, 1)
4585        input = torch.rand(5)
4586        with self.assertRaises(ValueError):
4587            nn.BCEWithLogitsLoss()(input, target)
4588
4589    def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss(self):
4590        sigmoid = nn.Sigmoid()
4591
4592        target = torch.rand(64, 4)
4593        output = torch.rand(64, 4) - 0.5
4594
4595        self.assertEqual(nn.BCEWithLogitsLoss()(output, target), nn.BCELoss()(sigmoid(output), target))
4596
4597        weight = torch.rand(4)
4598        self.assertEqual(nn.BCEWithLogitsLoss(weight)(output, target), nn.BCELoss(weight)(sigmoid(output), target))
4599
4600        target = torch.zeros(4, 1, dtype=torch.float)
4601        output = torch.empty(4, 1, dtype=torch.float).fill_(-100)
4602
4603        self.assertEqual(nn.BCEWithLogitsLoss()(output, target), nn.BCELoss()(sigmoid(output), target))
4604
4605        self.assertEqual(nn.BCEWithLogitsLoss(reduction='none')(output, target),
4606                         nn.BCELoss(reduction='none')(sigmoid(output), target))
4607
4608        weight = torch.rand(1, dtype=torch.float)
4609        self.assertEqual(nn.BCEWithLogitsLoss(weight)(output, target), nn.BCELoss(weight)(sigmoid(output), target))
4610
4611    def test_bce_loss_input_range(self):
4612        bceloss = nn.BCELoss()
4613
4614        target = torch.rand(25, 25)
4615        output_valid = torch.rand(25, 25)
4616        output_too_negative = output_valid - 1.0
4617        output_too_positive = output_valid + 1.0
4618
4619        loss_valid = bceloss(output_valid, target)
4620        with self.assertRaisesRegex(RuntimeError, 'between 0 and 1'):
4621            loss_too_negative = bceloss(output_too_negative, target)
4622        with self.assertRaisesRegex(RuntimeError, 'between 0 and 1'):
4623            loss_too_positive = bceloss(output_too_positive, target)
4624
4625    def test_bce_loss_size_mismatch(self):
4626        bceloss = nn.BCELoss()
4627        a = torch.rand(25)
4628        b = torch.rand(25, 1)
4629        with self.assertRaisesRegex(ValueError, r'Using a target size \('):
4630            bceloss(a, b)
4631
4632    def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
4633        x_size = 1024
4634        y_size = 256
4635        target = torch.rand(x_size, y_size)
4636
4637        for reduction in ['none', 'mean', 'sum']:
4638            output_sig = torch.rand(x_size, y_size) - 0.5
4639            output_logits = output_sig.clone().detach()
4640
4641            output_sig.requires_grad = True
4642            output_logits.requires_grad = True
4643            weight = torch.rand(y_size)
4644
4645            loss_sig = nn.BCELoss(weight, reduction=reduction)(
4646                torch.sigmoid(output_sig), target
4647            )
4648            loss_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)(
4649                output_logits, target
4650            )
4651
4652            self.assertEqual(loss_logits, loss_sig)
4653
4654            if reduction == 'none':
4655                grad = torch.rand(x_size, y_size)
4656                loss_sig.backward(grad)
4657                loss_logits.backward(grad)
4658            else:
4659                loss_sig.backward()
4660                loss_logits.backward()
4661
4662            self.assertEqual(output_sig.grad, output_logits.grad)
4663
4664    def test_bce_with_logits_has_correct_forward_grad(self):
4665        output = torch.randn(3, 5, requires_grad=True, dtype=torch.double)
4666        target = torch.randn(3, 5, dtype=torch.double)
4667        for reduction in ('sum', 'mean', 'none'):
4668            gradcheck(lambda self, target: nn.BCEWithLogitsLoss(reduction=reduction)(self, target),
4669                      (output, target), check_forward_ad=True)
4670
4671    def test_bce_with_logits_has_correct_grad_at_zero(self):
4672        output = torch.zeros(3, 1, requires_grad=True)
4673        target = torch.zeros(3, 1)
4674        nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward()
4675        expected_grad = torch.empty(3, 1).fill_(0.5)
4676        self.assertEqual(output.grad, expected_grad)
4677
4678    def test_bce_with_logits_broadcasts_weights(self):
4679        target = torch.rand(16, 4)
4680        output = torch.rand(16, 4) - 0.5
4681
4682        weight = torch.rand(4)
4683        out1 = nn.BCEWithLogitsLoss(weight)(output, target)
4684
4685        weight = weight.expand(16, 4).contiguous()
4686        out2 = nn.BCEWithLogitsLoss(weight)(output, target)
4687
4688        self.assertEqual(out1, out2)
4689
4690        weight = torch.rand(16, 1)
4691        out1 = nn.BCEWithLogitsLoss(weight)(output, target)
4692
4693        weight = weight.expand(16, 4).contiguous()
4694        out2 = nn.BCEWithLogitsLoss(weight)(output, target)
4695
4696        self.assertEqual(out1, out2)
4697
4698    def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
4699        target = torch.rand(64, 4)
4700        output = torch.rand(64, 4) - 0.5
4701        pos_weight = torch.ones(64, 4)
4702
4703        self.assertEqual(nn.BCEWithLogitsLoss()(output, target),
4704                         nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target))
4705
4706    def test_bce_with_logits_broadcasts_pos_weights(self):
4707        target = torch.rand(64, 4)
4708        output = torch.rand(64, 4) - 0.5
4709        pos_weight = torch.rand(4)
4710        out1 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
4711
4712        pos_weight1 = pos_weight.expand(1, 4)
4713        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight1)(output, target)
4714
4715        pos_weight2 = pos_weight.expand(64, 4)
4716        out3 = nn.BCEWithLogitsLoss(pos_weight=pos_weight2)(output, target)
4717
4718        self.assertEqual(out1, out2)
4719        self.assertEqual(out1, out3)
4720
4721    def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
4722        output = torch.zeros(3, 1, requires_grad=True)
4723        target = torch.zeros(3, 1)
4724        pos_weight = torch.ones(3, 1)
4725        nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward()
4726        expected_grad = torch.empty(3, 1).fill_(0.5)
4727        grad = output.grad
4728        self.assertEqual(grad, expected_grad)
4729
4730    def test_bce_with_logits_stability(self):
4731        output = torch.tensor([0., -120.])
4732        target = torch.tensor([0., 1.])
4733        pos_weight = torch.tensor([1., 1.])
4734
4735        out1 = nn.BCEWithLogitsLoss()(output, target)
4736        self.assertTrue(torch.isfinite(out1).all().item())
4737
4738        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
4739        self.assertTrue(torch.isfinite(out2).all().item())
4740
4741    def test_bce_loss_broadcasts_weights(self):
4742        sigmoid = nn.Sigmoid()
4743        target = torch.rand(16, 4)
4744        output = torch.rand(16, 4) - 0.5
4745
4746        weight = torch.rand(4)
4747        out1 = nn.BCELoss(weight)(sigmoid(output), target)
4748
4749        weight = weight.expand(16, 4).contiguous()
4750        out2 = nn.BCELoss(weight)(sigmoid(output), target)
4751
4752        self.assertEqual(out1, out2)
4753
4754        weight = torch.rand(16, 1)
4755        out1 = nn.BCELoss(weight)(sigmoid(output), target)
4756
4757        weight = weight.expand(16, 4).contiguous()
4758        out2 = nn.BCELoss(weight)(sigmoid(output), target)
4759
4760        self.assertEqual(out1, out2)
4761
4762    def test_hardtanh_inplace_gradgrad(self):
4763        v = torch.randn(8, requires_grad=True, dtype=torch.double)
4764
4765        def func(root):
4766            x = root.clone()
4767            return F.hardtanh(x, inplace=True)
4768
4769        gradcheck(func, [v])
4770        gradgradcheck(func, [v])
4771
4772    # test hardtanh backward for large tensor
4773    def test_hardtanh_backward(self):
4774        x = torch.randn(128, 10000, requires_grad=True)
4775        grad = torch.randn(128, 10000)
4776        z = torch.zeros(128, 10000)
4777        y = F.hardtanh(x)
4778        y.backward(grad)
4779        # ref backward path for hardtanh
4780        mask = (x > -1) & (x < 1)
4781        x_grad_ref = torch.where(mask, grad, z)
4782        self.assertEqual(x.grad, x_grad_ref)
4783
4784    def test_batchnorm_nhwc_cpu(self):
4785        def helper(self, mod, size, dtype, mixed_dtype=False, format=torch.channels_last, precision=None):
4786            channels = size[1]
4787            input = torch.randn(size, dtype=dtype, device='cpu', requires_grad=True)
4788            input = input.contiguous(memory_format=format).to(dtype)
4789            input.retain_grad()
4790            grad = torch.randn(size, dtype=dtype, device='cpu')
4791            grad = grad.contiguous(memory_format=format)
4792            bn = mod(channels).cpu().to(dtype)
4793            bn.weight.data.uniform_()
4794            bn.bias.data.uniform_()
4795
4796            ref_input = input.detach().clone().contiguous().requires_grad_(True)
4797            ref_grad = grad.detach().clone().contiguous()
4798            ref_bn = mod(channels).cpu().to(dtype)
4799            ref_bn.load_state_dict(bn.state_dict())
4800
4801            if mixed_dtype:
4802                bn.float()
4803                ref_bn.float()
4804
4805            out = bn(input)
4806            out.backward(grad)
4807            ref_out = ref_bn(ref_input)
4808            ref_out.backward(ref_grad)
4809
4810            self.assertTrue(out.is_contiguous(memory_format=format))
4811            self.assertTrue(ref_out.is_contiguous())
4812            self.assertEqual(out, ref_out)
4813            self.assertEqual(bn.weight.grad, ref_bn.weight.grad, atol=precision, rtol=precision)
4814            self.assertEqual(bn.bias.grad, ref_bn.bias.grad)
4815            self.assertEqual(input.grad, ref_input.grad)
4816
4817        # test NC11 and N1HW; test mixed dtype
4818        for shape in [(4, 8, 10, 10), (4, 1, 9, 9), (4, 9, 1, 1)]:
4819            for dtype in [torch.float, torch.bfloat16, torch.float16]:
4820                for mixed_dtype in [False, True]:
4821                    if dtype == torch.float:
4822                        mixed_dtype = False
4823                    helper(self, nn.BatchNorm2d, shape, dtype, mixed_dtype, torch.channels_last)
4824
4825        precisons = {torch.float: 1e-4, torch.bfloat16: 1e-4, torch.float16: None}
4826        for shape in [(4, 8, 2, 10, 10), (4, 1, 2, 9, 9), (4, 9, 1, 1, 1)]:
4827            for dtype in [torch.float, torch.bfloat16, torch.float16]:
4828                for mixed_dtype in [False, True]:
4829                    if dtype == torch.float:
4830                        mixed_dtype = False
4831                    helper(self, nn.BatchNorm3d, shape, dtype, mixed_dtype, torch.channels_last_3d, precisons[dtype])
4832
4833    @parametrize_test(
4834        'bn_module',
4835        [
4836            subtest(torch.nn.BatchNorm2d, name="BatchNorm2d"),
4837            subtest(torch.nn.SyncBatchNorm, name="SyncBatchNorm"),
4838        ],
4839    )
4840    def test_batchnorm_non_contig_cpu(self, bn_module):
4841        def helper(self, dtype):
4842            input = torch.arange(6, dtype=torch.float).reshape(1, 3, 2, 1).cpu()
4843            input = input.permute(0, 2, 1, 3)
4844
4845            bn = bn_module(2).cpu().float().eval()
4846            bn.weight.data.uniform_()
4847            bn.bias.data.uniform_()
4848
4849            ref_input = input.detach().clone().contiguous()
4850            ref_bn = nn.BatchNorm2d(2).cpu().float().eval()
4851            ref_bn.load_state_dict(bn.state_dict())
4852
4853            out = bn(input)
4854            ref_out = ref_bn(ref_input)
4855
4856            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
4857            self.assertTrue(ref_out.is_contiguous())
4858            self.assertEqual(out, ref_out)
4859
4860            input_bf = torch.arange(24, dtype=dtype).reshape(1, 3, 2, 4)
4861            input_bf = input_bf.permute(0, 2, 1, 3)
4862            input_f = input_bf.float()
4863            bn_mix = bn_module(2).float().eval()
4864            ref_bn_f = deepcopy(bn_mix)
4865            out_bf = bn_mix(input_bf)
4866            ref_out_bf = ref_bn_f(input_f)
4867            self.assertEqual(ref_out_bf, out_bf.float(), atol=0.05, rtol=0.05)
4868
4869        helper(self, torch.bfloat16)
4870        helper(self, torch.float16)
4871
4872    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
4873    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4874    def test_batchnorm_cudnn_nhwc(self):
4875        def run_test(input, grad_output):
4876            c = input.size(1)
4877            mod = nn.BatchNorm2d(c).cuda().float()
4878            mod.weight.data.uniform_()
4879            mod.bias.data.uniform_()
4880            ref_input = input.detach().clone().contiguous().requires_grad_(True)
4881            ref_grad = grad.detach().clone().contiguous()
4882            ref_mod = nn.BatchNorm2d(c).cuda().float()
4883            ref_mod.load_state_dict(mod.state_dict())
4884            out = mod(input)
4885            out.backward(grad_output)
4886            ref_out = ref_mod(ref_input)
4887            ref_out.backward(ref_grad)
4888            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
4889            self.assertTrue(ref_out.is_contiguous())
4890            self.assertEqual(out, ref_out)
4891            self.assertEqual(mod.weight.grad, ref_mod.weight.grad)
4892            self.assertEqual(mod.bias.grad, ref_mod.bias.grad)
4893            self.assertEqual(input.grad, ref_input.grad)
4894
4895        input = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
4896        input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()
4897
4898        grad = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
4899        grad = grad.contiguous(memory_format=torch.channels_last)
4900        run_test(input, grad)
4901        # see #42588, grad is channels_last contiguous, but grad.suggest_memory_format (rightly) return "contiguous"
4902        # not channels_last
4903        input = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
4904        input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()
4905        grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
4906        grad = grad.permute(0, 2, 1, 3)
4907        run_test(input, grad)
4908
4909    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
4910    def test_batchnorm_cudnn_half(self):
4911        # THNN
4912        input = torch.randint(1, 10, (2, 3, 2, 2), dtype=torch.half, device="cuda", requires_grad=True)
4913        m = nn.BatchNorm2d(3).half().cuda()
4914        thnn_output = m(input)
4915        thnn_output.sum().backward()
4916        thnn_input_grad = input.grad.data.clone()
4917        self.assertEqualTypeString(thnn_output, input)
4918        # cuDNN
4919        if TEST_CUDNN:
4920            input.grad = None
4921            m = m.float()
4922            cudnn_output = m(input)
4923            cudnn_output.sum().backward()
4924            cudnn_input_grad = input.grad.data.clone()
4925            self.assertEqualTypeString(cudnn_output, input)
4926            self.assertEqual(cudnn_output, thnn_output)
4927            self.assertEqual(cudnn_input_grad, thnn_input_grad, atol=1e-3, rtol=0)
4928
4929    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
4930    def test_batchnorm_nonaffine_cuda_half_input(self):
4931        input = torch.randn(16, 3, 24, 24, dtype=torch.half, device="cuda")
4932        m = nn.BatchNorm2d(3, affine=False).cuda().float()  # keep running stats in FP32
4933        output = m(input)
4934        self.assertEqualTypeString(output, input)
4935        m.eval()
4936        output = m(input)
4937        self.assertEqualTypeString(output, input)
4938
4939    def test_batchnorm_raises_error_if_less_than_one_value_per_channel(self):
4940        x = torch.rand(10)[None, :, None]
4941        with self.assertRaises(ValueError):
4942            torch.nn.BatchNorm1d(10)(x)
4943
4944    def test_batchnorm_raises_error_if_running_mean_is_not_same_size_as_input(self):
4945        input = torch.rand(2, 10)
4946        running_var = torch.rand(10)
4947        wrong_sizes = [9, 11]
4948        for size in wrong_sizes:
4949            with self.assertRaises(RuntimeError):
4950                F.batch_norm(input, torch.rand(size), running_var)
4951
4952    def test_batchnorm_raises_error_if_running_var_is_not_same_size_as_input(self):
4953        input = torch.rand(2, 10)
4954        running_mean = torch.rand(10)
4955        wrong_sizes = [9, 11]
4956        for size in wrong_sizes:
4957            with self.assertRaises(RuntimeError):
4958                F.batch_norm(input, running_mean, torch.rand(size))
4959
4960    def test_batchnorm_raises_error_if_weight_is_not_same_size_as_input(self):
4961        input = torch.rand(2, 10)
4962        running_mean = torch.rand(10)
4963        running_var = torch.rand(10)
4964        wrong_sizes = [9, 11]
4965        for size in wrong_sizes:
4966            with self.assertRaises(RuntimeError):
4967                F.batch_norm(input, running_mean, running_var, weight=Parameter(torch.rand(size)))
4968
4969    def test_batchnorm_raises_error_if_bias_is_not_same_size_as_input(self):
4970        input = torch.rand(2, 10)
4971        running_mean = torch.rand(10)
4972        running_var = torch.rand(10)
4973        wrong_sizes = [9, 11]
4974        for size in wrong_sizes:
4975            with self.assertRaises(RuntimeError):
4976                F.batch_norm(input, running_mean, running_var, bias=Parameter(torch.rand(size)))
4977
4978    def test_batchnorm_raises_error_if_running_var_or_running_mean_have_forward_grad(self):
4979        args = (
4980            torch.randn(3, 2, 5),  # input
4981            torch.randn(2),  # running_mean
4982            torch.randn(2),  # running_var
4983        )
4984        kwargs = {'training': False, 'momentum': -1.2}
4985        fn = partial(F.batch_norm, **kwargs)
4986
4987        for dual_indices in ((0,), (1,), (1, 2), (0, 1), (0, 1, 2),):
4988            tangents = tuple(torch.rand_like(x) for x in args)
4989
4990            with fwAD.dual_level():
4991                duals = [fwAD.make_dual(primal, tangent) if i in dual_indices else primal
4992                         for i, (primal, tangent) in enumerate(zip(args, tangents))]
4993                msg = "batch_norm is not differentiable wrt running_mean and running_var"
4994                # 0 needs to have forward grad because otherwise we won't even run batch_norm_jvp
4995                if (1 in dual_indices or 2 in dual_indices) and 0 in dual_indices:
4996                    with self.assertRaisesRegex(RuntimeError, msg):
4997                        fn(*duals)
4998                else:
4999                    fn(*duals)
5000
5001    def test_batchnorm_buffer_update_when_stats_are_not_tracked(self):
5002        input_size = (32, 4)
5003        # Instantiate BN with buffers that are not None
5004        bn = nn.BatchNorm1d(input_size[1], track_running_stats=True)
5005        # Use buffers for normalization but don't update them
5006        bn.track_running_stats = False
5007        # Store initial values
5008        num_batches = bn.num_batches_tracked.clone()
5009        running_mean = bn.running_mean.clone()
5010        running_var = bn.running_var.clone()
5011        # Forward random tensor
5012        _ = bn(torch.rand(input_size))
5013        # Ensure none of the buffers has been updated
5014        self.assertTrue(torch.equal(num_batches, bn.num_batches_tracked))
5015        self.assertTrue(torch.equal(running_mean, bn.running_mean))
5016        self.assertTrue(torch.equal(running_var, bn.running_var))
5017
5018    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
5019    def test_batchnorm_nhwc_cuda(self):
5020        for dtype in (torch.half, torch.float):
5021            (N, C, H, W) = 2, 64, 50, 50
5022            model = torch.nn.BatchNorm2d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
5023            model = model.eval().cuda().to(dtype)
5024            inp1 = torch.randn(N, C, H, W, device=torch.device('cuda'), dtype=dtype)
5025            inp2 = inp1.contiguous(memory_format=torch.channels_last)
5026            out1 = model(inp1)
5027            out2 = model(inp2)
5028            self.assertTrue(torch.equal(out1, out2))
5029
5030    def test_batchnorm_load_state_dict(self):
5031        bn = torch.nn.BatchNorm2d(3)
5032        self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(0))
5033
5034        bn.num_batches_tracked = torch.tensor(10)
5035        self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(10))
5036
5037        empty_dict = OrderedDict()
5038        bn.load_state_dict(empty_dict, strict=False)
5039        self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(10))
5040
5041        # test that when `num_batches_tracked` is not in loaded state_dict,
5042        # meta num_batches_tracked is still replaced with singleton 0 tensor
5043        with torch.device('meta'):
5044            meta_bn = torch.nn.BatchNorm2d(3)
5045        self.assertTrue(meta_bn.num_batches_tracked.device == torch.device('meta'))
5046        meta_bn.load_state_dict(empty_dict, assign=True, strict=False)
5047        self.assertEqual(meta_bn.state_dict()["num_batches_tracked"], torch.tensor(0))
5048
5049    def test_batch_norm_update_stats(self):
5050        input = torch.rand(0, 1)
5051        running_mean = torch.rand(1)
5052        running_var = torch.rand(1)
5053        with self.assertRaisesRegex(RuntimeError,
5054                                    re.escape("input tensor must have at least one element, but got input_sizes = [0, 1]")):
5055            torch.batch_norm_update_stats(input=input, momentum=0.0, running_mean=running_mean, running_var=running_var)
5056
5057    def test_pairwise_distance(self):
5058        input1 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
5059        input2 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
5060        self.assertTrue(gradcheck(lambda x, y: F.pairwise_distance(x, y), (input1, input2)))
5061
5062    # TODO: Create an OpInfo for pdist
5063    def test_pdist(self):
5064        for device, trans in itertools.product(device_(), [False, True]):
5065            inp = torch.randn(4, 5, dtype=torch.double, device=device, requires_grad=True)
5066            if trans:
5067                inp = inp.transpose(0, 1)
5068            for p in [0, 1, 2, 0.5, 1.5, 2.5, float('inf')]:
5069                self.assertTrue(gradcheck(lambda x: F.pdist(x, p), (inp,)))
5070
5071    def test_pdist_zeros(self):
5072        """Test that grad is still valid when dist is 0"""
5073        for device in device_():
5074            inp = torch.randn(1, 3, dtype=torch.double, device=device, requires_grad=True).repeat([2, 1])
5075            for p in [0, 1, 2, 0.5, 1.5, 2.5, float('inf')]:
5076                self.assertTrue(gradcheck(lambda x: F.pdist(x, p), (inp,)))
5077
5078    def test_pdist_empty_row(self):
5079        for device in device_():
5080            inp = torch.randn(1, 3, dtype=torch.double, device=device, requires_grad=True)
5081            self.assertTrue(gradcheck(F.pdist, (inp,)))
5082
5083    def test_pdist_empty_col(self):
5084        for device in device_():
5085            inp = torch.randn(4, 0, dtype=torch.double, device=device, requires_grad=True)
5086            self.assertTrue(gradcheck(F.pdist, (inp,)))
5087
5088    @unittest.expectedFailure
5089    def test_pdist_cpu_gradgrad_unimplemented(self):
5090        inp = torch.randn(4, 5, requires_grad=True)
5091        gradgradcheck(F.pdist, (inp,))
5092
5093    @unittest.expectedFailure
5094    def test_pdist_cuda_gradgrad_unimplemented(self):
5095        inp = torch.randn(4, 5, device='cuda', requires_grad=True)
5096        gradgradcheck(F.pdist, (inp,))
5097
5098    # Merge into OpInfo?
5099    # test for backward in https://github.com/pytorch/pytorch/issues/15511
5100    def test_pdist_large(self):
5101        for device in device_():
5102            def func(x):
5103                return torch.pdist(x, p=2)
5104
5105            # shape[0] should be able to be (roughly) arbitrarily large, but the kernel
5106            # is currently limited to smaller sizes (see issue above); this is just testing
5107            # a floor.
5108            shape = (1000, 1)
5109            x = torch.randn(shape, device=device).requires_grad_()
5110            output = torch.pdist(x, p=2)
5111            # just run a single backward, as gradcheck/gradgradcheck is expensive here
5112            output.sum().backward()
5113
5114    def test_cosine_embedding_loss_with_diff_type(self):
5115        for device in device_():
5116            input1 = torch.tensor([[2, 3, 4], [6, 2, 4]], dtype=torch.double, device=device)
5117            input2 = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
5118            target = torch.tensor([1, -1], dtype=torch.int, device=device)
5119            expected = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
5120            for dt1 in get_all_math_dtypes(device):
5121                for dt2 in get_all_math_dtypes(device):
5122                    for dt3 in get_all_math_dtypes(device):
5123                        # dt3 is used as dtype for target = [1, -1], so let's skip unsigned type
5124                        if dt3 == torch.uint8:
5125                            continue
5126                        if dt1.is_complex or dt2.is_complex or dt3.is_complex:
5127                            continue
5128                        input1 = input1.to(dt1)
5129                        input2 = input2.to(dt2)
5130                        target = target.to(dt3)
5131                        result = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
5132                        self.assertEqual(result.item(), expected.item(), atol=0.001, rtol=0)
5133
5134    def test_cosine_embedding_loss_error_on_diff_shapes(self):
5135        for device in device_():
5136            input1 = torch.empty((0, 0), dtype=torch.double, device=device)
5137            input2 = torch.empty((0,), dtype=torch.double, device=device)
5138            target = torch.empty((0,), dtype=torch.int, device=device)
5139            with self.assertRaisesRegex(RuntimeError, ".*expects 2D.*"):
5140                torch.nn.functional.cosine_embedding_loss(input1, input2, target)
5141
5142    def test_cosine_embedding_loss_error_on_nonexpandable_shapes(self):
5143        for device in device_():
5144            input1 = torch.empty((1, 5), dtype=torch.double, device=device)
5145            input2 = torch.empty((1, 6), dtype=torch.double, device=device)
5146            target = torch.ones((1,), dtype=torch.int, device=device)
5147            with self.assertRaisesRegex(RuntimeError, ".*must match the size.*"):
5148                torch.nn.functional.cosine_embedding_loss(input1, input2, target)
5149
5150    def test_kl_div_with_diff_type(self):
5151        for device in device_():
5152            input = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
5153            target = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.double, device=device)
5154            expected = torch.nn.functional.kl_div(input, target)
5155            real_dtypes = (torch.float32, torch.float64, torch.float16)
5156            for input_dtype, target_dtype in product(real_dtypes, repeat=2):
5157                if (torch.device(device).type == 'cpu' and target_dtype == torch.float16):
5158                    continue
5159                input = input.to(input_dtype)
5160                target = target.to(target_dtype)
5161                result = torch.nn.functional.kl_div(input, target)
5162                self.assertEqual(result.item(), expected.item(), atol=0.001, rtol=0)
5163
5164    def test_kl_div_with_diff_type_log_target(self):
5165        for device in device_():
5166            input = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
5167            target = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.double, device=device).log()
5168            expected = torch.nn.functional.kl_div(input, target, log_target=True)
5169            real_dtypes = (torch.float32, torch.float64, torch.float16)
5170            for input_dtype, target_dtype in product(real_dtypes, repeat=2):
5171                if (torch.device(device).type == 'cpu' and target_dtype == torch.float16):
5172                    continue
5173                input = input.to(input_dtype)
5174                target = target.to(target_dtype)
5175                result = torch.nn.functional.kl_div(input, target, log_target=True)
5176                self.assertEqual(result.item(), expected.item(), atol=0.001, rtol=0)
5177
5178    def test_kl_div_log_softmax_target(self):
5179        for device in device_():
5180            a = torch.tensor([[1.0, 2, 3], [5.0, 5, 5]], device=device)
5181            b = torch.tensor([[1.0, 2, 3], [5.0, 5, 5]], device=device)
5182            self.assertEqual(
5183                F.kl_div(F.log_softmax(a, 1), F.log_softmax(b, 1), reduction='none', log_target=True),
5184                torch.zeros_like(a)
5185            )
5186
5187    def test_cosine_embedding_loss_no_reduce(self):
5188        input1 = torch.randn(15, 10, requires_grad=True, dtype=torch.double)
5189        input2 = torch.randn(15, 10, requires_grad=True, dtype=torch.double)
5190        target = torch.randn(15, dtype=torch.double).sign()
5191        self.assertTrue(gradcheck(lambda x, y, z: F.cosine_embedding_loss(
5192            x, y, z, reduction='none'), (input1, input2, target)))
5193        self.assertEqual(F.cosine_embedding_loss(input1, input2, target, reduction='none'),
5194                         loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, reduction='none'))
5195
5196    def test_cosine_embedding_loss_margin_no_reduce(self):
5197        input1 = torch.randn(15, 10, requires_grad=True, dtype=torch.double)
5198        input2 = torch.randn(15, 10, requires_grad=True, dtype=torch.double)
5199        target = torch.randn(15, dtype=torch.double).sign()
5200        self.assertTrue(gradcheck(lambda x, y, z: F.cosine_embedding_loss(
5201            x, y, z, margin=0.5, reduction='none'), (input1, input2, target)))
5202        self.assertEqual(F.cosine_embedding_loss(input1, input2, target, margin=0.5, reduction='none'),
5203                         loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target,
5204                                                                   margin=0.5, reduction='none'))
5205
5206    def test_cosine_embedding_loss_invalid_shape(self):
5207        input1 = torch.randn(15, 10)
5208        input2 = torch.randn(15, 10)
5209        target = torch.randn(15, 1).sign()
5210
5211        with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
5212            F.cosine_embedding_loss(input1, input2, target)
5213
5214        with self.assertRaisesRegex(RuntimeError, "1D target tensor expects 2D input tensors"):
5215            F.cosine_embedding_loss(torch.randn(10), torch.randn(10), torch.randn(10))
5216
5217        with self.assertRaisesRegex(RuntimeError, "0D target tensor expects 1D input tensors"):
5218            F.cosine_embedding_loss(torch.randn(2, 5), torch.randn(2, 5), torch.randn(()))
5219
5220    def test_margin_ranking_loss_no_reduce(self):
5221        input1 = torch.randn(15, dtype=torch.double).mul_(10).requires_grad_()
5222        input2 = torch.randn(15, dtype=torch.double).mul_(10).requires_grad_()
5223        target = torch.randn(15, dtype=torch.double).sign()
5224        self.assertTrue(gradcheck(lambda x, y, z: F.margin_ranking_loss(
5225            x, y, z, reduction='none'), (input1, input2, target)))
5226        self.assertEqual(F.margin_ranking_loss(input1, input2, target, reduction='none'),
5227                         loss_reference_fns['MarginRankingLoss'](input1, input2, target, reduction='none'))
5228
5229    def test_margin_ranking_loss_margin_no_reduce(self):
5230        input1 = torch.randn(15, dtype=torch.double).mul_(10).requires_grad_()
5231        input2 = torch.randn(15, dtype=torch.double).mul_(10).requires_grad_()
5232        target = torch.randn(15, dtype=torch.double).sign()
5233        self.assertTrue(gradcheck(lambda x, y, z: F.margin_ranking_loss(
5234            x, y, z, margin=0.5, reduction='none'), (input1, input2, target)))
5235        self.assertEqual(F.margin_ranking_loss(input1, input2, target, margin=0.5, reduction='none'),
5236                         loss_reference_fns['MarginRankingLoss'](input1, input2, target, margin=0.5, reduction='none'))
5237
5238    def test_triplet_margin_loss(self):
5239        input1 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5240        input2 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5241        input3 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5242        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
5243            x1, x2, x3), (input1, input2, input3)))
5244        self.assertEqual(F.triplet_margin_loss(input1, input2, input3),
5245                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3))
5246
5247    def test_triplet_margin_loss_swap(self):
5248        input1 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5249        input2 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5250        input3 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5251        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
5252            x1, x2, x3, swap=True), (input1, input2, input3)))
5253        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True),
5254                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True))
5255
5256    def test_triplet_margin_loss_no_reduce(self):
5257        input1 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5258        input2 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5259        input3 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5260        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
5261            x1, x2, x3, reduction='none'), (input1, input2, input3)))
5262        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='none'),
5263                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, reduction='none'))
5264
5265    def test_triplet_margin_loss_swap_no_reduce(self):
5266        input1 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5267        input2 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5268        input3 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5269        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
5270            x1, x2, x3, swap=True, reduction='none'), (input1, input2, input3)))
5271        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'),
5272                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none'))
5273
5274    def test_pointwise_loss_target_grad_none_reduction(self):
5275        i = torch.randn(5, 10)
5276        t = torch.randn(5, 10, requires_grad=True)
5277        self.assertEqual(F.mse_loss(i, t, reduction='none').size(), t.size())
5278        self.assertEqual(F.l1_loss(i, t, reduction='none').size(), t.size())
5279
5280    def test_pointwise_loss_broadcast(self):
5281        losses = {
5282            'mse_loss': lambda x, y, r: F.mse_loss(x, y, reduction=r),
5283            'l1_loss': lambda x, y, r: F.l1_loss(x, y, reduction=r),
5284            'smooth_l1_loss': lambda x, y, r: F.smooth_l1_loss(x, y, reduction=r),
5285            'huber_loss': lambda x, y, r: F.huber_loss(x, y, reduction=r),
5286        }
5287
5288        input = torch.randn(2, 1, requires_grad=True, dtype=torch.double)
5289        for fn in losses.values():
5290            for requires_grad in [True, False]:
5291                # When target.requires_grad=True, its impl is in Python, while the other is in TH.
5292                target = torch.randn(2, 10, requires_grad=requires_grad, dtype=torch.double)
5293                for reduction in ['none', 'mean', 'sum']:
5294                    l = fn(input, target, reduction)
5295                    if reduction == 'none':
5296                        self.assertEqual(l.size(), target.size())
5297                    self.assertTrue(gradcheck(fn, (input, target, reduction)))
5298
5299    # https://github.com/pytorch/pytorch/issues/27692 reports
5300    # that l1_loss get a wrong result for big batch size
5301    def test_l1_loss_correct(self):
5302        for dtype in [torch.float, torch.cfloat]:
5303            for N in range(1, 50, 10):
5304                input = torch.rand(N, 3, 1024, 1024, dtype=dtype)
5305                self.assertEqual(
5306                    torch.nn.L1Loss()(input, torch.zeros_like(input)),
5307                    input.abs().mean())
5308
5309    def test_smoothl1loss_intergral_target(self):
5310        def _input_grad(input, target, reduction):
5311            output = F.smooth_l1_loss(input, target, reduction=reduction, beta=0.5)
5312            output.sum().backward()
5313            return input.grad
5314
5315        for device, dtype, reduction in product(device_(),
5316                                                integral_types(),
5317                                                ('none', 'sum', 'mean')):
5318            input = torch.randn(2, 2, device=device, requires_grad=True)
5319            target = torch.randint(0, 9, (2, 2), device=device, dtype=dtype)
5320
5321            input_grad_with_float_target = _input_grad(input, target.float(), reduction)
5322
5323            input_grad = _input_grad(input.detach().clone().requires_grad_(True),
5324                                     target,
5325                                     reduction)
5326            self.assertEqual(input_grad, input_grad_with_float_target)
5327
5328    def test_smoothl1loss_negative_beta_not_supported(self):
5329        with self.assertRaises(RuntimeError):
5330            F.smooth_l1_loss(torch.randn(2, 2), torch.randn(2, 2), beta=-1.0)
5331
5332    def test_huber_loss_invalid_delta(self):
5333        def _test_huber_loss_delta_error_helper(delta):
5334            input, target = torch.randn(2, 2), torch.randn(2, 2)
5335            loss = torch.nn.HuberLoss(delta=delta)
5336            with self.assertRaises(RuntimeError):
5337                loss(input, target)
5338
5339        def test_huber_loss_negative_delta():
5340            _test_huber_loss_delta_error_helper(delta=-0.5)
5341
5342        def test_huber_loss_zero_delta():
5343            _test_huber_loss_delta_error_helper(delta=0.0)
5344
5345        test_huber_loss_negative_delta()
5346        test_huber_loss_zero_delta()
5347
5348    @set_default_dtype(torch.double)
5349    def test_cosine_similarity(self):
5350        # Check cosine_similarity input/output shapes
5351        input_size = (1, 3, 2, 1)
5352        expected_size = (1, 2, 1)
5353        input1 = torch.randn(input_size, requires_grad=True)
5354        input2 = torch.randn(input_size, requires_grad=True)
5355        self.assertEqual(F.cosine_similarity(input1, input2, dim=1).size(), expected_size)
5356
5357        # Check numerical precision, issue #18057
5358        vv1 = torch.tensor([float(i) for i in range(84)]).unsqueeze(0)
5359        vv2 = torch.tensor([float(i) for i in range(84)]).unsqueeze(0)
5360        out = F.cosine_similarity(vv1, vv2)
5361        self.assertLessEqual(out, 1.0)
5362
5363        # Check dividing by 0.
5364        # previous behavior: <x,y>/max(eps, ||x|| * ||y||)
5365        # current: <x/max(eps, ||x||), y/max(eps,||y||)>
5366        # if f(x,y) is the cosine similarity, then
5367        # df/dx = y/(||x|| * ||y||) - (x * <x,y> * ||y||/||x||)/(||x|| * ||y||)^2
5368        # the tests below check division by zero in the backward formula when
5369        # x := input2 = 0, y := input1 != 0.
5370        # For these inputs the gradient wrt x simplifies to g(x,y) := y/(||x|| * ||y||)
5371        # Previous test checks g(x,y) == y/eps,
5372        # Current test checks g(x,y) == (y/||y||)/eps.
5373        input1 = torch.randn(10).requires_grad_()
5374        input2 = torch.zeros_like(input1).requires_grad_()
5375        torch.cosine_similarity(input1, input2, 0).sum().backward()
5376        self.assertEqual(input1.grad, torch.zeros_like(input1))
5377        self.assertEqual(input2.grad, input1 / input1.norm() * 1e8)
5378
5379        # Check type promotion, issue #61454
5380        input = torch.tensor(12.)
5381        out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)
5382        self.assertEqual(out, 1.)
5383
5384        # Check broadcasting #109333
5385        a = torch.ones(2, 3, dtype=torch.float)
5386        b = torch.ones(1, 1, dtype=torch.float)
5387        out = F.cosine_similarity(a, b)
5388        self.assertEqual(out, torch.ones(2, dtype=torch.float))
5389
5390        a = torch.ones(2, 3, dtype=torch.float)
5391        b = torch.ones(1, dtype=torch.float)
5392        out = F.cosine_similarity(a, b)
5393        self.assertEqual(out, torch.ones(2, dtype=torch.float))
5394
5395
5396    def test_grid_sample_error_checking(self):
5397        input = torch.empty(1, 1, 2, 2)
5398        grid = torch.empty(1, 1, 1, 2)
5399
5400        # assert no error
5401        F.grid_sample(input, grid, align_corners=False)
5402
5403        with self.assertRaisesRegex(ValueError, "but got: 'garbage'"):
5404            F.grid_sample(input, grid, mode='garbage', align_corners=False)
5405
5406        with self.assertRaisesRegex(ValueError, "but got: 'garbage'"):
5407            F.grid_sample(input, grid, padding_mode='garbage', align_corners=False)
5408
5409        with self.assertRaisesRegex(RuntimeError, "expected grid to have size 1 in last dimension"):
5410            F.grid_sample(input[0], grid, align_corners=False)
5411
5412        with self.assertRaisesRegex(RuntimeError, "expected grid to have size 2 in last dimension"):
5413            F.grid_sample(input, torch.empty(1, 1, 1, 1, 3), align_corners=False)
5414
5415        with self.assertRaisesRegex(RuntimeError, "expected grid and input to have same batch size"):
5416            F.grid_sample(input, torch.empty(2, 1, 1, 2), align_corners=False)
5417
5418        with self.assertRaisesRegex(RuntimeError, "expected grid to have size 2 in last dimension"):
5419            F.grid_sample(input, torch.empty(1, 1, 1, 3), align_corners=False)
5420
5421        with self.assertRaisesRegex(RuntimeError, "expected input to have non-empty spatial dimensions"):
5422            F.grid_sample(torch.empty(1, 1, 0, 2), grid, align_corners=False)
5423
5424        with self.assertRaisesRegex(RuntimeError, "bicubic interpolation only supports 4D input"):
5425            F.grid_sample(torch.empty(1, 1, 2, 2, 2), torch.empty(1, 1, 1, 1, 3), mode='bicubic')
5426
5427        if TEST_CUDA:
5428            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
5429                F.grid_sample(input.cuda(), grid, align_corners=False)
5430
5431    def test_affine_grid_error_checking(self):
5432        # 2D affine
5433        theta = torch.empty(1, 2, 3, dtype=torch.double)
5434        size = torch.Size([1, 1, 2, 2])
5435
5436        # assert no error
5437        F.affine_grid(theta, size, align_corners=False)
5438
5439        # check for warning for empty span along dimension
5440        with warnings.catch_warnings(record=True) as w:
5441            # Ensure warnings are being shown
5442            warnings.simplefilter("always")
5443            # Should not trigger warning
5444            F.affine_grid(theta, torch.Size([1, 1, 2, 1]), align_corners=False)
5445            # Check no warning occurs
5446            self.assertNotIn('See the documentation of affine_grid for details.', ' '.join(map(str, w)))
5447            # Should trigger warning
5448            F.affine_grid(theta, torch.Size([1, 1, 2, 1]), align_corners=True)
5449            # Check warning occurs
5450            self.assertIn('See the documentation of affine_grid for details.', ' '.join(map(str, w)))
5451
5452        with self.assertRaisesRegex(ValueError, "Expected theta to have floating point type"):
5453            F.affine_grid(theta.int(), size, align_corners=False)
5454
5455        with self.assertRaisesRegex(ValueError, "Expected a batch of 2D affine matrices of shape Nx2x3"):
5456            F.affine_grid(theta[0], size, align_corners=False)
5457
5458        with self.assertRaisesRegex(ValueError, "Expected a batch of 2D affine matrices of shape Nx2x3"):
5459            F.affine_grid(theta.unsqueeze(0), size, align_corners=False)
5460
5461        with self.assertRaisesRegex(ValueError, "Expected a batch of 2D affine matrices of shape Nx2x3"):
5462            F.affine_grid(theta.repeat(1, 2, 1), size, align_corners=False)
5463
5464        with self.assertRaisesRegex(ValueError, "Expected a batch of 2D affine matrices of shape Nx2x3"):
5465            F.affine_grid(theta.repeat(1, 1, 2), size, align_corners=False)
5466
5467        # 3D affine
5468        theta = torch.empty(1, 3, 4, dtype=torch.double)
5469        size = torch.Size([1, 1, 2, 2, 2])
5470
5471        # assert no error
5472        F.affine_grid(theta, size, align_corners=False)
5473
5474        # check for warning for empty span along dimension
5475        with warnings.catch_warnings(record=True) as w:
5476            # Ensure warnings are being shown
5477            warnings.simplefilter("always")
5478            # Should not trigger warning
5479            F.affine_grid(theta, torch.Size([1, 1, 3, 2, 1]), align_corners=False)
5480            # Check no warning occurs
5481            self.assertNotIn('See the documentation of affine_grid for details.', ' '.join(map(str, w)))
5482            # Should trigger warning
5483            F.affine_grid(theta, torch.Size([1, 1, 3, 2, 1]), align_corners=True)
5484            # Check warning occurs
5485            self.assertIn('See the documentation of affine_grid for details.', ' '.join(map(str, w)))
5486
5487        with self.assertRaisesRegex(ValueError, "Expected a batch of 3D affine matrices of shape Nx3x4"):
5488            F.affine_grid(theta[0], size, align_corners=False)
5489
5490        with self.assertRaisesRegex(ValueError, "Expected a batch of 3D affine matrices of shape Nx3x4"):
5491            F.affine_grid(theta.unsqueeze(0), size, align_corners=False)
5492
5493        with self.assertRaisesRegex(ValueError, "Expected a batch of 3D affine matrices of shape Nx3x4"):
5494            F.affine_grid(theta.repeat(1, 2, 1), size, align_corners=False)
5495
5496        with self.assertRaisesRegex(ValueError, "Expected a batch of 3D affine matrices of shape Nx3x4"):
5497            F.affine_grid(theta.repeat(1, 1, 2), size, align_corners=False)
5498
5499        with self.assertRaisesRegex(NotImplementedError, "affine_grid only supports 4D and 5D sizes"):
5500            F.affine_grid(theta, torch.Size([1, 2, 2]), align_corners=False)
5501
5502        with self.assertRaisesRegex(NotImplementedError, "affine_grid only supports 4D and 5D sizes"):
5503            F.affine_grid(theta, torch.Size([1, 1, 2, 2, 2, 2]), align_corners=False)
5504
5505    @parametrize_test('device', ['cpu'] + (['cuda'] if TEST_CUDA else []))
5506    @parametrize_test('nd', [2, 3])
5507    def test_affine_grid_backward_cl_cf_consistency(self, device, nd):
5508        # Test based on reported issue: https://github.com/pytorch/pytorch/issues/124154
5509
5510        theta = torch.rand([6, nd, nd + 1], requires_grad=True, device=device)
5511        size = [6, 3, 4, 5] if nd == 2 else [6, 3, 4, 5, 5]
5512        grid = torch.nn.functional.affine_grid(theta, size, align_corners=False)
5513
5514        grad_tensor = torch.rand(grid.shape, device=device)
5515
5516        memory_format_cl = torch.channels_last if nd == 2 else torch.channels_last_3d
5517        grad_tensor_cl = grad_tensor.contiguous(memory_format=memory_format_cl)
5518
5519        assert theta.grad is None
5520        grid.backward(grad_tensor_cl)
5521        theta_grad_cl = theta.grad.clone().contiguous()
5522
5523        theta.grad.zero_()
5524        grid.backward(grad_tensor)
5525        theta_grad_cf = theta.grad
5526
5527        self.assertEqual(theta_grad_cf, theta_grad_cl)
5528
5529    @set_default_dtype(torch.double)
5530    def test_grid_sample(self):
5531        # Backward pass of native C++ and CUDA kernels branch depending on whether input requires gradient,
5532        # so we test both cases.
5533        def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad):
5534            def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners):
5535                for grid_dim_contig_order in [(0, 1, 2, 3), (0, 3, 1, 2), (3, 0, 1, 2), (0, 2, 1, 3)]:
5536                    # grid_dim_contig_order specifies the dimension order that can
5537                    # make grid to be contiguous.
5538                    # i.e., grid.permute(grid_dim_contig_order) is contiguous.
5539                    # e.g., with grid_dim_contig_order=[0, 3, 1, 2], grid should be
5540                    #       initialized with contiguous tensor of shape [N, 2, H, W]
5541                    #       and permuted to [N, H, W, 2] afterwards.
5542                    grid_shape = [N, H, W, 2]
5543                    grid_init_shape = [grid_shape[d] for d in grid_dim_contig_order]
5544                    grid_fwd_permute = [None, None, None, None]
5545                    for i, d in enumerate(grid_dim_contig_order):
5546                        grid_fwd_permute[d] = i
5547
5548                    def get_grid(device='cpu', data=None):
5549                        if data is not None:
5550                            assert list(data.shape) == grid_shape
5551                            data = data.permute(grid_dim_contig_order).to(device)
5552                        else:
5553                            data = torch.randn(grid_init_shape, device=device)
5554                        grid = data.permute(grid_fwd_permute)
5555                        assert grid.permute(grid_dim_contig_order).is_contiguous()
5556                        return grid
5557
5558                    input_cpu = torch.randn(C, N, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad)
5559                    grid_cpu = get_grid().requires_grad_()
5560                    out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
5561                                            align_corners=align_corners)
5562                    self.assertTrue(out_cpu.size() == torch.Size([N, C, H, W]))
5563
5564                    gradients = torch.randn_like(out_cpu)
5565                    out_cpu.backward(gradients)
5566
5567
5568                    # Compare against unvectorized CPU fallback
5569
5570                    # NOTE [ grid_sample CPU fallback ]
5571                    # grid_sample uses AVX for 2d images, but that requires 32-bit indexing for
5572                    # 32-bit floats. So we also have a fallback that is used only for float tensors
5573                    # requiring 64-bit indexing. That requires too much memory to run on CI, so we
5574                    # also export the fallback and test it here to ensure feature parity with
5575                    # the vectorized version.
5576                    input_fallback = input_cpu.float().detach_().requires_grad_()
5577                    grid_fallback = grid_cpu.float().detach_().requires_grad_()
5578                    out_fallback = torch._grid_sampler_2d_cpu_fallback(
5579                        input_fallback, grid_fallback,
5580                        F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
5581                        F.GRID_SAMPLE_PADDING_MODES[padding_mode],
5582                        align_corners)
5583                    self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5)
5584
5585                    out_fallback.backward(gradients.float())
5586                    if input_requires_grad:
5587                        self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5)
5588                    self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5)
5589
5590                    if TEST_CUDA:
5591                        input_cuda = input_cpu.detach().transpose(0, 1).cuda().transpose(0, 1).requires_grad_(input_requires_grad)
5592                        grid_cuda = get_grid('cuda', grid_cpu.detach()).requires_grad_()
5593                        out_cuda = F.grid_sample(input_cuda, grid_cuda, mode=mode, padding_mode=padding_mode,
5594                                                 align_corners=align_corners)
5595                        self.assertEqual(out_cpu, out_cuda)
5596
5597                        out_cuda.backward(gradients.cuda())
5598                        if input_requires_grad:
5599                            self.assertEqual(input_cpu.grad, input_cuda.grad)
5600                        self.assertEqual(grid_cpu.grad, grid_cuda.grad, atol=5e-5, rtol=0)
5601
5602                        # check that zero-dimensional input strides don't error out
5603                        base_input = torch.randn(N, C, 1, IW)
5604                        input_cpu = base_input.expand_as(input_cuda).requires_grad_(input_requires_grad)
5605                        out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
5606                                                align_corners=align_corners)
5607
5608                        input_cuda = base_input.cuda().expand_as(input_cuda).requires_grad_(input_requires_grad)
5609                        out_cuda = F.grid_sample(input_cuda, grid_cuda, mode=mode, padding_mode=padding_mode,
5610                                                 align_corners=align_corners)
5611                        self.assertEqual(out_cpu, out_cuda)
5612
5613            # test same size output
5614            test_shape(N, C, H, W, H, W, mode, padding_mode, align_corners)
5615
5616            # test larger output
5617            N = random.randint(2, 8)
5618            C = random.randint(2, 8)
5619            IH = random.randint(2, 8)
5620            IW = random.randint(2, 8)
5621            H = random.randint(IH + 1, 12)
5622            W = random.randint(IW + 1, 12)
5623            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
5624
5625            # test smaller output
5626            N = random.randint(2, 8)
5627            C = random.randint(2, 8)
5628            IH = random.randint(2, 8)
5629            IW = random.randint(2, 8)
5630            H = random.randint(2, IH)
5631            W = random.randint(2, IW)
5632            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
5633
5634            # test 1x1 inpput
5635            N = random.randint(2, 8)
5636            C = random.randint(2, 8)
5637            IH = 1
5638            IW = 1
5639            H = random.randint(2, 5)
5640            W = random.randint(2, 5)
5641            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
5642
5643            # testing empty grid
5644            N = random.randint(2, 8)
5645            C = random.randint(2, 8)
5646            IH = random.randint(2, 8)
5647            IW = random.randint(2, 8)
5648            W = random.randint(3, IW + 2)
5649            test_shape(N, C, IH, IW, 0, W, mode, padding_mode, align_corners)
5650
5651            # testing empty channel
5652            N = random.randint(2, 8)
5653            IH = random.randint(2, 8)
5654            IW = random.randint(2, 8)
5655            H = random.randint(3, IH + 2)
5656            W = random.randint(3, IW + 2)
5657            test_shape(N, 0, IH, IW, H, W, mode, padding_mode, align_corners)
5658
5659            # testing empty batch
5660            C = random.randint(2, 8)
5661            IH = random.randint(2, 8)
5662            IW = random.randint(2, 8)
5663            H = random.randint(3, IH + 2)
5664            W = random.randint(3, IW + 2)
5665            test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners)
5666
5667        for mode in ('bilinear', 'nearest', 'bicubic'):
5668            for padding_mode in ('zeros', 'border', 'reflection'):
5669                for align_corners in (True, False):
5670                    # test known input on CPU
5671                    input = torch.arange(1., 11).view(1, 1, 2, 5)
5672                    grid = torch.tensor(
5673                        [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]],
5674                         [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]]).view(1, 2, 5, 2)
5675                    if mode == 'bilinear':
5676                        if padding_mode == 'zeros':
5677                            if align_corners:
5678                                groundtruth = torch.tensor(
5679                                    [[0.0000, 6.0000000000, 5.0000, 4.8340, 9.0000],
5680                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 0.0000]]).view(1, 1, 2, 5)
5681                            else:
5682                                groundtruth = torch.tensor(
5683                                    [[0.0000, 6.5000000000, 1.2500, 4.6675000191, 4.6250],
5684                                     [0.5000, 7.1665000916, 1.2500, 5.0000000000, 0.0000]]).view(1, 1, 2, 5)
5685                        elif padding_mode == 'border':
5686                            if align_corners:
5687                                groundtruth = torch.tensor(
5688                                    [[1.2000, 6.0000000000, 5.0000, 4.8340, 9.0000],
5689                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 8.7500]]).view(1, 1, 2, 5)
5690                            else:
5691                                groundtruth = torch.tensor(
5692                                    [[1.0000, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
5693                                     [1.0000, 7.1665000916, 5.0000, 5.0000000000, 10.0000]]).view(1, 1, 2, 5)
5694                        elif padding_mode == 'reflection':
5695                            if align_corners:
5696                                groundtruth = torch.tensor(
5697                                    [[3.4500, 6.0000000000, 5.0000, 4.8340, 9.0000],
5698                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 7.7500]]).view(1, 1, 2, 5)
5699                            else:
5700                                groundtruth = torch.tensor(
5701                                    [[3.0000004768, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
5702                                     [1.0000000000, 7.1665000916, 5.0000, 5.0000000000, 9.2500]]).view(1, 1, 2, 5)
5703                        else:
5704                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
5705                    elif mode == 'nearest':
5706                        if padding_mode == 'zeros':
5707                            if align_corners:
5708                                groundtruth = torch.tensor(
5709                                    [[0., 8., 5., 7., 9.],
5710                                     [1., 8., 5., 8., 0.]]).view(1, 1, 2, 5)
5711                            else:
5712                                groundtruth = torch.tensor(
5713                                    [[0., 8., 5., 7., 0.],
5714                                     [1., 8., 5., 8., 0.]]).view(1, 1, 2, 5)
5715                        elif padding_mode == 'border':
5716                            if align_corners:
5717                                groundtruth = torch.tensor(
5718                                    [[1., 8., 5., 7., 9.],
5719                                     [1., 8., 5., 8., 10.]]).view(1, 1, 2, 5)
5720                            else:
5721                                groundtruth = torch.tensor(
5722                                    [[1., 8., 5., 7., 9.],
5723                                     [1., 8., 5., 8., 10.]]).view(1, 1, 2, 5)
5724                        elif padding_mode == 'reflection':
5725                            if align_corners:
5726                                groundtruth = torch.tensor(
5727                                    [[1., 8., 5., 7., 9.],
5728                                     [1., 8., 5., 8., 9.]]).view(1, 1, 2, 5)
5729                            else:
5730                                groundtruth = torch.tensor(
5731                                    [[1., 8., 5., 7., 9.],
5732                                     [1., 8., 5., 8., 9.]]).view(1, 1, 2, 5)
5733                        else:
5734                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
5735                    elif mode == 'bicubic':
5736                        if padding_mode == 'zeros':
5737                            if align_corners:
5738                                groundtruth = torch.tensor(
5739                                    [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000],
5740                                     [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]]).view(1, 1, 2, 5)
5741                            else:
5742                                groundtruth = torch.tensor(
5743                                    [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264],
5744                                     [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]]).view(1, 1, 2, 5)
5745                        elif padding_mode == 'border':
5746                            if align_corners:
5747                                groundtruth = torch.tensor(
5748                                    [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000],
5749                                     [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]]).view(1, 1, 2, 5)
5750                            else:
5751                                groundtruth = torch.tensor(
5752                                    [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781],
5753                                     [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]]).view(1, 1, 2, 5)
5754                        elif padding_mode == 'reflection':
5755                            if align_corners:
5756                                groundtruth = torch.tensor(
5757                                    [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000],
5758                                     [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]]).view(1, 1, 2, 5)
5759                            else:
5760                                groundtruth = torch.tensor(
5761                                    [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531],
5762                                     [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]]).view(1, 1, 2, 5)
5763                        else:
5764                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
5765
5766                    else:
5767                        raise AssertionError(f"missing groundtruth test for interpolation mode '{mode}'")
5768                    output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
5769                                           align_corners=align_corners)
5770                    self.assertEqual(output, groundtruth, atol=1e-5, rtol=0,
5771                                     msg=f"groundtruth comparison failed for mode={mode}, "
5772                                     f"padding_mode={padding_mode}")
5773
5774                    # See NOTE [ grid_sample CPU fallback ]
5775                    output = torch._grid_sampler_2d_cpu_fallback(
5776                        input.float(), grid.float(),
5777                        F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
5778                        F.GRID_SAMPLE_PADDING_MODES[padding_mode],
5779                        align_corners)
5780                    self.assertEqual(output, groundtruth.float(), atol=1e-5, rtol=0)
5781
5782                    # explicit check for gradient edge cases
5783                    input = torch.arange(0., 5).expand((1, 1, 5, 5))
5784                    grid = torch.tensor(
5785                        [[[1.0, 1.0], [1.0, -1.0], [0.8, 0.8], [0.8, -0.8]],
5786                         [[-1.0, -1.0], [-1.0, 1.0], [-0.8, -0.8], [-0.8, 0.8]]]).view(1, 2, 4, 2).requires_grad_()
5787                    if mode == 'bilinear':
5788                        if padding_mode == 'zeros':
5789                            if align_corners:
5790                                groundtruth = torch.tensor(
5791                                    [[[[-8., -8.], [-8., 0.], [2., 0.], [2., 0.]],
5792                                      [[2., 0.], [2., 0.], [2., 0.], [2., 0.]]]]).view(1, 2, 4, 2)
5793                            else:
5794                                groundtruth = torch.tensor(
5795                                    [[[[-5., -5.], [-5., 5.], [-10., -10.], [-10., 10.]],
5796                                      [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2)
5797                        elif padding_mode == 'border':
5798                            if align_corners:
5799                                groundtruth = torch.tensor(
5800                                    [[[[-0., -0.], [-0., 0.], [2., 0.], [2., 0.]],
5801                                      [[0., 0.], [0., 0.], [2., 0.], [2., 0.]]]]).view(1, 2, 4, 2)
5802                            else:
5803                                groundtruth = torch.tensor(
5804                                    [[[[-0., -0.], [-0., 0.], [-0., -0.], [-0., 0.]],
5805                                      [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2)
5806                        elif padding_mode == 'reflection':
5807                            if align_corners:
5808                                groundtruth = torch.tensor(
5809                                    [[[[-0., -0.], [-0., 0.], [2., 0.], [2., 0.]],
5810                                      [[0., 0.], [0., 0.], [2., 0.], [2., 0.]]]]).view(1, 2, 4, 2)
5811                            else:
5812                                groundtruth = torch.tensor(
5813                                    [[[[-0., -0.], [-0., 0.], [-0., -0.], [-0., 0.]],
5814                                      [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2)
5815                        else:
5816                            raise AssertionError(f"missing gradient groundtruth test for padding mode '{padding_mode}'")
5817                    elif mode == 'nearest':
5818                        groundtruth = torch.tensor(
5819                            [[[[-0., -0.], [-0., 0.], [-0., -0.], [-0., 0.]],
5820                              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2)
5821                    elif mode == 'bicubic':
5822                        if padding_mode == 'zeros':
5823                            if align_corners:
5824                                groundtruth = torch.tensor(
5825                                    [[[[-4.5, -6.], [-4.5, 6.], [2.725679, 0.740878], [2.725679, -0.740878]],
5826                                      [[1.5, 0.], [1.5, 0.], [1.927921, -0.05688], [1.927921, 0.05688]]]]).view(1, 2, 4, 2)
5827                            else:
5828                                groundtruth = torch.tensor(
5829                                    [[[[-5.859375, -5.888672], [-5.859375, 5.888672], [-5.6250, -7.5000], [-5.6250, 7.5000]],
5830                                      [[-0.234375, -0.263672], [-0.234375, 0.263672], [1.8750, 0.], [1.8750, 0.]]]]
5831                                ).view(1, 2, 4, 2)
5832                        elif padding_mode == 'border':
5833                            if align_corners:
5834                                groundtruth = torch.tensor(
5835                                    [[[[1.5, 0.], [1.5, 0.], [1.74, 0.], [1.74, 0.]],
5836                                      [[1.5, 0.], [1.5, 0.], [1.74, 0.], [1.74, 0.]]]]).view(1, 2, 4, 2)
5837                            else:
5838                                groundtruth = torch.tensor(
5839                                    [[[[-0.46875, 0.], [-0.46875, 0.], [1.8750, 0.], [1.8750, 0.]],
5840                                      [[-0.46875, 0.], [-0.46875, 0.], [1.8750, 0.], [1.8750, 0.]]]]).view(1, 2, 4, 2)
5841                        elif padding_mode == 'reflection':
5842                            if align_corners:
5843                                groundtruth = torch.tensor(
5844                                    [[[[0., 0.], [0., 0.], [1.92, 0.], [1.92, 0.]],
5845                                      [[0., 0.], [0., 0.], [1.92, 0.], [1.92, 0.]]]]).view(1, 2, 4, 2)
5846                            else:
5847                                groundtruth = torch.tensor(
5848                                    [[[[0., 0.], [0., 0.], [1.875, 0.], [1.875, 0.]],
5849                                      [[0., 0.], [0., 0.], [1.875, 0.], [1.875, 0.]]]]).view(1, 2, 4, 2)
5850                        else:
5851                            raise AssertionError(f"missing gradient groundtruth test for padding mode '{padding_mode}'")
5852                    else:
5853                        raise AssertionError(f"missing gradient groundtruth test for interpolation mode '{mode}'")
5854                    for input_requires_grad in [False, True]:
5855                        input = input.requires_grad_(input_requires_grad)
5856                        F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
5857                                      align_corners=align_corners).sum().backward()
5858                        self.assertEqual(grid.grad, groundtruth, atol=1e-5, rtol=0,
5859                                         msg=f"gradient groundtruth comparison failed for mode={mode}, "
5860                                         f"padding_mode={padding_mode}, input_requires_grad={input_requires_grad}")
5861                        grid.grad.zero_()
5862
5863                    # See NOTE [ grid_sample CPU fallback ]
5864                    torch._grid_sampler_2d_cpu_fallback(
5865                        input.float(), grid.float(),
5866                        F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
5867                        F.GRID_SAMPLE_PADDING_MODES[padding_mode],
5868                        align_corners).sum().backward()
5869                    self.assertEqual(grid.grad, groundtruth, atol=1e-5, rtol=0)
5870
5871                    # do gradcheck
5872                    N = random.randint(2, 8)
5873                    C = random.randint(2, 6)
5874                    H = random.randint(2, 8)
5875                    W = random.randint(2, 8)
5876                    input = torch.randn(N, C, H, W, requires_grad=True)
5877                    grid = torch.randn(N, H, W, 2, requires_grad=True)
5878
5879                    for input_requires_grad in [False, True]:
5880                        input.requires_grad_(input_requires_grad)
5881                        self.assertTrue(gradcheck(
5882                            lambda inp, grd: F.grid_sample(inp, grd, mode=mode, padding_mode=padding_mode,
5883                                                           align_corners=align_corners),
5884                            (input, grid)))
5885                        test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad)
5886                        if TEST_CUDNN:
5887                            with cudnn.flags(enabled=False):
5888                                test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad)
5889
5890    @set_default_dtype(torch.double)
5891    def test_grid_sample_3d(self):
5892        # Backward pass of native C++ and CUDA kernels branch depending on whether input requires gradient,
5893        # so we test both cases.
5894        def test(N, C, D, H, W, mode, padding_mode, align_corners, input_requires_grad):
5895            def test_shape(N, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners):
5896                input_cpu = torch.randn(C, N, ID, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad)
5897                grid_cpu = torch.randn(D, N, H, W, 3).transpose(0, 1).requires_grad_()
5898                out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
5899                                        align_corners=align_corners)
5900                self.assertTrue(out_cpu.size() == torch.Size([N, C, D, H, W]))
5901
5902                gradients = torch.randn_like(out_cpu)
5903                out_cpu.backward(gradients)
5904
5905                if TEST_CUDA:
5906                    input_cuda = input_cpu.detach().transpose(0, 1).cuda().transpose(0, 1).requires_grad_(input_requires_grad)
5907                    grid_cuda = grid_cpu.detach().transpose(0, 1).cuda().transpose(0, 1).requires_grad_()
5908                    out_cuda = F.grid_sample(input_cuda, grid_cuda, mode=mode, padding_mode=padding_mode,
5909                                             align_corners=align_corners)
5910                    self.assertEqual(out_cpu, out_cuda)
5911
5912                    out_cuda.backward(gradients.cuda())
5913                    if input_requires_grad:
5914                        self.assertEqual(input_cpu.grad, input_cuda.grad)
5915                    self.assertEqual(grid_cpu.grad, grid_cuda.grad, atol=5e-5, rtol=0)
5916
5917                    # check that zero-dimensional input strides don't error out
5918                    base_input = torch.randn(N, C, 1, IH, IW)
5919                    input_cpu = base_input.expand_as(input_cuda).requires_grad_(input_requires_grad)
5920                    grid_cpu = torch.randn(N, D, H, W, 3, requires_grad=True)
5921                    out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
5922                                            align_corners=align_corners)
5923
5924                    input_cuda = base_input.cuda().expand_as(input_cuda).requires_grad_(input_requires_grad)
5925                    grid_cuda = grid_cpu.detach().cuda().requires_grad_()
5926                    out_cuda = F.grid_sample(input_cuda, grid_cuda, mode=mode, padding_mode=padding_mode,
5927                                             align_corners=align_corners)
5928                    self.assertEqual(out_cpu, out_cuda)
5929
5930            # test same size output
5931            test_shape(N, C, D, H, W, D, H, W, mode, padding_mode, align_corners)
5932
5933            # test larger output
5934            N = random.randint(2, 7)
5935            C = random.randint(2, 5)
5936            ID = random.randint(2, 7)
5937            IH = random.randint(2, 7)
5938            IW = random.randint(2, 7)
5939            D = random.randint(ID + 1, 10)
5940            H = random.randint(IH + 1, 10)
5941            W = random.randint(IW + 1, 10)
5942            test_shape(N, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners)
5943
5944            # test smaller output
5945            N = random.randint(2, 7)
5946            C = random.randint(2, 5)
5947            ID = random.randint(2, 7)
5948            IH = random.randint(2, 7)
5949            IW = random.randint(2, 7)
5950            D = random.randint(2, ID)
5951            H = random.randint(2, IH)
5952            W = random.randint(2, IW)
5953            test_shape(N, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners)
5954
5955            # test 1x1 inpput
5956            N = random.randint(2, 7)
5957            C = random.randint(2, 7)
5958            ID = 1
5959            IH = 1
5960            IW = 1
5961            H = random.randint(2, 5)
5962            W = random.randint(2, 5)
5963            test_shape(N, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners)
5964
5965            # testing empty grid
5966            N = random.randint(2, 7)
5967            C = random.randint(2, 5)
5968            ID = random.randint(2, 7)
5969            IH = random.randint(2, 7)
5970            IW = random.randint(2, 7)
5971            D = random.randint(3, ID + 2)
5972            W = random.randint(3, IW + 2)
5973            test_shape(N, C, ID, IH, IW, D, 0, W, mode, padding_mode, align_corners)
5974
5975            # testing empty channel
5976            N = random.randint(2, 7)
5977            ID = random.randint(2, 5)
5978            IH = random.randint(2, 7)
5979            IW = random.randint(2, 7)
5980            D = random.randint(3, ID + 2)
5981            H = random.randint(3, IH + 2)
5982            W = random.randint(3, IW + 2)
5983            test_shape(N, 0, ID, IH, IW, D, H, W, mode, padding_mode, align_corners)
5984
5985            # testing empty batch
5986            C = random.randint(2, 5)
5987            ID = random.randint(2, 7)
5988            IH = random.randint(2, 7)
5989            IW = random.randint(2, 7)
5990            D = random.randint(3, ID + 2)
5991            H = random.randint(3, IH + 2)
5992            W = random.randint(3, IW + 2)
5993            test_shape(0, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners)
5994
5995        for mode in ('bilinear', 'nearest'):
5996            for padding_mode in ('zeros', 'border', 'reflection'):
5997                for align_corners in (True, False):
5998                    # do gradcheck
5999                    N = random.randint(2, 5)
6000                    C = random.randint(2, 4)
6001                    D = random.randint(2, 5)
6002                    H = random.randint(2, 5)
6003                    W = random.randint(2, 5)
6004                    input = torch.randn(N, C, D, H, W, requires_grad=True)
6005                    grid = torch.randn(N, D, H, W, 3, requires_grad=True)
6006                    self.assertTrue(gradcheck(
6007                        lambda inp, grid: F.grid_sample(inp, grid, mode=mode, padding_mode=padding_mode,
6008                                                        align_corners=align_corners),
6009                        (input, grid)))
6010                    input = input.requires_grad_(False)
6011                    self.assertTrue(gradcheck(
6012                        lambda grid: F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
6013                                                   align_corners=align_corners),
6014                        (grid,)))
6015
6016                    for input_requires_grad in [False, True]:
6017                        test(N, C, D, H, W, mode, padding_mode, align_corners, input_requires_grad)
6018
6019    def test_grid_sample_nearest_neighbor_rounding_mode_consistency(self):
6020
6021        device_list = ['cpu']
6022        if TEST_CUDA:
6023            device_list.append('cuda')
6024
6025        def normalize_indices(indices_unnormalized: torch.Tensor, dim_size: int, align_corners: bool):
6026            if align_corners:
6027                indices_normalized = 2 * indices_unnormalized / (dim_size - 1) - 1
6028            else:
6029                indices_normalized = (indices_unnormalized * 2 + 1) / dim_size - 1
6030            return indices_normalized
6031
6032        test_dim_size = 10
6033        non_test_dim_size = 9
6034        step_size = 0.1
6035
6036        batch_size = 1
6037        channel_size = 1
6038
6039        mode = 'nearest'
6040        for device in device_list:
6041            for padding_mode in ('zeros', 'border', 'reflection'):
6042                for align_corners in (True, False):
6043                    # Unnormalized inquiry indices
6044                    inquiry_indices_unnormalized = torch.arange(
6045                        0,
6046                        test_dim_size - 1 + step_size, step_size,
6047                        dtype=torch.float32,
6048                        device=device
6049                    )
6050                    # Note that even though we are trying to create normalized indices
6051                    # which results in x.0 and x.5 indices after unnormalization,
6052                    # because of the numerical error,
6053                    # the rounding direction might not always be expected as designed.
6054                    # The best we could do is to ensure the rounding behaviors across
6055                    # different implementations for different dimensions are
6056                    # exactly the same.
6057                    inquiry_indices = normalize_indices(
6058                        indices_unnormalized=inquiry_indices_unnormalized,
6059                        dim_size=test_dim_size,
6060                        align_corners=align_corners
6061                    )
6062                    num_inqueries = inquiry_indices.shape[0]
6063                    inquiry_fixed_indices = torch.full((num_inqueries,), 0.5, dtype=torch.float32, device=device)
6064                    array_data = torch.rand(test_dim_size, dtype=torch.float32, device=device)
6065                    # 2D grid sample x-dim interpolation
6066                    # The input_tensor_2d_x is of shape
6067                    # [batch_size, channel_size, non_test_dim_size, test_dim_size]
6068                    input_tensor_2d_x = array_data.reshape(1, test_dim_size).repeat(
6069                        batch_size,
6070                        channel_size,
6071                        non_test_dim_size,
6072                        1
6073                    )
6074                    # The grid_tensor_2d_x is of shape
6075                    # [batch_size, 1, num_inqueries]
6076                    grid_tensor_2d_x = torch.cat(
6077                        tensors=(
6078                            inquiry_indices.reshape(num_inqueries, 1),
6079                            inquiry_fixed_indices.reshape(num_inqueries, 1),
6080                        ),
6081                        dim=1
6082                    ).repeat(batch_size, 1, 1, 1)
6083                    # The output_tensor_2d_x is of shape
6084                    # [batch_size, channel_size, 1, num_inqueries]
6085                    output_tensor_2d_x = F.grid_sample(
6086                        input=input_tensor_2d_x,
6087                        grid=grid_tensor_2d_x,
6088                        mode=mode,
6089                        padding_mode=padding_mode,
6090                        align_corners=align_corners,
6091                    )
6092                    # 2D grid sample y-dim interpolation
6093                    # The input_tensor_2d_y is of shape
6094                    # [batch_size, channel_size, test_dim_size, non_test_dim_size]
6095                    input_tensor_2d_y = torch.transpose(input_tensor_2d_x, 3, 2)
6096                    # The grid_tensor_2d_y is of shape
6097                    # [batch_size, 1, num_inqueries]
6098                    grid_tensor_2d_y = torch.index_select(
6099                        grid_tensor_2d_x,
6100                        -1,
6101                        torch.tensor([1, 0], dtype=torch.int64, device=device)
6102                    )
6103                    # The output_tensor_2d_y is of shape
6104                    # [batch_size, channel_size, 1, num_inqueries]
6105                    output_tensor_2d_y = F.grid_sample(
6106                        input=input_tensor_2d_y,
6107                        grid=grid_tensor_2d_y,
6108                        mode=mode,
6109                        padding_mode=padding_mode,
6110                        align_corners=align_corners,
6111                    )
6112                    self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_2d_y[0, 0, 0, :], atol=0, rtol=0)
6113                    # 3D grid sample x-dim interpolation
6114                    # The input_tensor_3d_x is of shape
6115                    # [batch_size, channel_size, non_test_dim_size, non_test_dim_size, test_dim_size]
6116                    input_tensor_3d_x = array_data.reshape(1, test_dim_size).repeat(
6117                        batch_size, channel_size, non_test_dim_size, non_test_dim_size, 1)
6118                    # The grid_tensor_3d_x is of shape
6119                    # [batch_size, 1, 1, num_inqueries]
6120                    grid_tensor_3d_x = torch.cat(
6121                        tensors=(
6122                            inquiry_indices.reshape(num_inqueries, 1),
6123                            inquiry_fixed_indices.reshape(num_inqueries, 1),
6124                            inquiry_fixed_indices.reshape(num_inqueries, 1),
6125                        ),
6126                        dim=1
6127                    ).repeat(batch_size, 1, 1, 1, 1)
6128                    # The output_tensor_3d_x is of shape
6129                    # [batch_size, channel_size, 1, 1, num_inqueries]
6130                    output_tensor_3d_x = F.grid_sample(
6131                        input=input_tensor_3d_x,
6132                        grid=grid_tensor_3d_x,
6133                        mode=mode,
6134                        padding_mode=padding_mode,
6135                        align_corners=align_corners,
6136                    )
6137                    self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_3d_x[0, 0, 0, 0, :], atol=0, rtol=0)
6138                    # 3D grid sample y-dim interpolation
6139                    # The input_tensor_3d_y is of shape
6140                    # [batch_size, channel_size, non_test_dim_size, test_dim_size, non_test_dim_size]
6141                    input_tensor_3d_y = torch.transpose(input_tensor_3d_x, 4, 3)
6142                    # The grid_tensor_3d_y is of shape
6143                    # [batch_size, 1, 1, num_inqueries]
6144                    grid_tensor_3d_y = torch.index_select(
6145                        grid_tensor_3d_x,
6146                        -1,
6147                        torch.tensor([1, 0, 2], dtype=torch.int64, device=device)
6148                    )
6149                    # The output_tensor_3d_y is of shape
6150                    # [batch_size, channel_size, 1, 1, num_inqueries]
6151                    output_tensor_3d_y = F.grid_sample(
6152                        input=input_tensor_3d_y,
6153                        grid=grid_tensor_3d_y,
6154                        mode=mode,
6155                        padding_mode=padding_mode,
6156                        align_corners=align_corners,
6157                    )
6158                    self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_3d_y[0, 0, 0, 0, :], atol=0, rtol=0)
6159                    # 3D grid sample z-dim interpolation
6160                    # The input_tensor_3d_z is of shape
6161                    # [batch_size, channel_size, non_test_dim_size, non_test_dim_size, test_dim_size]
6162                    input_tensor_3d_z = torch.transpose(input_tensor_3d_x, 4, 2)
6163                    # The grid_tensor_3d_z is of shape
6164                    # [batch_size, 1, 1, num_inqueries]
6165                    grid_tensor_3d_z = torch.index_select(
6166                        grid_tensor_3d_x,
6167                        -1,
6168                        torch.tensor([1, 2, 0], dtype=torch.int64, device=device)
6169                    )
6170                    # The output_tensor_3d_z is of shape
6171                    # [batch_size, channel_size, 1, 1, num_inqueries]
6172                    output_tensor_3d_z = F.grid_sample(
6173                        input=input_tensor_3d_z,
6174                        grid=grid_tensor_3d_z,
6175                        mode=mode,
6176                        padding_mode=padding_mode,
6177                        align_corners=align_corners,
6178                    )
6179                    self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_3d_z[0, 0, 0, 0, :], atol=0, rtol=0)
6180
6181    @set_default_dtype(torch.double)
6182    def test_affine_grid(self):
6183        # test known input on CPU
6184        input = torch.arange(1., 7).view(1, 2, 3)
6185        output = F.affine_grid(input, torch.Size([1, 1, 2, 2]), align_corners=True)
6186        groundtruth = torch.tensor(
6187            [[[0., -3.], [2., 5.]], [[4., 7.], [6., 15.]]]).view(1, 2, 2, 2)
6188        self.assertEqual(output, groundtruth)
6189        output = F.affine_grid(input, torch.Size([1, 1, 2, 2]), align_corners=False)
6190        groundtruth = torch.tensor(
6191            [[[1.5, 1.5], [2.5, 5.5]], [[3.5, 6.5], [4.5, 10.5]]]).view(1, 2, 2, 2)
6192        self.assertEqual(output, groundtruth)
6193
6194        for align_corners in (True, False):
6195            # do gradcheck
6196            N = random.randint(1, 8)
6197            C = random.randint(1, 8)
6198            H = random.randint(1, 8)
6199            W = random.randint(1, 8)
6200            sz = torch.Size([N, C, H, W])
6201            inp = torch.randn(N, 2, 3, requires_grad=True)
6202            with warnings.catch_warnings(record=True):
6203                warnings.simplefilter("always")  # python2 requires this so other tests can trigger
6204                self.assertTrue(gradcheck(
6205                    lambda inp: F.affine_grid(inp, sz, align_corners=align_corners),
6206                    (inp,)))
6207
6208        # test CPU against CUDA
6209        if TEST_CUDA:
6210            N = random.randint(1, 8)
6211            C = random.randint(1, 8)
6212            H = random.randint(1, 8)
6213            W = random.randint(1, 8)
6214            sz = torch.Size([N, C, H, W])
6215            for align_corners in (True, False):
6216                input_cpu = torch.randn(N, 2, 3, requires_grad=True)
6217                with warnings.catch_warnings(record=True):
6218                    warnings.simplefilter("always")  # python2 requires this so other tests can trigger
6219                    out_cpu = F.affine_grid(input_cpu, sz, align_corners=align_corners)
6220                gradients = torch.randn(out_cpu.size())
6221                out_cpu.backward(gradients)
6222                input_gpu = input_cpu.detach().cuda().requires_grad_()
6223                with warnings.catch_warnings(record=True):
6224                    warnings.simplefilter("always")  # python2 requires this so other tests can trigger
6225                    out_cuda = F.affine_grid(input_gpu, sz, align_corners=align_corners)
6226                out_cuda.backward(gradients.cuda())
6227                self.assertEqual(out_cpu, out_cuda)
6228                self.assertEqual(input_cpu.grad, input_gpu.grad)
6229
6230    @set_default_dtype(torch.double)
6231    def test_affine_grid_3d(self):
6232        # test known input on CPU
6233        input = torch.arange(1., 13).view(1, 3, 4)
6234        output = F.affine_grid(input, torch.Size([1, 1, 2, 2, 2]), align_corners=True)
6235        groundtruth = torch.tensor(
6236            [[[[[-2., -10., -18.], [0., 0., 0.]], [[2., 2., 2.], [4., 12., 20.]]],
6237              [[[4., 4., 4.], [6., 14., 22.]], [[8., 16., 24.], [10., 26., 42.]]]]]).view(1, 2, 2, 2, 3)
6238        self.assertEqual(output, groundtruth)
6239        output = F.affine_grid(input, torch.Size([1, 1, 2, 2, 2]), align_corners=False)
6240        groundtruth = torch.tensor(
6241            [[[[[1., -1., -3.], [2., 4., 6.]], [[3., 5., 7.], [4., 10., 16.]]],
6242              [[[4., 6., 8.], [5., 11., 17.]], [[6., 12., 18.], [7., 17., 27.]]]]]).view(1, 2, 2, 2, 3)
6243        self.assertEqual(output, groundtruth)
6244
6245        for align_corners in (True, False):
6246            # do gradcheck
6247            N = random.randint(1, 8)
6248            C = random.randint(1, 8)
6249            D = random.randint(1, 8)
6250            H = random.randint(1, 8)
6251            W = random.randint(1, 8)
6252            sz = torch.Size([N, C, D, H, W])
6253            inp = torch.randn(N, 3, 4, requires_grad=True)
6254            with warnings.catch_warnings(record=True):
6255                warnings.simplefilter("always")  # python2 requires this so other tests can trigger
6256                self.assertTrue(gradcheck(
6257                    lambda inp: F.affine_grid(inp, sz, align_corners=align_corners),
6258                    (inp,)))
6259
6260        # test CPU against CUDA
6261        if TEST_CUDA:
6262            N = random.randint(1, 8)
6263            C = random.randint(1, 8)
6264            D = random.randint(1, 8)
6265            H = random.randint(1, 8)
6266            W = random.randint(1, 8)
6267            sz = torch.Size([N, C, D, H, W])
6268            for align_corners in (True, False):
6269                input_cpu = torch.randn(N, 3, 4, requires_grad=True)
6270                with warnings.catch_warnings(record=True):
6271                    warnings.simplefilter("always")  # python2 requires this so other tests can trigger
6272                    out_cpu = F.affine_grid(input_cpu, sz, align_corners=align_corners)
6273                gradients = torch.randn(out_cpu.size())
6274                out_cpu.backward(gradients)
6275                input_gpu = input_cpu.detach().cuda().requires_grad_()
6276                with warnings.catch_warnings(record=True):
6277                    warnings.simplefilter("always")  # python2 requires this so other tests can trigger
6278                    out_cuda = F.affine_grid(input_gpu, sz, align_corners=align_corners)
6279                out_cuda.backward(gradients.cuda())
6280                self.assertEqual(out_cpu, out_cuda)
6281                self.assertEqual(input_cpu.grad, input_gpu.grad)
6282
6283    def test_channel_shuffle_return_alias_of_self(self):
6284        # gh-76616: nn.ChannelShuffle will return alias of self with an empty input tensor
6285        groups = 3
6286        input_tensor = torch.rand([0, 9, 4, 4])
6287        output = torch.nn.ChannelShuffle(groups)(input_tensor)
6288        torch.testing.assert_close(output, input_tensor)
6289
6290    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
6291    def test_native_channel_shuffle_return_alias_of_self(self):
6292        groups = 3
6293        input_tensor = torch.rand([0, 9, 4, 4])
6294        output = torch.native_channel_shuffle(input_tensor, groups)
6295        torch.testing.assert_close(output, input_tensor)
6296
6297    @set_default_dtype(torch.double)
6298    def test_upsamplingLinear1d(self):
6299        for align_corners in [True, False]:
6300            for recompute_scale_factor in [True, False]:
6301                kwargs = dict(
6302                    mode='linear', align_corners=align_corners, recompute_scale_factor=recompute_scale_factor
6303                )
6304                # test float scale factor up & downsampling
6305                for scale_factor in [0.5, 1.5, 2]:
6306                    m = nn.Upsample(scale_factor=scale_factor, **kwargs)
6307                    in_t = torch.ones(1, 1, 2)
6308                    out_size = int(math.floor(in_t.shape[-1] * scale_factor))
6309                    with warnings.catch_warnings(record=True) as w:
6310                        out_t = m(in_t)
6311                    self.assertEqual(torch.ones(1, 1, out_size), out_t.data)
6312
6313                    input = torch.randn(1, 1, 2, requires_grad=True)
6314                    if not recompute_scale_factor:
6315                        gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), (input,))
6316                    else:
6317                        gradcheck(lambda x: F.interpolate(x, scale_factor=scale_factor, **kwargs), (input,))
6318
6319    def test_upsamplingLinear1d_spatial_invariance(self):
6320        m = nn.Upsample(scale_factor=3, mode='linear', align_corners=False)
6321        in_t_9 = torch.zeros(1, 1, 9)
6322        in_t_9[:, :, :4].normal_()
6323        with warnings.catch_warnings(record=True) as w:
6324            out_t_9 = m(in_t_9)
6325            out_t_5 = m(in_t_9[:, :, :5])
6326        self.assertEqual(out_t_9[:, :, :15], out_t_5)
6327
6328    @set_default_dtype(torch.double)
6329    def test_upsampling_not_recompute_scale_factor(self):
6330        # test output against known input: result must match opencv
6331        in_t = torch.arange(8.).view(1, 2, 2, 2)
6332        expected_out_t = torch.tensor(
6333            [[[[-0.32725, -0.08843, 0.37933, 0.79744],
6334              [0.15039, 0.38921, 0.85697, 1.27508],
6335              [1.08591, 1.32473, 1.79249, 2.21060],
6336              [1.92213, 2.16095, 2.62871, 3.04682]],
6337
6338             [[3.67275, 3.91157, 4.37933, 4.79744],
6339              [4.15039, 4.38921, 4.85697, 5.27508],
6340              [5.08591, 5.32473, 5.79249, 6.21060],
6341              [5.92213, 6.16095, 6.62871, 7.04682]]]])
6342        if IS_PPC:
6343            # Both OpenCV and PyTorch give a slightly different result on PPC
6344            expected_out_t = torch.tensor(
6345                [[[[-0.32725, -0.08843, 0.37933, 0.79744],
6346                  [0.15039, 0.38921, 0.85697, 1.27508],
6347                  [1.08591, 1.32473, 1.79249, 2.21060],
6348                  [1.92212, 2.16094, 2.62870, 3.04681]],
6349
6350                 [[3.67275, 3.91157, 4.37933, 4.79743],
6351                  [4.15039, 4.38921, 4.85697, 5.27508],
6352                  [5.08591, 5.32473, 5.79249, 6.21059],
6353                  [5.92212, 6.16094, 6.62870, 7.04680]]]])
6354        out_t = F.interpolate(in_t, scale_factor=2.3, mode='bicubic', align_corners=False, recompute_scale_factor=False)
6355        torch.set_printoptions(precision=5)
6356        self.assertEqual(out_t, expected_out_t, atol=1e-4, rtol=0)
6357
6358        device_list = ['cpu']
6359        if TEST_CUDA:
6360            device_list.append('cuda')
6361
6362        for align_corners in [True, False]:
6363            kwargs = dict(mode='bicubic', align_corners=align_corners)
6364            # test float scale factor up & downsampling
6365            for device in device_list:
6366                for scale_factor in [0.6, 1.6, 2.3]:
6367                    in_t = torch.ones(2, 2, 2, 2).to(device)
6368                    out_t = F.interpolate(in_t, scale_factor=scale_factor, **kwargs)
6369                    out_size = int(math.floor(in_t.shape[-1] * scale_factor))
6370                    self.assertEqual(torch.ones(2, 2, out_size, out_size), out_t.data, atol=1e-5, rtol=0)
6371
6372                    input = torch.randn(2, 2, 2, 2, requires_grad=True)
6373                    gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])
6374
6375    def test_upsamplingBilinear2d_spatial_invariance(self):
6376        m = nn.Upsample(scale_factor=3, mode='bilinear', align_corners=False)
6377        in_t_9 = torch.zeros(1, 1, 9, 9)
6378        in_t_9[:, :, :4, :4].normal_()
6379        with warnings.catch_warnings(record=True) as w:
6380            out_t_9 = m(in_t_9)
6381            out_t_5 = m(in_t_9[:, :, :5, :5])
6382        self.assertEqual(out_t_9[:, :, :15, :15], out_t_5)
6383
6384    def test_upsamplingTrilinear3d_spatial_invariance(self):
6385        m = nn.Upsample(scale_factor=3, mode='trilinear', align_corners=False)
6386        in_t_9 = torch.zeros(1, 1, 9, 9, 9)
6387        in_t_9[:, :, :4, :4, :4].normal_()
6388        with warnings.catch_warnings(record=True) as w:
6389            out_t_9 = m(in_t_9)
6390            out_t_5 = m(in_t_9[:, :, :5, :5, :5])
6391        self.assertEqual(out_t_9[:, :, :15, :15, :15], out_t_5)
6392
6393    def test_upsampling_small_scale(self):
6394        m = torch.nn.Upsample(scale_factor=0.5, mode="bilinear")
6395        in_t = torch.arange(1, 5, dtype=torch.get_default_dtype()).reshape(1, 1, 2, 2)
6396        out_t = m(in_t)
6397        expected_out_t = torch.tensor([[[[2.5]]]])
6398        self.assertEqual(expected_out_t, out_t)
6399
6400    def test_upsampling_bfloat16(self, dtype=torch.bfloat16):
6401        def helper(size, scale_factor, mode, device, memory_format=torch.contiguous_format):
6402            input = torch.randn(size, device=device, dtype=dtype).to(memory_format=memory_format).detach().requires_grad_(True)
6403            inputf = input.to(torch.float32).to(memory_format=torch.contiguous_format).detach().requires_grad_(True)
6404            m = nn.Upsample(scale_factor=scale_factor, mode=mode)
6405
6406            outf = m(inputf)
6407            out = m(input)
6408            self.assertEqual(out.to(torch.float32), outf, atol=0.05, rtol=0)
6409
6410            ginput = torch.randn(out.shape, device=device, dtype=dtype).to(memory_format=memory_format)
6411            ginputf = ginput.to(torch.float32).to(memory_format=torch.contiguous_format)
6412            out.backward(ginput)
6413            outf.backward(ginputf)
6414            self.assertEqual(input.grad.to(torch.float32), inputf.grad, atol=0.01, rtol=0.01)
6415
6416        for device in ['cpu']:
6417            helper([3, 20, 11, 7], 2, 'nearest', device)
6418            helper([3, 20, 11, 7], 2, 'nearest', device, torch.channels_last)
6419            helper([3, 20, 11, 7, 3], 2, 'nearest', device)
6420            helper([3, 20, 30], 2, 'linear', device)
6421            helper([3, 20, 11, 7], 2, 'bilinear', device)
6422            helper([3, 20, 11, 7], 2, 'bilinear', device, torch.channels_last)
6423            helper([1, 3, 11, 7], 2, 'bicubic', device)
6424            helper([1, 3, 11, 7], 2, 'bicubic', device, torch.channels_last)
6425            helper([3, 20, 11, 7, 3], 2, 'trilinear', device)
6426
6427            helper([3, 5, 5], 257., 'nearest', device)
6428            helper([3, 20, 11, 7], 20, 'nearest', device)
6429            helper([3, 20, 11, 7, 3], 20, 'nearest', device)
6430            helper([1, 2, 11, 7], 257, 'nearest', device, torch.channels_last)
6431            helper([1, 2, 2000, 2000], 1 / 377., 'nearest', device)
6432            helper([1, 2, 2000, 2000], 1 / 257., 'nearest', device, torch.channels_last)
6433            helper([3, 2, 11, 7, 3], 20, 'nearest', device, torch.channels_last_3d)
6434            helper([3, 5, 5], 10, 'linear', device)
6435            helper([3, 5, 5], 257, 'linear', device)
6436            helper([1, 2, 11, 7], 257, 'bilinear', device)
6437            helper([1, 2, 11, 7], 257, 'bilinear', device, torch.channels_last)
6438            helper([1, 3, 11, 7], 10, 'bicubic', device)
6439            helper([1, 3, 11, 7], 10, 'bicubic', device, torch.channels_last)
6440            helper([1, 1, 11, 7], 257, 'bicubic', device)
6441            helper([3, 2, 11, 7, 3], 20, 'trilinear', device)
6442            helper([3, 2, 11, 7, 3], 20, 'trilinear', device, torch.channels_last_3d)
6443
6444    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
6445    def test_interpolate_illegal_memory_access(self):
6446        in_s = 45
6447        out_s = 14
6448
6449        input = torch.ones((1, 1, in_s), device='cuda', requires_grad=True)
6450        # note we allocated grad_output to be larger so out of bound access
6451        # would be visible in grad_input
6452        grad = torch.ones((1, 1, out_s * 2), device='cuda', requires_grad=True)
6453        grad = grad[:, :, :out_s]
6454
6455        input_ref = input.detach().cpu().requires_grad_()
6456        grad_ref = grad.cpu()
6457
6458        out = F.interpolate(input, size=(out_s,), mode='nearest')
6459        out.backward(grad)
6460
6461        out_ref = F.interpolate(input_ref, size=(out_s,), mode='nearest')
6462        out_ref.backward(grad_ref)
6463
6464        self.assertEqual(out_ref, out)
6465        self.assertEqual(input_ref.grad, input.grad)
6466
6467    def test_interpolate_undefined_behavior_casting(self):
6468        x = torch.ones([1, 1, 16, 16])
6469        self.assertRaises(RuntimeError, lambda: F.interpolate(x, scale_factor=-1e20, mode="bilinear"))
6470        self.assertRaises(RuntimeError, lambda: F.interpolate(x, scale_factor=1e20, mode="bilinear"))
6471
6472    def test_interpolate_buffer_overflow(self):
6473        # Test buffer overflow issue due to inaccurate floating point
6474        # representation for integer values. See issue below for details.
6475        # https://github.com/pytorch/pytorch/issues/88939
6476
6477        def helper(size, dtype, mode, device, is_channels_last):
6478            input = torch.ones(size, dtype=dtype, device=device)
6479            if is_channels_last:
6480                if len(size) == 3:
6481                    input = input.transpose(1, 2).contiguous().transpose(1, 2)
6482                elif len(size) == 4:
6483                    input = input.to(memory_format=torch.channels_last)
6484                else:
6485                    input = input.to(memory_format=torch.channels_last_3d)
6486            output1 = F.interpolate(input, 2, mode=mode, align_corners=True)
6487            # reset the corner value and expect the output is changed as well
6488            # the output won't be changed on buffer overflow
6489            input[(-1,) * len(size)] = 0.5
6490            output2 = F.interpolate(input, 2, mode=mode, align_corners=True)
6491            self.assertNotEqual(output1, output2)
6492
6493        size_dtype_list = []
6494        # We set the size larger than the floating point exactly representable range
6495        # float: exact representable range (-2**24,2**24)
6496        size_dtype_list.append(([1, 10, 2**24 + 4], torch.float))
6497        size_dtype_list.append(([1, 10, 2, 2**24 + 4], torch.float))
6498        size_dtype_list.append(([1, 10, 2, 2, 2**24 + 4], torch.float))
6499        # bfloat16: exact representable range (-2**8, 2**8)
6500        size_dtype_list.append(([1, 10, 2**8 + 4], torch.bfloat16))
6501        size_dtype_list.append(([1, 10, 2, 2**8 + 4], torch.bfloat16))
6502        size_dtype_list.append(([1, 10, 2, 2, 2**8 + 4], torch.bfloat16))
6503        # half: exact representable range (-2**11, 2**11)
6504        size_dtype_list.append(([1, 10, 2**11 + 4], torch.half))
6505        size_dtype_list.append(([1, 10, 2, 2**11 + 4], torch.half))
6506        size_dtype_list.append(([1, 10, 2, 2, 2**11 + 4], torch.half))
6507
6508        # TODO: turn on cuda test after buffer overflow issue is fixed in cuda kernel
6509        # devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
6510        devices = ['cpu']
6511
6512        for mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
6513            for size_dtype in size_dtype_list:
6514                size, dtype = size_dtype
6515                if (
6516                    mode == 'linear' and len(size) != 3
6517                    or (mode == 'bilinear' and len(size) != 4)
6518                    or (mode == 'bicubic' and len(size) != 4)
6519                    or (mode == 'trilinear' and len(size) != 5)
6520                ):
6521                    continue
6522                for device in devices:
6523                    if (
6524                        device == 'cpu' and dtype == torch.half
6525                        or (device == 'cuda' and dtype == torch.bfloat16)
6526                    ):
6527                        # no half precision support on cpu or bfloat16 on cuda yet
6528                        continue
6529                    for is_channels_last in (True, False):
6530                        helper(size, dtype, mode, device, is_channels_last)
6531
6532
6533    @set_default_dtype(torch.double)
6534    def test_interpolate(self):
6535        def _test_interpolate_non_integer_size_warning(in_t, out_size, dim, **kwargs):
6536            test_sizes = [float(out_size),
6537                          torch.tensor(out_size, dtype=torch.float)]
6538            for size in test_sizes:
6539                self.assertRaisesRegex(TypeError,
6540                                       "(expected size to be one of int or).*",
6541                                       F.interpolate, in_t, size=(size,) * dim, **kwargs)
6542
6543        def _test_interpolate_helper(in_t, scale_factor, layer):
6544            out_size = int(math.floor(in_t.shape[-1] * scale_factor))
6545            dim = len(in_t.shape) - 2
6546            out_shape = [1, 1] + [out_size] * dim
6547            with warnings.catch_warnings(record=True) as w:
6548                out_t = layer(in_t)
6549            self.assertEqual(torch.ones(out_shape), out_t)
6550
6551            self.assertEqual(
6552                F.interpolate(in_t, (out_size,) * dim, **kwargs),
6553                F.interpolate(in_t, scale_factor=scale_factor, **kwargs))
6554            gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [in_t], nondet_tol=GRADCHECK_NONDET_TOL)
6555            gradgradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [in_t], nondet_tol=GRADCHECK_NONDET_TOL)
6556            _test_interpolate_non_integer_size_warning(in_t, out_size, dim, **kwargs)
6557
6558        def _make_input(dim, device):
6559            size = [1, 1]
6560            size += [2] * dim
6561            return torch.ones(size, requires_grad=True, device=device)
6562
6563        device_list = ['cpu']
6564        if TEST_CUDA:
6565            device_list.append('cuda')
6566
6567        for device in device_list:
6568            for scale_factor in [0.5, 1.5, 2]:
6569                for mode in ['nearest', 'area']:
6570                    kwargs = dict(mode=mode)
6571                    m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device)
6572                    for input in [_make_input(1, device), _make_input(2, device), _make_input(3, device)]:
6573                        _test_interpolate_helper(input, scale_factor, m)
6574
6575                for align_corners in [True, False]:
6576                    kwargs = dict(mode='linear', align_corners=align_corners)
6577                    m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device)
6578                    _test_interpolate_helper(_make_input(1, device), scale_factor, m)
6579
6580                    kwargs = dict(mode='bilinear', align_corners=align_corners)
6581                    m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device)
6582                    _test_interpolate_helper(_make_input(2, device), scale_factor, m)
6583
6584                    kwargs = dict(mode='bicubic', align_corners=align_corners)
6585
6586                    def m(t):
6587                        return F.interpolate(t, scale_factor=scale_factor, **kwargs).to(device)
6588                    _test_interpolate_helper(_make_input(2, device), scale_factor, m)
6589
6590                    kwargs = dict(mode='trilinear', align_corners=align_corners)
6591                    m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device)
6592                    _test_interpolate_helper(_make_input(3, device), scale_factor, m)
6593
6594    def test_linear_broadcasting(self):
6595        m = nn.Linear(5, 8)
6596        inp = torch.randn(2, 3, 5)
6597        expected = m(inp.view(6, 5)).view(2, 3, 8)
6598        self.assertEqual(expected, m(inp))
6599
6600    def test_linear_raise_on_scalar_input(self):
6601        # This used to cause an int underflow issue when reshaping the input
6602        # see https://github.com/pytorch/pytorch/issues/119161
6603        m = nn.Linear(1, 1)
6604        inp = torch.ones(1).squeeze()
6605        with self.assertRaisesRegex(RuntimeError, ".*both arguments.*1D.*"):
6606            m(inp)
6607
6608    @parametrize_test('device', ['cpu'] + (['cuda'] if TEST_CUDA else []))
6609    @parametrize_test('bias', [
6610        subtest(False, name='nobias'), subtest(True, name='bias')])
6611    @parametrize_test('weight_layout', [
6612        subtest(torch.strided, name='weightStrided'),
6613        subtest(torch.sparse_coo, name='weightCOO'),
6614        subtest(torch.sparse_csr, name='weightCSR'),
6615        subtest(torch.sparse_csc, name='weightCSC'),
6616        # TODO: addmm: computation on CPU is not implemented for Strided + Strided @ SparseBsr
6617        # subtest(torch.sparse_bsr, name='weightBSR'),
6618        # subtest(torch.sparse_bsc, name='weightBSC'),
6619    ])
6620    def test_linear_autograd(self, device, bias, weight_layout):
6621        module = nn.Linear(4, 4, bias=bias, device=device)
6622        if weight_layout == torch.strided:
6623            pass
6624        elif weight_layout == torch.sparse_csr:
6625            module.weight = nn.Parameter(module.weight.to_sparse_csr())
6626        elif weight_layout == torch.sparse_csc:
6627            module.weight = nn.Parameter(module.weight.to_sparse_csc())
6628        elif weight_layout == torch.sparse_bsr:
6629            module.weight = nn.Parameter(module.weight.to_sparse_bsr((2, 2)))
6630        elif weight_layout == torch.sparse_bsc:
6631            module.weight = nn.Parameter(module.weight.to_sparse_bsc((2, 2)))
6632        elif weight_layout == torch.sparse_coo:
6633            module.weight = nn.Parameter(module.weight.to_sparse_coo())
6634        else:
6635            raise AssertionError
6636
6637        inp = torch.randn(4, requires_grad=True, device=device)
6638        res = module(inp)
6639        if bias:
6640            expected = (torch.einsum("i,ji->j", inp, module.weight.to_dense())) + module.bias
6641        else:
6642            expected = (torch.einsum("i,ji->j", inp, module.weight.to_dense()))
6643        self.assertEqual(res, expected)
6644
6645        grad_output = torch.randn(4, device=device)
6646        grads = torch.autograd.grad(res, [module.weight, inp], grad_output)
6647        grads_expected = torch.autograd.grad(expected, [module.weight, inp], grad_output)
6648
6649        self.assertEqual(grads_expected[0].layout, weight_layout)
6650
6651        for g, ge in zip(grads, grads_expected):
6652            self.assertEqual(g, ge)
6653
6654    def test_bilinear(self):
6655        module = nn.Bilinear(10, 10, 8)
6656        input1 = torch.randn(4, 10, requires_grad=True)
6657        input2 = torch.randn(4, 10, requires_grad=True)
6658        grad_output = torch.randn(4, 8)
6659        res = module(input1, input2)
6660        expected = (torch.einsum("bi,kij,bj->bk", input1, module.weight, input2) +
6661                    module.bias)
6662        self.assertEqual(res, expected)
6663        grads = torch.autograd.grad(res, [module.weight, module.bias, input1, input2], grad_output)
6664        grads_expected = torch.autograd.grad(expected, [module.weight, module.bias, input1, input2], grad_output)
6665        for g, ge in zip(grads, grads_expected):
6666            self.assertEqual(g, ge)
6667
6668    def test_bilinear_non_contiguous(self):
6669        module = nn.Bilinear(7, 7, 5)
6670        input1 = torch.randn(4, 7, 10, requires_grad=True)
6671        input2 = torch.randn(4, 7, 10, requires_grad=True)
6672        input1_tp = input1.transpose(1, 2)
6673        input2_tp = input2.transpose(1, 2)
6674
6675        grad_output = torch.randn(4, 10, 5)
6676
6677        def run(input1_tp, input2_tp):
6678            input1.grad = input2.grad = None
6679            output = module(input1_tp, input2_tp)
6680            output.backward(grad_output)
6681
6682            return output.data, input1.grad.data, input2.grad.data
6683
6684        out_nc, g1_nc, g2_nc = run(input1_tp, input2_tp)
6685        input1_tp = input1_tp.contiguous()
6686        input2_tp = input2_tp.contiguous()
6687        out, g1, g2 = run(input1_tp, input2_tp)
6688
6689        self.assertEqual(out, out_nc)
6690        self.assertEqual(g1, g1_nc)
6691        self.assertEqual(g2, g2_nc)
6692
6693    def test_bilinear_no_bias(self):
6694        module = nn.Bilinear(10, 10, 8, dtype=torch.double)
6695        module_no_bias = nn.Bilinear(10, 10, 8, False, dtype=torch.double)
6696
6697        module.bias.data.zero_()
6698        module.weight.data.copy_(module_no_bias.weight)
6699
6700        input1 = torch.randn(4, 10, requires_grad=True, dtype=torch.double)
6701        input2 = torch.randn(4, 10, requires_grad=True, dtype=torch.double)
6702        grad_output = torch.randn(4, 8, dtype=torch.double)
6703
6704        def run(net):
6705            input1.grad = input2.grad = None
6706            output = net(input1, input2)
6707            output.backward(grad_output)
6708
6709            return output.data, input1.grad.data, input2.grad.data
6710
6711        out, g1, g2 = run(module)
6712        out_nb, g1_nb, g2_nb = run(module_no_bias)
6713
6714        self.assertEqual(out, out_nb)
6715        self.assertEqual(g1, g1_nb)
6716        self.assertEqual(g2, g2_nb)
6717
6718        _assertGradAndGradgradChecks(self,
6719                                     lambda x1, x2: F.bilinear(x1, x2, module_no_bias.weight, module_no_bias.bias),
6720                                     (input1, input2))
6721
6722    def test_bilinear_broadcasting(self):
6723        m = nn.Bilinear(5, 6, 8)
6724        input1 = torch.randn(2, 3, 5)
6725        input2 = torch.randn(2, 3, 6)
6726        expected = m(input1.view(6, 5), input2.view(6, 6)).view(2, 3, 8)
6727        self.assertEqual(expected, m(input1, input2))
6728
6729    def test_fold_invalid_arg(self):
6730        # input.size(1) not divisible by \prod(kernel_size)
6731
6732        fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 3))
6733        with self.assertRaisesRegex(RuntimeError, r"be divisible by the product of kernel_size"):
6734            fold(torch.randn(1, 5, 9))
6735
6736        with self.assertRaisesRegex(RuntimeError, r"be divisible by the product of kernel_size"):
6737            fold(torch.randn(1, 19, 9))
6738
6739        # input.size(2) not matching the total number of sliding blocks
6740
6741        with self.assertRaisesRegex(RuntimeError, r"match the calculated number of sliding blocks"):
6742            fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 3))
6743            fold(torch.randn(1, 6, 10))
6744
6745        with self.assertRaisesRegex(RuntimeError, r"match the calculated number of sliding blocks"):
6746            fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 3), stride=(2, 2))
6747            fold(torch.randn(1, 6, 5))
6748
6749        with self.assertRaisesRegex(RuntimeError, r"match the calculated number of sliding blocks"):
6750            fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 3), stride=(2, 2), dilation=(1, 2), padding=(2, 0))
6751            fold(torch.randn(1, 6, 5))  # should be 4 * 1 = 4 sliding blocks
6752
6753        fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2), stride=1, dilation=8, padding=0)
6754        with self.assertRaisesRegex(RuntimeError, r"calculated shape of the array of sliding blocks as"):
6755            fold(torch.randn(1, 12, 12))
6756
6757    def test_unfold_invalid_arg(self):
6758        # input wrong dimension
6759
6760        unfold = nn.Unfold(kernel_size=(2, 3))
6761
6762        # calculated output shape is too small
6763        with self.assertRaisesRegex(RuntimeError, r"its components must be at least one"):
6764            unfold = nn.Unfold(kernel_size=(2, 3))
6765            unfold(torch.randn(1, 2, 2, 2))
6766
6767        with self.assertRaisesRegex(RuntimeError, r"its components must be at least one"):
6768            unfold = nn.Unfold(kernel_size=(5, 3), padding=(1, 1))
6769            unfold(torch.randn(1, 2, 2, 3))
6770
6771        with self.assertRaisesRegex(RuntimeError, r"its components must be at least one"):
6772            unfold = nn.Unfold(kernel_size=(1, 3), padding=(1, 1), dilation=(1, 2))
6773            unfold(torch.randn(1, 2, 2, 2))
6774
6775    def test_softmin(self):
6776        x = torch.randn(2, 16)
6777        self.assertEqual(F.softmin(x, 1), F.softmax(-x, 1))
6778        self.assertEqual(F.softmin(x, 0), F.softmax(-x, 0))
6779
6780    def test_adaptive_log_softmax(self):
6781        # args validation
6782        with self.assertRaises(ValueError):
6783            _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 15, 15], div_value=2.)
6784
6785        with self.assertRaises(ValueError):
6786            _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 15, 10], div_value=2.)
6787
6788        with self.assertRaises(ValueError):
6789            _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 25], div_value=2.)
6790
6791        with self.assertRaisesRegex(ValueError, "cutoffs should be a sequence of unique,"):
6792            _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 20], div_value=2.)
6793
6794        # not raise
6795        _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 19], div_value=2.)
6796
6797        # input shapes
6798        with self.assertRaisesRegex(RuntimeError, r"Input and target should have the same size"):
6799            asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.)
6800            x = torch.randn(2, 16)
6801            y = torch.tensor([0, 5, 10])
6802            asfm(x, y)
6803
6804        # out-of-bound targets
6805        with self.assertRaisesRegex(RuntimeError, r"Target values should be in"):
6806            asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.)
6807            x = torch.randn(2, 16)
6808            y = torch.tensor([0, 20])
6809            asfm(x, y)
6810
6811        # cluster sizes
6812        asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.)
6813        x = torch.randn(2, 16)
6814        y = torch.tensor([0, 17])
6815
6816        self.assertEqual(asfm.head.weight.size(), (5 + 3, 16))   # 5 targets in head, 3 clusters, dimensionality 16
6817        self.assertEqual(asfm.tail[0][1].weight.size(), (5, 8))  # 5 targets in this cluster, dimensionality 8
6818        self.assertEqual(asfm.tail[1][1].weight.size(), (5, 4))
6819        self.assertEqual(asfm.tail[2][1].weight.size(), (5, 2))
6820        self.assertEqual(asfm(x, y).output.size(), (2, ))
6821
6822        # test no_batch_dim support
6823        asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.)
6824        x = torch.randn(1, 16)
6825        y = torch.tensor([17])
6826        x2 = x.squeeze(0)
6827        y2 = y.squeeze(0)
6828        self.assertEqual(asfm(x, y).output.squeeze(0), asfm(x2, y2).output)
6829
6830        # log_probs actually returns log_proba
6831        asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 4, [2], div_value=2.)
6832        x = torch.randn(4, 8)
6833        logprob_out = asfm.log_prob(x)
6834
6835        self.assertEqual(torch.exp(logprob_out).data.sum(1), torch.ones(4))
6836
6837        # forward returns the same thing as log_probs
6838        for v in [0, 1, 2, 3]:
6839            y = torch.full((4,), v, dtype=torch.long)
6840            out, loss = asfm(x, y)
6841
6842            self.assertEqual(out, logprob_out.gather(1, y.unsqueeze(1)).squeeze())
6843            self.assertEqual(loss, F.nll_loss(logprob_out, y))
6844
6845        # predict
6846        x = torch.randn(64, 8).abs_()
6847
6848        # argmax in shortlist
6849        asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 10, [4, 8], div_value=2., head_bias=True)
6850        asfm.head.weight.data.abs_()
6851        asfm.head.bias.data.abs_()
6852        asfm.head.weight.data[asfm.shortlist_size:, :].zero_()
6853
6854        out = asfm.predict(x)
6855        self.assertEqual(out, asfm.log_prob(x).argmax(dim=1))
6856
6857        # argmax outside of shortlist
6858        asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 10, [4, 8], div_value=2., head_bias=True)
6859        asfm.head.weight.data.abs_()
6860        asfm.head.bias.data.abs_()
6861        asfm.head.weight.data[:asfm.shortlist_size, :].zero_()
6862
6863        out = asfm.predict(x)
6864        self.assertEqual(out, asfm.log_prob(x).argmax(dim=1))
6865
6866        # half of the argmax in shortlist, half in clusters
6867        asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 10, [4, 8], div_value=2., head_bias=True)
6868        asfm.head.weight.data.abs_()
6869        asfm.head.bias.data.abs_()
6870
6871        x[:32, :asfm.shortlist_size].zero_()
6872        x[32:, asfm.shortlist_size:].zero_()
6873
6874        asfm.head.weight.data[:asfm.shortlist_size, asfm.shortlist_size:].zero_()
6875        asfm.head.weight.data[asfm.shortlist_size:, :asfm.shortlist_size].zero_()
6876
6877        out = asfm.predict(x)
6878        self.assertEqual(out, asfm.log_prob(x).argmax(dim=1))
6879
6880    def test_cross_entropy_loss(self, dtype=torch.bfloat16):
6881        loss_cpu = nn.CrossEntropyLoss().cpu()
6882        inputf = torch.randn(15, 10, device="cpu", dtype=torch.float, requires_grad=True)
6883        input = inputf.to(dtype).detach().requires_grad_(True)
6884        target = torch.empty(15, dtype=torch.long).random_(10)
6885
6886        outf = loss_cpu(inputf, target)
6887        out = loss_cpu(input, target)
6888        self.assertEqual(out, outf.to(dtype=dtype), atol=1e-1, rtol=0)
6889
6890        outf.backward()
6891        out.backward()
6892        self.assertEqual(input.grad, inputf.grad.to(dtype=dtype), atol=1e-1, rtol=0)
6893
6894    def test_cross_entropy_loss_precision(self):
6895        # Regression test for #55657
6896        loss_cpu = nn.CrossEntropyLoss().cpu()
6897        inputf = torch.randn(128, 2, 768, 768, device="cpu", dtype=torch.float)
6898        inputd = inputf.double()
6899        target = torch.randint(2, (128, 768, 768), dtype=torch.long)
6900
6901        outf = loss_cpu(inputf, target)
6902        outd = loss_cpu(inputd, target)
6903        self.assertEqual(outf, outd, exact_dtype=False)
6904
6905    def test_cross_entropy_loss_zero_div(self):
6906        # Test for issue #73165
6907        input_1 = torch.rand([5, 0], dtype=torch.float32)
6908        input_2 = torch.rand([5, 0], dtype=torch.float32)
6909        torch.nn.CrossEntropyLoss()(input_1, input_2)
6910
6911    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
6912    def test_convert_sync_batchnorm(self):
6913        module = torch.nn.Sequential(
6914            torch.nn.BatchNorm1d(100),
6915            torch.nn.InstanceNorm1d(100)
6916        ).cuda()
6917
6918        # necessary to have an anchor point for comparison, in case the
6919        # convert_sync_batchnorm updates in place
6920        comp_module = torch.nn.Sequential(
6921            torch.nn.BatchNorm1d(100),
6922            torch.nn.InstanceNorm1d(100)
6923        ).cuda()
6924        comp_module.load_state_dict(module.state_dict())
6925
6926        sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
6927        children = list(sync_bn_module.children())
6928        self.assertEqual(children[0].__class__, torch.nn.SyncBatchNorm)
6929        self.assertEqual(children[1].__class__, torch.nn.InstanceNorm1d)
6930
6931        for layer, converted_layer in zip(comp_module.children(), sync_bn_module.children()):
6932            for key in layer.state_dict().keys():
6933                self.assertEqual(layer.state_dict()[key].device, converted_layer.state_dict()[key].device)
6934                self.assertEqual(layer.state_dict()[key], converted_layer.state_dict()[key])
6935
6936    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
6937    def test_sync_batchnorm_backward_elemt(self):
6938        device = 'cuda'
6939        saved_input = torch.rand(2, 3, 2, 1, device=device)
6940        grad_output = torch.rand(2, 3, 2, 1, device=device)
6941        mean = torch.rand(3, device=device)
6942        invstd = torch.rand(3, device=device)
6943        weight = torch.rand(3, device=device)
6944        sum_dy = torch.rand(3, device=device)
6945        sum_dy_xmu = torch.rand(3, device=device)
6946        count_tensor = torch.tensor([5, 5, 5], dtype=torch.int32, device=device)
6947
6948        gI_contiguous = torch.batch_norm_backward_elemt(
6949            grad_output,
6950            saved_input,
6951            mean,
6952            invstd,
6953            weight,
6954            sum_dy,
6955            sum_dy_xmu,
6956            count_tensor
6957        )
6958
6959        # Test batch_norm_backward_elemt gives the same answer for all
6960        # combinations of contiguous as channels_last input
6961        for a, b in [
6962                (torch.channels_last, torch.contiguous_format),
6963                (torch.contiguous_format, torch.channels_last),
6964                (torch.channels_last, torch.channels_last),
6965        ]:
6966            gI_actual = torch.batch_norm_backward_elemt(
6967                grad_output.contiguous(memory_format=a),
6968                saved_input.contiguous(memory_format=b),
6969                mean,
6970                invstd,
6971                weight,
6972                sum_dy,
6973                sum_dy_xmu,
6974                count_tensor
6975            )
6976            self.assertEqual(gI_actual, gI_contiguous)
6977
6978    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
6979    def test_sync_batchnorm_accuracy_cuda(self):
6980        # The target of this test is to test the functionality and accuracy of
6981        #   those single-GPU cuda kernels used in SyncBatchNorm
6982        # They are:
6983        #   fwd: torch.batch_norm_stats, torch.batch_norm_gather_stats_with_counts, torch.batch_norm_elemt
6984        #   bwd: torch.batch_norm_backward_reduce, torch.batch_norm_backward_elemt
6985
6986        def _batch_norm_stats(data, memory_format, mean_axes):
6987            mean1, _ = torch.batch_norm_stats(data, 1e-5)
6988            mean2, _ = torch.batch_norm_stats(data.to(memory_format=memory_format), 1e-5)
6989            mean_ref = torch.mean(data, mean_axes, keepdim=False)
6990
6991            self.assertEqual(mean_ref, mean1)
6992            self.assertEqual(mean_ref, mean2)
6993
6994        _batch_norm_stats(torch.randn(1, 96, 112, 112, dtype=torch.float, device='cuda'), torch.channels_last, (0, 2, 3))
6995        _batch_norm_stats(torch.randn(1, 96, 112, 112, 112, dtype=torch.float, device='cuda'), torch.channels_last_3d, (0, 2, 3, 4))
6996
6997    def test_flatten(self):
6998        tensor_input = torch.randn(2, 1, 2, 3)
6999
7000        # Flatten Tensor
7001
7002        flatten = nn.Flatten(start_dim=1, end_dim=-1)
7003        tensor_output = flatten(tensor_input)
7004        self.assertEqual(tensor_output.size(), torch.Size([2, 6]))
7005
7006    def test_unflatten(self):
7007        tensor_input = torch.randn(2, 50)
7008
7009        # Unflatten Tensor (unflattened_size as a tuple of ints and list of ints)
7010
7011        for us in ((2, 5, 5), [2, 5, 5]):
7012            unflatten = nn.Unflatten(dim=1, unflattened_size=us)
7013            tensor_output = unflatten(tensor_input)
7014            self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5]))
7015
7016        # Unflatten NamedTensor
7017
7018        unflatten = nn.Unflatten(dim='features', unflattened_size=(('C', 2), ('H', 5), ('W', 5)))
7019        named_tensor_input = tensor_input.refine_names('N', 'features')
7020        named_tensor_output = unflatten(named_tensor_input)
7021        self.assertEqual(named_tensor_output.size(), torch.Size([2, 2, 5, 5]))
7022
7023    def test_unflatten_invalid_arg(self):
7024        # Wrong type for unflattened_size (tuple of floats)
7025
7026        with self.assertRaisesRegex(
7027                TypeError,
7028                r"unflattened_size must be tuple of ints, but found element of type float at pos 2"):
7029            nn.Unflatten(dim=1, unflattened_size=(2, 5, 5.0))
7030
7031        # Wrong type for unflattened_size (list of lists and list of tuples)
7032        for us in ([['C', 2], ['W', 5], ['H', 5]], [('C', 2), ('W', 5), ('H', 5)]):
7033            with self.assertRaisesRegex(
7034                    TypeError,
7035                    r"unflattened_size must be a tuple of tuples, but found type list"):
7036                nn.Unflatten(dim='features', unflattened_size=us)
7037
7038        # Wrong type for unflattened_size (tuple of lists)
7039
7040        with self.assertRaisesRegex(
7041                TypeError,
7042                r"unflattened_size must be tuple of tuples, but found element of type list at pos 0"):
7043            nn.Unflatten(dim='features', unflattened_size=(['C', 2], ['W', 5], ['H', 5]))
7044
7045        # Wrong type for unflattened_size (tuple of dicts)
7046
7047        with self.assertRaisesRegex(
7048                TypeError,
7049                r"unflattened_size must be tuple of tuples, but found element of type dict at pos 0"):
7050            nn.Unflatten(dim='features', unflattened_size=({'C': 2}, {'W': 5}, {'H': 5}))
7051
7052    def test_layer_norm_grads_with_create_graph_flag(self):
7053        atol = 1e-5
7054        rtol = 1e-3
7055
7056        x = torch.randn((4, 4, 16), requires_grad=True)
7057        layer_norm = nn.LayerNorm((16,), 1e-5, True)
7058        with torch.no_grad():
7059            layer_norm.weight = torch.nn.Parameter(0.1 * torch.ones_like(layer_norm.weight))
7060
7061        grads1 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=False)[0]
7062        grads2 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=True)[0]
7063
7064        self.assertEqual(grads1, grads2, rtol=rtol, atol=atol)
7065
7066        if TEST_CUDA:
7067            x = x.to('cuda')
7068            layer_norm = layer_norm.to('cuda')
7069
7070            grads1 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=False)[0]
7071            grads2 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=True)[0]
7072
7073            self.assertEqual(grads1, grads2, rtol=rtol, atol=atol)
7074
7075    def test_layer_norm_eps(self):
7076        # test for https://github.com/pytorch/pytorch/issues/108072
7077        x = torch.Tensor([[[2.0, 2.0], [14.0, 14.0]], [[2.0, 2.0], [14.0, 14.0]]])
7078        ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False)
7079        self.assertEqual(ln.forward(x), torch.zeros_like(x))
7080
7081    def test_padding_list(self):
7082        # Padding can be a list, or tuple (regression test for gh-54452)
7083        x = torch.randn(4, 8, 32, 32)
7084        net = torch.nn.ConvTranspose2d(8, 16, kernel_size=3, padding=[3, 3])
7085        y = net(x)
7086
7087        net = torch.nn.ConvTranspose2d(8, 16, kernel_size=3, padding=(3, 3))
7088        y = net(x)
7089
7090    def test_fractional_max_pool2d_invalid_output_ratio(self):
7091        arg_1 = [2, 1]
7092        arg_2 = [0.5, 0.5, 0.6]
7093        arg_class = torch.nn.FractionalMaxPool2d(kernel_size=arg_1, output_ratio=arg_2,)
7094        arg_3_0_tensor = torch.rand([20, 16, 50, 32], dtype=torch.float32)
7095        arg_3_0 = arg_3_0_tensor.clone()
7096        arg_3 = [arg_3_0,]
7097
7098        with self.assertRaisesRegex(ValueError,
7099                                    "fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints."):
7100            res = arg_class(*arg_3)
7101
7102    def test_max_pool1d_invalid_output_size(self):
7103        arg_1 = 3
7104        arg_2 = 255
7105        arg_3 = False
7106        arg_class = torch.nn.MaxPool1d(kernel_size=arg_1, stride=arg_2, return_indices=arg_3)
7107        arg_4_0 = torch.as_tensor([[0.3204]])
7108        arg_4 = [arg_4_0,]
7109
7110        with self.assertRaises(RuntimeError):
7111            res = arg_class(*arg_4)
7112
7113    def test_pickle_module_no_weights_only_warning(self):
7114        with warnings.catch_warnings(record=True) as w:
7115            pickle.loads(pickle.dumps(torch.nn.Linear(10, 10)))
7116        self.assertEqual(len(w), 0)
7117
7118class TestFusionEval(TestCase):
7119    @set_default_dtype(torch.double)
7120    @given(X=hu.tensor(shapes=((5, 3, 5, 5),), dtype=np.double),
7121           running_mean=hu.tensor(shapes=(6,), dtype=np.double),
7122           running_var=hu.tensor(shapes=(6,), dtype=np.double))
7123    def test_fuse_module_eval_numerics(self, X, running_mean, running_var):
7124        inputs, _ = X
7125
7126        iC, oC = inputs.shape[1], len(running_mean[0])
7127        inputs = torch.from_numpy(inputs)
7128        kernel_size = (3, 3)
7129
7130        conv_ref = torch.nn.Conv2d(iC, oC, bias=True, kernel_size=kernel_size)
7131        bn_ref = torch.nn.BatchNorm2d(oC)
7132        bn_ref.running_mean = torch.from_numpy(running_mean[0])
7133        bn_ref.running_var = torch.from_numpy(running_var[0])
7134
7135        conv_ref.eval()
7136        bn_ref.eval()
7137
7138        Y_ref = bn_ref(conv_ref(inputs))
7139        conv_bn_fused = torch.nn.utils.fusion.fuse_conv_bn_eval(conv_ref,
7140                                                                bn_ref)
7141        Y_hat = conv_bn_fused(inputs)
7142
7143        self.assertEqual(Y_ref, Y_hat, msg="Conv+BN fusion results are off")
7144
7145        na_bn_ref = torch.nn.BatchNorm2d(oC, affine=False)
7146        na_bn_ref.running_mean = torch.from_numpy(running_mean[0])
7147        na_bn_ref.running_var = torch.from_numpy(running_var[0])
7148        na_bn_ref.eval()
7149
7150        Y_ref = na_bn_ref(conv_ref(inputs))
7151        conv_na_bn_fused = torch.nn.utils.fusion.fuse_conv_bn_eval(conv_ref,
7152                                                                   na_bn_ref)
7153        Y_hat = conv_na_bn_fused(inputs)
7154
7155        self.assertEqual(Y_ref, Y_hat, msg="Conv+BN(non-affine) fusion results are off")
7156
7157
7158class TestConstantPadNd(TestCase):
7159    def test_constant_pad_nd(self):
7160        a = torch.tensor([[1, 2], [3, 4]])
7161        res = torch.constant_pad_nd(a, [1, 2, 1, 0], 9)
7162        expected = torch.tensor([
7163            [9, 9, 9, 9, 9],
7164            [9, 1, 2, 9, 9],
7165            [9, 3, 4, 9, 9]
7166        ])
7167        self.assertEqual(res, expected)
7168
7169    def test_preserves_memory_format(self):
7170        nchw_tensor = torch.rand((1, 2, 5, 3))
7171        nchw_padded = torch.constant_pad_nd(nchw_tensor, [1, 2], 0.5)
7172        self.assertTrue(nchw_padded.is_contiguous(memory_format=torch.contiguous_format))
7173
7174        nhwc_tensor = nchw_tensor.contiguous(memory_format=torch.channels_last)
7175        nhwc_padded = torch.constant_pad_nd(nhwc_tensor, [1, 2], 0.5)
7176        self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last))
7177
7178
7179class TestAddRelu(TestCase):
7180    def test_add_relu(self):
7181        a = torch.rand((7, 11))
7182        b = torch.rand((7, 11))
7183        a = a.float()
7184        b = b.float()
7185        a = a * -10
7186        a = a + 5
7187        add_res = a + b
7188        relu_res = torch.relu(add_res)
7189        add_relu_res = torch._VF._add_relu(a, b)
7190
7191        self.assertEqual(add_relu_res, relu_res)
7192
7193    def test_add_relu_broadcasting(self):
7194        a = torch.rand((1, 32))
7195        b = 1
7196        b_scalar = torch.ones(1, 32)
7197        res = torch._VF._add_relu(a, b)
7198        broadcasted_res = torch._VF._add_relu(a, b_scalar)
7199
7200        self.assertEqual(broadcasted_res, res)
7201
7202
7203def add_test(test, decorator=None):
7204    def add(test_name, fn):
7205        if hasattr(TestNN, test_name):
7206            raise RuntimeError('Found two tests with the same name: ' + test_name)
7207        if decorator is not None:
7208            fn = decorator(fn)
7209        setattr(TestNN, test_name, fn)
7210
7211    test_name = test.get_name()
7212    if not hasattr(test, 'test_cpu') or test.test_cpu:
7213        add(test_name, lambda self, test=test: test(self))
7214    cuda_test_name = test_name + '_cuda'
7215    # With dtype enable, it's good enough to test against three floating types
7216    kwargs = {}
7217    if 'extra_args' in get_function_arglist(test.test_cuda):
7218        kwargs['extra_args'] = test.extra_args
7219
7220    if 'dtype' in get_function_arglist(test.test_cuda):
7221        if tf32_is_not_fp32() and test.with_tf32:
7222
7223            def with_tf32_off(self, test=test, kwargs=kwargs):
7224                with tf32_off():
7225                    test.test_cuda(self, dtype=torch.float, **kwargs)
7226
7227            add(cuda_test_name + '_fp32', with_tf32_off)
7228
7229            def with_tf32_on(self, test=test, kwargs=kwargs):
7230                with tf32_on(self, test.tf32_precision):
7231                    test.test_cuda(self, dtype=torch.float, **kwargs)
7232
7233            add(cuda_test_name + '_tf32', with_tf32_on)
7234        else:
7235            add(cuda_test_name + '_float', lambda self,
7236                test=test, kwargs=kwargs: test.test_cuda(self, dtype=torch.float, **kwargs))
7237        add(cuda_test_name + '_double', lambda self,
7238            test=test, kwargs=kwargs: test.test_cuda(self, dtype=torch.double, **kwargs))
7239
7240        def test_half(self, test=test, kwargs=kwargs):
7241            test.test_cuda(self, dtype=torch.half, **kwargs)
7242        if getattr(test, 'check_half', True):
7243            add(cuda_test_name + '_half', test_half)
7244
7245        def test_bfloat16(self, test=test, kwargs=kwargs):
7246            test.test_cuda(self, dtype=torch.bfloat16, **kwargs)
7247        if getattr(test, 'check_bfloat16', True):
7248            add(cuda_test_name + '_bfloat16', test_bfloat16)
7249
7250        def test_cfloat(self, test=test, kwargs=kwargs):
7251            test.test_cuda(self, dtype=torch.cfloat, **kwargs)
7252
7253        def test_cdouble(self, test=test, kwargs=kwargs):
7254            test.test_cuda(self, dtype=torch.cdouble, **kwargs)
7255        if getattr(test, 'check_complex', False):
7256            add(cuda_test_name + '_cfloat', test_cfloat)
7257            add(cuda_test_name + '_cdouble', test_cdouble)
7258
7259    else:
7260        def with_tf32_off(self, test=test, kwargs=kwargs):
7261            with tf32_off():
7262                test.test_cuda(self, **kwargs)
7263
7264        if tf32_is_not_fp32() and test.with_tf32:
7265            add(cuda_test_name + '_fp32', with_tf32_off)
7266
7267            def with_tf32_on(self, test=test, kwargs=kwargs):
7268                with tf32_on(self, test.tf32_precision):
7269                    test.test_cuda(self, **kwargs)
7270
7271            add(cuda_test_name + '_tf32', with_tf32_on)
7272        else:
7273            add(cuda_test_name, with_tf32_off)
7274
7275for test_params in module_tests + new_module_tests:
7276    # TODO: CUDA is not implemented yet
7277    if 'constructor' not in test_params:
7278        name = test_params.pop('module_name')
7279        test_params['constructor'] = getattr(nn, name)
7280    decorator = test_params.pop('decorator', None)
7281    test = NewModuleTest(**test_params)
7282    add_test(test, decorator)
7283    if 'check_eval' in test_params:
7284        # create a new test that is identical but that sets module.training to False
7285        desc = test_params.get('desc', None)
7286        test_params['desc'] = 'eval' if desc is None else desc + '_eval'
7287
7288        def gen_eval_constructor(constructor):
7289            def eval_constructor(*args, **kwargs):
7290                cons = constructor(*args, **kwargs)
7291                cons.training = False
7292                return cons
7293            eval_constructor.__name__ = constructor.__name__
7294            return eval_constructor
7295
7296        test_params['constructor'] = gen_eval_constructor(test_params['constructor'])
7297        test = NewModuleTest(**test_params)
7298        add_test(test, decorator)
7299    if 'check_with_long_tensor' in test_params:
7300        fullname = test_params.get('fullname', None)
7301        if fullname:
7302            test_params['fullname'] = fullname + '_with_long_tensor'
7303        else:
7304            desc = test_params.get('desc', None)
7305            test_params['desc'] = 'with_long_tensor' if desc is None else desc + '_with_long_tensor'
7306
7307        def double_equivalent_of_long_tensor(size):
7308            return torch.randint(-1000, 1000, size=size).double()
7309
7310        def apply_to_cons(t):
7311            if t.is_floating_point():
7312                if isinstance(t, Parameter):
7313                    return Parameter(double_equivalent_of_long_tensor(t.size()))
7314                elif isinstance(t, torch.Tensor):
7315                    return double_equivalent_of_long_tensor(t.size())
7316            else:
7317                return t
7318
7319        def gen_long_tensor_constructor(constructor):
7320            def long_tensor_constructor(*args, **kwargs):
7321                cons = constructor(*args, **kwargs)
7322                cons._apply(apply_to_cons)
7323                return cons
7324            long_tensor_constructor.__name__ = constructor.__name__
7325            return long_tensor_constructor
7326
7327        def gen_long_tensor_input(input_size):
7328            def input_func():
7329                return double_equivalent_of_long_tensor(input_size)
7330            return input_func
7331
7332        def reference_fn(i, p, m):
7333            # For bad reasons this would create LongTensors that requires gradients
7334            # Remove requires_grad to avoid this
7335            for p in m.parameters():
7336                p.requires_grad_(False)
7337            m._apply(lambda t: t.long())
7338            input = i.long()
7339            out = m.forward(input)
7340            return out
7341
7342        test_params['constructor'] = gen_long_tensor_constructor(test_params['constructor'])
7343        test_params['input_fn'] = gen_long_tensor_input(test_params['input_size'])
7344        test_params['reference_fn'] = reference_fn
7345        test_params['check_forward_only'] = True
7346        # Currently we don't support conv2d/conv3d for LongTensor in CUDA
7347        test_params['test_cuda'] = False
7348        test = NewModuleTest(**test_params)
7349
7350        add_test(test, decorator)
7351
7352for test_params in criterion_tests:
7353    if 'constructor' not in test_params:
7354        name = test_params.pop('module_name')
7355        test_params['constructor'] = getattr(nn, name)
7356    test = CriterionTest(**test_params)
7357    decorator = test_params.pop('decorator', None)
7358    add_test(test, decorator)
7359    if 'check_sum_reduction' in test_params:
7360        desc = test_params.get('desc', None)
7361        test_params['desc'] = 'sum_reduction' if desc is None else desc + '_sum_reduction'
7362
7363        def gen_sum_reduction_constructor(constructor):
7364            def sum_reduction_constructor(*args, **kwargs):
7365                cons = constructor(*args, reduction='sum', **kwargs)
7366                return cons
7367            sum_reduction_constructor.__name__ = constructor.__name__
7368            return sum_reduction_constructor
7369
7370        test_params['constructor'] = gen_sum_reduction_constructor(test_params['constructor'])
7371        test = CriterionTest(**test_params)
7372        add_test(test, decorator)
7373
7374
7375class UnpoolingNet(nn.Module):
7376    def __init__(self, pool, unpool):
7377        super().__init__()
7378        self.pool = pool
7379        self.unpool = unpool
7380
7381    def forward(self, input):
7382        return self.unpool(*self.pool(input))
7383
7384
7385add_test(NewModuleTest(
7386    constructor=lambda: UnpoolingNet(
7387        nn.MaxPool1d(2, return_indices=True),
7388        nn.MaxUnpool1d(2)),
7389    input_size=(1, 1, 4),
7390    fullname='MaxUnpool1d_net',
7391    default_dtype=torch.double,))
7392add_test(NewModuleTest(
7393    constructor=lambda: UnpoolingNet(
7394        nn.MaxPool2d(2, return_indices=True),
7395        nn.MaxUnpool2d(2)),
7396    input_size=(1, 1, 2, 4),
7397    fullname='MaxUnpool2d_net',
7398    default_dtype=torch.double,))
7399add_test(NewModuleTest(
7400    constructor=lambda: UnpoolingNet(
7401        nn.MaxPool3d(2, return_indices=True),
7402        nn.MaxUnpool3d(2)),
7403    input_size=(1, 1, 2, 4, 6),
7404    fullname='MaxUnpool3d_net',
7405    check_gradgrad=False,
7406    default_dtype=torch.double,))
7407
7408add_test(NewModuleTest(
7409    constructor=lambda: UnpoolingNet(
7410        nn.MaxPool1d(2, return_indices=True),
7411        nn.MaxUnpool1d(2)),
7412    input_size=(1, 4),
7413    reference_fn=single_batch_reference_fn,
7414    fullname='MaxUnpool1d_net_no_batch_dim',
7415    default_dtype=torch.double,))
7416add_test(NewModuleTest(
7417    constructor=lambda: UnpoolingNet(
7418        nn.MaxPool2d(2, return_indices=True),
7419        nn.MaxUnpool2d(2)),
7420    input_size=(1, 2, 4),
7421    reference_fn=single_batch_reference_fn,
7422    fullname='MaxUnpool2d_net_no_batch_dim',
7423    default_dtype=torch.double,))
7424
7425add_test(NewModuleTest(
7426    constructor=lambda: UnpoolingNet(
7427        nn.MaxPool3d(2, return_indices=True),
7428        nn.MaxUnpool3d(2)),
7429    input_size=(1, 2, 4, 6),
7430    reference_fn=single_batch_reference_fn,
7431    fullname='MaxUnpool3d_net_no_batch_dim',
7432    check_gradgrad=False,
7433    default_dtype=torch.double,))
7434
7435class _AdaptiveLogSoftmaxWithLoss(nn.AdaptiveLogSoftmaxWithLoss):
7436    def __call__(self, input):
7437        t = torch.tensor([0, 1, 4, 8]).to(input.device)
7438        return nn.AdaptiveLogSoftmaxWithLoss.__call__(self, input, t).output
7439
7440add_test(NewModuleTest(
7441    constructor=lambda: _AdaptiveLogSoftmaxWithLoss(16, 10, [2, 6]),
7442    input_size=(4, 16),
7443    fullname='AdaptiveLogSoftmax',
7444    with_tf32=True,
7445    tf32_precision=0.005,
7446    default_dtype=torch.double))
7447
7448
7449# The following are helpers for TestNN.test_affine_*
7450if torch.cuda.is_available():
7451    def device_():
7452        return ['cpu', 'cuda']
7453else:
7454    def device_():
7455        return ['cpu']
7456
7457
7458def angle_rad_():
7459    return [r * math.pi * 2 for r in [0.0, 0.5, 0.25, 0.125, random.random()]]
7460
7461
7462def axis_vector_():
7463    t = (random.random(), random.random(), random.random())
7464    l = sum(x ** 2 for x in t) ** 0.5
7465
7466    return [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), tuple(x / l for x in t)]
7467
7468
7469def input_size2d_():
7470    return [[1, 1, 3, 5], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 3, 4]]
7471
7472
7473def output_size2d_():
7474    return [[1, 1, 5, 3], [1, 1, 3, 5], [1, 1, 4, 3], [1, 1, 5, 5], [1, 1, 6, 6]]
7475
7476
7477def input_size2dsq_():
7478    return [[1, 1, 2, 2], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 6, 6]]
7479
7480
7481def output_size2dsq_():
7482    return [[1, 1, 2, 2], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 5, 5], [1, 1, 6, 6]]
7483
7484
7485def input_size3d_():
7486    return [[1, 1, 2, 2, 2], [1, 1, 2, 3, 4], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 3, 4, 5]]
7487
7488
7489def input_size3dsq_():
7490    return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 6, 6, 6]]
7491
7492
7493def output_size3dsq_():
7494    return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 5, 5, 5], [1, 1, 6, 6, 6]]
7495
7496
7497def output_size3d_():
7498    return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 3, 4, 5], [1, 1, 4, 3, 2], [1, 1, 5, 5, 5], [1, 1, 6, 6, 6]]
7499
7500
7501def _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad):
7502    input_center = [(x - 1) / 2.0 for x in input_size]
7503    output_center = [(x - 1) / 2.0 for x in output_size]
7504
7505    s = math.sin(angle_rad)
7506    c = math.cos(angle_rad)
7507
7508    intrans_ary = np.array([
7509        [1, 0, input_center[2]],
7510        [0, 1, input_center[3]],
7511        [0, 0, 1],
7512    ], dtype=np.float64)
7513
7514    inscale_ary = np.array([
7515        [input_center[2], 0, 0],
7516        [0, input_center[3], 0],
7517        [0, 0, 1],
7518    ], dtype=np.float64)
7519
7520    rotation_ary = np.array([
7521        [c, -s, 0],
7522        [s, c, 0],
7523        [0, 0, 1],
7524    ], dtype=np.float64)
7525
7526    outscale_ary = np.array([
7527        [1.0 / output_center[2], 0, 0],
7528        [0, 1.0 / output_center[3], 0],
7529        [0, 0, 1],
7530    ], dtype=np.float64)
7531
7532    outtrans_ary = np.array([
7533        [1, 0, -output_center[2]],
7534        [0, 1, -output_center[3]],
7535        [0, 0, 1],
7536    ], dtype=np.float64)
7537
7538    reorder_ary = np.array([
7539        [0, 1, 0],
7540        [1, 0, 0],
7541        [0, 0, 1],
7542    ], dtype=np.float64)
7543
7544    transform_ary = np.dot(np.dot(np.dot(np.dot(
7545        intrans_ary,
7546        inscale_ary),
7547        rotation_ary.T),
7548        outscale_ary),
7549        outtrans_ary)
7550    grid_ary = np.dot(np.dot(np.dot(reorder_ary, rotation_ary.T), outscale_ary), outtrans_ary)
7551
7552    transform_tensor = torch.from_numpy(rotation_ary).to(device, torch.float32)
7553    transform_tensor = transform_tensor[:2].unsqueeze(0)
7554
7555    return transform_tensor, transform_ary, grid_ary
7556
7557
7558def _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_rad, axis_vector):
7559    input_center = [(x - 1) / 2.0 for x in input_size]
7560    output_center = [(x - 1) / 2.0 for x in output_size]
7561
7562    s = math.sin(angle_rad)
7563    c = math.cos(angle_rad)
7564    c1 = 1 - c
7565
7566    intrans_ary = np.array([
7567        [1, 0, 0, input_center[2]],
7568        [0, 1, 0, input_center[3]],
7569        [0, 0, 1, input_center[4]],
7570        [0, 0, 0, 1],
7571    ], dtype=np.float64)
7572
7573    inscale_ary = np.array([
7574        [input_center[2], 0, 0, 0],
7575        [0, input_center[3], 0, 0],
7576        [0, 0, input_center[4], 0],
7577        [0, 0, 0, 1],
7578    ], dtype=np.float64)
7579
7580    l, m, n = axis_vector
7581    scipyRotation_ary = np.array([
7582        [l * l * c1 + c, m * l * c1 - n * s, n * l * c1 + m * s, 0],
7583        [l * m * c1 + n * s, m * m * c1 + c, n * m * c1 - l * s, 0],
7584        [l * n * c1 - m * s, m * n * c1 + l * s, n * n * c1 + c, 0],
7585        [0, 0, 0, 1],
7586    ], dtype=np.float64)
7587
7588    z, y, x = axis_vector
7589    torchRotation_ary = np.array([
7590        [x * x * c1 + c, y * x * c1 - z * s, z * x * c1 + y * s, 0],
7591        [x * y * c1 + z * s, y * y * c1 + c, z * y * c1 - x * s, 0],
7592        [x * z * c1 - y * s, y * z * c1 + x * s, z * z * c1 + c, 0],
7593        [0, 0, 0, 1],
7594    ], dtype=np.float64)
7595
7596    outscale_ary = np.array([
7597        [1.0 / output_center[2], 0, 0, 0],
7598        [0, 1.0 / output_center[3], 0, 0],
7599        [0, 0, 1.0 / output_center[4], 0],
7600        [0, 0, 0, 1],
7601    ], dtype=np.float64)
7602
7603    outtrans_ary = np.array([
7604        [1, 0, 0, -output_center[2]],
7605        [0, 1, 0, -output_center[3]],
7606        [0, 0, 1, -output_center[4]],
7607        [0, 0, 0, 1],
7608    ], dtype=np.float64)
7609
7610    reorder_ary = np.array([
7611        [0, 0, 1, 0],
7612        [0, 1, 0, 0],
7613        [1, 0, 0, 0],
7614        [0, 0, 0, 1],
7615    ], dtype=np.float64)
7616
7617    transform_ary = np.dot(np.dot(np.dot(np.dot(
7618        intrans_ary,
7619        inscale_ary),
7620        np.linalg.inv(scipyRotation_ary)),
7621        outscale_ary),
7622        outtrans_ary)
7623    grid_ary = np.dot(np.dot(np.dot(reorder_ary, np.linalg.inv(scipyRotation_ary)), outscale_ary), outtrans_ary)
7624
7625    transform_tensor = torch.from_numpy(torchRotation_ary).to(device, torch.float32)
7626    transform_tensor = transform_tensor[:3].unsqueeze(0)
7627
7628    return transform_tensor, transform_ary, grid_ary
7629# end TestNN.test_affine_* helpers
7630
7631
7632class TestNNDeviceType(NNTestCase):
7633    def _test_InstanceNorm_general(self, cls, input, device, dtype=torch.float):
7634        # default case track_running_stats=False
7635        b, c = input.size(0), input.size(1)
7636        input_var = input.to(device=device, dtype=dtype).requires_grad_()
7637
7638        IN = cls(c, eps=0).to(device, dtype)
7639
7640        output = IN(input_var)
7641        out_reshaped = output.view(b * c, -1)
7642
7643        mean = out_reshaped.mean(1)
7644        var = out_reshaped.var(1, unbiased=False)
7645
7646        self.assertEqual(torch.abs(mean.data).mean(), 0, atol=1e-5, rtol=0)
7647        self.assertEqual(torch.abs(var.data).mean(), 1, atol=1e-5, rtol=0)
7648
7649        # check that eval mode doesn't change behavior
7650        grad_out = torch.randn_like(output)
7651        res1 = output.data.clone()
7652        output.backward(grad_out)
7653        grad1 = input_var.grad.data.clone()
7654
7655        IN.eval()
7656        output = IN(input_var)
7657        input_var.grad = None
7658        output.backward(grad_out)
7659        res2 = output.data
7660        grad2 = input_var.grad.data
7661        self.assertEqual(res1, res2)
7662        self.assertEqual(grad1, grad2)
7663
7664        # If track_running_stats=True and momentum=1, running_mean/var should be
7665        # equal to mean/var of the input (with unbias correction)
7666        IN = cls(c, momentum=1, eps=0, track_running_stats=True).to(device, dtype)
7667
7668        output = IN(input_var)
7669
7670        input_reshaped = input_var.transpose(1, 0).reshape(c, -1)
7671        mean = input_reshaped.mean(1)
7672
7673        input_reshaped = input_var.transpose(1, 0).reshape(c, b, -1)
7674        var = input_reshaped.var(2, unbiased=True)[:, :]
7675
7676        self.assertEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, atol=1e-5, rtol=0)
7677        self.assertEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, atol=1e-5, rtol=0)
7678
7679        # in eval mode, adding X * std to a channel in input should make the
7680        # corresponding channel in output have mean X
7681        IN.eval()
7682        delta = IN.running_var.sqrt() * torch.arange(c, device=device, dtype=dtype)
7683        delta = delta.view(-1, *[1 for _ in range(2, input.dim())])
7684        output = IN(input_var + delta)
7685        self.assertEqual(output.transpose(0, 1).reshape(c, -1).mean(1), torch.arange(c, dtype=dtype))
7686
7687    def _test_InstanceNorm_cuda_half(self, cls, input, device):
7688        # THNN
7689        input = input.to(device=device, dtype=torch.half).random_(1, 10).requires_grad_(True)
7690        m = cls(input.size(1), affine=True, track_running_stats=True).to(device, torch.half)
7691        thnn_output = m(input)
7692        thnn_output.sum().backward()
7693        thnn_input_grad = input.grad.data.clone()
7694        self.assertEqualTypeString(thnn_output, input)
7695        # cuDNN
7696        if TEST_CUDNN:
7697            input.grad = None
7698            m = m.float()
7699            cudnn_output = m(input)
7700            cudnn_output.sum().backward()
7701            cudnn_input_grad = input.grad.data.clone()
7702            self.assertEqualTypeString(cudnn_output, input)
7703            self.assertEqual(cudnn_output, thnn_output, atol=1e-4, rtol=0)
7704            self.assertEqual(cudnn_input_grad, thnn_input_grad, atol=1e-3, rtol=0)
7705
7706    def _test_LayerNorm_general(self, device, dtype=torch.float):
7707        for i in range(2, 6):
7708            shape = torch.randint(3, 6, (i,), dtype=torch.long).tolist()
7709            x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
7710            normalized_ndim = random.randint(1, i - 1)  # inclusive
7711            normalized_shape = shape[-normalized_ndim:]
7712            unnormalized_shape = shape[:-normalized_ndim]
7713
7714            # test that LN normalizes to mean 0 and stddev 1
7715            ln = nn.LayerNorm(normalized_shape, eps=0).to(device, dtype)
7716            ln.weight.data.fill_(1)
7717            ln.bias.data.fill_(0)
7718            output = ln(x)
7719            out_reshaped = output.view(*(unnormalized_shape + [-1]))
7720            mean = out_reshaped.mean(-1)
7721            var = out_reshaped.var(-1, unbiased=False)
7722
7723            delta = 1e-1 if (dtype == torch.bfloat16 or dtype == torch.half) else 1e-5
7724            self.assertEqual(torch.abs(mean.data).mean(), 0, atol=delta, rtol=0)
7725            self.assertEqual(torch.abs(var.data).mean(), 1, atol=delta, rtol=0)
7726
7727            # test that LN applies weight and bias correctly
7728            scale, bias = torch.empty(2).uniform_(0.2, 2).tolist()
7729            ln.weight.data.fill_(scale)
7730            ln.bias.data.fill_(bias)
7731            output = ln(x)
7732            out_reshaped = output.view(*(unnormalized_shape + [-1]))
7733            mean = out_reshaped.mean(-1)
7734            var = out_reshaped.var(-1, unbiased=False)
7735            self.assertEqual(torch.abs(mean.data).mean(), bias, atol=delta, rtol=0)
7736            self.assertEqual(torch.abs(var.data).mean(), scale ** 2, atol=delta, rtol=0)
7737
7738        bad_norm_shape_input_shape = {
7739            (): (),
7740            (2, 3): (3,),
7741            (2,): (1, 2, 3),
7742            (10,): (2, 3),
7743            10: (2, 3),
7744        }
7745        for norm_shape, input_shape in bad_norm_shape_input_shape.items():
7746            ln = nn.LayerNorm(norm_shape)
7747            input = torch.empty(input_shape, device=device, dtype=dtype).uniform_(0, 10)
7748            self.assertRaises(RuntimeError, lambda: ln(input))
7749
7750    def _test_LayerNorm_cuda_half(self, device):
7751        input = torch.empty(2, 3, 3, 2, device=device, dtype=torch.half).random_(1, 10).requires_grad_(True)
7752        m = nn.LayerNorm([3, 2]).to(device, torch.half)
7753        output = m(input)
7754        output.sum().backward()
7755        self.assertEqualTypeString(output, input)
7756
7757    def _test_LayerNorm_cpu_mixed_dtype(self, device, dtype):
7758        for elementwise_affine in [True, False]:
7759            # layer norm input shape is normalized to m x n, cpu vectorized on n,
7760            # so make sure n exceeds vector length
7761            input = torch.empty(2, 3, 11, 3, device=device, dtype=dtype).random_(1, 10)
7762            m = nn.LayerNorm([11, 3], elementwise_affine=elementwise_affine).to(device, dtype)
7763
7764            # fp32
7765            m_fp32 = deepcopy(m).to(device, torch.float)
7766            x_fp32 = input.clone().detach().float().requires_grad_()
7767            out_fp32 = m_fp32(x_fp32)
7768            out_fp32.sum().backward()
7769
7770            # bf16/half
7771            m_bf16 = deepcopy(m)
7772            x_bf16 = input.clone().detach().requires_grad_()
7773            out_bf16 = m_bf16(x_bf16)
7774            out_bf16.sum().backward()
7775
7776            # bf16/half mixed type
7777            m_mix = deepcopy(m).to(device, torch.float)
7778            x_mix = input.clone().detach().requires_grad_()
7779            out_mix = m_mix(x_mix)
7780            out_mix.sum().backward()
7781            self.assertEqual(out_fp32.to(dtype=dtype), out_bf16)
7782            self.assertEqual(out_fp32.to(dtype=dtype), out_mix)
7783            self.assertEqual(x_fp32.grad.to(dtype=dtype), x_bf16.grad, atol=1e-1, rtol=1e-1)
7784            self.assertEqual(x_fp32.grad.to(dtype=dtype), x_mix.grad, atol=1e-1, rtol=1e-1)
7785
7786    def _test_GroupNorm_general(self, device, dtype=torch.float):
7787        good_shape_g = {
7788            (1, 2, 3, 4): 2,
7789            (2, 3, 10): 3,
7790            (3, 1, 1, 1, 2): 1,
7791            (2, 6, 4, 2, 2): 3,
7792            (1, 256, 1, 1): 32,
7793        }
7794        for shape_g, grad in product(good_shape_g.items(), [True, False]):
7795            shape, g = shape_g
7796            x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
7797            x.requires_grad_(grad)
7798            b = shape[0]
7799            c = shape[1]
7800
7801            # test that GN normalizes to mean 0 and stddev 1
7802            gn = nn.GroupNorm(g, c, eps=0).to(device, dtype)
7803            gn.weight.data.fill_(1)
7804            gn.bias.data.fill_(0)
7805            output = gn(x)
7806            out_reshaped = output.view(b, g, -1)
7807            mean = out_reshaped.mean(-1)
7808            var = out_reshaped.var(-1, unbiased=False)
7809            self.assertEqual(torch.abs(mean).mean(), 0, atol=1e-5, rtol=0)
7810            self.assertEqual(torch.abs(var).mean(), 1, atol=1e-5, rtol=0)
7811
7812            output.backward(torch.randn_like(output))
7813            if output.is_cuda:
7814                torch.cuda.synchronize()
7815
7816            # test that GN applies weight and bias correctly
7817            scale = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
7818            bias = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
7819            gn.weight.data.copy_(scale)
7820            gn.bias.data.copy_(bias)
7821            output = gn(x)
7822            out_reshaped = output.view(b, c, -1)
7823            out_normed = (out_reshaped - bias.view(c, 1)) / scale.view(c, 1)
7824            out_normed_reshaped = out_normed.view(b, g, -1)
7825            mean = out_normed_reshaped.mean(-1)
7826            var = out_normed_reshaped.var(-1, unbiased=False)
7827            self.assertEqual(torch.abs(mean).mean(), 0, atol=1e-5, rtol=0)
7828            self.assertEqual(torch.abs(var).mean(), 1, atol=1e-5, rtol=0)
7829
7830        bad_shape_g = {
7831            (1, 2, 3, 4): 3,
7832            (2, 3, 10): 2,
7833            (3, 1, 1, 1, 2): 10,
7834            (2, 6, 4, 2, 2): 4,
7835        }
7836        for shape, g in bad_shape_g.items():
7837            with self.assertRaises(ValueError):
7838                gn = nn.GroupNorm(g, shape[1])
7839
7840    def _test_GroupNorm_cuda_half(self):
7841        input = torch.zeros(2, 4, 3, 2, requires_grad=True).cuda().half().random_(1, 10)
7842        m = nn.GroupNorm(2, 4).to("cuda", torch.half)
7843        output = m(input)
7844        output.sum().backward()
7845        self.assertEqualTypeString(output, input)
7846
7847    def _test_GroupNorm_cpu_mixed_dtype(self):
7848        def helper(self, size, groups, memory_format, dtype):
7849            channels = size[1]
7850            input = torch.randn(size).cpu().to(dtype=dtype)
7851            input_bf1 = input.contiguous(memory_format=memory_format).detach().requires_grad_(True)
7852            input_bf2 = input_bf1.clone().detach().requires_grad_(True)
7853            input_f = input_bf1.float().detach().requires_grad_(True)
7854            m_bf = nn.GroupNorm(groups, channels).cpu().to(dtype=dtype)
7855            m_f = deepcopy(m_bf).float()
7856            m_f2 = deepcopy(m_f)
7857            # bfloat16 input and bfloat16 parameters
7858            out = m_bf(input_bf1)
7859            # bfloat16 input and float parameters
7860            out2 = m_f(input_bf2)
7861            # float input and float parameters
7862            out3 = m_f2(input_f)
7863            self.assertEqual(out, out2, atol=5e-3, rtol=5e-3)
7864            self.assertEqual(out2.float(), out3, atol=5e-3, rtol=5e-3)
7865            grad_out = torch.randn(out2.shape).cpu().to(dtype=dtype)
7866            grad_out_bf1 = grad_out.contiguous(memory_format=memory_format).detach().requires_grad_(True)
7867            grad_out_bf2 = grad_out_bf1.clone().detach().requires_grad_(True)
7868            grad_out_f = grad_out_bf2.clone().float().detach().requires_grad_(True)
7869            # bfloat16/half input grad and float parameters
7870            out2.backward(grad_out_bf2, retain_graph=True)
7871            # float input grad and float parameters
7872            out3.backward(grad_out_f, retain_graph=True)
7873            # bfloat16/half input grad and bfloat16/half parameters
7874            out.backward(grad_out_bf1, retain_graph=True)
7875            # Need higher tolerances atol=1e-4 and rtol=1e-4 on macos
7876            self.assertEqual(m_f.weight.grad, m_f2.weight.grad, atol=1e-4, rtol=1e-4)
7877            self.assertEqual(m_f.bias.grad, m_f2.bias.grad, atol=1e-5, rtol=1e-5)
7878            self.assertEqual(input_bf2.grad.float(), input_f.grad, atol=5e-5, rtol=5e-3)
7879            # Full bf16/half has lower precision compared with mixed bf16/half and fp32.
7880            # Use Amp to keep module parameters in acc dtype, i.e. float, for better numerical stability
7881            atol = None
7882            rtol = None
7883            if dtype == torch.bfloat16:
7884                atol = 1e-2
7885                rtol = 1.2e-1
7886            else:
7887                assert dtype == torch.half
7888                atol = 5e-3
7889                rtol = 1.5e-2
7890            self.assertEqual(m_bf.weight.grad, m_f.weight.grad.to(dtype=dtype), atol=atol, rtol=rtol)
7891            self.assertEqual(m_bf.bias.grad, m_f.bias.grad.to(dtype=dtype), atol=atol, rtol=rtol)
7892            self.assertEqual(input_bf1.grad, input_bf2.grad, atol=atol, rtol=rtol)
7893
7894        cl_formats = {4: torch.channels_last, 5: torch.channels_last_3d}
7895        for dtype in [torch.bfloat16, torch.half]:
7896            for shape, g in [((1, 8, 4, 3), 2), ((1, 8, 3, 4), 4),
7897                             ((4, 40, 40, 40), 2), ((4, 8, 40, 40), 4),
7898                             ((1, 8, 40, 40), 4), ((1, 8, 40, 40), 2),
7899                             ((1, 8, 50, 50), 2), ((1, 8, 50, 50), 4),
7900                             ((1, 40, 50, 50), 2), ((1, 9, 3, 4, 5), 3),
7901                             ((1, 60, 10, 10, 10), 3), ((1, 9, 10, 50, 50), 3),
7902                             ((1, 60, 10, 50, 50), 3), ((1, 8, 65, 55), 2),
7903                             ((1, 3, 65, 55), 1), ((1, 3, 20, 20), 1)]:
7904                for is_cl in [False, True]:
7905                    format = cl_formats[len(shape)] if is_cl else torch.contiguous_format
7906                    helper(self, shape, g, format, dtype)
7907
7908    def _test_module_empty_inputs(self, module, inputs):
7909        for _inp in inputs:
7910            _inp.requires_grad_(True)
7911        out = module(*inputs)
7912        gO = torch.rand_like(out)
7913        out.backward(gO)
7914
7915        for p in module.parameters():
7916            if p.requires_grad:
7917                self.assertEqual(p.grad, torch.zeros_like(p.grad))
7918
7919        for _inp in inputs:
7920            self.assertEqual(_inp.grad, torch.zeros_like(_inp))
7921
7922    @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
7923                     "Scipy v1.0 and/or numpy not found")
7924    @expectedFailureMPS  # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098
7925    @tf32_on_and_off()
7926    @bf32_on_and_off()
7927    def test_affine_2d_rotate0(self, device):
7928        # scipy before 1.0.0 do not support homogeneous coordinate
7929        # scipy.ndimage.affine_transform, so we need to skip.
7930        input_size = [1, 1, 3, 3]
7931        input_ary = np.array(np.random.random(input_size), dtype=np.float32)
7932        output_size = [1, 1, 5, 5]
7933        angle_rad = 0.
7934
7935        transform_tensor, transform_ary, offset = \
7936            _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
7937
7938        scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
7939            input_ary[0, 0],
7940            transform_ary,
7941            offset=offset,
7942            output_shape=output_size[2:],
7943            order=1,
7944            mode='nearest',
7945            prefilter=False))
7946
7947        affine_tensor = torch.nn.functional.affine_grid(
7948            transform_tensor,
7949            torch.Size(output_size),
7950            align_corners=True
7951        )
7952
7953        gridsample_ary = torch.nn.functional.grid_sample(
7954            torch.tensor(input_ary, device=device).to(device),
7955            affine_tensor,
7956            padding_mode='border',
7957            align_corners=True
7958        ).to('cpu')
7959
7960        self.assertEqual(scipy_ary.mean(), gridsample_ary.mean())
7961        self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
7962
7963    @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
7964                     "Scipy v1.0 and/or numpy not found")
7965    @expectedFailureMPS  # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098
7966    @tf32_on_and_off(0.001)
7967    @bf32_on_and_off(0.001)
7968    def test_affine_2d_rotate90(self, device):
7969        # scipy before 1.0.0 do not support homogeneous coordinate
7970        # scipy.ndimage.affine_transform, so we need to skip.
7971        for input_size2dsq, output_size2dsq in \
7972                itertools.product(input_size2dsq_(), output_size2dsq_()):
7973            input_size = input_size2dsq
7974            input_ary = np.array(np.random.random(input_size), dtype=np.float32)
7975            output_size = output_size2dsq
7976            angle_rad = 0.25 * math.pi * 2
7977
7978            transform_tensor, transform_ary, offset = \
7979                _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
7980
7981            scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
7982                input_ary[0, 0],
7983                transform_ary,
7984                offset=offset,
7985                output_shape=output_size[2:],
7986                order=1,
7987                mode='nearest',
7988                prefilter=True))
7989
7990            if input_size2dsq == output_size2dsq:
7991                self.assertEqual(scipy_ary.mean(), input_ary.mean())
7992            self.assertEqual(scipy_ary[0, 0], input_ary[0, 0, 0, -1])
7993            self.assertEqual(scipy_ary[0, -1], input_ary[0, 0, -1, -1])
7994            self.assertEqual(scipy_ary[-1, -1], input_ary[0, 0, -1, 0])
7995            self.assertEqual(scipy_ary[-1, 0], input_ary[0, 0, 0, 0])
7996
7997            affine_tensor = torch.nn.functional.affine_grid(
7998                transform_tensor,
7999                torch.Size(output_size),
8000                align_corners=True
8001            )
8002
8003            gridsample_ary = torch.nn.functional.grid_sample(
8004                torch.tensor(input_ary, device=device).to(device),
8005                affine_tensor,
8006                padding_mode='border',
8007                align_corners=True
8008            ).to('cpu')
8009
8010            self.assertEqual(scipy_ary.mean(), gridsample_ary.mean())
8011            self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
8012
8013    @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
8014                     "Scipy v1.0 and/or numpy not found")
8015    @expectedFailureMPS  # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098
8016    @tf32_on_and_off(0.005)
8017    @bf32_on_and_off(0.005)
8018    def test_affine_2d_rotate45(self, device):
8019        # scipy before 1.0.0 do not support homogeneous coordinate
8020        # scipy.ndimage.affine_transform, so we need to skip.
8021        input_size = [1, 1, 3, 3]
8022        input_ary = np.array(np.zeros(input_size), dtype=np.float32)
8023        input_ary[0, 0, 0, :] = 0.5
8024        input_ary[0, 0, 2, 2] = 1.0
8025        output_size = [1, 1, 3, 3]
8026        angle_rad = 0.125 * math.pi * 2
8027
8028        transform_tensor, transform_ary, offset = \
8029            _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
8030
8031        scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
8032            input_ary[0, 0],
8033            transform_ary,
8034            offset=offset,
8035            output_shape=output_size[2:],
8036            order=1,
8037            mode='nearest',
8038            prefilter=False))
8039
8040        affine_tensor = torch.nn.functional.affine_grid(
8041            transform_tensor,
8042            torch.Size(output_size),
8043            align_corners=True
8044        )
8045
8046        gridsample_ary = torch.nn.functional.grid_sample(
8047            torch.tensor(input_ary, device=device).to(device),
8048            affine_tensor,
8049            padding_mode='border',
8050            align_corners=True
8051        ).to('cpu')
8052
8053        self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
8054
8055    @onlyCUDA
8056    @largeTensorTest("60GB", "cpu")
8057    @largeTensorTest("16GB", "cuda")
8058    def test_avg_pool_large_tensor(self, device):
8059        # test for https://github.com/pytorch/pytorch/issues/113833
8060        a = torch.randn(128, 256, 256, 256, dtype=torch.half, device=device, requires_grad=True)
8061        a_cpu = a.detach().cpu().float()
8062        m = torch.nn.AvgPool2d(2)
8063        o = m(a)
8064        a_cpu.requires_grad = True
8065        o.sum().backward()
8066        o_cpu = m(a_cpu)
8067        o_cpu.sum().backward()
8068        # workaround for memory usage overhead of assertEqual
8069        self.assertTrue(torch.allclose(a.grad.cpu(), a_cpu.grad.half()))
8070
8071    @onlyCUDA
8072    @largeTensorTest("48GB", "cpu")
8073    @largeTensorTest("48GB", "cuda")
8074    def test_avg_pool_large_tensor2(self, device):
8075        # test for https://github.com/pytorch/pytorch/issues/129785
8076        out_size = [2048, 64, 104, 79]
8077        size = [2048, 64, 209, 159]
8078        inp = torch.randn(size, device=device, requires_grad=True, dtype=torch.float)
8079        inp_cpu = inp.detach().cpu()
8080        m = torch.nn.AvgPool2d([2, 2], [2, 2], [0, 0], False, True, None)
8081        o = m(inp)
8082        inp_cpu.requires_grad = True
8083        o.sum().backward()
8084        o_cpu = m(inp_cpu)
8085        o_cpu.sum().backward()
8086        self.assertEqual(o.shape, out_size)
8087        self.assertEqual(o_cpu.shape, out_size)
8088        # reduce memory usage
8089        self.assertEqual(inp.grad.sum(), inp_cpu.grad.sum())
8090
8091    @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
8092                     "Scipy v1.0 and/or numpy not found")
8093    @expectedFailureMPS  # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098
8094    @tf32_on_and_off(0.005)
8095    @bf32_on_and_off(0.005)
8096    def test_affine_2d_rotateRandom(self, device):
8097        # scipy before 1.0.0 do not support homogeneous coordinate
8098        # scipy.ndimage.affine_transform, so we need to skip.
8099        for angle_rad, input_size2d, output_size2d in \
8100                itertools.product(angle_rad_(), input_size2d_(), output_size2d_()):
8101
8102            input_size = input_size2d
8103            input_ary = np.array(np.random.random(input_size), dtype=np.float32).round(3)
8104            output_size = output_size2d
8105
8106            input_ary[0, 0, 0, 0] = 2
8107            input_ary[0, 0, 0, -1] = 4
8108            input_ary[0, 0, -1, 0] = 6
8109            input_ary[0, 0, -1, -1] = 8
8110
8111            transform_tensor, transform_ary, grid_ary = \
8112                _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
8113
8114            scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
8115                input_ary[0, 0],
8116                transform_ary,
8117                output_shape=output_size[2:],
8118                order=1,
8119                mode='nearest',
8120                prefilter=False))
8121
8122            affine_tensor = torch.nn.functional.affine_grid(
8123                transform_tensor,
8124                torch.Size(output_size),
8125                align_corners=True
8126            )
8127
8128            gridsample_ary = torch.nn.functional.grid_sample(
8129                torch.tensor(input_ary, device=device).to(device),
8130                affine_tensor,
8131                padding_mode='border',
8132                align_corners=True
8133            ).to('cpu')
8134
8135            affine_tensor = affine_tensor.to('cpu')
8136
8137            for r in range(affine_tensor.size(1)):
8138                for c in range(affine_tensor.size(2)):
8139                    grid_out = np.dot(grid_ary, [r, c, 1])
8140                    self.assertEqual(affine_tensor[0, r, c], grid_out[:2], exact_dtype=False)
8141
8142            self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
8143
8144    @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
8145                     "Scipy v1.0 and/or numpy not found")
8146    @expectedFailureMPS  # aten::grid_sampler_3d not implemented https://github.com/pytorch/pytorch/issues/77764
8147    @tf32_on_and_off(0.005)
8148    @bf32_on_and_off(0.005)
8149    def test_affine_3d_rotateRandom(self, device):
8150        # scipy before 1.0.0 do not support homogeneous coordinate
8151        # scipy.ndimage.affine_transform, so we need to skip.
8152        for angle_rad, axis_vector, input_size3d, output_size3d in \
8153                itertools.product(angle_rad_(), axis_vector_(), input_size3d_(), output_size3d_()):
8154            input_size = input_size3d
8155            input_ary = np.array(np.random.random(input_size), dtype=np.float32)
8156            output_size = output_size3d
8157
8158            input_ary[0, 0, 0, 0, 0] = 2
8159            input_ary[0, 0, 0, 0, -1] = 3
8160            input_ary[0, 0, 0, -1, 0] = 4
8161            input_ary[0, 0, 0, -1, -1] = 5
8162            input_ary[0, 0, -1, 0, 0] = 6
8163            input_ary[0, 0, -1, 0, -1] = 7
8164            input_ary[0, 0, -1, -1, 0] = 8
8165            input_ary[0, 0, -1, -1, -1] = 9
8166
8167            transform_tensor, transform_ary, grid_ary = \
8168                _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_rad, axis_vector)
8169
8170            scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
8171                input_ary[0, 0],
8172                transform_ary,
8173                output_shape=output_size[2:],
8174                order=1,
8175                mode='nearest',
8176                prefilter=False))
8177
8178            affine_tensor = torch.nn.functional.affine_grid(
8179                transform_tensor,
8180                torch.Size(output_size),
8181                align_corners=True
8182            )
8183
8184            gridsample_ary = torch.nn.functional.grid_sample(
8185                torch.tensor(input_ary, device=device).to(device),
8186                affine_tensor,
8187                padding_mode='border',
8188                align_corners=True
8189            ).to('cpu')
8190
8191            affine_tensor = affine_tensor.to('cpu')
8192
8193            for i in range(affine_tensor.size(1)):
8194                for r in range(affine_tensor.size(2)):
8195                    for c in range(affine_tensor.size(3)):
8196                        grid_out = np.dot(grid_ary, [i, r, c, 1])
8197                        self.assertEqual(affine_tensor[0, i, r, c], grid_out[:3], exact_dtype=False)
8198
8199            self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
8200
8201
8202    @onlyCUDA
8203    @dtypes(torch.float, torch.half)
8204    def test_batchnorm_large_batch(self, device, dtype):
8205        bn = nn.BatchNorm2d(1).to(device, dtype)
8206        data = torch.rand(880801, 1, 1, 1, device=device, dtype=dtype)
8207        out = bn(data).sum().backward()
8208
8209    @dtypesIfCUDA(torch.float, torch.double, torch.half, torch.complex128)
8210    @dtypesIfMPS(torch.float, torch.half, torch.complex64)
8211    @dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128)
8212    def test_conv_empty_input(self, device, dtype):
8213        def help(input, conv, memory_format):
8214            ref_out = conv(input)
8215            conv_cl = conv.to(memory_format=memory_format)
8216            out_cl = conv_cl(input)
8217            self.assertEqual(ref_out, out_cl)
8218            input_cl = input.to(memory_format=memory_format)
8219            out_cl2 = conv(input_cl)
8220            self.assertEqual(out_cl, out_cl2)
8221            out_cl3 = conv_cl(input_cl)
8222            self.assertEqual(out_cl, out_cl3)
8223
8224        # channels_last case
8225        input2d = torch.randn((0, 4, 20, 20)).to(device=device, dtype=dtype)
8226        conv2d = torch.nn.Conv2d(4, 4, 3, 1).to(device=device, dtype=dtype)
8227        help(input2d, conv2d, torch.channels_last)
8228        # channels_last_3d case
8229        input3d = torch.randn((0, 4, 20, 20, 20)).to(device=device, dtype=dtype)
8230        conv3d = torch.nn.Conv3d(4, 4, 3, 1).to(device=device, dtype=dtype)
8231        help(input3d, conv3d, torch.channels_last_3d)
8232        # non-contiguous case
8233        weight = torch.rand(4, 8, 3, 3)[:, ::2, :, :].to(device=device, dtype=dtype)
8234        bias = torch.rand(4).to(device=device, dtype=dtype)
8235        out = F.conv2d(input2d, weight, bias, (1, 1), 0, (1, 1), 1)
8236        weight = weight.contiguous()
8237        out_ref = F.conv2d(input2d, weight, bias, (1, 1), 0, (1, 1), 1)
8238        self.assertEqual(out_ref, out)
8239        # sigfpe reported in https://github.com/pytorch/pytorch/issues/94125
8240        with self.assertRaises(RuntimeError):
8241            inp = torch.empty([1, 1, 1, 0], dtype=dtype, device=device)
8242            weight = torch.empty([1, 0, 1], dtype=dtype, device=device)
8243            torch._C._nn.slow_conv3d(inp, weight, 1)
8244
8245        with self.assertRaisesRegex(RuntimeError, re.escape("2D kernel_size expected")):
8246            torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[], padding=[1, 1], stride=[1, 1],
8247                                     weight=torch.rand([1, 1]))
8248        with self.assertRaisesRegex(RuntimeError, re.escape("2D stride expected")):
8249            torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[1, 1], padding=[1, 1], stride=[],
8250                                     weight=torch.rand([1, 1]))
8251        with self.assertRaisesRegex(RuntimeError, re.escape("2D padding expected")):
8252            torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[1, 1], padding=[], stride=[1, 1],
8253                                     weight=torch.rand([1, 1]))
8254
8255    def test_InstanceNorm1d_general(self, device):
8256        b = random.randint(3, 5)
8257        c = random.randint(3, 5)
8258        d = random.randint(8, 10)
8259
8260        input = torch.rand(b, c, d)
8261        self._test_InstanceNorm_general(nn.InstanceNorm1d, input, device)
8262
8263        if self.device_type == 'cuda':
8264            self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input, device)
8265
8266    def test_InstanceNorm2d_general(self, device):
8267        b = random.randint(3, 5)
8268        c = random.randint(3, 5)
8269        w = random.randint(3, 6)
8270        h = random.randint(6, 8)
8271
8272        input = torch.rand(b, c, h, w)
8273        self._test_InstanceNorm_general(nn.InstanceNorm2d, input, device)
8274
8275        if self.device_type == 'cuda':
8276            self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input, device)
8277
8278    def test_InstanceNorm3d_general(self, device):
8279        b = random.randint(3, 5)
8280        c = random.randint(3, 5)
8281        w = random.randint(2, 5)
8282        h = random.randint(2, 5)
8283        d = random.randint(2, 5)
8284
8285        input = torch.rand(b, c, h, w, d)
8286        self._test_InstanceNorm_general(nn.InstanceNorm3d, input, device)
8287
8288        if self.device_type == 'cuda':
8289            self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input, device)
8290
8291    @parametrize_test("instance_norm_cls", [nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d], name_fn=lambda c: c.__name__)
8292    @parametrize_test("no_batch_dim", [True, False])
8293    @parametrize_test("affine", [True, False])
8294    def test_instancenorm_raises_error_if_input_channels_is_not_num_features(self, device, instance_norm_cls, no_batch_dim, affine):
8295        inst_norm = instance_norm_cls(4, affine=affine)
8296        size = [2] * inst_norm._get_no_batch_dim()
8297        if not no_batch_dim:
8298            size = [3] + size
8299        t = torch.randn(size)
8300        if affine:
8301            with self.assertRaisesRegex(ValueError, "expected input's size at dim="):
8302                inst_norm(t)
8303        else:
8304            with warnings.catch_warnings(record=True) as w:
8305                inst_norm(t)
8306            self.assertIn("which is not used because affine=False", str(w[0].message))
8307
8308    def test_instancenorm_raises_error_if_less_than_one_value_per_channel(self, device):
8309        x = torch.rand(10)[None, :, None]
8310        with self.assertRaises(ValueError):
8311            torch.nn.InstanceNorm1d(10)(x).to(device)
8312
8313    def test_instancenorm_raises_error_for_single_spatial_element_during_training(self, device):
8314        BATCH_SIZE = 10
8315        NUM_CHANNELS = 3
8316        norms = [torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d]
8317        for i, norm in enumerate(norms):
8318            m = norm(NUM_CHANNELS, track_running_stats=True)
8319            m.to(device)
8320
8321            # Create an appropriately-sized input with a single spatial element.
8322            input = torch.randn(BATCH_SIZE, NUM_CHANNELS, *[1 for _ in range(i + 1)],
8323                                device=device)
8324            with self.assertRaises(ValueError):
8325                m(input)
8326
8327            # Single spatial element should be fine in eval.
8328            m.eval()
8329            m(input)
8330
8331    def test_LayerNorm_general(self, device):
8332        self._test_LayerNorm_general(device)
8333
8334        if self.device_type == 'cuda' or self.device_type == 'cpu':
8335            for dtype in [torch.half, torch.bfloat16]:
8336                self._test_LayerNorm_general(device, dtype=dtype)
8337
8338        if self.device_type == 'cuda':
8339            self._test_LayerNorm_cuda_half(device)
8340
8341        if self.device_type == 'cpu':
8342            for dtype in [torch.half, torch.bfloat16]:
8343                self._test_LayerNorm_cpu_mixed_dtype(device, dtype=dtype)
8344
8345    @onlyNativeDeviceTypes
8346    def test_LayerNorm_numeric(self, device):
8347        def layer_norm_ref(X, gamma, beta, normalized_shape, eps):
8348            feature_size = np.prod(normalized_shape)
8349            X_view = X.view(-1, feature_size)
8350            mean = X_view.mean(dim=-1, keepdim=True)
8351            var = X_view.var(dim=-1, unbiased=False, keepdim=True)
8352            Y = (X_view - mean) / torch.sqrt(var + eps)
8353            Y = Y * gamma.view(-1) + beta.view(-1)
8354            return Y.view(*X.size())
8355
8356        normalized_shape = [256, 256, 144]
8357        layer_norm = nn.LayerNorm(normalized_shape).float().to(device)
8358        X = torch.rand(2, *normalized_shape, dtype=torch.float32,
8359                       device=device)
8360
8361        Y = layer_norm(X)
8362        Y_ref = layer_norm_ref(X, layer_norm.weight.data, layer_norm.bias.data,
8363                               normalized_shape, layer_norm.eps)
8364        self.assertEqual(Y, Y_ref, rtol=0, atol=1e-5)
8365
8366        if self.device_type == 'cuda':
8367            layer_norm.cpu()
8368            Y_cpu = layer_norm(X.cpu())
8369            self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5)
8370
8371    @onlyCPU
8372    def test_glu_bfloat16(self, device):
8373        def test_dtype(fn, input, dtype):
8374            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
8375            input2 = input.detach().clone().float().requires_grad_(True)
8376            out = fn(input)
8377            out.sum().backward()
8378            out2 = fn(input2)
8379            out2.sum().backward()
8380            self.assertEqual(out.dtype, dtype)
8381            self.assertEqual(input.grad.dtype, dtype)
8382            self.assertEqual(out, out2, exact_dtype=False)
8383            self.assertEqual(input.grad, input2.grad, atol=1e-2, rtol=0, exact_dtype=False)
8384
8385        def func(device):
8386            return torch.nn.GLU(dim=-1).to(device)
8387
8388        shapes = [[1, 3, 1, 6], [1, 3, 1, 128], [1, 3, 256, 256]]
8389        for shape in shapes:
8390            x = torch.randn(shape, device=device)
8391            test_dtype(func(device), x, torch.bfloat16)
8392
8393    @onlyNativeDeviceTypes
8394    def test_GroupNorm_general(self, device):
8395        self._test_GroupNorm_general(device)
8396
8397        if self.device_type == 'cuda':
8398            self._test_GroupNorm_cuda_half()
8399
8400        if self.device_type == 'cpu':
8401            self._test_GroupNorm_cpu_mixed_dtype()
8402
8403    def test_GroupNorm_raises_error_if_one_value_per_group(self, device):
8404        x = torch.rand(10)[None, :, None]
8405        with self.assertRaises(ValueError):
8406            torch.nn.GroupNorm(10, 10)(x).to(device)
8407
8408    def test_GroupNorm_empty(self, device):
8409        mod = torch.nn.GroupNorm(2, 4).to(device)
8410        inp = torch.randn(0, 4, 2, 2, device=device)
8411        _test_module_empty_input(self, mod, inp)
8412        if self.device_type == 'cuda' and self.has_cudnn():
8413            with torch.backends.cudnn.flags(enabled=False):
8414                _test_module_empty_input(self, mod, inp)
8415
8416    @onlyCPU
8417    @dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
8418    def test_groupnorm_nhwc(self, device, dtype):
8419        def helper(self, size, groups, memory_format, is_mixed):
8420            channels = size[1]
8421            input = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
8422            input = input.contiguous(memory_format=memory_format)
8423            input.retain_grad()
8424            grad = torch.randn(size, dtype=dtype, device=device)
8425            grad = grad.contiguous(memory_format=memory_format)
8426            if dtype == torch.bfloat16 and is_mixed:
8427                gn = nn.GroupNorm(groups, channels).to(device).to(torch.float)
8428            else:
8429                gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
8430            gn.weight.data.uniform_()
8431            gn.bias.data.uniform_()
8432
8433            ref_input = input.detach().clone().contiguous(memory_format=torch.contiguous_format).requires_grad_(True)
8434            ref_grad = grad.detach().clone().contiguous(memory_format=torch.contiguous_format)
8435            if dtype == torch.bfloat16 and is_mixed:
8436                ref_gn = nn.GroupNorm(groups, channels).to(device).to(torch.float)
8437            else:
8438                ref_gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
8439            ref_gn.load_state_dict(gn.state_dict())
8440            out = gn(input)
8441            out.backward(grad)
8442            ref_out = ref_gn(ref_input)
8443            ref_out.backward(ref_grad)
8444
8445            self.assertTrue(out.is_contiguous(memory_format=memory_format))
8446            self.assertTrue(ref_out.is_contiguous(memory_format=torch.contiguous_format))
8447            self.assertEqual(out, ref_out)
8448            # parameters in bfloat16/Half is not recommended
8449            atol = 5e-4
8450            rtol = 8e-3
8451
8452            self.assertEqual(gn.weight.grad, ref_gn.weight.grad, atol=atol, rtol=rtol)
8453            self.assertEqual(gn.bias.grad, ref_gn.bias.grad, atol=atol, rtol=rtol)
8454            self.assertEqual(input.grad, ref_input.grad, atol=atol, rtol=rtol)
8455
8456        for is_mixed in [True, False]:
8457            helper(self, (4, 8, 10, 10), 4, torch.channels_last, is_mixed)
8458            helper(self, (2, 30, 9, 9), 3, torch.channels_last, is_mixed)
8459            helper(self, (4, 8, 40, 40), 4, torch.channels_last, is_mixed)
8460            helper(self, (4, 40, 40, 40), 2, torch.channels_last, is_mixed)
8461            helper(self, (2, 30, 50, 50), 3, torch.channels_last, is_mixed)
8462            helper(self, (2, 60, 50, 50), 3, torch.channels_last, is_mixed)
8463            helper(self, (2, 9, 7, 11, 15), 3, torch.channels_last_3d, is_mixed)
8464            helper(self, (2, 9, 7, 200, 15), 3, torch.channels_last_3d, is_mixed)
8465            helper(self, (2, 60, 7, 200, 15), 3, torch.channels_last_3d, is_mixed)
8466
8467    @onlyNativeDeviceTypes
8468    def test_GroupNorm_memory_format(self, device):
8469        # Tests for regression reported in https://github.com/pytorch/pytorch/issues/92166
8470
8471        def helper(input_format, grad_format, B=2, C=4, W=4, H=4):
8472            import copy
8473            net_orig = torch.nn.GroupNorm(B, C).to(device=device)
8474            net = copy.deepcopy(net_orig)
8475            x_orig = torch.rand(B, C, W, H, device=device, requires_grad=True)
8476            grad_orig = torch.rand(B, C, W, H, device=device)
8477            x = x_orig.clone().detach().to(memory_format=input_format).requires_grad_(True)
8478            grad = grad_orig.detach().to(memory_format=grad_format)
8479
8480            y = net(x)
8481            y.backward(grad)
8482
8483            y_orig = net_orig(x_orig)
8484            y_orig.backward(grad_orig)
8485
8486            self.assertEqual(y, y_orig)
8487            self.assertEqual(x.grad, x_orig.grad)
8488
8489        for input_format in [torch.contiguous_format, torch.channels_last]:
8490            for grad_format in [torch.contiguous_format, torch.channels_last]:
8491                helper(input_format, grad_format)
8492
8493    @onlyNativeDeviceTypes
8494    def test_GroupNorm_numeric(self, device):
8495        def group_norm_ref(X, gamma, beta, groups, channels, eps):
8496            batch_size = X.size()[0]
8497            X_view = X.view(batch_size, groups, -1)
8498            mean = X_view.mean(dim=-1, keepdim=True)
8499            var = X_view.var(dim=-1, unbiased=False, keepdim=True)
8500            Y = ((X_view - mean) / torch.sqrt(var + eps)).view(
8501                batch_size, channels, -1)
8502            Y = Y * gamma.view(channels, 1) + beta.view(channels, 1)
8503            return Y.view(*X.size())
8504
8505        batch_size = 1
8506        groups = 2
8507        channels = 8
8508        group_norm = nn.GroupNorm(groups, channels).float().to(device)
8509        X = torch.rand(batch_size, channels, 256, 256, 72,
8510                       dtype=torch.float32, device=device)
8511
8512        Y = group_norm(X)
8513        Y_ref = group_norm_ref(
8514            X, group_norm.weight.data, group_norm.bias.data, groups,
8515            channels, group_norm.eps)
8516        self.assertEqual(Y, Y_ref, rtol=0, atol=1e-5)
8517
8518        if self.device_type == 'cuda':
8519            group_norm.cpu()
8520            Y_cpu = group_norm(X.cpu())
8521            self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5)
8522
8523    @onlyNativeDeviceTypes
8524    @dtypes(torch.float64, torch.complex128)
8525    def test_pad(self, device, dtype):
8526        # Assert assertion errors are raised for invalid circular padding values
8527        inputs = torch.randn(1, 1, 4, device=device, dtype=dtype, requires_grad=True)
8528        # Should raise error when trying to wrap around more than once
8529        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (5, 4), mode='circular'))
8530        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (3, 6), mode='circular'))
8531        # Should raise error when negative padding results in negative output shape
8532        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (-3, -2), mode='circular'))
8533
8534        # assert that relfection padding errors when pad >= input size
8535        expected_err_msg = r"Padding size should be less than the corresponding input dimension"
8536        inputs = torch.randn(1, 1, 2, 3, device=device, dtype=dtype)
8537        self.assertRaisesRegex(RuntimeError, expected_err_msg,
8538                               lambda: F.pad(inputs, (1, 1, 3, 0), mode='reflect'))
8539        inputs = torch.randn(1, 1, 2, device=device, dtype=dtype)
8540        self.assertRaisesRegex(RuntimeError, expected_err_msg,
8541                               lambda: F.pad(inputs, (2, 1), mode='reflect'))
8542
8543        inputs = torch.rand(1, 3, 4, 4, device=device, dtype=dtype)
8544        # assert that pad doesn't return a view into the input tensor
8545        for mode in 'constant', 'reflect', 'replicate', 'circular':
8546            out = F.pad(inputs, (0, 0, 0, 0), mode=mode)
8547            out.fill_(4)
8548            self.assertTrue(torch.all(torch.abs(inputs) < 2))
8549
8550            out = F.pad(inputs, (0, 0, -1, -1), mode=mode)
8551            out.fill_(4)
8552            self.assertTrue(torch.all(torch.abs(inputs) < 2))
8553
8554    @onlyNativeDeviceTypes
8555    @dtypes(torch.float64, torch.complex128)
8556    def test_ReplicationPad_empty(self, device, dtype):
8557        for mod, inp in [
8558                (torch.nn.ReplicationPad1d(3), torch.randn(0, 3, 10, device=device, dtype=dtype)),
8559                (torch.nn.ReplicationPad2d(3), torch.randn(0, 3, 10, 10, device=device, dtype=dtype)),
8560                (torch.nn.ReplicationPad3d(3), torch.randn(0, 3, 10, 10, 10, device=device, dtype=dtype))]:
8561            _test_module_empty_input(self, mod, inp, check_size=False)
8562
8563        with self.assertRaisesRegex(RuntimeError, 'Expected 2D or 3D'):
8564            mod = torch.nn.ReplicationPad1d(2)
8565            inp = torch.randn(3, 0, 10, device=device, dtype=dtype)
8566            mod(inp)
8567
8568        with self.assertRaisesRegex(RuntimeError, 'Expected 3D or 4D'):
8569            mod = torch.nn.ReplicationPad2d((2, 2, 2, 2))
8570            inp = torch.randn(43, 0, 10, 10, device=device, dtype=dtype)
8571            mod(inp)
8572
8573        with self.assertRaisesRegex(RuntimeError, 'Expected 4D or 5D'):
8574            mod = torch.nn.ReplicationPad3d((2, 2, 2, 2, 2, 2))
8575            inp = torch.randn(3, 0, 10, 10, 10, device=device, dtype=dtype)
8576            mod(inp)
8577
8578        with self.assertRaisesRegex(RuntimeError, 'padding size is expected to be 2'):
8579            torch._C._nn.replication_pad1d(torch.randn([2]), padding=[])
8580
8581        with self.assertRaisesRegex(RuntimeError, 'padding size is expected to be 4'):
8582            torch._C._nn.replication_pad2d(torch.randn([2]), padding=[])
8583
8584        with self.assertRaisesRegex(RuntimeError, 'padding size is expected to be 6'):
8585            torch._C._nn.replication_pad3d(torch.randn([2]), padding=[])
8586
8587    @expectedFailureMPS  # TODO(hvaara): Investigate as possible bug.
8588    def test_ReplicationPad1d_large(self, device):
8589        shapes = ([2, 65736, 4], [65736, 2, 4])
8590        pl, pr = 3, 4
8591        for shape in shapes:
8592            x = torch.randn(shape, device=device, requires_grad=True)
8593            model = torch.nn.ReplicationPad1d((pl, pr))
8594
8595            # forward
8596            out = model(x)
8597            self.assertEqual(out[:, :, pl : -pr], x)
8598
8599            left_padding = out[:, :, : pl]
8600            self.assertEqual(left_padding, x[:, :, :1].expand_as(left_padding))
8601            right_padding = out[:, :, -pr :]
8602            self.assertEqual(right_padding, x[:, :, -1:].expand_as(right_padding))
8603
8604            # backward
8605            g = torch.randn_like(out)
8606            out.backward(g)
8607            self.assertEqual(x.grad[:, :, 1 : -1], g[:, :, pl + 1 : -pr - 1])
8608
8609            self.assertEqual(x.grad[:, :, 0], g[:, :, : pl + 1].sum(-1))
8610            self.assertEqual(x.grad[:, :, -1], g[:, :, -pr - 1:].sum(-1))
8611
8612    @expectedFailureMPS  # TODO(hvaara): Investigate as possible bug.
8613    def test_ReplicationPad2d_large(self, device):
8614        shapes = ([2, 65736, 4, 4], [65736, 2, 4, 4])
8615        pl, pr, pt, pb = 3, 4, 5, 6
8616        for shape in shapes:
8617            x = torch.randn(shape, device=device, requires_grad=True)
8618            model = torch.nn.ReplicationPad2d((pl, pr, pt, pb))
8619
8620            # forward center, edge
8621            out = model(x)
8622            self.assertEqual(out[:, :, pt : -pb, pl : -pr], x)
8623
8624            left_padding = out[:, :, pt : -pb, : pl]
8625            self.assertEqual(left_padding, x[:, :, :, :1].expand_as(left_padding))
8626            right_padding = out[:, :, pt : -pb, -pr :]
8627            self.assertEqual(right_padding, x[:, :, :, -1:].expand_as(right_padding))
8628            top_padding = out[:, :, : pt, pl : -pr]
8629            self.assertEqual(top_padding, x[:, :, :1, :].expand_as(top_padding))
8630            bottom_padding = out[:, :, -pb : , pl : -pr]
8631            self.assertEqual(bottom_padding, x[:, :, -1:, :].expand_as(bottom_padding))
8632
8633            # forward corner
8634            tl_padding = out[:, :, : pt + 1, : pl + 1]
8635            self.assertEqual(tl_padding, x[:, :, :1, :1].expand_as(tl_padding))
8636            tr_padding = out[:, :, : pt + 1, -pr - 1:]
8637            self.assertEqual(tr_padding, x[:, :, :1, -1:].expand_as(tr_padding))
8638            bl_padding = out[:, :, -pb - 1:, : pl + 1]
8639            self.assertEqual(bl_padding, x[:, :, -1:, :1].expand_as(bl_padding))
8640            br_padding = out[:, :, -pb - 1:, -pr - 1:]
8641            self.assertEqual(br_padding, x[:, :, -1:, -1:].expand_as(br_padding))
8642
8643            # backward center, edge
8644            g = torch.randn_like(out)
8645            out.backward(g)
8646            self.assertEqual(x.grad[:, :, 1:-1, 1:-1], g[:, :, pt + 1 : -pb - 1, pl + 1 : -pr - 1])
8647
8648            self.assertEqual(x.grad[:, :, 1:-1, 0], g[:, :, pt + 1 : -pb - 1, : pl + 1].sum(-1))
8649            self.assertEqual(x.grad[:, :, 1:-1, -1], g[:, :, pt + 1 : -pb - 1, -pr - 1 :].sum(-1))
8650            self.assertEqual(x.grad[:, :, 0, 1:-1], g[:, :, : pt + 1, pl + 1 : -pr - 1].sum(-2))
8651            self.assertEqual(x.grad[:, :, -1, 1:-1], g[:, :, -pb - 1 :, pl + 1 : -pr - 1].sum(-2))
8652
8653            # backward corner
8654            self.assertEqual(x.grad[:, :, 0, 0], g[:, :, : pt + 1, : pl + 1].sum((-2, -1)))
8655            self.assertEqual(x.grad[:, :, 0, -1], g[:, :, : pt + 1, -pr - 1 :].sum((-2, -1)))
8656            self.assertEqual(x.grad[:, :, -1, 0], g[:, :, -pb - 1 :, : pl + 1].sum((-2, -1)))
8657            self.assertEqual(x.grad[:, :, -1, -1], g[:, :, -pb - 1 :, -pr - 1 :].sum((-2, -1)))
8658
8659    @largeTensorTest("6GB")
8660    def test_ReplicationPad3d_large(self, device):
8661        shapes = ([1, 65736, 2, 2, 2], [65736, 1, 2, 2, 2])
8662        pl, pr, pt, pbt, pf, pbk = 3, 4, 5, 6, 7, 8
8663
8664        for shape in shapes:
8665            x = torch.randn(shape, device=device, requires_grad=True)
8666            model = torch.nn.ReplicationPad3d((pl, pr, pt, pbt, pf, pbk))
8667
8668            # forward center
8669            out = model(x)
8670            self.assertEqual(out[:, :, pf : -pbk, pt : -pbt, pl : -pr], x)
8671
8672            # backward center
8673            g = torch.randn_like(out)
8674            out.backward(g)
8675            self.assertEqual(x.grad[:, :, 1:-1, 1:-1, 1:-1], g[:, :, pf + 1 : -pbk - 1, pt + 1 : -pbt - 1, pl + 1 : -pr - 1])
8676
8677    @onlyNativeDeviceTypes
8678    def test_Bilinear_empty(self, device):
8679        mod = torch.nn.Bilinear(20, 30, 40).to(device)
8680        inp1 = torch.randn(0, 10, 20, requires_grad=True, device=device)
8681        inp2 = torch.randn(0, 10, 30, requires_grad=True, device=device)
8682
8683        output = mod(inp1, inp2)
8684        output.sum().backward()
8685
8686        self.assertEqual(inp1, torch.zeros_like(inp1))
8687        self.assertEqual(inp2, torch.zeros_like(inp2))
8688
8689        self.assertEqual(inp1.grad, torch.zeros_like(inp1))
8690        self.assertEqual(inp2.grad, torch.zeros_like(inp2))
8691
8692    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
8693    @onlyNativeDeviceTypes
8694    def test_TransformerEncoderLayer_empty(self, device):
8695        for training in (True, False):
8696            for batch_first, input_shape in [(True, (0, 10, 512)),
8697                                             (False, (10, 0, 512))]:
8698                input = torch.rand(*input_shape, device=device, dtype=torch.double)
8699                encoder_layer = nn.TransformerEncoderLayer(
8700                    d_model=512, nhead=8, batch_first=batch_first, dtype=torch.double).to(device)
8701                if not training:
8702                    encoder_layer = encoder_layer.eval()
8703                    with torch.no_grad():
8704                        _test_module_empty_input(self, encoder_layer, input, check_size=False, inference=True)
8705                    if batch_first and not TEST_WITH_CROSSREF:
8706                        with torch.no_grad():
8707                            # A NestedTensor with no tensors inside it doesn't have dim 3 (or dim
8708                            # 2, for that matter) so it can't hit the fast path, nor can we give a
8709                            # result.
8710                            with self.assertRaisesRegex(
8711                                    AssertionError, 'MultiheadAttention does not support NestedTensor outside'):
8712                                nt = torch.nested.nested_tensor([], device=device)
8713                                _test_module_empty_input(self, encoder_layer, nt, check_size=False, inference=True)
8714
8715                            nt = torch.nested.nested_tensor([torch.rand(0, 512, device=device, dtype=torch.double)], device=device)
8716                            _test_module_empty_input(self, encoder_layer, nt, check_size=False, inference=True)
8717                else:
8718                    _test_module_empty_input(self, encoder_layer, input, check_size=False)
8719
8720    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
8721    @onlyNativeDeviceTypes
8722    def test_TransformerEncoder_empty(self, device):
8723        for batch_first, input_shape in [(True, (0, 10, 512)),
8724                                         (False, (10, 0, 512))]:
8725            input = torch.rand(*input_shape, device=device, dtype=torch.double)
8726            encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=batch_first, dtype=torch.double).to(device)
8727            transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6).to(device)
8728            _test_module_empty_input(self, transformer_encoder, input, check_size=False)
8729
8730    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
8731    @onlyNativeDeviceTypes
8732    def test_TransformerDecoderLayer_empty(self, device):
8733        for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)),
8734                                                     (False, (10, 0, 512), (20, 0, 512))]:
8735            memory = torch.rand(*memory_shape, device=device, dtype=torch.double)
8736            tgt = torch.rand(*tgt_shape, requires_grad=True, device=device, dtype=torch.double)
8737            decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=batch_first, dtype=torch.double).to(device)
8738            self._test_module_empty_inputs(decoder_layer, [tgt, memory])
8739
8740    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
8741    @onlyNativeDeviceTypes
8742    def test_TransformerDecoder_empty(self, device):
8743        for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)),
8744                                                     (False, (10, 0, 512), (20, 0, 512))]:
8745            memory = torch.rand(*memory_shape, device=device, dtype=torch.double)
8746            tgt = torch.rand(*tgt_shape, requires_grad=True, device=device, dtype=torch.double)
8747            decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=batch_first, dtype=torch.double).to(device)
8748            transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6).to(device)
8749            self._test_module_empty_inputs(transformer_decoder, [tgt, memory])
8750
8751    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
8752    @onlyNativeDeviceTypes
8753    def test_Transformer_empty(self, device):
8754        for batch_first, src_shape, tgt_shape in [(True, (10, 0, 512), (20, 0, 512))]:
8755            transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12, dtype=torch.double).to(device)
8756            src = torch.rand(*src_shape, requires_grad=True, device=device, dtype=torch.double)
8757            tgt = torch.rand(*tgt_shape, requires_grad=True, device=device, dtype=torch.double)
8758            self._test_module_empty_inputs(transformer_model, [src, tgt])
8759
8760    @onlyNativeDeviceTypes
8761    @dtypes(torch.float32, torch.complex64)
8762    def test_ReflectionPad_empty(self, device, dtype):
8763        for mod, inp in [
8764                (torch.nn.ReflectionPad1d(2), torch.randn(0, 3, 10, device=device, dtype=dtype)),
8765                (torch.nn.ReflectionPad2d(2), torch.randn(0, 3, 10, 10, device=device, dtype=dtype)),
8766                (torch.nn.ReflectionPad3d(3), torch.randn(0, 3, 10, 10, 10, device=device, dtype=dtype))]:
8767            _test_module_empty_input(self, mod, inp, check_size=False)
8768
8769        with self.assertRaisesRegex(RuntimeError, '2D or 3D'):
8770            mod = torch.nn.ReflectionPad1d(2)
8771            inp = torch.randn(3, 0, 10, device=device, dtype=dtype)
8772            mod(inp)
8773
8774        with self.assertRaisesRegex(RuntimeError, '3D or 4D'):
8775            mod = torch.nn.ReflectionPad2d(2)
8776            inp = torch.randn(3, 0, 10, 10, device=device, dtype=dtype)
8777            mod(inp)
8778
8779        with self.assertRaisesRegex(RuntimeError, '4D or 5D'):
8780            mod = torch.nn.ReflectionPad3d(3)
8781            inp = torch.randn(3, 0, 10, 10, 10, device=device, dtype=dtype)
8782            mod(inp)
8783
8784    @onlyCUDA   # Test if CPU and GPU results match
8785    def test_ReflectionPad2d_large(self, device):
8786        shapes = ([2, 65736, 6, 6], [65736, 2, 6, 6])
8787        pad = (1, 2, 3, 4)
8788        for shape in shapes:
8789            x = torch.randn(shape, device=device, requires_grad=True)
8790            ref_x = x.detach().cpu().requires_grad_()
8791
8792            out = F.pad(x, pad, mode='reflect')
8793            ref_out = F.pad(ref_x, pad, mode='reflect')
8794
8795            self.assertEqual(out, ref_out)
8796
8797            g = torch.randn_like(out)
8798            ref_g = g.cpu()
8799
8800            out.backward(g)
8801            ref_out.backward(ref_g)
8802
8803            self.assertEqual(x.grad, ref_x.grad)
8804
8805    @onlyNativeDeviceTypes
8806    def test_LocalResponseNorm_empty(self, device):
8807        mod = torch.nn.LocalResponseNorm(2).to(device)
8808        inp = torch.ones(0, 5, 24, 24, device=device)
8809        _test_module_empty_input(self, mod, inp, check_size=False)
8810
8811    @onlyCUDA   # Test if CPU and GPU results match
8812    def test_ReflectionPad3d_large(self, device):
8813        shapes = ([2, 1000, 7, 7, 7], [1000, 2, 7, 7, 7])
8814        pad = (1, 2, 3, 4, 5, 6)
8815        for shape in shapes:
8816            x = torch.randn(shape, device=device, requires_grad=True)
8817            ref_x = x.detach().cpu().requires_grad_()
8818
8819            out = F.pad(x, pad, mode='reflect')
8820            ref_out = F.pad(ref_x, pad, mode='reflect')
8821
8822            self.assertEqual(out, ref_out)
8823
8824            g = torch.randn_like(out)
8825            ref_g = g.cpu()
8826
8827            out.backward(g)
8828            ref_out.backward(ref_g)
8829
8830            self.assertEqual(x.grad, ref_x.grad)
8831
8832    @onlyNativeDeviceTypes
8833    @dtypes(torch.float, torch.double)
8834    def test_MarginLoss_empty(self, device, dtype):
8835        for mod, x, y in [
8836                (torch.nn.MultiMarginLoss().to(device),
8837                 torch.randn(0, 10, requires_grad=True, device=device, dtype=dtype),
8838                 torch.ones(0, device=device).type(torch.long)),
8839                (torch.nn.MultiLabelMarginLoss().to(device),
8840                 torch.randn(0, 10, requires_grad=True, device=device, dtype=dtype),
8841                 torch.ones(0, 10, device=device).type(torch.long))]:
8842
8843            out = mod(x, y)
8844            out.sum().backward()
8845
8846            self.assertEqual(x, torch.zeros_like(x))
8847            self.assertEqual(x.grad, torch.zeros_like(x))
8848
8849            with self.assertRaisesRegex(RuntimeError, 'Expected'):
8850                x = torch.randn(0, requires_grad=True, device=device, dtype=dtype)
8851                y = torch.ones(10, device=device).type(torch.long)
8852                mod(x, y)
8853
8854            with self.assertRaisesRegex(RuntimeError, 'Expected'):
8855                x = torch.randn(10, 0, requires_grad=True, device=device, dtype=dtype)
8856                y = torch.ones(10, 0, device=device).type(torch.long)
8857                mod(x, y)
8858
8859    @onlyCUDA
8860    def test_MarginLoss_warnings(self, device):
8861        model = torch.nn.Linear(128, 22, device=device)
8862        loss = torch.nn.MultiMarginLoss()
8863        x = torch.rand((56, 128), device=device)
8864        targets = torch.randint(22, (56,), device=device)
8865        f = io.StringIO()
8866        with contextlib.redirect_stderr(f):
8867            out = model(x)
8868            l = loss(out, targets)
8869            l.backward()
8870        self.assertTrue(len(f.getvalue()) == 0)
8871
8872    @onlyNativeDeviceTypes
8873    def test_Unfold_empty(self, device):
8874        inp = torch.randn(0, 3, 3, 4, device=device)
8875        unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device)
8876        _test_module_empty_input(self, unfold, inp, check_size=False)
8877
8878        with self.assertRaisesRegex(RuntimeError, 'Expected 3D or 4D'):
8879            inp = torch.randn(3, 0, 3, 4, device=device)
8880            unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device)
8881            unfold(inp)
8882
8883    @onlyCUDA
8884    @dtypes(torch.float, torch.double)
8885    @tf32_on_and_off(0.005)
8886    def test_rnn_fused(self, device, dtype):
8887
8888        def copy_rnn(rnn1, rnn2):
8889            for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights):
8890                for x, y in zip(x_layer, y_layer):
8891                    x.data.copy_(y.data)
8892
8893        def check_rnn_grads(rnn1, rnn2):
8894            for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights):
8895                for x, y in zip(x_layer, y_layer):
8896                    self.assertEqual(x.grad, y.grad, atol=5e-5, rtol=0)
8897
8898        input_size = 10
8899        hidden_size = 6
8900        num_layers = 2
8901        seq_length = 7
8902        batch = 6
8903        input_val = torch.randn(seq_length, batch, input_size, dtype=dtype)
8904        grad_output = torch.randn(seq_length, batch, hidden_size, dtype=dtype)
8905        hx_val = torch.randn(num_layers, batch, hidden_size, dtype=dtype)
8906        grad_hy = torch.randn(num_layers, batch, hidden_size, dtype=dtype)
8907        with torch.backends.cudnn.flags(enabled=False, allow_tf32=None):
8908            for module in (nn.GRU, nn.LSTM):
8909                for bias in (True, False):
8910                    rnn = module(input_size, hidden_size, num_layers, bias=bias).to(dtype)
8911                    rnn_device = module(input_size, hidden_size, num_layers, bias=bias).to(device, dtype)
8912                    copy_rnn(rnn, rnn_device)
8913
8914                    is_lstm = isinstance(rnn, nn.LSTM)
8915                    if is_lstm:
8916                        hx = (hx_val.clone().requires_grad_(True),
8917                              hx_val.clone().add(1).requires_grad_(True))
8918                        hx_device = (hx_val.clone().to(device).requires_grad_(True),
8919                                     hx_val.clone().to(device).add(1).requires_grad_(True))
8920                    else:
8921                        hx = hx_val.clone().requires_grad_(True)
8922                        hx_device = hx_val.clone().to(device).requires_grad_(True)
8923
8924                    inp = input_val.clone().requires_grad_(True)
8925                    inp_cu = input_val.clone().to(device).requires_grad_(True)
8926                    output1, hy1 = rnn(inp, hx)
8927                    output2, hy2 = rnn_device(inp_cu, hx_device)
8928                    if is_lstm:
8929                        torch.autograd.backward(
8930                            [output1, hy1[0], hy1[1]], [grad_output, grad_hy, grad_hy + 1]
8931                        )
8932                        torch.autograd.backward(
8933                            [output2, hy2[0], hy2[1]],
8934                            [grad_output.to(device), grad_hy.to(device), (grad_hy + 1).to(device)]
8935                        )
8936                    else:
8937                        torch.autograd.backward([output1, hy1], [grad_output, grad_hy])
8938                        torch.autograd.backward([output2, hy2], [grad_output.to(device), grad_hy.to(device)])
8939
8940                    self.assertEqual(output1, output2)
8941                    self.assertEqual(hy1, hy2)
8942
8943                    check_rnn_grads(rnn, rnn_device)
8944                    self.assertEqual(inp.grad, inp_cu.grad)
8945                    if is_lstm:
8946                        self.assertEqual(hx[0].grad, hx_device[0].grad)
8947                        self.assertEqual(hx[1].grad, hx_device[1].grad)
8948                    else:
8949                        self.assertEqual(hx.grad, hx_device.grad)
8950
8951    @dtypesIfMPS(torch.float)
8952    @dtypes(torch.double)
8953    def test_BatchNorm_empty(self, device, dtype):
8954        mod = torch.nn.BatchNorm2d(3).to(device)
8955        inp = torch.randn(0, 3, 2, 2, device=device, dtype=dtype)
8956        _test_module_empty_input(self, mod, inp)
8957        if self.device_type == 'cuda' and self.has_cudnn():
8958            with torch.backends.cudnn.flags(enabled=False):
8959                _test_module_empty_input(self, mod, inp)
8960
8961        self.assertEqual(mod.running_mean, torch.tensor([0., 0, 0], device=device))
8962        self.assertEqual(mod.running_var, torch.tensor([1., 1, 1], device=device))
8963        self.assertEqual(mod.weight.grad, torch.tensor([0., 0, 0], device=device))
8964        self.assertEqual(mod.bias.grad, torch.tensor([0., 0, 0], device=device))
8965
8966    @onlyCUDA
8967    @largeTensorTest('16GB')
8968    def test_prelu_backward_32bit_indexing(self, device):
8969        m = torch.nn.PReLU().cuda().half()
8970        input_ = torch.ones((1024, 1024, 1024, 2), dtype=torch.half, device=device)
8971        output = m(input_)
8972        output.backward(input_)
8973
8974    def test_linear_empty(self, device):
8975        mod = torch.nn.Linear(7, 7).to(device)
8976        inp = torch.randn(0, 7, device=device)
8977        _test_module_empty_input(self, mod, inp)
8978
8979    def test_one_hot(self, device):
8980        # cuda throws device assert for invalid data
8981        # xla ignores out of bound indices
8982        if self.device_type not in ('cuda', 'mps', 'xla'):
8983            with self.assertRaises(RuntimeError):
8984                torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)
8985
8986            with self.assertRaises(RuntimeError):
8987                torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)
8988
8989        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
8990        expected = torch.tensor([[0, 0, 0, 1, 0],
8991                                 [0, 0, 0, 0, 1],
8992                                 [0, 1, 0, 0, 0],
8993                                 [1, 0, 0, 0, 0]], device=device)
8994        self.assertEqual(t, expected)
8995
8996        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
8997        expected = torch.tensor([[0, 0, 0, 1, 0],
8998                                 [0, 0, 0, 0, 1],
8999                                 [0, 1, 0, 0, 0],
9000                                 [1, 0, 0, 0, 0]], device=device)
9001        self.assertEqual(t, expected)
9002
9003        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
9004        expected = torch.tensor([[0, 0, 0, 1, 0, 0],
9005                                 [0, 0, 0, 0, 1, 0],
9006                                 [0, 1, 0, 0, 0, 0],
9007                                 [1, 0, 0, 0, 0, 0]], device=device)
9008        self.assertEqual(t, expected)
9009
9010        t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
9011        expected = torch.tensor([[[0, 0, 0, 1, 0],
9012                                  [0, 0, 0, 0, 1]],
9013                                 [[0, 1, 0, 0, 0],
9014                                  [1, 0, 0, 0, 0]]], device=device)
9015        self.assertEqual(t, expected)
9016
9017        t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
9018        expected = torch.tensor([0, 0, 0, 0, 1], device=device)
9019        self.assertEqual(t, expected)
9020
9021        t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
9022        expected = torch.empty([4, 0, 100], dtype=torch.long)
9023        self.assertEqual(t, expected)
9024
9025        with self.assertRaises(RuntimeError):
9026            torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))
9027
9028        with self.assertRaises(RuntimeError):
9029            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)
9030
9031    @expectedFailureMPS  # NotImplementedError: aten::rrelu_with_noise https://github.com/pytorch/pytorch/issues/77764
9032    def test_nn_empty(self, device):
9033        # One off tests to ensure scalars from nn.yaml are properly applied
9034        def verify_scalars(input, output):
9035            self.assertEqual(input.shape, output.shape)
9036            self.assertEqual(0, output.numel())
9037
9038        for input_shape in [(0), (0, 2)]:
9039            for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid,
9040                           torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid,
9041                           torch.nn.Tanh]:
9042                input = torch.randn(input_shape, device=device, requires_grad=True)
9043                m = module()
9044                output = m(input)
9045                verify_scalars(input, output)
9046
9047    @expectedFailureMPS  # NotImplementedError: aten::rrelu_with_noise https://github.com/pytorch/pytorch/issues/77764
9048    def test_nn_scalars(self, device):
9049        # One off tests to ensure scalars from nn.yaml are properly applied
9050        def verify_scalars(input, output):
9051            if input.dim() == 0:
9052                self.assertEqual((), output.shape)
9053            else:
9054                self.assertNotEqual((), output.shape)
9055            output.sum().backward()
9056            self.assertEqual(input.shape, input.grad.shape)
9057
9058        for input_shape in [(5, 6), ()]:
9059            for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid,
9060                           torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid,
9061                           torch.nn.Tanh]:
9062                input = torch.randn(input_shape, device=device, requires_grad=True)
9063                m = module()
9064                output = m(input)
9065                verify_scalars(input, output)
9066
9067    def test_nn_scalars_reductions(self, device):
9068        # One off tests to ensure scalars from nn.yaml are properly applied
9069        def verify_reduction_scalars(input, reduction, output):
9070            if reduction != 'none' or input.dim() == 0:
9071                self.assertEqual((), output.shape)
9072            else:
9073                self.assertNotEqual((), output.shape)
9074            output.sum().backward()
9075            self.assertEqual(input.shape, input.grad.shape)
9076
9077        for input_shape in [(5, 6), ()]:
9078            for reduction in ['none', 'mean', 'sum']:
9079                for module in [torch.nn.BCELoss, torch.nn.L1Loss, torch.nn.MSELoss,
9080                               torch.nn.SmoothL1Loss, torch.nn.SoftMarginLoss]:
9081                    input = torch.randn(input_shape, device=device, requires_grad=True)
9082                    target = torch.empty(input_shape, device=device).random_(2)
9083                    sigmoid = nn.Sigmoid()
9084
9085                    input = torch.randn(input_shape, device=device, requires_grad=True)
9086                    m = module(reduction=reduction)
9087                    output = m(sigmoid(input), target)
9088                    verify_reduction_scalars(input, reduction, output)
9089
9090    # verify that bogus reduction strings are errors
9091    @onlyNativeDeviceTypes
9092    def test_invalid_reduction_strings(self, device):
9093        input = torch.randn(3, 5, requires_grad=True, device=device)
9094        cinput = torch.randn(3, 5, requires_grad=True, device=device, dtype=torch.cfloat)
9095        target = torch.tensor([1, 0, 4], device=device)
9096        var = torch.ones(size=input.size(), requires_grad=True, device=device)
9097
9098        for reduction in ['none', 'invalid']:
9099            def v(fn):
9100                if reduction == 'invalid':
9101                    self.assertRaises(ValueError, lambda: fn())
9102                else:
9103                    fn()
9104
9105            v(lambda: F.nll_loss(input, target, reduction=reduction))
9106            v(lambda: F.cross_entropy(input, target, reduction=reduction))
9107
9108            v(lambda: F.kl_div(input, input, reduction=reduction))
9109            v(lambda: F.huber_loss(input, input, reduction=reduction))
9110            v(lambda: F.smooth_l1_loss(input, input, reduction=reduction))
9111            v(lambda: F.l1_loss(input, input, reduction=reduction))
9112            v(lambda: F.l1_loss(cinput, cinput, reduction=reduction))
9113            v(lambda: F.mse_loss(input, input, reduction=reduction))
9114            v(lambda: F.hinge_embedding_loss(input, input, reduction=reduction))
9115            v(lambda: F.poisson_nll_loss(input, input, reduction=reduction))
9116            v(lambda: F.gaussian_nll_loss(input, input, var, reduction=reduction))
9117            v(lambda: F.binary_cross_entropy(torch.sigmoid(input), input.gt(0).to(torch.get_default_dtype()), reduction=reduction))
9118            v(lambda: F.binary_cross_entropy_with_logits(input, input, reduction=reduction))
9119
9120            zeros = torch.zeros_like(input).to(torch.int64)
9121            v(lambda: F.multilabel_soft_margin_loss(input, zeros, reduction=reduction))
9122
9123            v(lambda: F.triplet_margin_loss(input, input, input, reduction=reduction))
9124            v(lambda: F.triplet_margin_with_distance_loss(input, input, input, reduction=reduction))
9125            v(lambda: F.margin_ranking_loss(input, input, input.sign(), reduction=reduction))
9126            v(lambda: F.cosine_embedding_loss(input, input, input[:, 0].sign(), reduction=reduction))
9127
9128            log_probs = torch.randn(50, 16, 20, requires_grad=True, device=device).log_softmax(2)
9129            targets = torch.randint(1, 20, (16, 30), dtype=torch.long, device=device)
9130            input_lengths = torch.full((16,), 50, dtype=torch.long, device=device)
9131            target_lengths = torch.randint(10, 30, (16,), dtype=torch.long, device=device)
9132            v(lambda: F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction=reduction))
9133
9134            # FIXME: should we allow derivatives on these?
9135            v(lambda: F.soft_margin_loss(input, input.sign().detach(), reduction=reduction))
9136
9137    @onlyNativeDeviceTypes
9138    def test_smooth_l1_loss_vs_huber_loss(self, device):
9139        def _make_test_tensor(shape, contiguous=True):
9140            if contiguous:
9141                test_tensor = torch.randn(shape, device=device)
9142            else:
9143                # Select every other element in the innermost dimension to
9144                # make it non-contiguous.
9145                doubled_shape = list(shape)
9146                doubled_shape[-1] *= 2
9147                test_tensor = torch.randn(doubled_shape, device=device)
9148                test_tensor = test_tensor[..., ::2]
9149            return test_tensor
9150
9151        def _test_smooth_l1_loss_vs_huber_loss_helper(input, target, beta, require_equal):
9152            for reduction in ['mean', 'sum', 'none']:
9153                smooth_l1 = torch.nn.SmoothL1Loss(beta=beta, reduction=reduction)
9154                # beta hyper-parameter is called delta for Huber
9155                huber = torch.nn.HuberLoss(delta=beta, reduction=reduction)
9156                smooth_l1_loss = smooth_l1(input, target)
9157                huber_loss = huber(input, target)
9158
9159                if require_equal:
9160                    self.assertEqual(smooth_l1_loss, huber_loss)
9161                else:
9162                    # Huber loss should be larger than smooth L1 loss by a factor of beta.
9163                    self.assertEqual(smooth_l1_loss * beta, huber_loss)
9164
9165        def _test_smooth_l1_loss_vs_huber_loss_multi_input_helper(beta, require_equal):
9166            # Test the non-vectorized case.
9167            shape = (2, 2)
9168            _test_smooth_l1_loss_vs_huber_loss_helper(input=_make_test_tensor(shape),
9169                                                      target=_make_test_tensor(shape),
9170                                                      beta=beta,
9171                                                      require_equal=require_equal)
9172
9173            # Test the vectorized case (innermost dim > 32).
9174            shape = (64, 64)
9175            _test_smooth_l1_loss_vs_huber_loss_helper(input=_make_test_tensor(shape),
9176                                                      target=_make_test_tensor(shape),
9177                                                      beta=beta,
9178                                                      require_equal=require_equal)
9179
9180            # Test the non-contiguous case.
9181            _test_smooth_l1_loss_vs_huber_loss_helper(input=_make_test_tensor(shape, contiguous=False),
9182                                                      target=_make_test_tensor(shape, contiguous=False),
9183                                                      beta=beta,
9184                                                      require_equal=require_equal)
9185
9186        def test_equal_when_beta_is_one():
9187            _test_smooth_l1_loss_vs_huber_loss_multi_input_helper(beta=1.0, require_equal=True)
9188
9189        def test_unequal_when_beta_is_less_than_one():
9190            _test_smooth_l1_loss_vs_huber_loss_multi_input_helper(beta=0.5, require_equal=False)
9191
9192        def test_unequal_when_beta_is_greater_than_one():
9193            _test_smooth_l1_loss_vs_huber_loss_multi_input_helper(beta=1.5, require_equal=False)
9194
9195        test_equal_when_beta_is_one()
9196        test_unequal_when_beta_is_less_than_one()
9197        test_unequal_when_beta_is_greater_than_one()
9198
9199    @onlyCPU
9200    def test_smooth_l1_loss_bfloat16(self, device):
9201        def test_dtype(fn, input, target, dtype):
9202            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
9203            input2 = input.detach().clone().float().requires_grad_(True)
9204            target = target.detach().clone().to(dtype=dtype)
9205            target2 = target.detach().clone().float()
9206            out = fn(input, target)
9207            out.sum().backward()
9208            out2 = fn(input2, target2)
9209            out2.sum().backward()
9210            self.assertEqual(out.dtype, dtype)
9211            self.assertEqual(input.grad.dtype, dtype)
9212            self.assertEqual(out, out2, exact_dtype=False)
9213            self.assertEqual(input.grad, input2.grad, exact_dtype=False)
9214
9215        def func(device):
9216            return nn.SmoothL1Loss().to(device=device)
9217
9218        shapes = [[1, 3, 1, 6], [1, 3, 1, 128], [1, 3, 128, 128]]
9219        for shape in shapes:
9220            x = torch.randn(shape, device=device, requires_grad=True)
9221            t = torch.randn(shape, device=device)
9222            test_dtype(func(device), x, t, torch.bfloat16)
9223
9224    # We don't want to make propagating NaN a hard requirement on ops, but for
9225    # these easy ones, we should make them do so.
9226    # MPS: NotImplementedError: aten::rrelu_with_noise_ https://github.com/pytorch/pytorch/issues/77764
9227    # MPS: NotImplementedError: aten::hardshrink.out https://github.com/pytorch/pytorch/issues/77764
9228    @expectedFailureMPS
9229    def test_nonlinearity_propagate_nan(self, device):
9230        def test(nonlinearity, *args, **kwargs):
9231            x = torch.tensor([nan], device=device)
9232            fn = getattr(F, nonlinearity)
9233            try:
9234                self.assertTrue(math.isnan(fn(x, *args, **kwargs).item()))
9235            except Exception as e:
9236                if 'not implemented' not in str(e):
9237                    raise
9238
9239        test('relu')
9240        test('relu', inplace=True)
9241        test('relu6')
9242        test('elu')
9243        test('selu')
9244        test('celu')
9245        test('rrelu')
9246        test('rrelu', inplace=True)
9247        test('hardtanh')
9248        test('tanh')
9249        test('sigmoid')
9250        test('logsigmoid')
9251        test('hardshrink')
9252        test('tanhshrink')
9253        test('softsign')
9254        test('softmin', 0)
9255        test('softmax', 0)
9256        test('log_softmax', 0)
9257        test('leaky_relu', 0.2)
9258        test('threshold', 3, 2)
9259        test('threshold', 3, 2, inplace=True)
9260
9261    @expectedFailureMPS  # TypeError: float64 the MPS framework doesn't support float64
9262    @parametrize_test("mode", ["nearest-exact", "nearest"])
9263    def test_upsamplingNearest1d(self, device, mode):
9264        # Forward AD does not support XLA because XLA tensors don't have storage
9265        check_forward_ad = torch.device(device).type != 'xla'
9266
9267        m = nn.Upsample(size=4, mode=mode)
9268        in_t = torch.ones(1, 1, 2, device=device, dtype=torch.double)
9269        in_uint8_t = torch.ones(1, 1, 2, dtype=torch.uint8, device=device)
9270        with warnings.catch_warnings(record=True) as w:
9271            out_t = m(in_t)
9272            out_uint8_t = m(in_uint8_t)
9273        self.assertEqual(torch.ones(1, 1, 4, device=device, dtype=torch.double), out_t.data)
9274        self.assertEqual(torch.ones(1, 1, 4, dtype=torch.uint8, device=device), out_uint8_t.data)
9275
9276        # Checks upsampling
9277        input = torch.randn(1, 1, 2, requires_grad=True, device=device, dtype=torch.double)
9278        gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_forward_ad=check_forward_ad)
9279        gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_fwd_over_rev=check_forward_ad)
9280
9281        # Checks downsampling
9282        input = torch.randn(1, 1, 20, requires_grad=True, device=device, dtype=torch.double)
9283        gradcheck(lambda x: F.interpolate(x, 11, mode=mode), [input], check_forward_ad=check_forward_ad)
9284        gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_fwd_over_rev=check_forward_ad)
9285
9286        # consistency CUDA/CPU check
9287        if torch.device(device).type == 'cuda':
9288            input_cuda = torch.randn(1, 1, 20, device=device, dtype=torch.double)
9289            input_cpu = input_cuda.cpu()
9290            output_cuda = F.interpolate(input_cuda, 4, mode=mode)
9291            output_cpu = F.interpolate(input_cpu, 4, mode=mode)
9292            self.assertEqual(output_cuda.cpu(), output_cpu)
9293
9294            output_cuda = F.interpolate(input_cuda, 24, mode=mode)
9295            output_cpu = F.interpolate(input_cpu, 24, mode=mode)
9296            self.assertEqual(output_cuda.cpu(), output_cpu)
9297
9298    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
9299    def test_upsamplingNearest1d_correctness(self, device, isize, osize):
9300        # Here we check if output matches OpenCV's INTER_NEAREST-like result
9301        in_t = torch.arange(isize, dtype=torch.float, device=device).unsqueeze(0).unsqueeze(0)
9302        out_t = F.interpolate(
9303            in_t, size=(osize, ), recompute_scale_factor=False, mode="nearest"
9304        )
9305        # compute expected output as OpenCV
9306        expected_out = torch.zeros(osize, dtype=torch.float).unsqueeze(0).unsqueeze(0)
9307        scale = 1.0 * isize / osize
9308        for o in range(osize):
9309            i_f32 = o * scale
9310            i = int(i_f32)
9311            expected_out[0, 0, o] = in_t[0, 0, i]
9312        expected_out = expected_out.to(device=device)
9313        self.assertEqual(out_t, expected_out)
9314
9315    def test_upsamplingNearestExact1d_rescale(self, device):
9316        # Checks https://github.com/pytorch/pytorch/issues/62237
9317        isize = 20
9318        in_t = torch.arange(isize, dtype=torch.float, device=device).unsqueeze(0).unsqueeze(0)
9319        # for s in [1.00001, 0.99999]:  # 0.9999 case is broken
9320        # See issue: https://github.com/pytorch/pytorch/issues/62396
9321        for s in [1.00001, ]:
9322            out_t = F.interpolate(
9323                in_t, scale_factor=s, recompute_scale_factor=False, mode="nearest-exact"
9324            )
9325            expected_out = in_t
9326            self.assertEqual(out_t, expected_out, msg=f"scale: {s}")
9327
9328        # checks data duplication if output_size == 2 * input_size
9329        # for s in [2.00001, 1.99999]:  # 1.99999 case is broken
9330        # See issue: https://github.com/pytorch/pytorch/issues/62396
9331        for s in [2.00001, ]:
9332            out_t = F.interpolate(
9333                in_t, scale_factor=s, recompute_scale_factor=False, mode="nearest-exact"
9334            )
9335            # input is [[[0, 1, 2, 3, ..., 9]]]
9336            # expected out is [[[0, 0, 1, 1, 2, 2, ..., 9, 9]]]
9337            expected_out = in_t.repeat_interleave(2, dim=-1)
9338            self.assertEqual(out_t, expected_out)
9339
9340    @skipIfMps  # Partially passes https://github.com/pytorch/pytorch/issues/134430
9341    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
9342    def test_upsamplingNearestExact1d_correctness(self, device, isize, osize):
9343        # Here we check if output matches Scikit-Image/Scipy-like result
9344        # Checks https://github.com/pytorch/pytorch/issues/34808
9345        in_t = torch.arange(isize, dtype=torch.float, device=device).unsqueeze(0).unsqueeze(0)
9346        out_t = F.interpolate(
9347            in_t, size=(osize, ), recompute_scale_factor=False, mode="nearest-exact"
9348        )
9349        # compute expected output as scikit-image/scipy
9350        expected_out = torch.zeros(osize, dtype=torch.float).unsqueeze(0).unsqueeze(0)
9351        scale = 1.0 * isize / osize
9352        for o in range(osize):
9353            i_f32 = (o + 0.5) * scale
9354            i = int(i_f32)
9355            expected_out[0, 0, o] = in_t[0, 0, i]
9356        expected_out = expected_out.to(device=device)
9357        self.assertEqual(out_t, expected_out)
9358
9359    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
9360    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
9361    @parametrize_test("mode", ["nearest", "nearest-exact"])
9362    def test_upsamplingNearest2d(self, device, memory_format, mode):
9363        # Forward AD does not support XLA because XLA tensors don't have storage
9364        check_forward_ad = torch.device(device).type != 'xla'
9365
9366        in_t = torch.ones(1, 2, 2, 2, device=device, dtype=torch.double).contiguous(memory_format=memory_format)
9367        in_uint8_t = torch.ones(1, 2, 2, 2, dtype=torch.uint8, device=device).contiguous(memory_format=memory_format)
9368        with warnings.catch_warnings(record=True) as w:
9369            out_t = F.interpolate(in_t, size=4, mode=mode)
9370            out_uint8_t = F.interpolate(in_uint8_t, size=4, mode=mode)
9371            self.assertEqual(len(w), 0)
9372        self.assertEqual(torch.ones(1, 2, 4, 4, device=device, dtype=torch.double), out_t)
9373        self.assertEqual(torch.ones(1, 2, 4, 4, dtype=torch.uint8, device=device), out_uint8_t)
9374        # Assert that memory format is carried through to the output
9375        self.assertTrue(out_t.is_contiguous(memory_format=memory_format))
9376
9377        # test forward when input's height is not same as width
9378        in_t = torch.ones(1, 2, 2, 1, device=device, dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
9379        out_t = F.interpolate(in_t, size=(4, 2), mode=mode)
9380        self.assertEqual(torch.ones(1, 2, 4, 2, device=device, dtype=torch.double), out_t)
9381        self.assertTrue(out_t.is_contiguous(memory_format=memory_format))
9382
9383        out_t.backward(torch.randn_like(out_t))
9384        self.assertTrue(in_t.grad.is_contiguous(memory_format=memory_format))
9385
9386        # test backward when input's height is not same as width
9387        input = torch.ones(
9388            1, 2, 2, 1, requires_grad=True, device=device,
9389            dtype=torch.double).contiguous(memory_format=memory_format)
9390        gradcheck(lambda x: F.interpolate(x, size=(4, 2), mode=mode), [input], check_forward_ad=check_forward_ad)
9391        gradgradcheck(lambda x: F.interpolate(x, size=(4, 2), mode=mode), [input], check_fwd_over_rev=check_forward_ad)
9392
9393        input = torch.randn(
9394            1, 2, 2, 2, requires_grad=True, device=device,
9395            dtype=torch.double).contiguous(memory_format=memory_format)
9396        self.assertEqual(
9397            F.interpolate(input, 4, mode=mode),
9398            F.interpolate(input, scale_factor=2, mode=mode))
9399        gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_forward_ad=check_forward_ad)
9400        gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_fwd_over_rev=check_forward_ad)
9401
9402        # Assert that cpu and cuda handle channels_last memory format in the same way
9403        # https://github.com/pytorch/pytorch/issues/54590
9404        if torch.device(device).type == 'cuda':
9405            for shapes, scale_factor in product([
9406                (2, 2, 3, 4), (2, 3, 4, 5), (3, 1, 2, 2), (1, 5, 3, 2)
9407            ], [0.5, 1.5, 2]):
9408                a_cuda = torch.randn(
9409                    *shapes, device=device,
9410                    dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
9411                a_cpu = a_cuda.detach().cpu().requires_grad_()
9412
9413                out_cuda = F.interpolate(a_cuda, scale_factor=scale_factor, mode=mode)
9414                out_cpu = F.interpolate(a_cpu, scale_factor=scale_factor, mode=mode)
9415
9416                self.assertEqual(out_cpu.cuda(), out_cuda)
9417
9418                g_cuda = torch.randn_like(out_cuda)
9419                g_cpu = g_cuda.cpu()
9420
9421                out_cuda.backward(g_cuda)
9422                out_cpu.backward(g_cpu)
9423
9424                self.assertEqual(a_cuda.grad, a_cpu.grad)
9425
9426    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
9427    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
9428    def test_upsamplingNearest2d_correctness(self, device, memory_format, isize, osize):
9429        # Here we check if output matches OpenCV's INTER_NEAREST-like result
9430        in_t = torch.arange(isize * isize, dtype=torch.float, device=device).reshape(1, 1, isize, isize)
9431        in_t = in_t.contiguous(memory_format=memory_format)
9432        out_t = F.interpolate(
9433            in_t, size=(osize, osize), recompute_scale_factor=False, mode="nearest"
9434        )
9435        # compute expected output as OpenCV
9436        expected_out = torch.zeros(1, 1, osize, osize, dtype=torch.float)
9437        scale = 1.0 * isize / osize
9438        for o1 in range(osize):
9439            i1_f32 = o1 * scale
9440            i1 = int(i1_f32)
9441            for o2 in range(osize):
9442                i2_f32 = o2 * scale
9443                i2 = int(i2_f32)
9444                expected_out[0, 0, o1, o2] = in_t[0, 0, i1, i2]
9445        expected_out = expected_out.to(device=device)
9446        self.assertEqual(out_t, expected_out)
9447
9448    @skipIfMps  # Partially passes https://github.com/pytorch/pytorch/issues/134430
9449    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
9450    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
9451    def test_upsamplingNearestExact2d_correctness(self, device, memory_format, isize, osize):
9452        # Here we check if output matches Scikit-Image/Scipy-like result
9453        # Checks https://github.com/pytorch/pytorch/issues/34808
9454        in_t = torch.arange(isize * isize, dtype=torch.float, device=device).reshape(1, 1, isize, isize)
9455        in_t = in_t.contiguous(memory_format=memory_format)
9456        out_t = F.interpolate(
9457            in_t, size=(osize, osize), recompute_scale_factor=False, mode="nearest-exact"
9458        )
9459        # compute expected output as Scikit-Image/Scipy
9460        expected_out = torch.zeros(1, 1, osize, osize, dtype=torch.float)
9461        scale = 1.0 * isize / osize
9462        for o1 in range(osize):
9463            i1_f32 = (o1 + 0.5) * scale
9464            i1 = int(i1_f32)
9465            for o2 in range(osize):
9466                i2_f32 = (o2 + 0.5) * scale
9467                i2 = int(i2_f32)
9468                expected_out[0, 0, o1, o2] = in_t[0, 0, i1, i2]
9469        expected_out = expected_out.to(device=device)
9470        self.assertEqual(out_t, expected_out)
9471
9472    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
9473    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d])
9474    @parametrize_test("mode", ["nearest", "nearest-exact"])
9475    def test_upsamplingNearest3d(self, device, memory_format, mode):
9476        # Forward AD does not support XLA because XLA tensors don't have storage
9477        check_forward_ad = torch.device(device).type != 'xla'
9478
9479        m = nn.Upsample(size=4, mode=mode)
9480        in_t = torch.ones(1, 2, 2, 2, 2, device=device, dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
9481        in_uint8_t = torch.ones(
9482            1, 2, 2, 2, 2, dtype=torch.uint8, device=device
9483        ).contiguous(memory_format=memory_format)
9484        with warnings.catch_warnings(record=True) as w:
9485            out_t = m(in_t)
9486            out_uint8_t = m(in_uint8_t)
9487        expected_output = torch.ones(1, 2, 4, 4, 4, device=device, dtype=torch.double)
9488        self.assertEqual(expected_output, out_t)
9489        self.assertEqual(expected_output.to(torch.uint8), out_uint8_t)
9490        # Assert that memory format is carried through to the output
9491        self.assertTrue(out_t.is_contiguous(memory_format=memory_format))
9492        out_t.backward(torch.randn_like(out_t))
9493        self.assertTrue(in_t.grad.is_contiguous(memory_format=memory_format))
9494
9495        input = torch.randn(
9496            1, 2, 2, 2, 2, requires_grad=True, device=device, dtype=torch.double
9497        ).contiguous(memory_format=memory_format)
9498        gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_forward_ad=check_forward_ad)
9499        gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_fwd_over_rev=check_forward_ad)
9500
9501        # Assert that cpu and cuda handle channels_last memory format in the same way
9502        # https://github.com/pytorch/pytorch/issues/54590
9503        if torch.device(device).type == 'cuda':
9504            a = torch.ones(
9505                2, 2, 2, 3, 4, device=device, requires_grad=True, dtype=torch.double
9506            ).contiguous(memory_format=torch.channels_last_3d)
9507            # make the data asymmetric; ensure that cuda/cpu handle channels_last appropriately.
9508            a[1][1][1][2][2] = a[1][1][1][2][3] = 0
9509
9510            out_cuda = torch.nn.functional.interpolate(a, scale_factor=2, mode=mode)
9511            out_cpu = torch.nn.functional.interpolate(a.to('cpu'), scale_factor=2, mode=mode)
9512            self.assertEqual(out_cpu, out_cuda.to('cpu'))
9513
9514            gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [a], check_forward_ad=check_forward_ad)
9515            gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [a], check_fwd_over_rev=check_forward_ad)
9516
9517            gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [a.to('cuda')], check_forward_ad=check_forward_ad)
9518            gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [a.to('cuda')], check_fwd_over_rev=check_forward_ad)
9519
9520    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d])
9521    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
9522    def test_upsamplingNearest3d_correctness(self, device, memory_format, isize, osize):
9523        # Here we check if output matches OpenCV's INTER_NEAREST-like result
9524        in_t = torch.arange(isize * isize * isize, dtype=torch.float, device=device)
9525        in_t = in_t.reshape(1, 1, isize, isize, isize)
9526        in_t = in_t.contiguous(memory_format=memory_format)
9527        out_t = F.interpolate(
9528            in_t, size=(osize, osize, osize), recompute_scale_factor=False, mode="nearest"
9529        )
9530        # compute expected output as OpenCV
9531        expected_out = torch.zeros(1, 1, osize, osize, osize, dtype=torch.float)
9532        scale = 1.0 * isize / osize
9533        for o1 in range(osize):
9534            i1_f32 = o1 * scale
9535            i1 = int(i1_f32)
9536            for o2 in range(osize):
9537                i2_f32 = o2 * scale
9538                i2 = int(i2_f32)
9539                for o3 in range(osize):
9540                    i3_f32 = o3 * scale
9541                    i3 = int(i3_f32)
9542                    expected_out[0, 0, o1, o2, o3] = in_t[0, 0, i1, i2, i3]
9543        expected_out = expected_out.to(device=device)
9544        self.assertEqual(out_t, expected_out)
9545
9546    @expectedFailureMPS  # NotImplementedError: aten::_upsample_nearest_exact3d.out https://github.com/pytorch/pytorch/issues/77764
9547    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d])
9548    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
9549    def test_upsamplingNearestExact3d_correctness(self, device, memory_format, isize, osize):
9550        # Here we check if output matches Scikit-Image/Scipy-like result
9551        # Checks https://github.com/pytorch/pytorch/issues/34808
9552        in_t = torch.arange(isize * isize * isize, dtype=torch.float, device=device)
9553        in_t = in_t.reshape(1, 1, isize, isize, isize)
9554        in_t = in_t.contiguous(memory_format=memory_format)
9555        out_t = F.interpolate(
9556            in_t, size=(osize, osize, osize), recompute_scale_factor=False, mode="nearest-exact"
9557        )
9558        # compute expected output as Scikit-Image/Scipy
9559        expected_out = torch.zeros(1, 1, osize, osize, osize, dtype=torch.float)
9560        scale = 1.0 * isize / osize
9561        for o1 in range(osize):
9562            i1_f32 = (o1 + 0.5) * scale
9563            i1 = int(i1_f32)
9564            for o2 in range(osize):
9565                i2_f32 = (o2 + 0.5) * scale
9566                i2 = int(i2_f32)
9567                for o3 in range(osize):
9568                    i3_f32 = (o3 + 0.5) * scale
9569                    i3 = int(i3_f32)
9570                    expected_out[0, 0, o1, o2, o3] = in_t[0, 0, i1, i2, i3]
9571        expected_out = expected_out.to(device=device)
9572        self.assertEqual(out_t, expected_out)
9573
9574    @parametrize_test("antialias", [True, False])
9575    @parametrize_test("align_corners", [True, False])
9576    @parametrize_test("mode", ["bilinear", "bicubic"])
9577    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
9578    @onlyNativeDeviceTypes
9579    def test_upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory_format):
9580        # Forward AD does not support XLA because XLA tensors don't have storage
9581        check_forward_ad = torch.device(device).type != 'xla'
9582
9583        kwargs = dict(mode=mode, align_corners=align_corners, antialias=antialias)
9584        # test float scale factor up & downsampling
9585        for scale_factor in [0.5, 1.5, 2]:
9586            in_t = torch.ones(
9587                2, 3, 8, 8, device=device,
9588                dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
9589            out_size = int(math.floor(in_t.shape[-1] * scale_factor))
9590            with warnings.catch_warnings(record=True) as w:
9591                out_t = F.interpolate(in_t, scale_factor=scale_factor, **kwargs)
9592            expected_out = torch.ones(2, 3, out_size, out_size, device=device, dtype=torch.double)
9593            self.assertEqual(expected_out, out_t)
9594            # Assert that memory format is carried through to the output
9595            self.assertTrue(out_t.is_contiguous(memory_format=memory_format))
9596            out_t.backward(torch.randn_like(out_t))
9597            self.assertTrue(in_t.grad.is_contiguous(memory_format=memory_format))
9598
9599            if torch.device(device).type == 'cuda':
9600                # Bilinear backward is nondeterministic because of atomicAdd usage
9601                nondet_tol = 1e-5
9602            else:
9603                nondet_tol = 0.0
9604
9605            input = torch.randn(
9606                2, 3, 8, 8, device=device,
9607                dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
9608            gradcheck(
9609                lambda x: F.interpolate(x, out_size, **kwargs),
9610                [input],
9611                check_forward_ad=check_forward_ad, nondet_tol=nondet_tol
9612            )
9613            gradgradcheck(
9614                lambda x: F.interpolate(x, out_size, **kwargs),
9615                [input],
9616                check_fwd_over_rev=check_forward_ad, nondet_tol=nondet_tol
9617            )
9618
9619            # Assert that cpu and cuda give same results
9620            if torch.device(device).type == 'cuda':
9621                for shapes in [
9622                    (2, 2, 3, 4), (2, 3, 4, 5), (3, 1, 2, 2), (1, 5, 3, 2)
9623                ]:
9624                    a_cuda = torch.randn(
9625                        *shapes, device=device, dtype=torch.double
9626                    ).contiguous(memory_format=memory_format).requires_grad_()
9627                    a_cpu = a_cuda.detach().cpu().requires_grad_()
9628
9629                    with warnings.catch_warnings(record=True):
9630                        out_cuda = F.interpolate(a_cuda, scale_factor=scale_factor, **kwargs)
9631                        out_cpu = F.interpolate(a_cpu, scale_factor=scale_factor, **kwargs)
9632
9633                    self.assertEqual(out_cpu, out_cuda.cpu())
9634
9635                    g_cuda = torch.randn_like(out_cuda)
9636                    g_cpu = g_cuda.cpu()
9637
9638                    out_cuda.backward(g_cuda)
9639                    out_cpu.backward(g_cpu)
9640
9641                    self.assertEqual(a_cuda.grad, a_cpu.grad)
9642
9643    @parametrize_test("antialias", [True, False])
9644    @parametrize_test("num_channels", [3, 5])
9645    @parametrize_test("mode", ["nearest", "nearest-exact", "bilinear", "bicubic"])
9646    @parametrize_test("dtype", integral_types() + floating_types())
9647    @onlyNativeDeviceTypes
9648    def test_upsamplingBiMode2d_nonsupported_dtypes(self, device, antialias, num_channels, mode, dtype):
9649        x = torch.ones(1, num_channels, 32, 32, dtype=dtype, device=device)
9650
9651        should_raise_runtime_error = True
9652
9653        if "nearest" in mode:
9654            if antialias:
9655                raise SkipTest("Nearest mode does not have antialiasing")
9656            if dtype in (torch.uint8, ) + floating_types():
9657                should_raise_runtime_error = False
9658
9659        elif mode in ("bilinear", "bicubic"):
9660            if dtype in floating_types() or (device == "cpu" and dtype == torch.uint8):
9661                should_raise_runtime_error = False
9662
9663        if should_raise_runtime_error:
9664            with self.assertRaisesRegex(RuntimeError, "not implemented for"):
9665                F.interpolate(x, (12, 12), mode=mode, antialias=antialias)
9666        else:
9667            _ = F.interpolate(x, (12, 12), mode=mode, antialias=antialias)
9668
9669    @expectedFailureMPS  # NotImplementedError: aten::_upsample_bilinear2d_aa.out https://github.com/pytorch/pytorch/issues/77764
9670    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
9671    def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format):
9672        # NOTE: We expand the batch dim such that `b*c` is above the maximum
9673        # size of CUDA grid z-dimension (2**16)
9674        shape = [23000, 3, 8, 8]
9675        t_in = torch.arange(3 * 8 * 8, dtype=torch.float, device=device).reshape(1, *shape[1:])
9676        t_in = t_in.expand(shape)
9677        t_in = t_in.contiguous(memory_format=memory_format)
9678        # This expected result is obtain using PIL.Image.resize
9679        # for c in range(3):
9680        #   a_in = t_in.numpy()[0, c, ...]
9681        #   pil_in = Image.fromarray(a_in)
9682        #   pil_out = pil_in.resize((2, 2), resample=Image.LINEAR)
9683        expected_out = torch.tensor([
9684            17.035713, 20.25, 42.75, 45.964287, 81.03572, 84.25,
9685            106.75, 109.96428, 145.0357, 148.25, 170.75, 173.9643
9686        ], device=device, dtype=t_in.dtype).reshape(1, 3, 2, 2)
9687        t_out = F.interpolate(t_in, size=(2, 2), mode="bilinear", align_corners=False, antialias=True)
9688        self.assertEqual(expected_out.expand([*shape[:2], 2, 2]), t_out)
9689
9690    # Partially passes. NotImplementedError: aten::upsample_bicubic2d.out https://github.com/pytorch/pytorch/issues/77764
9691    @skipIfMps
9692    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
9693    @parametrize_test("mode", ["bilinear", "bicubic"])
9694    @parametrize_test("antialias", [True, False])
9695    @parametrize_test("align_corners", [True, False])
9696    @parametrize_test("num_channels", [3, 5])
9697    @parametrize_test("output_size", [32, 600])
9698    @parametrize_test("check_as_unsqueezed_3d_tensor", [True, False])
9699    @parametrize_test("non_contig", [False, "sliced", "restrided"])
9700    @parametrize_test("batch_size", [1, 5])
9701    def test_upsamplingBiMode2d_consistency(
9702        self,
9703        device,
9704        memory_format,
9705        mode,
9706        antialias,
9707        align_corners,
9708        num_channels,
9709        output_size,
9710        check_as_unsqueezed_3d_tensor,
9711        non_contig,
9712        batch_size,
9713    ):
9714        # Check output value consistency between resized_input_uint8 and resized input_float
9715        if torch.device(device).type == "cuda":
9716            raise SkipTest("CUDA implementation is not yet supporting uint8")
9717
9718        torch.manual_seed(0)
9719
9720        # - input range is set to [30, 220] for bicubic mode, because the bicubic kernel may create
9721        #   [intermediate] values outside of the [0, 255] range, which need
9722        #   to be clipped in uint8 path, but not in float path. This isn't
9723        #   an issue with bilinear kernel.
9724        input_range = (30, 220) if mode == "bicubic" else (0, 256)
9725        input_ui8 = torch.randint(*input_range, size=(batch_size, num_channels, 400, 400), dtype=torch.uint8, device=device)
9726        input_ui8 = input_ui8.contiguous(memory_format=memory_format)
9727
9728        if non_contig == "sliced":
9729            input_ui8 = input_ui8[:, :, 10:-10, 10:-10]
9730        elif non_contig == "restrided":
9731            input_ui8 = input_ui8[:, :, ::2, ::2]
9732
9733        if batch_size == 1 and check_as_unsqueezed_3d_tensor:
9734            input_ui8 = input_ui8[0, ...]
9735            input_ui8 = input_ui8[None, ...]
9736
9737        input_f32 = input_ui8.float()
9738
9739        output_f32 = F.interpolate(
9740            input_f32, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=antialias
9741        ).round().clip(0, 255)
9742        output_ui8 = F.interpolate(
9743            input_ui8, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=antialias
9744        )
9745
9746        if non_contig is False:
9747            self.assertTrue(input_ui8.is_contiguous(memory_format=memory_format))
9748
9749        # FIXME if-clause shows the current behaviour which is definitely unexpected.
9750        # Ideally we want to fix it such that both the ui8 and f32 outputs are also channels_last
9751        # See for more details: https://github.com/pytorch/pytorch/pull/100373
9752        if batch_size == 1 and check_as_unsqueezed_3d_tensor and memory_format == torch.channels_last:
9753            self.assertTrue(output_ui8.is_contiguous())
9754            self.assertTrue(output_f32.is_contiguous())
9755        else:
9756            self.assertTrue(output_ui8.is_contiguous(memory_format=memory_format))
9757            self.assertTrue(output_f32.is_contiguous(memory_format=memory_format))
9758
9759        if mode == "bilinear":
9760            torch.testing.assert_close(output_f32, output_ui8.float(), rtol=0, atol=1)
9761        else:
9762            diff = (output_f32 - output_ui8.float()).abs()
9763            self.assertLess(diff.max(), 15)
9764
9765            threshold = 2
9766            percent = 3
9767            self.assertLess((diff > threshold).float().mean(), percent / 100)
9768
9769            threshold = 5
9770            percent = 1
9771            self.assertLess((diff > threshold).float().mean(), percent / 100)
9772
9773            self.assertLess(diff.mean(), 0.4)
9774
9775    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
9776    @parametrize_test("align_corners", [True, False])
9777    @parametrize_test("input_size, output_size", [(399, 437), (403, 377)])
9778    def test_upsamplingBiLinear2d_consistency_interp_size_bug(self, device, memory_format, align_corners, input_size, output_size):
9779        # Non-regression test for https://github.com/pytorch/pytorch/pull/101403
9780
9781        if torch.device(device).type == "cuda":
9782            raise SkipTest("CUDA implementation is not yet supporting uint8")
9783
9784        mode = "bilinear"
9785        input_ui8 = torch.randint(0, 256, size=(1, 3, input_size, input_size), dtype=torch.uint8, device=device)
9786        input_ui8 = input_ui8.contiguous(memory_format=memory_format)
9787        input_f32 = input_ui8.float()
9788
9789        output_f32 = F.interpolate(
9790            input_f32, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=False
9791        ).round().to(torch.uint8)
9792        output_ui8 = F.interpolate(
9793            input_ui8, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=False
9794        )
9795        torch.testing.assert_close(output_f32, output_ui8, atol=1, rtol=0)
9796
9797    @expectedFailureMPS  # NotImplementedError: aten::upsample_bicubic2d.out https://github.com/pytorch/pytorch/issues/77764
9798    def test_upsamplingBicubic2d_correctness(self, device):
9799        # test output against known input: align_corners=False result must match opencv
9800        in_t = torch.arange(8., device=device).view(1, 2, 2, 2)
9801        expected_out_t = torch.tensor(
9802            [[[[-0.31641, 0.01562, 0.56250, 0.89453],
9803              [0.34766, 0.67969, 1.22656, 1.55859],
9804              [1.44141, 1.77344, 2.32031, 2.65234],
9805              [2.10547, 2.43750, 2.98438, 3.31641]],
9806
9807             [[3.68359, 4.01562, 4.56250, 4.89453],
9808              [4.34766, 4.67969, 5.22656, 5.55859],
9809              [5.44141, 5.77344, 6.32031, 6.65234],
9810              [6.10547, 6.43750, 6.98438, 7.31641]]]], device=device)
9811        out_t = F.interpolate(in_t, scale_factor=2, mode='bicubic', align_corners=False)
9812        torch.set_printoptions(precision=5)
9813        self.assertEqual(out_t, expected_out_t, atol=1e-5, rtol=0)
9814
9815    @expectedFailureMPS  # NotImplementedError: aten::_upsample_bicubic2d_aa.out https://github.com/pytorch/pytorch/issues/77764
9816    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
9817    def test_upsamplingBicubic2d_aa_correctness(self, device, memory_format):
9818        t_in = torch.arange(3 * 8 * 8, dtype=torch.float, device=device).reshape(1, 3, 8, 8)
9819        t_in = t_in.contiguous(memory_format=memory_format)
9820        # This expected result is obtain using PIL.Image.resize
9821        # for c in range(3):
9822        #   a_in = t_in.numpy()[0, c, ...]
9823        #   pil_in = Image.fromarray(a_in)
9824        #   pil_out = pil_in.resize((2, 2), resample=Image.BICUBIC)
9825        expected_out = torch.tensor([
9826            15.1205635, 18.760439, 44.23956, 47.879436, 79.12056, 82.76044,
9827            108.23956, 111.87944, 143.12057, 146.76044, 172.23956, 175.87943
9828        ], device=device, dtype=t_in.dtype).reshape(1, 3, 2, 2)
9829        t_out = F.interpolate(t_in, size=(2, 2), mode="bicubic", align_corners=False, antialias=True)
9830        self.assertEqual(expected_out, t_out)
9831
9832    @expectedFailureMPS  # NotImplementedError: aten::upsample_trilinear3d.out https://github.com/pytorch/pytorch/issues/77764
9833    @parametrize_test("align_corners", [True, False])
9834    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d])
9835    def test_upsamplingTrilinear3d(self, device, align_corners, memory_format):
9836        kwargs = dict(mode='trilinear', align_corners=align_corners)
9837
9838        # test float scale factor up & downsampling
9839        for scale_factor in [0.5, 1.5, 2]:
9840            m = nn.Upsample(scale_factor=scale_factor, **kwargs)
9841            in_t = torch.ones(1, 2, 4, 4, 4, device=device, dtype=torch.double)
9842            in_t = in_t.contiguous(memory_format=memory_format).requires_grad_()
9843            out_size = int(math.floor(in_t.shape[-1] * scale_factor))
9844            with warnings.catch_warnings(record=True) as w:
9845                out_t = m(in_t)
9846            expected_out = torch.ones(1, 2, out_size, out_size, out_size, device=device, dtype=torch.double)
9847            self.assertEqual(expected_out, out_t)
9848            # Assert that memory format is carried through to the output
9849            self.assertTrue(out_t.is_contiguous(memory_format=memory_format))
9850
9851            grad_out = torch.randn_like(out_t).contiguous(memory_format=memory_format)
9852            in_t.grad = None
9853            out_t.backward(grad_out)
9854            grad_in = in_t.grad
9855            self.assertTrue(grad_in.is_contiguous(memory_format=memory_format))
9856
9857            if memory_format == torch.channels_last_3d:
9858                # check if grad inputs CF and CL match
9859                in_t.grad = None
9860                out_t.backward(grad_out.contiguous())
9861                self.assertEqual(in_t.grad, grad_in)
9862
9863            input = torch.randn(1, 2, 4, 4, 4, requires_grad=True, dtype=torch.double)
9864            self.assertEqual(
9865                F.interpolate(input, (out_size, out_size, out_size), **kwargs),
9866                F.interpolate(input, scale_factor=scale_factor, **kwargs))
9867            gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])
9868            gradgradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])
9869
9870    @onlyCUDA
9871    @dtypes(torch.half)
9872    @largeTensorTest('40GB')
9873    def test_upsampling_64bit_indexing_channels_last(self, device, dtype):
9874        x = torch.rand((32, 64, 512, 512), dtype=dtype, device=device)
9875        out = torch.nn.functional.interpolate(x.to(memory_format=torch.channels_last), scale_factor=2, mode='nearest')
9876        out_ref = torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
9877        del x
9878        self.assertTrue(torch.allclose(out, out_ref))
9879
9880    @onlyCUDA
9881    @dtypes(torch.half)
9882    @largeTensorTest('40GB')
9883    def test_replicatepad_64bit_indexing(self, device, dtype):
9884        conv = torch.nn.Conv1d(128, 128, 3, 1, 1, padding_mode="replicate", device=device, dtype=dtype)
9885        x = torch.randn(size=(256 * 448 * 2, 128, 96), dtype=dtype, device=device)
9886        y = conv(x)
9887        torch.mean(y).backward()
9888
9889    @onlyCUDA
9890    @dtypes(torch.half)
9891    @largeTensorTest('40GB')
9892    def test_upsamplingnearest2d_backward_64bit_indexing(self, device, dtype):
9893        x = torch.randn(size=(36, 128, 512, 512), device=device, dtype=dtype).requires_grad_()
9894        y = F.interpolate(x, scale_factor=2, mode="nearest")
9895        y.backward(torch.randn_like(y))
9896
9897    def _slow_masked_softmax(self, input, mask):
9898        exp = torch.exp(input)
9899        exp = exp * mask
9900        s = exp.sum(dim=3, keepdim=True).expand(exp.size())
9901        return exp / s
9902
9903    def test_masked_softmax_mask_types(self, device):
9904        # Test that mask type 0 (LxL attention mask), mask type 1 (BxL padding mask),
9905        # and mask type 2 (generic BxHxLxL mask) are processed correctly on the
9906        # fast path and the results match explicit slow calculation.
9907        sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]
9908
9909        for (B, num_heads, L) in sizes:
9910
9911            # mask_type == 0 => attention mask of shape LxL
9912            src_mask_orig = torch.randint(0, 2, (L, L)).bool()
9913            src_mask = src_mask_orig.reshape(1, 1, L, L).expand(B, num_heads, L, L).bool()
9914
9915            # mask_type == 1 => padding mask of shape BxL
9916            src_key_padding_mask_orig = torch.randint(0, 2, (B, L)).bool()
9917            src_key_padding_mask = src_key_padding_mask_orig.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool()
9918
9919            # mask_type == 2 =>  shape BxHxLxL
9920            generic_mask = torch.randint(0, 2, (B, num_heads, L, L)).bool()
9921            masks = [(src_mask_orig, src_mask, 0),
9922                     (src_key_padding_mask_orig, src_key_padding_mask, 1),
9923                     (generic_mask, generic_mask, 2)
9924                     ]
9925            for dim in [0, 3]:
9926                for mask_orig, mask, mask_type in masks:
9927                    if (self.device_type == "cuda") and (num_heads % 2) and (mask_type == 1):
9928                        # CUDA path doesn't support padding mask when the number of heads is odd
9929                        continue
9930                    input = torch.randn((B, num_heads, L, L))
9931                    if (self.device_type == "cuda"):
9932                        input = input.cuda()
9933                        mask = mask.cuda()
9934                        mask_orig = mask_orig.cuda()
9935                    native_res = torch._masked_softmax(input, mask_orig, dim, mask_type)
9936                    mask = ~mask
9937
9938                    def slow_masked_softmax(input, mask):
9939                        exp = torch.exp(input)
9940                        exp = exp * mask
9941                        s = exp.sum(dim=dim, keepdim=True).expand(exp.size())
9942                        return exp / s
9943
9944                    pt_res = slow_masked_softmax(input, mask)
9945                    pt_res = torch.nan_to_num(pt_res)
9946
9947                    mask_not = mask.logical_not()
9948                    # In result, should only fill the entirely masked out rows since those are non-deterministic (*may* be 0)
9949                    # Converts rows with all True's to False
9950                    mask_out = mask_not.all(dim, keepdim=True).expand(mask_not.shape)
9951                    self.assertEqual(
9952                        pt_res.masked_fill(mask_out, 0),
9953                        native_res.masked_fill(mask_out, 0),
9954                        exact_dtype=True
9955                    )
9956
9957    @onlyCUDA
9958    @gcIfJetson
9959    def test_masked_softmax_devices_parity(self):
9960        # Test that softmax with mask type 0 (LxL attention mask), mask type 1 (BxL padding mask),
9961        # and mask type 2 (BxHxLxL generic mask) gives the same result on CPU and on CUDA.
9962
9963        sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]
9964        for (B, num_heads, L) in sizes:
9965            # mask_type == 0 => attention mask of shape LxL
9966            src_mask = torch.randint(0, 2, (L, L)).bool()
9967            # mask_type == 1 => padding mask of shape BxL
9968            src_key_padding_mask = torch.randint(0, 2, (B, L)).bool()
9969            # mask_type == 2 => generic mask of shape BxHxLxL
9970            generic_mask = torch.randint(0, 2, (B, num_heads, L, L)).bool()
9971            masks = [(src_mask, 0), (src_key_padding_mask, 1), (generic_mask, 2)]
9972            input = torch.randn((B, num_heads, L, L))
9973            for dim in [0, 3]:
9974                for mask, mask_type in masks:
9975                    if (num_heads % 2) and (mask_type == 1):
9976                        # CUDA path doesn't support padding mask when the number of heads is odd
9977                        continue
9978
9979                    def softmax_on_device(mask, input, device):
9980                        # Compute softmax on a given device
9981                        input_device = input.to(device)
9982                        mask_device = mask.to(device)
9983                        softmax_res = torch._masked_softmax(input_device, mask_device, dim, mask_type)
9984                        if mask_type == 0:
9985                            mask_expanded = mask_device.reshape(1, 1, L, L).expand(B, num_heads, L, L).bool()
9986                        elif mask_type == 1:
9987                            mask_expanded = mask_device.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool()
9988                        else:
9989                            mask_expanded = mask_device
9990                        # In result, should only fill the entirely masked out rows since those are non-deterministic (*may* be 0)
9991                        # Fill rows with all True's with 0
9992                        mask_out = mask_expanded.all(dim, keepdim=True).expand(mask_expanded.shape)
9993                        softmax_res = softmax_res.masked_fill(mask_out, 0)
9994                        return softmax_res
9995
9996                    cpu_res = softmax_on_device(mask, input, "cpu")
9997                    cuda_res = softmax_on_device(mask, input, "cuda")
9998                    self.assertEqual(cpu_res, cuda_res, exact_dtype=True)
9999
10000    def test_masked_softmax(self, device):
10001        sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]
10002        for (B, num_heads, L) in sizes:
10003            for dim in [0, 3]:
10004                input = torch.randn((B, num_heads, L, L))
10005                mask = torch.randint(0, 2, (B, L))
10006                mask = mask.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool()
10007                mask_type = 1   # BxL => src_key_padding_mask
10008                if (self.device_type == "cuda"):
10009                    input = input.cuda()
10010                    mask = mask.cuda()
10011                native_res = torch._masked_softmax(input, mask, dim, mask_type)
10012                mask = ~mask
10013
10014                def slow_masked_softmax(input, mask):
10015                    exp = torch.exp(input)
10016                    exp = exp * mask
10017                    s = exp.sum(dim=dim, keepdim=True).expand(exp.size())
10018                    return exp / s
10019
10020                pt_res = slow_masked_softmax(input, mask)
10021                pt_res = torch.nan_to_num(pt_res)
10022
10023                mask_not = mask.logical_not()
10024                # In result, should only fill the entirely masked out rows since those are non-deterministic (*may* be 0)
10025                # Converts rows with all True's to False
10026                mask_out = mask_not.all(dim, keepdim=True).expand(mask_not.shape)
10027                self.assertEqual(
10028                    pt_res.masked_fill(mask_out, 0),
10029                    native_res.masked_fill(mask_out, 0),
10030                    exact_dtype=True
10031                )
10032
10033    @dtypes(torch.bfloat16, torch.half)
10034    @precisionOverride({torch.bfloat16: 2e-2, torch.half: 3e-3})
10035    def test_masked_softmax_lowp(self, dtype):
10036        sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]
10037        for (B, num_heads, L) in sizes:
10038            for dim in [0, 3]:
10039                input_lowp = torch.randn((B, num_heads, L, L), dtype=dtype).requires_grad_()
10040                input_ref = input_lowp.float().detach().requires_grad_()
10041                mask = torch.randint(0, 2, (B, L))
10042                mask = mask.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool()
10043
10044                for mask_type in [1, 2]:
10045                    res_ref = torch._masked_softmax(input_ref, mask, dim, mask_type)
10046                    res = torch._masked_softmax(input_lowp, mask, dim, mask_type)
10047                    self.assertEqual(res_ref.to(dtype), res)
10048
10049                    grad_lowp = torch.randn_like(res_ref).to(dtype=dtype)
10050                    grad_ref = grad_lowp.float()
10051
10052                    res_ref.backward(grad_ref)
10053                    res.backward(grad_lowp)
10054                    self.assertEqual(input_ref.grad.to(dtype), input_lowp.grad)
10055
10056    def _test_masked_softmax_helper(self, input, dim, mask, mask_type):
10057        input_ref = input.detach().clone().requires_grad_()
10058        result = torch._masked_softmax(input, mask, dim, mask_type)
10059
10060        expected = torch._softmax(input_ref.masked_fill(mask, float('-inf')), dim, False)
10061        grad = torch.randn_like(expected).to(dtype=expected.dtype)
10062
10063        result.backward(grad)
10064        expected.backward(grad)
10065
10066        # Make sure the optional argument works as well
10067        if dim == input.dim() - 1:
10068            input_ref_default = input.detach().clone().requires_grad_()
10069            result_default = torch._masked_softmax(input_ref_default, mask, None, mask_type)
10070            result_default.backward(grad)
10071            self.assertEqual(result, result_default)
10072            self.assertEqual(input.grad, input_ref_default.grad)
10073
10074        # In result, should only fill the entirely masked out rows since those are non-deterministic (*may* be 0)
10075        # Converts rows with all True's to False
10076        mask_out = mask.all(dim, keepdim=True).expand(mask.shape)
10077        self.assertEqual(result.masked_fill(mask_out, 0), expected.masked_fill(mask_out, 0))
10078
10079        self.assertEqual(input.grad, torch.nan_to_num(input_ref.grad))
10080        self.assertEqual(input.grad, input.grad.masked_fill(mask, 0.0))
10081
10082    def test_masked_softmax_grad(self, device):
10083        shapes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]
10084        for shape in shapes:
10085            dims = [0, len(shape) - 1] if len(shape) > 0 else [0]
10086            for dim in dims:
10087                for mask_type in [1, 2]:  # 1 = BxL => src_key_padding_mask
10088                    input = torch.randn(shape, requires_grad=True)
10089                    mask = torch.randint(0, 2, shape).bool()
10090                    if (self.device_type == "cuda"):
10091                        input = input.cuda().detach().requires_grad_()
10092                        mask = mask.cuda()
10093                    self._test_masked_softmax_helper(input, dim, mask, mask_type)
10094
10095    # In this test, the forward pass is expected to produce nan's because when dim=0, we only have unspecified values
10096    def test_masked_softmax_forward_with_nans(self, device):
10097        dim = 0
10098        shapes = [(4, 5), (50, 100), (1500, 1200)]
10099        for (x, y) in shapes:
10100            for mask_type in [1, 2]:  # 1 = BxL => src_key_padding_mask
10101                input = torch.randn((x, y), requires_grad=True)
10102                mask = torch.tensor([i % 2 for i in range(y)]).expand((x, y)).bool()
10103                if (self.device_type == "cuda"):
10104                    input = input.cuda().detach().requires_grad_()
10105                    mask = mask.cuda()
10106                self._test_masked_softmax_helper(input, dim, mask, mask_type)
10107
10108    @onlyCUDA
10109    def test_masked_softmax_transformer_layout(self, device):
10110        B = 211
10111        num_heads = 16
10112        L = 42
10113        input = torch.randn((B, num_heads, L, L))
10114        dim = input.dim() - 1
10115        mask = torch.randint(0, 2, (B, L))
10116        mask_type = 1   # BxL => src_key_padding_mask
10117        if (self.device_type == "cuda"):
10118            input = input.cuda()
10119            mask = mask.cuda()
10120        mask = mask.bool()
10121        native_res = torch._masked_softmax(input, mask, dim, mask_type)
10122        mask = mask.reshape(B, 1, 1, L).expand(B, num_heads, L, L)
10123        mask = ~mask
10124        mask = mask.float()
10125
10126        pt_res = self._slow_masked_softmax(input, mask)
10127        self.assertEqual(pt_res, native_res, exact_dtype=True)
10128
10129    @onlyCUDA
10130    def test_masked_softmax_TxT_layout(self, device):
10131        B = 211
10132        num_heads = 16
10133        L = 42
10134        input = torch.randn((B, num_heads, L, L))
10135        dim = input.dim() - 1
10136        mask = torch.randint(0, 2, (L, L))
10137        mask_type = 0   # LxL => src_mask
10138        if (self.device_type == "cuda"):
10139            input = input.cuda()
10140            mask = mask.cuda()
10141        mask = mask.bool()
10142        native_res = torch._masked_softmax(input, mask, dim, mask_type)
10143        mask = mask.expand(B, num_heads, L, L)
10144        mask = ~mask
10145        mask = mask.float()
10146
10147        pt_res = self._slow_masked_softmax(input, mask)
10148        self.assertEqual(pt_res, native_res, exact_dtype=True)
10149
10150    @onlyCPU
10151    @dtypes(torch.bfloat16, torch.half)
10152    def test_log_softmax_cpu(self, device, dtype):
10153        for dim in [0, 1]:
10154            inputf = torch.rand(200, 200, device=device, dtype=torch.float, requires_grad=True)
10155            input = inputf.to(dtype).detach().requires_grad_(True)
10156            outf = F.log_softmax(inputf, dim=dim)
10157            out = F.log_softmax(input, dim=dim)
10158            self.assertEqual(out, outf.to(dtype=dtype), atol=0.1, rtol=0)
10159
10160            out.sum().backward()
10161            outf.sum().backward()
10162            self.assertEqual(input.grad, inputf.grad.to(dtype), atol=0.1, rtol=0)
10163
10164    @onlyCPU
10165    @dtypes(torch.bfloat16, torch.half)
10166    def test_softmax_cpu(self, device, dtype):
10167        for dim in [0, 1]:
10168            inputf = torch.rand(200, 200, device=device, dtype=torch.float, requires_grad=True)
10169            input = inputf.to(dtype).detach().requires_grad_(True)
10170            outf = F.softmax(inputf, dim=dim)
10171            out = F.softmax(input, dim=dim)
10172            self.assertEqual(out, outf.to(dtype), atol=1e-3, rtol=0)
10173
10174            out.sum().backward()
10175            outf.sum().backward()
10176            self.assertEqual(input.grad, inputf.grad.to(dtype), atol=1e-3, rtol=0)
10177
10178    @dtypesIfCUDA(torch.half, torch.float)
10179    @dtypes(torch.float)
10180    def test_softmax_results(self, device, dtype):
10181        # Non-even sizes and non-zero shifts test fallback paths in vectorized kernel
10182        # Note: dim1 > 1024 is needed to exercise the vectorized (non-persistent) path, (16, 30576) is BERT-esque
10183        sizes = [(0, 10), (32, 20), (10, 0), (31, 20), (32, 21), (31, 23), (32, 1536), (31, 2048), (33, 2049), (16, 30576)]
10184        shifts = [(0, 0), (1, 0), (0, 1), (1, 1)]
10185        for fn in [F.softmax, F.log_softmax]:
10186            for size in sizes:
10187                for shift in shifts:
10188                    input = torch.rand(size, device=device, dtype=dtype)
10189                    # Note: With the largest tests we can hit upper limit of fp16 when we
10190                    # sum, so scale the input down to stay in a nicer range.
10191                    if dtype == torch.float16:
10192                        input = input / 100.
10193                    input = input[shift[0]:, shift[1]:]
10194                    # Note; Don't want to bprop back through slice op
10195                    input = input.detach().requires_grad_(True)
10196                    ref_input = input.clone().cpu().detach().requires_grad_(True)
10197                    for dim in [0, 1]:
10198                        ref_output = fn(ref_input, dtype=torch.float, dim=dim)
10199                        output = fn(input, dtype=torch.float, dim=dim)
10200                        grad_output = torch.rand(size, device=device, dtype=dtype)
10201                        grad_output = grad_output[shift[0]:, shift[1]:]
10202                        ref_grad_output = grad_output.clone().cpu().detach()
10203                        grad_input, = torch.autograd.grad(output, input, grad_outputs=(grad_output), create_graph=True)
10204                        ref_grad_input, = torch.autograd.grad(ref_output, ref_input,
10205                                                              grad_outputs=(ref_grad_output), create_graph=True)
10206                        grad_input.sum().backward()
10207                        ref_grad_input.sum().backward()
10208
10209                        self.assertEqual(output, ref_output)
10210                        self.assertEqual(grad_input, ref_grad_input)
10211                        self.assertEqual(input.grad, ref_input.grad)
10212
10213    @onlyCUDA
10214    @dtypes(torch.float, torch.half)
10215    @largeTensorTest("20GB")
10216    @largeTensorTest("64GB", "cpu")
10217    def test_warp_softmax_64bit_indexing(self, device, dtype):
10218        def run_test(*shape):
10219            x = torch.randn(shape, device="cuda", dtype=torch.float16, requires_grad=True)
10220            y = F.log_softmax(x, dim=-1, dtype=dtype)
10221            y.backward(y)
10222            with torch.no_grad():
10223                xx = x.cpu().requires_grad_()
10224            yy = F.log_softmax(xx.float(), dim=-1).to(dtype)
10225            yy.backward(yy)
10226            # workaround to reduce memory usage vs. self.assertEqual, see #84944
10227            rtol, atol = torch.testing._comparison.get_tolerances(dtype, rtol=None, atol=None)
10228            self.assertTrue(torch.allclose(y.cpu(), yy, rtol=rtol, atol=atol))
10229            # x is half
10230            rtol, _ = torch.testing._comparison.get_tolerances(torch.half, rtol=None, atol=None)
10231            self.assertTrue(torch.allclose(x.grad.cpu(), xx.grad, rtol=rtol, atol=1e-3))
10232
10233        run_test(1100000000, 2)  # Illegal memory access https://github.com/pytorch/pytorch/issues/52715
10234        run_test(2200000000, 1)  # invalid configuration argument https://github.com/pytorch/pytorch/issues/52716
10235
10236    @onlyCUDA
10237    @dtypes(torch.half)
10238    @largeTensorTest("20GB")
10239    @largeTensorTest("2GB", "cpu")
10240    @precisionOverride({torch.half: 0.001})
10241    def test_softmax_64bit_indexing(self, device, dtype):
10242        def run_test(*shape):
10243            x = torch.ones(shape, device=device, dtype=dtype, requires_grad=True)
10244            y = F.log_softmax(x, dim=-1, dtype=dtype)
10245            y.backward(y)
10246            self.assertEqual(y[0], y[-1])
10247            self.assertEqual(x.grad[0], x.grad[-1])
10248
10249        run_test(1024 * 256 + 1, 8192)  # https://github.com/pytorch/pytorch/issues/84144
10250
10251
10252    @dtypes(torch.float)
10253    @dtypesIfCUDA(torch.float, torch.half)
10254    def test_log_softmax_big(self, device, dtype):
10255        def _test_helper(shape):
10256            # generate a tensor with big numbers that are exactly representable in dtype
10257            # and are at a constant offset from tensor with small numbers
10258            # the logsoftmax of a small and big tensors should be equal
10259            x_small = torch.randint(100, shape, dtype=dtype, device=device)
10260            offset = 1.5e3 if dtype == torch.half else 1e7
10261            x_big = x_small + offset
10262            self.assertEqual(F.log_softmax(x_small, -1), F.log_softmax(x_big, -1))
10263        _test_helper((16, 4))
10264        if self.device_type == 'cuda':
10265            # test non-persistent softmax kernel
10266            _test_helper((4, 1536))
10267
10268    def test_save_lstm_compatibility(self, device):
10269        # Test that saving an LSTM in PyTorch 1.7 and older can still be
10270        # loaded in newer versions of PyTorch.
10271        model = nn.LSTM(2, 3)
10272        x = torch.randn(32, 5, 2)
10273        expected = model(x)
10274
10275        # Get a state dict for PyTorch 1.7 LSTM. Before PyTorch 1.8, proj_size
10276        # didn't exist.
10277        assert model.proj_size == 0
10278        state_dict = model.__dict__
10279        del state_dict['proj_size']
10280
10281        # load a model
10282        loaded_model = nn.LSTM(2, 3)
10283        loaded_model.__setstate__(state_dict)
10284        result = loaded_model(x)
10285        self.assertEqual(result, expected)
10286
10287    @onlyCUDA
10288    @tf32_on_and_off(0.005)
10289    def test_grid_sample_large(self, device):
10290        def issue_35202():
10291            input_tensor = torch.rand(1, 1, 480, 640, dtype=torch.float, device=device, requires_grad=True)
10292            coords = torch.tensor([[-10059144, 67680944], [67680944, 67680944]], dtype=torch.float, device=device)
10293            coords = coords.unsqueeze(0).unsqueeze(0).repeat(1, 1, 1, 1)
10294            result = torch.nn.functional.grid_sample(input_tensor, coords)
10295            self.assertEqual(result, torch.tensor([[[[0., 0.]]]], dtype=torch.float, device=device))
10296            result.backward(torch.ones_like(result))
10297            torch.cuda.synchronize()
10298        issue_35202()
10299
10300        def issue_24823_1(dtype):
10301            image = torch.arange(27, 0, -1, dtype=dtype, device=device).view(1, 1, 3, 3, 3)
10302            image.requires_grad_()
10303            grid = torch.nn.functional.affine_grid(
10304                torch.tensor([[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]], dtype=dtype, device=device),
10305                (1, 1, 3, 3, 3))
10306            grid[:, 1, 1, 1, 0] = float('inf')
10307            result = torch.nn.functional.grid_sample(image, grid, padding_mode='zeros')
10308            tol_override = {'atol': 0.005, 'rtol': 0} if dtype == torch.half else {}
10309            self.assertEqual(result, torch.tensor([[[[[27., 26., 25.], [24., 23., 22.], [21., 20., 19.]],
10310                                                     [[18., 17., 16.], [15., 0., 13.], [12., 11., 10.]],
10311                                                     [[9., 8., 7.], [6., 5., 4.], [3., 2., 1.]]]]],
10312                                                  device=device, dtype=dtype), **tol_override)
10313            result.backward(torch.ones_like(result))
10314            expected_grad = torch.ones_like(image)
10315            expected_grad[0, 0, 1, 1, 1] = 0
10316            self.assertEqual(image.grad, expected_grad, atol=0.005, rtol=0)
10317        issue_24823_1(torch.half)
10318        issue_24823_1(torch.float)
10319        issue_24823_1(torch.double)
10320
10321        def issue_24823_2():
10322            param = torch.tensor([[[-1.0e+20, 0.0, 0.0], [0.0, -1.0e+20, 0.0]]], dtype=torch.float, device=device)
10323            img = torch.zeros((1, 1, 4, 4), dtype=torch.float, device=device, requires_grad=True)
10324            grid = torch.nn.functional.affine_grid(param, img.size())
10325            result = torch.nn.functional.grid_sample(img, grid)
10326            self.assertEqual(result, torch.zeros(1, 1, 4, 4, device=device, dtype=torch.float))
10327            result.backward(torch.ones_like(result))
10328            torch.cuda.synchronize()
10329        issue_24823_2()
10330
10331    @dtypes(torch.float, torch.double)
10332    @largeTensorTest(lambda self, device, dtype:
10333                     # Compute sum of the large tensor sizes:
10334                     # (im.numel() + small_image.numel() + small_image.grad.numel() +
10335                     #   large_view.grad.numel()) * sizeof(dtype)
10336                     32769 * (65536 + 3 * 65536 / 128) *
10337                     torch.tensor([], dtype=dtype).element_size())
10338    def test_grid_sample_large_index_2d(self, device, dtype):
10339        # Test 64-bit indexing with grid_sample (gh-41656)
10340        # Try accessing the corners, there should be no segfault
10341        coords = torch.tensor([[[-1., -1.],
10342                                [+1., -1.]],
10343
10344                               [[-1., +1.],
10345                                [+1., +1.]]], device=device, dtype=dtype)
10346        coords = coords.expand(1, 2, 2, 2)
10347        im = torch.zeros([1, 1, 32769, 65536], device=device, dtype=dtype)
10348
10349        # Compare sampling with large strides to the same op on a contiguous tensor
10350        coords = torch.rand(1, 4, 4, 2, device=device, dtype=dtype)
10351        large_view = im[..., 127::128]
10352        small_image = torch.rand_like(large_view)
10353        large_view[...] = small_image
10354        large_view.requires_grad, small_image.requires_grad = True, True
10355        self.assertTrue(
10356            sum(i * s for i, s in zip(large_view.size(), large_view.stride())) >= 2 ** 31,
10357            msg="View must use 64-bit indexing")
10358        for mode, padding_mode, align_corners in itertools.product(
10359                ('nearest', 'bilinear', 'bicubic'), ('zeros', 'border', 'reflection'), (True, False)):
10360            a = F.grid_sample(
10361                small_image, coords, mode=mode,
10362                padding_mode=padding_mode, align_corners=align_corners)
10363            a.sum().backward()
10364
10365            b = F.grid_sample(
10366                large_view, coords, mode=mode,
10367                padding_mode=padding_mode, align_corners=align_corners)
10368            b.sum().backward()
10369
10370            self.assertEqual(a, b)
10371            self.assertEqual(small_image.grad, large_view.grad)
10372
10373            small_image.grad.zero_()
10374            large_view.grad.zero_()
10375
10376    @dtypes(torch.float, torch.double)
10377    @largeTensorTest(lambda self, device, dtype:
10378                     # Compute sum of the large tensor sizes:
10379                     # (im.numel() + small_image.numel() + small_image.grad.numel() +
10380                     #   large_view.grad.numel()) * sizeof(dtype)
10381                     2 * 32769 * (32768 + 3 * 32768 / 128) *
10382                     torch.tensor([], dtype=dtype).element_size())
10383    def test_grid_sample_large_index_3d(self, device, dtype):
10384        # Test 64-bit indexing with grid_sample (gh-41656)
10385        # Try accessing the corners, there should be no segfault
10386        coords = torch.full((1, 2, 2, 2, 3), 1., device=device, dtype=dtype)
10387        im = torch.zeros([1, 1, 2, 32769, 32768], device=device, dtype=dtype)
10388
10389        result = F.grid_sample(im, coords, align_corners=False)
10390        self.assertEqual(result, torch.zeros((1, 1, 2, 2, 2), device=device, dtype=dtype))
10391
10392        # Compare sampling with large strides to the same op on a contiguous tensor
10393        coords = torch.rand(1, 1, 4, 4, 3, device=device, dtype=dtype)
10394        large_view = im[..., 127::128]
10395        small_image = torch.rand_like(large_view)
10396        large_view[...] = small_image
10397        small_image.requires_grad, large_view.requires_grad = True, True
10398        self.assertTrue(
10399            sum(i * s for i, s in zip(large_view.size(), large_view.stride())) >= 2 ** 31,
10400            msg="View must use 64-bit indexing")
10401        for mode, padding_mode, align_corners in itertools.product(
10402                ('nearest', 'bilinear'), ('zeros', 'border', 'reflection'), (True, False)):
10403            a = F.grid_sample(
10404                small_image, coords, mode=mode,
10405                padding_mode=padding_mode, align_corners=align_corners)
10406            a.sum().backward()
10407
10408            b = F.grid_sample(
10409                large_view, coords, mode=mode,
10410                padding_mode=padding_mode, align_corners=align_corners)
10411            b.sum().backward()
10412
10413            self.assertEqual(a, b)
10414            self.assertEqual(small_image.grad, large_view.grad)
10415
10416            small_image.grad.zero_()
10417            large_view.grad.zero_()
10418
10419    @onlyCUDA
10420    def test_grid_sample_half_precision(self):
10421        def helper(shape_in, shape_out, align_corners):
10422            for mode in ('bilinear', 'nearest', 'bicubic'):
10423                if len(shape_in) != 4 and mode == 'bicubic':
10424                    continue
10425                data = torch.randn(shape_in, device='cuda', dtype=torch.half)
10426                grid = torch.rand(shape_out, device='cuda', dtype=torch.half) * 2.0 - 1.0
10427
10428                out_half = F.grid_sample(data, grid, mode=mode, padding_mode='zeros', align_corners=align_corners)
10429                out_double = F.grid_sample(data.double(), grid.double(), mode=mode, padding_mode='zeros',
10430                                           align_corners=align_corners)
10431
10432                self.assertEqual(out_half, out_double.half(), msg=f"grid_sample with mode = {mode} doesn't match")
10433
10434        helper((32, 64, 16, 16), (32, 8, 8, 2), True)
10435        helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), True)
10436        helper((32, 64, 16, 16), (32, 8, 8, 2), False)
10437        helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), False)
10438
10439    @onlyCUDA
10440    def test_grid_sample_bfloat16_precision(self):
10441        def helper(shape_in, shape_out, align_corners):
10442            for mode in ('bilinear', 'nearest', 'bicubic'):
10443                if len(shape_in) != 4 and mode == 'bicubic':
10444                    continue
10445                data = torch.randn(shape_in, device='cuda', dtype=torch.bfloat16)
10446                grid = torch.rand(shape_out, device='cuda', dtype=torch.bfloat16) * 2.0 - 1.0
10447
10448                out_half = F.grid_sample(data, grid, mode=mode, padding_mode='zeros', align_corners=align_corners)
10449                out_double = F.grid_sample(data.double(), grid.double(), mode=mode, padding_mode='zeros',
10450                                           align_corners=align_corners)
10451
10452                self.assertEqual(out_half, out_double.bfloat16(), msg=f"grid_sample with mode = {mode} doesn't match")
10453
10454        helper((32, 64, 16, 16), (32, 8, 8, 2), True)
10455        helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), True)
10456        helper((32, 64, 16, 16), (32, 8, 8, 2), False)
10457        helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), False)
10458
10459    def _test_gumbel_softmax_st_shapes(self, device, dtype, shape, dim, count_expected):
10460        logits = torch.randn(shape, dtype=torch.float, device=device)
10461        logits = logits.to(dtype)
10462
10463        y_draw = F.gumbel_softmax(logits, hard=True, dim=dim)
10464
10465        # All values positive
10466        self.assertGreaterEqual(y_draw.min(), 0)
10467        # Shape unchanged
10468        self.assertTrue(y_draw.shape == logits.shape)
10469        # One choice per draw
10470        self.assertEqual(y_draw.sum(), count_expected, atol=torch.finfo(y_draw.dtype).eps, rtol=0)
10471
10472    def _test_gumbel_softmax_straight_through(self, device, dtype):
10473        num_draws = 100
10474
10475        logits = torch.tensor([[0.2, 0.8, 0.1]], device=device)
10476        logits = logits.reshape([1, 3])
10477        logits = logits.to(dtype).requires_grad_()
10478        probs = logits.softmax(dim=-1)
10479
10480        counts = torch.zeros_like(logits)
10481        for _ in range(num_draws):
10482            y_draw = F.gumbel_softmax(logits, hard=True)
10483            counts = counts + y_draw
10484
10485        # All values positive
10486        self.assertGreaterEqual(y_draw.min(), 0)
10487        # Each experiment should result in 1 draw.
10488        self.assertEqual(counts.sum(), num_draws, atol=torch.finfo(counts.dtype).eps, rtol=0)
10489
10490        # check results is asymptotically as expected.
10491        expected = probs * num_draws
10492        # ~z is approximately N(0,1) for unbiased count
10493        z = (counts - expected) / (expected * (1 - probs)).sqrt()
10494        # A (lazy) approximate 99% two-sided test:
10495        # occurs with prob alpha~>=0.01 if unbiased
10496        self.assertLess(z.abs().max().item(), 2.58)
10497
10498    def _test_gumbel_softmax_grad(self, device, dtype):
10499        # "hard" and "not hard" should propagate same gradient.
10500        logits_soft = torch.zeros(10, 10, dtype=dtype, device=device, requires_grad=True)
10501        logits_hard = torch.zeros(10, 10, dtype=dtype, device=device, requires_grad=True)
10502
10503        seed = torch.random.get_rng_state()
10504        y_soft = F.gumbel_softmax(logits_soft, hard=False)
10505        torch.random.set_rng_state(seed)
10506        y_hard = F.gumbel_softmax(logits_hard, hard=True)
10507
10508        y_soft.sum().backward()
10509        y_hard.sum().backward()
10510
10511        # 2eps = 1x addition + 1x subtraction.
10512        tol = 2 * torch.finfo(dtype).eps
10513        self.assertEqual(logits_soft.grad, logits_hard.grad, atol=tol, rtol=0)
10514
10515    @dtypesIfCUDA(torch.half, torch.float, torch.double)
10516    @dtypesIfMPS(torch.float)
10517    @dtypes(torch.float, torch.double)
10518    def test_gumbel_softmax(self, device, dtype):
10519        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5], dim=0, count_expected=1)
10520        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5], dim=-1, count_expected=1)
10521        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5, 4], dim=1, count_expected=5)
10522        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5, 4, 3], dim=1, count_expected=5 * 3)
10523        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5, 4, 3], dim=-1, count_expected=5 * 4)
10524        self._test_gumbel_softmax_straight_through(device, dtype)
10525        self._test_gumbel_softmax_grad(device, dtype)
10526
10527    def _test_rnn_retain_variables(self, device, dtype):
10528        rnns = [nn.LSTM(10, 20, num_layers=2).to(device, dtype),
10529                nn.GRU(10, 20, num_layers=2).to(device, dtype),
10530                nn.RNN(10, 20, num_layers=2).to(device, dtype)]
10531        for rnn in rnns:
10532            input = torch.randn(5, 6, 10, device=device, dtype=dtype, requires_grad=True)
10533            output = rnn(input)
10534            output[0].sum().backward(retain_graph=True)
10535            grads = [input.grad.data.clone()] + [p.grad.data.clone() for p in rnn.parameters()]
10536            for _ in range(4):
10537                rnn.zero_grad()
10538                input.grad.data.zero_()
10539                output[0].sum().backward(retain_graph=True)
10540                grads2 = [input.grad.data] + [p.grad.data for p in rnn.parameters()]
10541                self.assertEqual(grads, grads2)
10542
10543    @dtypesIfCUDA(torch.half, torch.float, torch.double)
10544    @dtypesIfMPS(torch.half, torch.float)
10545    @dtypes(torch.double)
10546    def test_rnn_retain_variables(self, device, dtype):
10547        self._test_rnn_retain_variables(device, dtype)
10548
10549        if self.device_type == 'cuda' and self.has_cudnn():
10550            with torch.backends.cudnn.flags(enabled=False):
10551                self._test_rnn_retain_variables(device, dtype)
10552
10553    @onlyCUDA
10554    @dtypes(torch.double)
10555    def test_lstmcell_backward_only_one_output_grad(self, device, dtype):
10556        # checks that undefined gradients doen't hamper the backward
10557        # see #11872
10558        l = torch.nn.LSTMCell(2, 3).to(device).to(dtype=dtype)
10559        s = torch.randn(1, 2, device=device, dtype=dtype, requires_grad=True)
10560        for i in range(2):
10561            out = l(s)[i]
10562            out.sum().backward()
10563            self.assertFalse(s.grad is None or s.grad.abs().sum().item() == 0)
10564
10565    def _test_rnn_mod(self, mod, inp):
10566        def flatten_out(mod, inp):
10567            out = mod(inp)
10568            return tuple([t if isinstance(t, torch.Tensor) else tt for t in out for tt in t])
10569        gradcheckfunc = partial(flatten_out, mod)
10570        with torch.backends.cudnn.flags(enabled=False):
10571            gradcheck(gradcheckfunc, inp, check_batched_grad=False)
10572            gradgradcheck(gradcheckfunc, inp, check_batched_grad=False)
10573
10574        if inp.is_cuda and not TEST_WITH_ROCM:
10575            # Assert that we have good error message around unsupported CuDNN double backward
10576            # NB: we trigger double backward using .backward() instead of autograd.grad due to
10577            # https://github.com/pytorch/pytorch/issues/37874
10578            with torch.backends.cudnn.flags(enabled=True):
10579                result = gradcheckfunc(inp)
10580                result[0].sum().backward(create_graph=True)
10581                grad0 = next(mod.parameters()).grad
10582                with self.assertRaisesRegex(RuntimeError,
10583                                            "please disable the CuDNN backend temporarily"):
10584                    grad0.sum().backward()
10585
10586                # Here we avoid the backward(create_graph=True) memory leak
10587                # described in https://github.com/pytorch/pytorch/issues/7343
10588                for param in mod.parameters():
10589                    param.grad = None
10590                inp.grad = None
10591
10592    # Merge into OpInfo?
10593    @skipMeta  # LSTM cell reuses output which was resized
10594    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
10595    @dtypes(torch.double)
10596    def test_LSTM_grad_and_gradgrad(self, device, dtype):
10597        hsize = 4
10598        inp = torch.rand(1, 3, hsize, device=device, dtype=dtype, requires_grad=True)
10599        for bias in [True, False]:
10600            mod = torch.nn.LSTM(hsize, hsize, bias=bias).to(device).to(dtype)
10601            self._test_rnn_mod(mod, inp)
10602
10603    @skipMeta  # GRU cell reuses output which was resized
10604    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
10605    @dtypes(torch.double)
10606    def test_GRU_grad_and_gradgrad(self, device, dtype):
10607        hsize = 4
10608        inp = torch.rand(1, 3, hsize, device=device, dtype=dtype, requires_grad=True)
10609        for bias in [True, False]:
10610            mod = torch.nn.GRU(hsize, hsize, bias=bias).to(device).to(dtype)
10611            self._test_rnn_mod(mod, inp)
10612
10613    @skipMeta
10614    @dtypes(torch.float32, torch.bfloat16)
10615    @onlyCPU
10616    def test_LSTM_differentiable_backward_using_oneDNN(self, dtype):
10617        batch = 10
10618        seq_len = 12
10619        input = 3
10620        Net = nn.LSTM(input, 3, 20, batch_first=True)
10621        import copy
10622        Net_clone = copy.deepcopy(Net)
10623        x = torch.rand(batch, seq_len, input)
10624        x1 = x.clone().requires_grad_(True)
10625        x2 = x.clone().requires_grad_(True)
10626
10627        torch._C._set_mkldnn_enabled(False)
10628        out1, _ = Net(x1)
10629        der_out1 = torch.autograd.grad(out1, x1,
10630                                       grad_outputs=torch.ones_like(out1),
10631                                       retain_graph=True,
10632                                       create_graph=True)[0]
10633        loss1 = der_out1.sum()
10634        loss1.backward(retain_graph=True)
10635
10636        torch._C._set_mkldnn_enabled(True)
10637        out2, _ = Net(x2)
10638        der_out2 = torch.autograd.grad(out2, x2,
10639                                       grad_outputs=torch.ones_like(out2),
10640                                       retain_graph=True,
10641                                       create_graph=True)[0]
10642        loss2 = der_out2.sum()
10643        loss2.backward(retain_graph=True)
10644        assert torch.allclose(der_out1, der_out2)
10645        assert torch.allclose(x1.grad, x2.grad)
10646
10647    @onlyCUDA
10648    def test_upsamplingNearest1d_launch_config(self, device):
10649        m = nn.Upsample(scale_factor=2)
10650        inp = torch.rand(2**25, 1, 1, device=device)
10651        out = m(inp)
10652        inp_ref = inp.cpu()
10653        out_ref = m(inp_ref)
10654        self.assertEqual(out_ref, out)
10655
10656    @onlyCUDA
10657    def test_upsamplingNearest2d_launch_config(self, device):
10658        m = nn.Upsample(scale_factor=2)
10659        inp = torch.rand(2**25, 1, 1, 1, device=device)
10660        out = m(inp)
10661        inp_ref = inp.cpu()
10662        out_ref = m(inp_ref)
10663        self.assertEqual(out_ref, out)
10664
10665    @onlyCUDA
10666    @gcIfJetson
10667    def test_upsamplingNearest3d_launch_config(self, device):
10668        m = nn.Upsample(scale_factor=2)
10669        inp = torch.rand(2**25, 1, 1, 1, 1, device=device)
10670        out = m(inp)
10671        inp_ref = inp.cpu()
10672        out_ref = m(inp_ref)
10673        self.assertEqual(out_ref, out)
10674
10675    @unittest.expectedFailure
10676    @skipIfRocm
10677    @onlyCUDA
10678    def test_upsamplingNearest2d_launch_fail(self, device):
10679        m = nn.Upsample(scale_factor=2)
10680        # launch grid_y == 2**16 (larger than maximum y-dimension limit 65535)
10681        inp = torch.rand(1, 1, 2**15, 2**8, device=device)
10682        out = m(inp)
10683
10684    @onlyCUDA
10685    @skipCUDAIfNotRocm
10686    def test_upsamplingNearest2d_launch_rocm(self, device):
10687        # test_upsamplingNearest2d_launch_fail should run OK on ROCm
10688        m = nn.Upsample(scale_factor=2)
10689        inp = torch.rand(1, 1, 2**15, 2**8, device=device)
10690        out = m(inp)
10691
10692    @onlyCUDA
10693    @skipCUDAIfCudnnVersionLessThan(7600)
10694    def test_CTCLoss_cudnn(self, device):
10695        def _helper(zero_infinity):
10696            target_lengths = [30, 25, 20]
10697            input_lengths = [50, 50, 50]
10698            targets = torch.randint(1, 15, (sum(target_lengths),), dtype=torch.int)
10699            log_probs = torch.randn(50, 3, 15, dtype=torch.float, device=device).log_softmax(2).requires_grad_()
10700
10701            log_probs_ref = log_probs.detach().clone().requires_grad_()
10702
10703            with torch.backends.cudnn.flags(enabled=True):
10704                res = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, zero_infinity=zero_infinity)
10705                res.backward()
10706
10707            expected = ctcloss_reference(log_probs, targets.cuda(), input_lengths, target_lengths).float()
10708
10709            with torch.backends.cudnn.flags(enabled=False):
10710                res2 = torch.nn.functional.ctc_loss(log_probs_ref, targets.cuda().long(), input_lengths, target_lengths,
10711                                                    zero_infinity=zero_infinity)
10712                res2.backward()
10713
10714            self.assertEqual(res, expected)
10715            self.assertEqual(res2, res)
10716            self.assertEqual(log_probs.grad, log_probs_ref.grad)
10717
10718        _helper(zero_infinity=True)
10719        _helper(zero_infinity=False)
10720
10721    def _CTCLoss_gen_losses(self, device, input_length, vocab_size, target_length, reduction, use_module_form):
10722        batch_size = 1
10723        log_probs = torch.randn(input_length, batch_size, vocab_size, dtype=torch.float, device=device) \
10724                         .log_softmax(2).requires_grad_()
10725        targets = torch.randint(low=1, high=vocab_size - 1, size=(batch_size, target_length),
10726                                dtype=torch.int, device=device)
10727        input_lengths = batch_size * [input_length]
10728        target_lengths = batch_size * [target_length]
10729
10730        log_probs_no_bd = log_probs.squeeze(1).detach().clone().requires_grad_()
10731        targets_no_bd = targets.squeeze(0).detach().clone()
10732        input_lengths_no_bd = torch.tensor(input_length)
10733        target_lengths_no_bd = torch.tensor(target_length)
10734
10735        # currently only length 2 and 1 right now, but left flexible for additional potential cases
10736        log_probs_refs = [log_probs.detach().clone().requires_grad_() for _ in range(2)]
10737        log_probs_no_bd_refs = [log_probs_no_bd.detach().clone().requires_grad_() for _ in range(1)]
10738
10739        losses = []
10740        losses_no_bd = []
10741
10742        has_cuda = torch.cuda.is_available()
10743        has_cudnn = has_cuda and 'cuda' in device and self.has_cudnn()
10744        # cudnn requires a cpu target
10745        if has_cuda and has_cudnn:
10746            targets = targets.cpu()
10747            targets_no_bd = targets_no_bd.cpu()
10748
10749        ctc_loss = (
10750            nn.CTCLoss(reduction=reduction, zero_infinity=True)
10751            if use_module_form
10752            else partial(torch.nn.functional.ctc_loss, reduction=reduction, zero_infinity=True)
10753        )
10754
10755        with torch.backends.cudnn.flags(enabled=has_cudnn):
10756            # batched case. log_probs.shape = (T, N, C), targets = (N, S), input_lengths/target_lengths = (N,)
10757            losses.append(ctc_loss(log_probs_refs[0], targets, input_lengths, target_lengths))
10758            # batched case. input.shape = (T, N, C), targets = (S,), input_lengths/target_lengths = (N,)
10759            losses.append(ctc_loss(log_probs_refs[1], targets_no_bd, input_lengths, target_lengths))
10760            # unbatched case. input.shape = (T, C), targets = (S,), input_lengths/target_lengths = (N,)
10761            losses_no_bd.append(ctc_loss(log_probs_no_bd_refs[0], targets_no_bd,
10762                                         input_lengths_no_bd, target_lengths_no_bd))
10763
10764            for loss in losses + losses_no_bd:
10765                loss.backward()
10766
10767        return losses, losses_no_bd, log_probs_refs, log_probs_no_bd_refs
10768
10769    def _assertEqual_list(self, expected, list_to_compare, atol=None, rtol=None):
10770        for ele in list_to_compare:
10771            self.assertEqual(expected, ele, atol=atol, rtol=rtol)
10772
10773    @expectedFailureMPS  # NotImplementedError: aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764
10774    @parametrize_test("reduction", ['none', 'mean', 'sum'])
10775    @parametrize_test("use_module_form", [True, False])
10776    def test_CTCLoss_no_batch_dim(self, device, reduction, use_module_form):
10777        input_length = 40
10778        vocab_size = 3
10779        target_length = 12
10780
10781        args = self._CTCLoss_gen_losses(device, input_length, vocab_size, target_length, reduction, use_module_form)
10782        losses, losses_no_bd, log_probs_refs, log_probs_no_bd_refs = args
10783
10784        # test output values
10785        self._assertEqual_list(losses[0], losses[1:], atol=1e-4, rtol=0)
10786        self._assertEqual_list(losses[0].squeeze(0), losses_no_bd, atol=1e-4, rtol=0)
10787
10788        # test gradient values
10789        self._assertEqual_list(log_probs_refs[0].grad, [t.grad for t in log_probs_refs[1:]], atol=1e-4, rtol=0)
10790        self._assertEqual_list(
10791            log_probs_refs[0].grad.squeeze(1),
10792            [t.grad for t in log_probs_no_bd_refs],
10793            atol=1e-4,
10794            rtol=0,
10795        )
10796
10797        # checking the output's shape
10798        # batch dim case should be (N,). no batch dim case should be ()
10799        self._assertEqual_list((1,) if reduction == 'none' else (), [loss.shape for loss in losses])
10800        self._assertEqual_list((), [loss.shape for loss in losses_no_bd])
10801
10802        # checking the gradient's shape
10803        # batch dim case should have shape (T, N, C). no batch dim case should have shape (T, C)
10804        self._assertEqual_list((input_length, 1, vocab_size), [t.grad.shape for t in log_probs_refs])
10805        self._assertEqual_list((input_length, vocab_size), [t.grad.shape for t in log_probs_no_bd_refs])
10806
10807    def _ordered_sequence(self, device, dtype):
10808        """Create ordered list of random sequences"""
10809        seqs = [torch.empty(random.randint(1, 6), device=device, dtype=dtype)
10810                for _ in range(5)]
10811        seqs = [s.random_(-128, 128) for s in seqs]
10812        ordered = sorted(seqs, key=len, reverse=True)
10813        return ordered
10814
10815    def _padded_sequence(self, device, dtype):
10816        """Create Tensor of random padded sequences"""
10817        ordered = self._ordered_sequence(device, dtype)
10818        lengths = [len(i) for i in ordered]
10819        padded_tensor = rnn_utils.pad_sequence(ordered)
10820        return padded_tensor, lengths
10821
10822    @onlyCUDA
10823    def test_device_mask(self, device):
10824        for enforce_sorted in [True, False]:
10825            padded, lengths = self._padded_sequence('cpu', torch.float)
10826            packed = rnn_utils.pack_padded_sequence(
10827                padded, lengths, enforce_sorted=enforce_sorted)
10828            self.assertFalse(packed.is_cuda)
10829            packed = packed.to(device)
10830            self.assertTrue(packed.is_cuda)
10831            unpacked, _ = rnn_utils.pad_packed_sequence(packed)
10832            self.assertTrue(unpacked.is_cuda)
10833            self.assertEqual(unpacked.dtype, torch.float)
10834
10835    @onlyCUDA
10836    def test_overwrite_module_params_on_conversion_cpu_device(self, device):
10837        # Test that under the current default settings
10838        # (`torch.__future__.get_overwrite_module_params_on_conversion() == False`),
10839        # a view to a module's parameters is not pointing to the same storage as
10840        # its base variable after converting the module to a different device.
10841        m = nn.Linear(20, 10)
10842        mw = m.weight[:]
10843        m.to(device)
10844        with torch.no_grad():
10845            # Without using `torch.no_grad()`, this will leak CUDA memory.
10846            # (Issue is filed at https://github.com/pytorch/pytorch/issues/21875)
10847            mw[0][0] = 5
10848            self.assertTrue(mw[0][0].device.type == "cpu")
10849            self.assertTrue(mw._base[0][0].device.type == "cuda")
10850
10851        try:
10852            torch.__future__.set_overwrite_module_params_on_conversion(True)
10853
10854            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
10855            # a view to a module's parameters is still pointing to the same storage as
10856            # its base variable after converting the module to a different device.
10857            m = nn.Linear(20, 10)
10858            mw = m.weight[:]
10859            m.to(device)
10860            with torch.no_grad():
10861                mw[0][0] = 5
10862            self.assertTrue(mw[0][0] == mw._base[0][0])
10863
10864            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
10865            # `cpu_module.to("cuda")` doesn't preserve previous references to
10866            # `cpu_module`'s parameters or gradients.
10867            m = nn.Linear(20, 10)
10868            m.weight.grad = torch.randn(10, 20)
10869            weight_ref = m.weight
10870            weight_grad_ref = m.weight.grad
10871            m.to(device)
10872            self.assertNotEqual(weight_ref.device, m.weight.device)
10873            self.assertNotEqual(weight_grad_ref.device, m.weight.grad.device)
10874        finally:
10875            torch.__future__.set_overwrite_module_params_on_conversion(False)
10876
10877    @onlyCUDA
10878    @dtypes(torch.half, torch.float)
10879    def test_softmax(self, device, dtype):
10880        input = torch.rand(32, 100, device=device, dtype=dtype, requires_grad=True)
10881        inputf = input.to(torch.float).detach().requires_grad_(True)
10882        out = F.softmax(input, dim=-1, dtype=torch.float)
10883        outf = F.softmax(inputf, dim=-1)
10884        # should be bitwise equal
10885        self.assertEqual(out, outf, atol=0, rtol=0)
10886        gO = torch.empty_like(outf).uniform_()
10887        out.backward(gO)
10888        outf.backward(gO)
10889        # should be bitwise equal
10890        self.assertEqual(input.grad, inputf.grad.to(dtype), atol=0, rtol=0)
10891
10892    def _test_batchnorm_grad(self, device, dtype=torch.double):
10893        bs, n_feat, size_feat = 4, 5, 6
10894        input = torch.arange(bs * n_feat * size_feat, device=device,
10895                             requires_grad=True, dtype=dtype).view(bs, n_feat, size_feat)
10896        weight = torch.arange(1, n_feat + 1, device=device, requires_grad=True, dtype=dtype)
10897        bias = torch.arange(n_feat, device=device, requires_grad=True, dtype=dtype)
10898        running_mean = 1 - torch.arange(n_feat, device=device, dtype=dtype)
10899        running_var = 2 * torch.arange(n_feat, device=device, dtype=dtype)
10900        for training in [False, True]:
10901            _assertGradAndGradgradChecks(self, F.batch_norm, (input, running_mean, running_var, weight, bias,
10902                                                              training, 0.1, 0.0001))
10903
10904    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
10905    def test_batchnorm_grad(self, device):
10906        self._test_batchnorm_grad(device)
10907
10908        if self.device_type == 'cuda' and self.has_cudnn():
10909            with torch.backends.cudnn.flags(enabled=False):
10910                self._test_batchnorm_grad(device)
10911
10912    @onlyCUDA
10913    def test_layernorm_half_precision(self):
10914        width = 128
10915        input = torch.rand(1, 5, width, device="cuda", dtype=torch.half) * 0.1
10916        normalized_shape = (width,)
10917        weight = torch.ones(width, device="cuda", dtype=torch.half)
10918        bias = torch.zeros(width, device="cuda", dtype=torch.half)
10919        eps = 1e-5
10920
10921        output_fp16 = torch.layer_norm(input, normalized_shape, weight, bias, eps)
10922        output_fp32 = torch.layer_norm(input.float(), normalized_shape, weight.float(), bias.float(), eps).half()
10923        self.assertEqual(output_fp16, output_fp32, atol=0, rtol=0)
10924
10925    @onlyCUDA
10926    def test_layernorm_weight_bias(self):
10927        width = 128
10928        input = torch.rand(1, 5, width, device="cuda", dtype=torch.float32) * 0.1
10929        normalized_shape = (width,)
10930        data = torch.randn(width, device="cuda", dtype=torch.float32)
10931        weight = torch.ones(width, device="cuda", dtype=torch.float32)
10932        bias = torch.zeros(width, device="cuda", dtype=torch.float32)
10933        eps = 1e-5
10934
10935        out_none_weight = torch.layer_norm(input, normalized_shape, None, data, eps)
10936        out_one_weight = torch.layer_norm(input, normalized_shape, weight, data, eps)
10937        self.assertEqual(out_none_weight, out_one_weight)
10938
10939        out_none_bias = torch.layer_norm(input, normalized_shape, data, None, eps)
10940        out_zero_bias = torch.layer_norm(input, normalized_shape, data, bias, eps)
10941        self.assertEqual(out_none_bias, out_zero_bias)
10942
10943    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
10944    def test_hardsigmoid_grad(self, device):
10945        inputs = (torch.randn(4, 16, 16, device=device, dtype=torch.double) - 0.5) * 10
10946        inputs.requires_grad = True
10947        self.assertTrue(gradcheck(F.hardsigmoid, (inputs,)))
10948
10949    # currently fails on XLA
10950    @onlyNativeDeviceTypes
10951    def test_hardswish_grad(self, device):
10952        inputs = (torch.randn(4, 16, 16, device=device, dtype=torch.double) - 0.5) * 10
10953        inputs.requires_grad = True
10954        self.assertTrue(gradcheck(F.hardswish, (inputs,)))
10955
10956
10957    def _test_batchnorm_eval(self, ndim, device, dtype, module_dtype=None):
10958        module_dtype = module_dtype or dtype
10959        module = nn.BatchNorm1d(3).to(device, module_dtype)
10960        module.eval()
10961
10962        data = torch.rand([3] * ndim, device=device, dtype=dtype, requires_grad=True)
10963        grad = torch.rand([3] * ndim, device=device, dtype=dtype)
10964
10965        # 1st pass
10966        res1 = module(data)
10967        res1.backward(grad)
10968        grad1 = data.grad.clone()
10969
10970        # 2nd pass
10971        if data.grad is not None:
10972            data.grad.data.zero_()
10973
10974        res2 = module(data)
10975        res2.backward(grad)
10976        grad2 = data.grad.clone()
10977        self.assertEqual(res1, res2)
10978        self.assertEqual(grad1, grad2)
10979
10980        # track_running_stats=False
10981        module = nn.BatchNorm1d(3, track_running_stats=False).to(device, module_dtype)
10982
10983        data = torch.rand(4, 3, device=device, dtype=dtype, requires_grad=True)
10984        grad = torch.rand(4, 3, device=device, dtype=dtype)
10985
10986        # 1st pass
10987        res1 = module(data)
10988        res1.backward(grad)
10989        grad1 = data.grad.clone()
10990
10991        # set eval
10992        module.eval()
10993
10994        # 2nd pass
10995        if data.grad is not None:
10996            data.grad.data.zero_()
10997
10998        res2 = module(data)
10999        res2.backward(grad)
11000        grad2 = data.grad.clone()
11001        self.assertEqual(res1, res2)
11002        self.assertEqual(grad1, grad2)
11003
11004    @dtypes(torch.float)
11005    @dtypesIfCUDA(torch.float, torch.bfloat16)
11006    def test_batchnorm_eval(self, device, dtype):
11007        self._test_batchnorm_eval(2, device, dtype)
11008        self._test_batchnorm_eval(3, device, dtype)
11009
11010        if self.device_type == 'cuda' and self.has_cudnn():
11011            with torch.backends.cudnn.flags(enabled=False):
11012                self._test_batchnorm_eval(2, device, dtype)
11013                self._test_batchnorm_eval(3, device, dtype)
11014
11015    @onlyCUDA
11016    @dtypes(torch.bfloat16, torch.half)
11017    def test_batchnorm_eval_mixed(self, device, dtype):
11018        # Test bfloat16 input with float module
11019        self._test_batchnorm_eval(2, device, dtype, torch.float)
11020        self._test_batchnorm_eval(3, device, dtype, torch.float)
11021
11022        if self.device_type == 'cuda' and self.has_cudnn():
11023            with torch.backends.cudnn.flags(enabled=False):
11024                self._test_batchnorm_eval(2, device, dtype, torch.float)
11025                self._test_batchnorm_eval(3, device, dtype, torch.float)
11026
11027    def _test_batchnorm_affine(self, ndim, device, dtype, module_dtype=None):
11028        # Compare affine against no-op weights and bias
11029        module_dtype = module_dtype or dtype
11030        module = nn.BatchNorm1d(3, affine=False).to(device, module_dtype)
11031        module_affine = nn.BatchNorm1d(3, affine=True).to(device, module_dtype)
11032        with torch.no_grad():
11033            module_affine.weight.fill_(1.0)
11034            module_affine.bias.zero_()
11035
11036        data = torch.rand([3] * ndim, device=device, dtype=dtype, requires_grad=True)
11037        grad = torch.ones_like(data, requires_grad=False)
11038
11039        # With weights all ones and bias all zeros
11040        res1 = module_affine(data)
11041        res1.backward(grad)
11042        grad1 = data.grad.clone()
11043        data.grad.zero_()
11044
11045        # Without any weights or bias
11046        res2 = module(data)
11047        res2.backward(grad)
11048        grad2 = data.grad
11049
11050        self.assertEqual(res1, res2)
11051        self.assertEqual(grad1, grad2)
11052
11053    @dtypes(torch.float)
11054    @dtypesIfCUDA(torch.float, torch.bfloat16)
11055    def test_batchnorm_affine(self, device, dtype):
11056        self._test_batchnorm_affine(2, device, dtype)
11057        self._test_batchnorm_affine(3, device, dtype)
11058
11059        if self.device_type == 'cuda' and self.has_cudnn():
11060            with torch.backends.cudnn.flags(enabled=False):
11061                self._test_batchnorm_affine(2, device, dtype)
11062                self._test_batchnorm_affine(3, device, dtype)
11063
11064    @onlyCUDA
11065    @dtypes(torch.bfloat16, torch.half)
11066    def test_batchnorm_affine_mixed(self, device, dtype):
11067        cudnn_enabled = [False]
11068        if self.device_type == 'cuda' and self.has_cudnn():
11069            # TODO: Test fails with cudnn, see gh-62034
11070            # cudnn_enabled = [False, True]
11071            pass
11072
11073        # Test bfloat16 input with float module
11074        for enabled in cudnn_enabled:
11075            with torch.backends.cudnn.flags(enabled=enabled):
11076                self._test_batchnorm_affine(2, device, dtype, torch.float)
11077                self._test_batchnorm_affine(3, device, dtype, torch.float)
11078
11079    def _test_batchnorm_simple_average(self, device, dtype, module_dtype=None):
11080        module_dtype = module_dtype or dtype
11081        module = nn.BatchNorm1d(3, momentum=None).to(dtype=module_dtype, device=device)
11082        zeros = torch.zeros(3, dtype=module_dtype, device=device)
11083        ones = torch.ones(3, dtype=module_dtype, device=device)
11084        self.assertEqual(module.running_mean, zeros)
11085        self.assertEqual(module.running_var, ones)
11086
11087        data1 = torch.rand(4, 3, dtype=dtype, device=device)
11088        data2 = torch.rand(4, 3, dtype=dtype, device=device)
11089
11090        # 1st pass
11091        res1 = module(data1)
11092        running_mean1 = module.running_mean.clone()
11093        running_var1 = module.running_var.clone()
11094        self.assertNotEqual(running_mean1, zeros)
11095        self.assertNotEqual(running_var1, ones)
11096
11097        # reset stats
11098        module.reset_running_stats()
11099        self.assertEqual(module.running_mean, zeros)
11100        self.assertEqual(module.running_var, ones)
11101
11102        # 2nd pass
11103        res2 = module(data2)
11104        running_mean2 = module.running_mean.clone()
11105        running_var2 = module.running_var.clone()
11106        self.assertNotEqual(running_mean2, zeros)
11107        self.assertNotEqual(running_var2, ones)
11108
11109        # reset stats
11110        module.reset_running_stats()
11111        self.assertEqual(module.running_mean, zeros)
11112        self.assertEqual(module.running_var, ones)
11113
11114        # 3rd (combined) pass
11115        res3 = module(data1)
11116        res4 = module(data2)
11117        self.assertEqual(res3, res1)
11118        self.assertEqual(res4, res2)
11119        self.assertEqual(module.running_mean, (running_mean1 + running_mean2) / 2)
11120        self.assertEqual(module.running_var, (running_var1 + running_var2) / 2)
11121
11122    @dtypes(torch.float)
11123    @dtypesIfCUDA(torch.float, torch.bfloat16)
11124    def test_batchnorm_simple_average(self, device, dtype):
11125        self._test_batchnorm_simple_average(device, dtype)
11126
11127        if self.device_type == 'cuda' and self.has_cudnn():
11128            with torch.backends.cudnn.flags(enabled=False):
11129                self._test_batchnorm_simple_average(device, dtype)
11130
11131    @onlyCUDA
11132    @dtypes(torch.bfloat16, torch.half)
11133    def test_batchnorm_simple_average_mixed(self, device, dtype):
11134        self._test_batchnorm_simple_average(device, dtype, torch.float)
11135
11136        if self.device_type == 'cuda' and self.has_cudnn():
11137            with torch.backends.cudnn.flags(enabled=False):
11138                self._test_batchnorm_simple_average(device, dtype, torch.float)
11139
11140    @onlyNativeDeviceTypes
11141    @dtypes(torch.float, torch.double)
11142    def test_grid_sample_nan_inf(self, device, dtype):
11143        input = torch.zeros([1, 1, 3, 3], device=device, dtype=dtype)
11144        grid = torch.tensor([[[[nan, 0], [0, inf]]]], device=device, dtype=dtype)
11145        for padding_mode in ('reflection', 'border', 'zeros'):
11146            sample = torch.nn.functional.grid_sample(input=input, grid=grid, mode='nearest',
11147                                                     padding_mode=padding_mode, align_corners=False)
11148            self.assertEqual(sample, torch.zeros([1, 1, 1, 2], device=device, dtype=dtype))
11149
11150    @expectedFailureMPS  # NotImplementedError aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764
11151    def test_CTCLoss_empty_target(self, device):
11152        target_lengths = [0, 0, 0]
11153        input_lengths = [50, 50, 50]
11154        targets = torch.randint(1, 15, (0,), dtype=torch.long, device=device)
11155        log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
11156        loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
11157        self.assertTrue((loss >= 0).all().item())
11158        self.assertEqual(-log_probs.sum(0)[:, 0], loss)
11159
11160        target_lengths = [0, 9, 0]
11161        input_lengths = [50, 50, 50]
11162        targets = torch.randint(1, 15, (9,), dtype=torch.long, device=device)
11163        log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
11164        loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
11165        self.assertTrue((loss >= 0).all().item())
11166        self.assertEqual(-log_probs.sum(0)[[0, 2], 0], loss[[0, 2]])
11167
11168    # Merge into OpInfo?
11169    @skipCUDAIf(True, """Test is flaky on Linux and Windows, typical error message:
11170                          https://github.com/pytorch/pytorch/issues/34870""")
11171    @expectedFailureMPS  # NotImplementedError aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764
11172    def test_ctc_loss(self, device):
11173        batch_size = 64
11174        num_labels = 101
11175        target_length = 15
11176        gradcheck_input_size = 10
11177
11178        ZERO_NONE = 0
11179        ZERO_SOME = 1
11180        ZERO_ALL = 2
11181
11182        # input_length, vary_lengths, zero_lengths
11183        tests = [(150, False, ZERO_NONE),
11184                 (150, True, ZERO_NONE),
11185                 (50, True, ZERO_SOME),
11186                 (50, True, ZERO_ALL)]
11187
11188        if 'cuda' in device:
11189            tests += [(50, False, ZERO_NONE),
11190                      (50, True, ZERO_NONE),
11191                      (150, True, ZERO_SOME),
11192                      (150, True, ZERO_ALL)]
11193
11194        for input_length, vary_lengths, zero_mode in tests:
11195            targets = torch.randint(1, num_labels, (batch_size, target_length),
11196                                    device=device, dtype=torch.long)
11197            x = torch.randn(gradcheck_input_size, dtype=torch.double, device=device, requires_grad=True)
11198            tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1,
11199                                       device=device)
11200            input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
11201                              if vary_lengths or i == 0 else input_length) for i in range(batch_size)]
11202            if zero_mode == ZERO_ALL:
11203                target_lengths = [0 for _ in range(batch_size)]
11204            else:
11205                target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
11206                                   if vary_lengths else target_length) for _ in range(batch_size)]
11207                if zero_mode == ZERO_SOME:
11208                    idxes = torch.randint(0, batch_size, (10,))
11209                    for i in idxes:
11210                        target_lengths[i] = 0
11211
11212            def ctc_after_softmax(x):
11213                x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels]
11214                          .view(input_length, batch_size, num_labels))
11215                log_probs = torch.log_softmax(x_full, 2)
11216                return torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
11217
11218            gradcheck(ctc_after_softmax, [x])
11219
11220    @onlyCUDA
11221    @skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
11222    @skipCUDAIfCudnnVersionLessThan(7600)
11223    def test_ctc_loss_cudnn(self, device):
11224        batch_size = 16
11225        input_length = 30
11226        num_labels = 101
11227        target_length = 15
11228        targets = torch.randint(1, num_labels, (batch_size * target_length,),
11229                                device='cuda', dtype=torch.long)
11230        log_probs = torch.log_softmax(torch.randn(input_length, batch_size, num_labels, device='cuda', dtype=torch.float), 2)
11231        log_probs.requires_grad_()
11232
11233        input_lengths = batch_size * [input_length]
11234        target_lengths = batch_size * [target_length]
11235        grad_out = torch.randn(batch_size, device='cuda', dtype=torch.float)
11236        with torch.backends.cudnn.flags(enabled=False):
11237            loss_native = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
11238            grad_native, = torch.autograd.grad(loss_native, log_probs, grad_out)
11239        loss_cudnn = torch.nn.functional.ctc_loss(log_probs, targets.to('cpu', torch.int32),
11240                                                  input_lengths, target_lengths, reduction='none')
11241        self.assertTrue("Cudnn" in str(loss_cudnn.grad_fn))
11242        grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
11243        self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)
11244
11245    @onlyCUDA
11246    @skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
11247    @skipCUDAIfCudnnVersionLessThan(8000)
11248    def test_ctc_loss_cudnn_tensor(self, device):
11249        batch_size = 16
11250        input_length = 30
11251        num_labels = 101
11252        target_length = 15
11253        targets = torch.randint(1, num_labels, (batch_size * target_length,),
11254                                device='cuda', dtype=torch.long)
11255        log_probs = torch.log_softmax(torch.randn(input_length, batch_size, num_labels, device='cuda', dtype=torch.float), 2)
11256        log_probs.requires_grad_()
11257
11258        input_lengths = batch_size * [input_length]
11259        input_lengths = torch.linspace(start=15, end=input_length, steps=batch_size, dtype=torch.long, device='cuda')
11260        target_lengths = torch.tensor(batch_size * [target_length], dtype=torch.long, device='cuda')
11261        grad_out = torch.randn(batch_size, device='cuda', dtype=torch.float)
11262        with torch.backends.cudnn.flags(enabled=False):
11263            loss_native = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
11264            grad_native, = torch.autograd.grad(loss_native, log_probs, grad_out)
11265        loss_cudnn = torch.nn.functional.ctc_loss(log_probs,
11266                                                  targets.to('cuda', torch.int32),
11267                                                  input_lengths.to('cuda', torch.int32),
11268                                                  target_lengths.to('cuda', torch.int32),
11269                                                  reduction='none')
11270        self.assertTrue("Cudnn" in str(loss_cudnn.grad_fn))
11271        grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
11272        self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)
11273
11274    @expectedFailureMPS  # RuntimeError: LSTM with projections is not currently supported with MPS.
11275    @dtypesIfCUDA(torch.half, torch.float, torch.double)
11276    @dtypes(torch.float)
11277    @tf32_on_and_off(0.005)
11278    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
11279    def test_variable_sequence(self, device, dtype):
11280        def pad(var, length):
11281            if var.size(0) == length:
11282                return var
11283            return torch.cat([var, var.new_zeros(length - var.size(0), *var.size()[1:])])
11284
11285        def maybe_index_tuple(maybe_tuple_of_tensors, index):
11286            if maybe_tuple_of_tensors is None:
11287                return None
11288            return tuple(maybe_tuple_of_tensors[j][:, index:index + 1, :].contiguous()
11289                         for j in range(2))
11290
11291        def check_lengths(lengths, enforce_sorted, use_default_hiddens, proj_size):
11292            input_size = 3
11293            hidden_size = 4
11294            num_layers = 2
11295            bidirectional = True
11296
11297            max_length = max(lengths)
11298            x_leaf = torch.randn(max_length, len(lengths), input_size, device=device,
11299                                 dtype=dtype, requires_grad=True)
11300            num_directions = 2 if bidirectional else 1
11301            lstm = nn.LSTM(input_size, hidden_size, bidirectional=bidirectional,
11302                           num_layers=num_layers, proj_size=proj_size).to(device, dtype)
11303            lstm2 = deepcopy(lstm).to(device, dtype)
11304            x = x_leaf
11305
11306            hidden0 = None
11307            if not use_default_hiddens:
11308                real_hidden_size = hidden_size if proj_size == 0 else proj_size
11309                hidden0 = (torch.randn(num_directions * num_layers, len(lengths), real_hidden_size,
11310                                       device=device, dtype=dtype),
11311                           torch.randn(num_directions * num_layers, len(lengths), hidden_size,
11312                                       device=device, dtype=dtype))
11313
11314            # Compute sequences separately
11315            seq_outs = []
11316            seq_hiddens = []
11317            for i, l in enumerate(lengths):
11318                hidden_i = maybe_index_tuple(hidden0, i)
11319                out, hid = lstm2(x[:l, i:i + 1], hidden_i)
11320                out_pad = pad(out, max_length)
11321                seq_outs.append(out_pad)
11322                seq_hiddens.append(hid)
11323            seq_out = torch.cat(seq_outs, 1)
11324            seq_hidden = tuple(torch.cat(hids, 1) for hids in zip(*seq_hiddens))
11325
11326            # Use packed format
11327            packed = rnn_utils.pack_padded_sequence(x, lengths, enforce_sorted=enforce_sorted)
11328            packed_out, packed_hidden = lstm(packed, hidden0)
11329            unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed_out)
11330
11331            # Check forward
11332            prec = dtype2prec_DONTUSE[dtype]
11333            self.assertEqual(packed_hidden, seq_hidden, atol=prec, rtol=0)
11334            self.assertEqual(unpacked, seq_out, atol=prec, rtol=0)
11335            self.assertEqual(unpacked_len, lengths, atol=prec, rtol=0)
11336
11337            # Check backward
11338            seq_out.sum().backward()
11339            grad_x = x_leaf.grad.data.clone()
11340            x_leaf.grad.data.zero_()
11341            unpacked.sum().backward()
11342
11343            self.assertEqual(x_leaf.grad, grad_x, atol=dtype2prec_DONTUSE[dtype], rtol=0)
11344            for p1, p2 in zip(lstm.parameters(), lstm2.parameters()):
11345                prec = dtype2prec_DONTUSE[dtype]
11346                if dtype == torch.float16:
11347                    prec = 4e-2
11348                self.assertEqual(p1.grad, p2.grad, atol=prec, rtol=0)
11349
11350        tests = [
11351            # enforce_sorted, lengths
11352            [True, [5]],
11353            [False, [5]],
11354            [True, [10, 10, 6, 2, 2, 1, 1]],
11355            [False, [10, 10, 6, 2, 2, 1, 1]],
11356            [False, [2, 1, 3, 2, 10, 5, 3]],
11357        ]
11358
11359        for enforce_sorted, seq_lens, in tests:
11360            for use_default_hiddens in (True, False):
11361                for proj_size in [0, 2]:
11362                    check_lengths(seq_lens, enforce_sorted, use_default_hiddens, proj_size)
11363
11364    def _test_batchnorm_update_stats(self, device, dtype=torch.float):
11365        module = nn.BatchNorm1d(3).to(device, dtype)
11366
11367        data = torch.rand(4, 3, device=device, dtype=dtype)
11368
11369        # training pass
11370        old_running_mean = module.running_mean.clone()
11371        old_running_var = module.running_var.clone()
11372        old_num_batches_tracked = module.num_batches_tracked.clone()
11373        module(data)
11374        self.assertNotEqual(old_running_mean, module.running_mean)
11375        self.assertNotEqual(old_running_var, module.running_var)
11376        self.assertEqual(old_num_batches_tracked + 1, module.num_batches_tracked)
11377
11378        # eval pass
11379        module.eval()
11380        old_running_mean = module.running_mean.clone()
11381        old_running_var = module.running_var.clone()
11382        old_num_batches_tracked = module.num_batches_tracked.clone()
11383        module(data)
11384        self.assertEqual(old_running_mean, module.running_mean)
11385        self.assertEqual(old_running_var, module.running_var)
11386        self.assertEqual(old_num_batches_tracked, module.num_batches_tracked)
11387
11388    def test_batchnorm_update_stats(self, device):
11389        self._test_batchnorm_update_stats(device)
11390
11391        if self.device_type == 'cuda' and self.has_cudnn():
11392            with torch.backends.cudnn.flags(enabled=False):
11393                self._test_batchnorm_update_stats(device)
11394
11395    @onlyCPU
11396    @dtypes(torch.bfloat16, torch.float16)
11397    def test_activations_bfloat16_half_cpu(self, device, dtype):
11398        def test_helper(fn, device, inp_dims, prec=None):
11399            torch.manual_seed(37)
11400            # bfloat16/half compute
11401            fn = fn.to(dtype=dtype)
11402            input = torch.randn(inp_dims, dtype=dtype, device=device, requires_grad=True)
11403            out = fn(input)
11404            grad_input = torch.randn_like(out, dtype=dtype, device=device)
11405            out.backward(grad_input)
11406
11407            # fp32 compute
11408            input2 = input.detach().clone().float().requires_grad_(True)
11409            out2 = fn.float()(input2)
11410            grad_input2 = grad_input.detach().clone().float()
11411            out2.backward(grad_input2)
11412
11413            self.assertEqual(out.dtype, dtype)
11414            self.assertEqual(input.grad.dtype, dtype)
11415            self.assertEqual(out, out2.to(dtype=dtype), atol=prec, rtol=prec)
11416            self.assertEqual(input.grad.data, input2.grad.data.to(dtype=dtype), atol=prec, rtol=prec)
11417
11418        shapes = [[1, 3, 1, 6], [1, 3, 1, 128], [1, 3, 256, 256]]
11419        for shape in shapes:
11420            test_helper(torch.nn.LogSigmoid(), device, shape)
11421            test_helper(torch.nn.Hardsigmoid(), device, shape)
11422            test_helper(torch.nn.Hardshrink(), device, shape)
11423            test_helper(torch.nn.Softshrink(), device, shape)
11424            test_helper(torch.nn.Hardswish(), device, shape)
11425            test_helper(torch.nn.Softplus(), device, shape)
11426            test_helper(torch.nn.SiLU(), device, shape)
11427            test_helper(torch.nn.Hardtanh(), device, shape)
11428            test_helper(torch.nn.Mish(), device, shape)
11429            test_helper(torch.nn.ELU(), device, shape)
11430            test_helper(torch.nn.PReLU(), device, shape)
11431            test_helper(torch.nn.GLU(), device, shape, prec=1e-2)
11432            test_helper(torch.nn.Threshold(0.1, 20), device, shape)
11433            test_helper(torch.nn.GELU(), device, shape)
11434            test_helper(torch.nn.Hardtanh(), device, shape)
11435            test_helper(torch.nn.LeakyReLU(), device, shape)
11436
11437    @onlyCUDA
11438    def test_activations_bfloat16(self, device):
11439        _test_bfloat16_ops(self, torch.nn.ReLU(), device, inp_dims=(5), prec=1e-2)
11440        _test_bfloat16_ops(self, torch.nn.Threshold(0.1, 20), device, inp_dims=(5), prec=1e-2)
11441        _test_bfloat16_ops(self, torch.nn.ELU(), device, inp_dims=(5), prec=1e-2)
11442        _test_bfloat16_ops(self, torch.nn.Softplus(), device, inp_dims=(5), prec=1e-2)
11443        _test_bfloat16_ops(self, torch.nn.Hardshrink(), device, inp_dims=(5), prec=1e-2)
11444        _test_bfloat16_ops(self, torch.nn.Softshrink(), device, inp_dims=(5), prec=1e-2)
11445        _test_bfloat16_ops(self, torch.nn.LeakyReLU(), device, inp_dims=(5), prec=1e-2)
11446
11447    @onlyNativeDeviceTypes
11448    def test_softmax_bfloat16(self, device):
11449        for dim in [0, 1, 2, 3]:
11450            _test_bfloat16_ops(self, torch.nn.Softmax(dim=dim), device, inp_dims=(16, 33, 15, 16), prec=1e-2)
11451            # test softmax with large input value which casues exp() to overflow
11452            _test_bfloat16_ops(self, torch.nn.Softmax(dim=dim), device, inp_dims=(16, 33, 15, 16), prec=0.05, scale_factor=1000.0)
11453
11454    def test_nll_loss_mismatched_batch(self, device):
11455        x = torch.randn((10, 3), requires_grad=True, device=device)
11456        # t should have size (10,)
11457        t = torch.zeros((3,), dtype=torch.int64, device=device)
11458        with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
11459            F.nll_loss(x, t)
11460
11461    def test_nll_loss_out_of_bounds_ignore_index(self, device):
11462        x = torch.randn(6, 3, requires_grad=True, device=device)
11463        t = torch.tensor([0, 1, 255, 0, 1, 2], dtype=torch.int64, device=device)
11464        for reduction in ['mean', 'none']:
11465            F.nll_loss(x, t, ignore_index=255, reduction=reduction).sum().backward()
11466
11467    def test_nll_loss_invalid_target_dim(self, device):
11468        x = torch.randn((10, 3), device=device)
11469        t = torch.zeros((10, 2), dtype=torch.int64, device=device)
11470        with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
11471            F.nll_loss(x, t)
11472
11473    def test_nll_loss_invalid_weights(self, device):
11474        x = torch.randn((10, 3), device=device)
11475        t = torch.empty(10, dtype=torch.int64, device=device).random_(0, 3)
11476        invalid_weights = [
11477            torch.randn(4, device=device),
11478            torch.randn(1, 3, device=device),
11479        ]
11480        msg = "weight tensor should be defined either for all 3 classes or no classes"
11481        for weight in invalid_weights:
11482            with self.assertRaisesRegex(RuntimeError, msg):
11483                F.nll_loss(x, t, weight=weight)
11484
11485    # Ref: https://github.com/pytorch/pytorch/issue/85005
11486    @onlyCUDA
11487    @largeTensorTest("120GB", "cpu")
11488    @largeTensorTest("45GB", "cuda")
11489    @parametrize_test("reduction", ("none", "mean", "sum"))
11490    def test_nll_loss_large_tensor(self, device, reduction):
11491        shape = [int(2 ** 16), int(2 ** 16) + 1]
11492
11493        input = torch.randn(shape, device=device, dtype=torch.float32, requires_grad=True)
11494        labels = torch.randint(shape[0], (shape[0],), dtype=torch.long, device=device)
11495
11496        out = F.nll_loss(input, labels, reduction=reduction)
11497
11498        with torch.no_grad():
11499            input_cpu = input.cpu().float().requires_grad_()
11500            labels_cpu = labels.cpu()
11501        out_cpu = F.nll_loss(input_cpu, labels_cpu, reduction=reduction)
11502        # workaround to reduce memory usage vs. self.assertEqual, see #84944
11503        rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
11504        if reduction == "sum":
11505            orig_rtol, orig_atol = rtol, atol
11506            rtol, atol = 7 * rtol, 3 * atol
11507        with torch.no_grad():
11508            self.assertTrue(torch.allclose(out.cpu(), out_cpu, rtol=rtol, atol=atol))
11509        if reduction == "sum":
11510            rtol, atol = orig_rtol, orig_atol
11511
11512        if reduction != "none":
11513            out.backward()
11514            out_cpu.backward()
11515            with torch.no_grad():
11516                self.assertTrue(torch.allclose(input.grad.cpu(), input_cpu.grad, rtol=rtol, atol=atol))
11517
11518    # Ref: https://github.com/pytorch/pytorch/issue/108345
11519    @onlyCUDA
11520    @largeTensorTest("20GB", "cpu")
11521    @largeTensorTest("20GB", "cuda")
11522    @parametrize_test("reduction", ("none", "mean", "sum"))
11523    def test_cross_entropy_64bit(self, device, reduction):
11524        labels = torch.zeros(190, 50, dtype=torch.long, device=device)
11525        logits = torch.ones(190, 229000, 50, dtype=torch.float, device=device)
11526        loss = torch.nn.functional.cross_entropy(logits, labels)
11527        loss_cpu = torch.nn.functional.cross_entropy(logits.cpu(), labels.cpu())
11528        print(logits.numel(), labels.numel(), loss.numel())
11529        self.assertTrue(torch.allclose(loss_cpu, loss.cpu(), rtol=1e-4, atol=1e-4))
11530
11531    def _nll_loss_helper(self, input_size, reduction, expected, device):
11532        input = torch.rand(input_size, requires_grad=True, device=device)
11533        num_channels = input_size[1]
11534        target_size = (input_size[0], ) + tuple(input_size[2:])
11535        target = torch.randint(num_channels, target_size, device=device)
11536
11537        output = F.nll_loss(input, target, reduction=reduction)
11538        self.assertEqual(output, expected, exact_dtype=False)
11539
11540        output.sum().backward()
11541        self.assertEqual(input.grad.size(), input.size())
11542
11543    def test_nll_loss_empty_tensor_reduction_none(self, device):
11544        self._nll_loss_helper([0, 3], "none", torch.empty([0], device=device), device)
11545        self._nll_loss_helper([0, 3, 5, 7], "none", torch.empty([0, 5, 7], device=device), device)
11546        self._nll_loss_helper([2, 3, 0, 7], "none", torch.empty([2, 0, 7], device=device), device)
11547        self._nll_loss_helper([2, 3, 5, 0], "none", torch.empty([2, 5, 0], device=device), device)
11548        self._nll_loss_helper([2, 3, 5, 7, 0], "none", torch.empty([2, 5, 7, 0], device=device), device)
11549
11550    @expectedFailureMPS  # RuntimeError: [srcBuf length] > 0 INTERNAL ASSERT FAILED https://github.com/pytorch/pytorch/issues/134431
11551    def test_nll_loss_empty_tensor_reduction_mean(self, device):
11552        nan = torch.tensor(float('nan'), device=device)
11553        self._nll_loss_helper([0, 3], "mean", nan, device)
11554        self._nll_loss_helper([0, 3, 5, 7], "mean", nan, device)
11555        self._nll_loss_helper([2, 3, 0, 7], "mean", nan, device)
11556        self._nll_loss_helper([2, 3, 5, 0], "mean", nan, device)
11557        self._nll_loss_helper([2, 3, 5, 7, 0], "mean", nan, device)
11558
11559    @expectedFailureMPS  # RuntimeError: [srcBuf length] > 0 INTERNAL ASSERT FAILED https://github.com/pytorch/pytorch/issues/134431
11560    def test_nll_loss_empty_tensor_reduction_sum(self, device):
11561        zero = torch.tensor(0, device=device)
11562        self._nll_loss_helper([0, 3], "sum", zero, device)
11563        self._nll_loss_helper([0, 3, 5, 7], "sum", zero, device)
11564        self._nll_loss_helper([2, 3, 0, 7], "sum", zero, device)
11565        self._nll_loss_helper([2, 3, 5, 0], "sum", zero, device)
11566        self._nll_loss_helper([2, 3, 5, 7, 0], "sum", zero, device)
11567
11568    @expectedFailureMPS  # AssertionError: Expected nan but got 0.0.
11569    def test_nll_loss_total_weight_is_zero(self, device):
11570
11571        def helper(input_size):
11572            input = torch.ones(input_size, requires_grad=True, device=device)
11573            num_channels = input_size[1]
11574            target_size = (input_size[0], ) + tuple(input_size[2:])
11575            target = torch.zeros(target_size, dtype=torch.long, device=device)
11576            weight = torch.zeros([num_channels], device=device)
11577            self.assertEqual(F.nll_loss(input, target, weight, reduction="sum").item(), 0.)
11578            self.assertEqual(F.nll_loss(input, target, weight, reduction="mean").item(), float("nan"))
11579            self.assertEqual(F.nll_loss(input, target, weight, reduction="none"), torch.zeros(target.shape, device=device))
11580
11581        helper([2, 3])
11582        helper([2, 3, 5, 7])
11583        helper([2, 3, 5, 7, 9])
11584
11585    @expectedFailureMPS  # AssertionError: Expected nan but got 0.0.
11586    def test_nll_loss_all_ignored(self, device):
11587
11588        def helper(input_size):
11589            input = torch.ones(input_size, device=device)
11590            num_channels = input_size[1]
11591            target_size = (input_size[0], ) + tuple(input_size[2:])
11592            target = torch.zeros(target_size, dtype=torch.long, device=device)
11593            self.assertEqual(F.nll_loss(input, target, ignore_index=0, reduction="sum").item(), 0)
11594            self.assertEqual(F.nll_loss(input, target, ignore_index=0, reduction="mean").item(), float("nan"))
11595            self.assertEqual(F.nll_loss(input, target, ignore_index=0, reduction="none"), torch.zeros(target.shape, device=device))
11596
11597        helper([2, 3])
11598        helper([2, 3, 5, 7])
11599        helper([2, 3, 5, 7, 9])
11600
11601    def test_nll_loss_byte_target_matches_long(self, device):
11602        N, C = 10, 4
11603        input = torch.randn(N, C, device=device, requires_grad=True)
11604        target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)
11605
11606        def compute_result_and_gradient(reduction, target_dtype):
11607            input_ = input.detach()
11608            input_.requires_grad_()
11609
11610            prob = F.log_softmax(input_, dim=-1)
11611            loss = nn.NLLLoss(reduction=reduction)
11612            result = loss(prob, target.to(target_dtype))
11613            result.sum().backward()
11614
11615            return result, input_.grad
11616
11617        for reduction in ["none", "mean", "sum"]:
11618            result_long, grad_long = compute_result_and_gradient(reduction, torch.long)
11619            result_byte, grad_byte = compute_result_and_gradient(reduction, torch.uint8)
11620            self.assertEqual(result_long, result_byte)
11621            self.assertEqual(grad_long, grad_byte)
11622
11623    @onlyCUDA
11624    @skipIfRocm
11625    @dtypes(torch.float16, torch.float32)
11626    def test_cross_entropy_loss_2d_out_of_bounds_class_index(self, device, dtype):
11627        # Test for issue #117532
11628        # Run in a different process to prevent the device-side assert from affecting other tests
11629        stderr = TestCase.runWithPytorchAPIUsageStderr(f"""\
11630#!/usr/bin/env python3
11631
11632import torch
11633import torch.nn.functional as F
11634from torch.testing._internal.common_utils import (run_tests, TestCase)
11635
11636class TestThatContainsCUDAAssert(TestCase):
11637    def test_cross_entropy_loss_2d_out_of_bounds_class_index(self):
11638        device = '{str(device)}'
11639        dtype = {str(dtype).strip("'")}
11640        ignore_index = 255
11641        b = 10
11642        n_classes = 3
11643        w = 768
11644        h = 1024
11645        pred = torch.randn(b, n_classes, w, h, dtype=dtype, device=device)
11646        labels = torch.zeros(b, w, h, dtype=torch.int64, device=device)
11647        labels[5, 200, 200] = ignore_index
11648        # Set invalid class index
11649        labels[5, 200, 200] = 254
11650
11651        x = F.cross_entropy(
11652            pred, labels, reduction="none", ignore_index=ignore_index
11653        )
11654        torch.cuda.synchronize()
11655
11656
11657if __name__ == '__main__':
11658    run_tests()
11659        """)
11660        self.assertIn('CUDA error: device-side assert triggered', stderr)
11661
11662
11663
11664    def test_cross_entropy_loss_prob_target_all_reductions(self, device):
11665        # Test with k-dimensional loss.
11666        for k in range(5):
11667            N, C = 5, 4
11668            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
11669            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11670            target = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11671            weight = torch.randn(C, device=device).abs()
11672
11673            for reduction, w in product(['none', 'mean', 'sum'], [None, weight]):
11674                m = torch.nn.CrossEntropyLoss(weight=w, reduction=reduction)
11675                output = m(input, target)
11676                output_ref = loss_reference_fns['CrossEntropyLoss'](
11677                    input, target, reduction=reduction, weight=w)
11678                self.assertEqual(output, output_ref)
11679
11680    def test_cross_entropy_loss_prob_target_unit_weights(self, device):
11681        # Test with k-dimensional loss.
11682        for k in range(5):
11683            N, C = 5, 4
11684            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
11685            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11686            target = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11687
11688            for reduction in ['none', 'mean', 'sum']:
11689                # Ensure result with unit weights is equivalent to result without weights.
11690                m = torch.nn.CrossEntropyLoss(reduction=reduction)
11691                unit_weight = torch.ones(C, device=device, dtype=target.dtype)
11692                m_unit = torch.nn.CrossEntropyLoss(weight=unit_weight, reduction=reduction)
11693                output = m(input, target)
11694                output_unit = m_unit(input, target)
11695                self.assertEqual(output, output_unit)
11696
11697    @parametrize_test('reduction', ['none', 'mean', 'sum'])
11698    @parametrize_test('weighted', [False, True])
11699    def test_cross_entropy_loss_prob_target_no_batch_dim(self, device, reduction, weighted):
11700        C = 5
11701        input = torch.randn(C, device=device).log_softmax(dim=-1)
11702        target = torch.randn(C, device=device).softmax(dim=-1)
11703        weight = torch.randn(C, device=device) if weighted else None
11704        m = nn.CrossEntropyLoss(reduction=reduction, weight=weight)
11705        loss_no_batch = m(input, target)
11706        loss_batch = m(input.unsqueeze(0), target.unsqueeze(0))
11707        if reduction == 'none':
11708            loss_batch = loss_batch.squeeze(0)
11709        self.assertEqual(loss_no_batch, loss_batch)
11710
11711    def test_cross_entropy_loss_index_target_unit_weights(self, device):
11712        # Test with k-dimensional loss.
11713        for k in range(5):
11714            N, C = 5, 4
11715            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
11716            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11717            target = torch.empty(N, *other_dims, dtype=torch.long, device=device).random_(0, C)
11718
11719            for reduction in ['none', 'mean', 'sum']:
11720                # Ensure result with unit weights is equivalent to result without weights.
11721                m = torch.nn.CrossEntropyLoss(reduction=reduction)
11722                unit_weight = torch.ones(C, device=device, dtype=input.dtype)
11723                m_unit = torch.nn.CrossEntropyLoss(weight=unit_weight, reduction=reduction)
11724                output = m(input, target)
11725                output_unit = m_unit(input, target)
11726                self.assertEqual(output, output_unit)
11727
11728    def test_cross_entropy_loss_one_hot_target(self, device):
11729        # Test with k-dimensional loss.
11730        for k in range(5):
11731            N, C = 5, 4
11732            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
11733            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11734            target = torch.empty(N, *other_dims, dtype=torch.long, device=device).random_(0, C)
11735            weight = torch.randn(C, device=device).abs()
11736
11737            # Get one-hot representation of the target.
11738            target_one_hot = F.one_hot(target, num_classes=C).to(input.dtype)
11739            # Need to put the C dim at index 1.
11740            target_one_hot = target_one_hot.permute(0, -1, *range(1, target_one_hot.dim() - 1))
11741
11742            for reduction, w in product(['none', 'mean', 'sum'], [None, weight]):
11743                # Skip this case for now because soft and hard label CE are not consistent
11744                # in the way they apply class weights (see issue #61309).
11745                if reduction == 'mean' and weight is not None:
11746                    continue
11747
11748                # Ensure loss computed with class indices matches loss
11749                # computed with one-hot class probs.
11750                m = torch.nn.CrossEntropyLoss(weight=w, reduction=reduction)
11751                output = m(input, target)
11752                output_one_hot = m(input, target_one_hot)
11753                self.assertEqual(output, output_one_hot)
11754
11755    def test_cross_entropy_label_smoothing_errors(self, device):
11756        N, C = 3, 4
11757        input_args = [
11758            (torch.randn((N, C), device=device), torch.arange(0, C, device=device)),
11759            (torch.randn((N, C), device=device), torch.randn(N, C, device=device))
11760        ]
11761        for input_arg in input_args:
11762            loss = nn.CrossEntropyLoss(label_smoothing=1.2)
11763            with self.assertRaisesRegex(RuntimeError,
11764                                        r"label_smoothing must be between 0\.0"):
11765                loss(*input_arg)
11766
11767    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
11768    @set_default_dtype(torch.double)
11769    def test_cross_entropy_label_smoothing_consistent_index_target_and_probs(self, device):
11770        N, C = 10, 4
11771        ks = range(5)
11772        reductions = ['none', 'mean', 'sum']
11773        label_smoothings = [0.05, 0.15]
11774
11775        for k, reduction, label_smoothing in product(ks, reductions, label_smoothings):
11776            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
11777            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11778            target = torch.empty(N, *other_dims, dtype=torch.long, device=device).random_(0, C)
11779
11780            # construct target probablity that should have the same result as label_smoothing
11781            target_proba = F.one_hot(target, num_classes=C)
11782            # Need to put the C dim at index 1.
11783            target_proba = target_proba.permute(0, -1, *range(1, target_proba.dim() - 1))
11784            target_mask = (target_proba == 1)
11785            target_proba = target_proba.to(dtype=input.dtype)
11786
11787            # y_k^ls = y_k * (1 - label_smoothing) + label_smoothing / n_classes
11788            # Get one-hot representation of the target.
11789            target_proba.masked_fill_(target_mask, 1 - label_smoothing + label_smoothing / C)
11790            target_proba.masked_fill_(~target_mask, label_smoothing / C)
11791
11792            loss = nn.CrossEntropyLoss(reduction=reduction)
11793            output_with_prob = loss(input, target_proba)
11794
11795            loss = nn.CrossEntropyLoss(
11796                reduction=reduction, label_smoothing=label_smoothing)
11797            output_with_index = loss(input, target)
11798
11799            self.assertEqual(output_with_prob, output_with_index,
11800                             rtol=1e-07, atol=1e-05)
11801
11802    def test_cross_entropy_label_smoothing_with_probs(self, device):
11803        N, C = 10, 4
11804        ks = range(5)
11805        reductions = ['none', 'mean', 'sum']
11806        label_smoothings = [0.05, 0.15]
11807
11808        # Test with k-dimensional loss.
11809        for k, label_smoothing in product(ks, label_smoothings):
11810            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
11811            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11812            target = F.log_softmax(torch.randn(N, C, *other_dims, device=device), dim=1)
11813
11814            for reduction in reductions:
11815                # use with label_smoothing
11816                loss = nn.CrossEntropyLoss(reduction=reduction, label_smoothing=label_smoothing)
11817                output_with_smoothing = loss(input, target)
11818
11819                # manually smoothing target
11820                # class_proba^ls = class_proba * (1 - label_smoothing) +
11821                #                  label_smoothing / n_classes
11822                target_with_smoothing = target * (1 - label_smoothing) + label_smoothing / C
11823                loss = nn.CrossEntropyLoss(reduction=reduction)
11824                output_with_manual_smoothing = loss(input, target_with_smoothing)
11825
11826                self.assertEqual(output_with_smoothing, output_with_manual_smoothing)
11827
11828
11829    def test_cross_entropy_label_smoothing_weight_ignore_indices(self, device):
11830        reductions = ['none', 'sum', 'mean']
11831        label_smoothings = [0.05, 0.15]
11832
11833        wgt = torch.tensor([0.3, 0.6], device=device)
11834        inp1 = torch.tensor([[0.3, 0.4], [1, 2]], device=device)
11835        inp2 = torch.tensor([[0.3, 0.6], [1, 2]], device=device)
11836
11837        targ_default_ignore_index = torch.tensor([-100, 1], device=device)
11838        targ_negative_ignore_index = torch.tensor([-2, 1], device=device)
11839        targ_positive_ignore_index = torch.tensor([2, 1], device=device)
11840
11841        for reduction, label_smoothing, weight in product(reductions, label_smoothings, (None, wgt)):
11842            def check_equal(loss, inp_targ_1, inp_targ_2):
11843                inp1, targ1 = inp_targ_1
11844                inp2, targ2 = inp_targ_2
11845                l1 = loss(inp1, targ1)
11846                l2 = loss(inp2, targ2)
11847                self.assertEqual(l1, l2)
11848
11849            # Default ignore_index
11850            loss = nn.CrossEntropyLoss(reduction=reduction,
11851                                       label_smoothing=label_smoothing,
11852                                       weight=weight)
11853            check_equal(loss, (inp1, targ_default_ignore_index), (inp2, targ_default_ignore_index))
11854            if reduction != 'none':
11855                # Check that we correctly tally the denominator for `mean`
11856                # i.e. we don't count the ignored_idx at all.
11857                check_equal(loss, (inp1, targ_default_ignore_index), (inp2[1:], targ_default_ignore_index[1:]))
11858
11859            # negative ignore_index
11860            loss = nn.CrossEntropyLoss(reduction=reduction,
11861                                       label_smoothing=label_smoothing,
11862                                       ignore_index=-2,
11863                                       weight=weight)
11864            check_equal(loss, (inp1, targ_negative_ignore_index), (inp2, targ_negative_ignore_index))
11865            if reduction != 'none':
11866                # Check that we correctly tally the denominator for `mean`
11867                # i.e. we don't count the ignored_idx at all.
11868                check_equal(loss, (inp1, targ_negative_ignore_index), (inp2[1:], targ_negative_ignore_index[1:]))
11869
11870            # positive ignore_index
11871            loss = nn.CrossEntropyLoss(reduction=reduction,
11872                                       label_smoothing=label_smoothing,
11873                                       ignore_index=2,
11874                                       weight=weight)
11875            check_equal(loss, (inp1, targ_positive_ignore_index), (inp2, targ_positive_ignore_index))
11876            if reduction != 'none':
11877                # Check that we correctly tally the denominator for `mean`
11878                # i.e. we don't count the ignored_idx at all.
11879                check_equal(loss, (inp1, targ_positive_ignore_index), (inp2[1:], targ_positive_ignore_index[1:]))
11880
11881    # Ref: https://github.com/pytorch/pytorch/issues/85005
11882    @onlyCUDA
11883    @largeTensorTest("45GB", "cpu")
11884    @largeTensorTest("70GB", "cuda")
11885    @parametrize_test("reduction", ("none", "mean", "sum"))
11886    def test_cross_entropy_large_tensor(self, device, reduction):
11887        logits = torch.randn(int(2 ** 16), int(2 ** 16) + 1, dtype=torch.float32, device='cuda', requires_grad=True)
11888        labels = torch.zeros(logits.size(0), dtype=torch.long, device='cuda')
11889        loss = F.cross_entropy(logits, labels, reduction=reduction)
11890        if reduction != "none":
11891            loss.backward()
11892
11893        with torch.no_grad():
11894            logits_cpu = logits.cpu().detach().requires_grad_()
11895            labels_cpu = labels.cpu().detach()
11896        loss_cpu = F.cross_entropy(logits_cpu, labels_cpu, reduction=reduction)
11897        if reduction != "none":
11898            loss_cpu.backward()
11899
11900        # workaround to reduce memory usage vs. self.assertEqual, see #84944
11901        rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
11902        self.assertTrue(torch.allclose(loss.cpu(), loss_cpu, rtol=rtol, atol=atol))
11903        if reduction != "none":
11904            self.assertTrue(torch.allclose(logits.grad.cpu(), logits_cpu.grad, rtol=rtol, atol=atol))
11905
11906    def test_smoothl1loss_backward_zero_beta(self, device):
11907        input = torch.randn(300, 256, requires_grad=True, device=device)
11908        target = input.detach()
11909
11910        loss = F.smooth_l1_loss(input, target, beta=0.0, reduction='sum')
11911        loss.backward()
11912
11913        grad_max_abs = input.grad.abs().max().item()
11914        self.assertLessEqual(grad_max_abs, 1.0)
11915
11916    def test_softshrink_negative(self, device):
11917        input = torch.randn(5, device=device, requires_grad=True)
11918        m = torch.nn.Softshrink(-1)
11919        with self.assertRaisesRegex(RuntimeError,
11920                                    r'lambda must be greater or equal to 0, but found to be -1\.'):
11921            m(input)
11922
11923    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
11924    def test_fold(self, device):
11925        def test_dtype(fn, input, dtype):
11926            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
11927            input2 = input.detach().clone().float().requires_grad_(True)
11928            out = fn(input)
11929            out.sum().backward()
11930            out2 = fn(input2)
11931            out2.sum().backward()
11932            self.assertEqual(out.dtype, dtype)
11933            self.assertEqual(input.grad.dtype, dtype)
11934            self.assertEqual(out, out2.to(dtype=dtype), atol=0.05, rtol=0)
11935            self.assertEqual(input.grad, input2.grad.to(dtype=dtype))
11936
11937        def func(x):
11938            return F.fold(x, output_size=(4, 5), kernel_size=(2, 2))
11939
11940        seeds = (44, 83, 71, 25, 999)
11941        for sd in seeds:
11942            torch.manual_seed(sd)
11943            x = torch.randn(1, 12, 12, device=device, requires_grad=True, dtype=torch.double)
11944            gradcheck(func, [x], check_forward_ad=True)
11945            gradgradcheck(func, [x], check_fwd_over_rev=True)
11946            if device == 'cpu':
11947                test_dtype(func, x, torch.bfloat16)
11948
11949
11950    def test_logsigmoid_out(self, device):
11951        # this isn't actually documented, but was broken previously:
11952        # https://github.com/pytorch/pytorch/issues/36499
11953        x = torch.randn(2, 3, device=device).t()
11954        empty_out = torch.randn(0, device=device)
11955        self.assertEqual(F.logsigmoid(x), F.logsigmoid(x, out=empty_out))
11956
11957        noncontig_out = torch.randn(2, 3, device=device).t()
11958        self.assertEqual(F.logsigmoid(x), F.logsigmoid(x, out=noncontig_out))
11959
11960    # Check that clip_grad_norm_ raises an error if the total norm of the
11961    # parameters' gradients is non-finite
11962    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
11963    def test_clip_grad_norm_error_if_nonfinite(self, device):
11964        norms_pos = [0.1, 1, 2, 3.5, inf]
11965        norms_neg = [-0.1, -1, -2, -3.5]
11966        norms_except_0 = norms_pos + norms_neg
11967        norms_all = norms_except_0 + [0]
11968
11969        # Each entry in test_cases has the following values, in this order:
11970        #
11971        # grad_only_one_elem    If True, only one element of the parameter's
11972        #                       gradient is set to the scalar grad, and the
11973        #                       rest of the elements are 0. If False, all grad
11974        #                       elements are equal to the scalar.
11975        #
11976        # prefix_finite_grad_param  If True, prefix a parameter that has a grad
11977        #                           of 1.
11978        #
11979        # scalars           Scalars to use as the parameter's grad, through
11980        #                   multiplication
11981        #
11982        # norms_nonfinite   Norm types that should produce nonfinite total norm
11983        #
11984        # norms_finite      Norm types that should produce finite total norm
11985        test_cases = [
11986            # Test errors from an infinite grad
11987            (False, False, [inf, -inf], norms_except_0, [0]),
11988            (False, True, [inf, -inf], norms_pos, norms_neg + [0]),
11989            (True, False, [inf, -inf], norms_pos, norms_neg + [0]),
11990            (True, True, [inf, -inf], norms_pos, norms_neg + [0]),
11991
11992            # Test errors from a NaN grad
11993            (False, False, [nan], norms_except_0, [0]),
11994            (False, True, [nan], norms_except_0, [0]),
11995            (True, False, [nan], norms_except_0, [0]),
11996            (True, True, [nan], norms_except_0, [0]),
11997
11998            # Test a grad that should never error
11999            (False, False, [2e22, -2e22], [], norms_all),
12000            (False, True, [2e22, -2e22], [], norms_all),
12001            (True, False, [2e22, -2e22], [], norms_all),
12002            (True, True, [2e22, -2e22], [], norms_all),
12003
12004            # Test a grad that will overflow to inf for only some norm orders
12005            (False, False, [2e200, -2e200], [3.5, 2, -2, -3.5], [inf, 1, 0.1, 0, -1, -0.1]),
12006            (False, True, [2e200, -2e200], [3.5, 2], norms_neg + [inf, 1, 0.1, 0]),
12007            (True, False, [2e200, -2e200], [3.5, 2], norms_neg + [inf, 1, 0.1, 0]),
12008            (True, True, [2e200, -2e200], [3.5, 2], norms_neg + [inf, 1, 0.1, 0]),
12009        ]
12010
12011        def gen_parameters(scalar, grad_only_one_elem, prefix_finite_grad_param):
12012            param = torch.ones(10, dtype=torch.float64, device=device, requires_grad=True)
12013
12014            if grad_only_one_elem:
12015                param[1].mul(scalar).sum().backward()
12016            else:
12017                param.mul(scalar).sum().backward()
12018
12019            if prefix_finite_grad_param:
12020                prefix_param = torch.ones(1, dtype=torch.float64, device=device, requires_grad=True)
12021                prefix_param.mul(1).sum().backward()
12022                parameters = [prefix_param, param]
12023            else:
12024                parameters = [param]
12025
12026            return parameters
12027
12028        def run_test_case(norm_type, error_if_nonfinite, scalar, grad_only_one_elem, prefix_finite_grad_param, is_norm_nonfinite):
12029            msg = (
12030                f'norm_type: {norm_type}, ',
12031                f'error_if_nonfinite: {error_if_nonfinite}, '
12032                f'scalar: {scalar}, '
12033                f'grad_only_one_elem: {grad_only_one_elem}, '
12034                f'prefix_finite_grad_param: {prefix_finite_grad_param}, '
12035                f'is_norm_nonfinite: {is_norm_nonfinite}')
12036
12037            parameters = gen_parameters(scalar, grad_only_one_elem, prefix_finite_grad_param)
12038
12039            # Should only throw an error if the total norm is expected to be
12040            # nonfinite and `error_if_nonfinite=True`
12041            if is_norm_nonfinite and error_if_nonfinite:
12042                error_msg = f'The total norm of order {float(norm_type)} for gradients'
12043
12044                grads_before = [p.grad.clone() for p in parameters]
12045
12046                with self.assertRaisesRegex(RuntimeError, error_msg, msg=msg):
12047                    clip_grad_norm_(parameters, 1, norm_type=norm_type, error_if_nonfinite=True)
12048
12049                # Grad should not change if error is thrown
12050                grads_after = [p.grad for p in parameters]
12051                self.assertEqual(grads_before, grads_after, msg=msg)
12052            else:
12053                clip_grad_norm_(parameters, 1, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite)
12054
12055        for grad_only_one_elem, prefix_finite_grad_param, scalars, norms_nonfinite, norms_finite in test_cases:
12056            for error_if_nonfinite in [False, True]:
12057                for norm_type, scalar in product(norms_nonfinite, scalars):
12058                    run_test_case(norm_type, error_if_nonfinite, scalar, grad_only_one_elem, prefix_finite_grad_param, True)
12059
12060                for norm_type, scalar in product(norms_finite, scalars):
12061                    run_test_case(norm_type, error_if_nonfinite, scalar, grad_only_one_elem, prefix_finite_grad_param, False)
12062
12063    @onlyCUDA
12064    @deviceCountAtLeast(2)
12065    @parametrize_test('foreach', (False, True))
12066    def test_clip_grad_norm_multi_device(self, devices, foreach):
12067        class TestModel(nn.Module):
12068            def __init__(self) -> None:
12069                super().__init__()
12070                self.layer1 = nn.Linear(10, 10)
12071                self.layer2 = nn.Linear(10, 10)
12072
12073        test_model = TestModel()
12074        test_model.layer1.to(devices[0])
12075        test_model.layer2.to(devices[1])
12076        ref_model = TestModel().to(devices[0])
12077        for norm_type in [2., math.inf]:
12078            for p in test_model.parameters():
12079                p.grad = torch.ones_like(p)
12080            for p in ref_model.parameters():
12081                p.grad = torch.ones_like(p)
12082            norm = clip_grad_norm_(test_model.parameters(), 0.5, norm_type=norm_type, foreach=foreach)
12083            expected = clip_grad_norm_(ref_model.parameters(), 0.5, norm_type=norm_type, foreach=foreach)
12084            self.assertEqual(norm, expected)
12085            for p, pe in zip(test_model.parameters(), ref_model.parameters()):
12086                self.assertEqual(p.grad.to(devices[0]), pe.grad)
12087
12088    def test_elu_inplace_overlap(self, device):
12089        dtype = torch.bfloat16 if device != 'mps:0' else torch.float16
12090        x = torch.randn((1, 6), dtype=dtype, device=device).expand((6, 6))
12091        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12092            F.elu(x, inplace=True)
12093        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12094            F.elu_(x)
12095
12096    # Merge into OpInfo?
12097    @onlyNativeDeviceTypes
12098    def test_elu_inplace_with_neg_alpha(self, device):
12099        a = torch.tensor([-1., 1.], device=device, requires_grad=True)
12100        b = torch.nn.functional.elu_(a.clone(), alpha=-2)
12101        with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
12102            b.backward(torch.ones(2, device=device))
12103
12104        a = torch.tensor([-1., 1.], device=device, requires_grad=True)
12105        b = torch.nn.functional.celu_(a.clone(), alpha=-2)
12106        with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
12107            b.backward(torch.ones(2, device=device))
12108
12109    @expectedFailureMeta  # https://github.com/pytorch/pytorch/issues/54897
12110    def test_hardswish_inplace_overlap(self, device):
12111        x = torch.randn((1, 6), device=device).expand((6, 6))
12112        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12113            F.hardswish(x, inplace=True)
12114
12115    def test_silu_inplace_overlap(self, device):
12116        x = torch.randn((1, 6), device=device).expand((6, 6))
12117        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12118            F.silu(x, inplace=True)
12119
12120    @onlyNativeDeviceTypes
12121    def test_mish_inplace_overlap(self, device):
12122        x = torch.randn((1, 6), device=device).expand((6, 6))
12123        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12124            F.mish(x, inplace=True)
12125
12126    def test_softplus_inplace_overlap(self, device):
12127        x = torch.randn((1, 6), device=device).expand((6, 6))
12128        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12129            F.softplus(x, out=x)
12130
12131    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
12132    def test_softplus_low_threshold(self, device):
12133        # Ensure gradients are computed correctly with a low threshold.
12134        model = torch.nn.Softplus(threshold=1).double()
12135        input = torch.tensor(0.9, device=device, dtype=torch.double,
12136                             requires_grad=True)
12137        output = model(input)
12138        torch.autograd.gradcheck(model, input)
12139
12140    def test_softshrink_inplace_overlap(self, device):
12141        x = torch.randn((1, 6), device=device).expand((6, 6))
12142        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12143            F.softshrink(x, out=x)
12144
12145    def test_leaky_relu_inplace_overlap(self, device):
12146        x = torch.randn((1, 6), device=device).expand((6, 6))
12147        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12148            F.leaky_relu(x, inplace=True)
12149        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12150            F.leaky_relu_(x)
12151
12152    # Merge into OpInfo?
12153    @expectedFailureMPS  # NotImplementedError: aten::rrelu_with_noise_ https://github.com/pytorch/pytorch/issues/77764
12154    def test_leaky_relu_inplace_with_neg_slope(self, device):
12155        a = torch.tensor([-1., 1.], device=device, requires_grad=True)
12156        b = torch.nn.functional.leaky_relu_(a.clone(), -2)
12157        with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
12158            b.backward(torch.ones(2, device=device))
12159
12160        a = torch.tensor([-1., 1.], device=device, requires_grad=True)
12161        b = torch.nn.functional.rrelu_(a.clone(), -5.0, 1.0)
12162        with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
12163            b.backward(torch.ones(2, device=device))
12164
12165    # Merge into OpInfo?
12166    def test_leaky_relu_inplace_with_zero_slope(self, device):
12167        a = torch.tensor([-2., 0., 2.], device=device, requires_grad=True)
12168        b = torch.nn.functional.leaky_relu_(a.clone(), 0.0)
12169        b.backward(torch.ones(3, device=device))
12170        expected = torch.tensor([0., 0., 1.], device=device)
12171        self.assertEqual(a.grad, expected)
12172
12173        dtype = torch.bfloat16 if device != 'mps:0' else torch.float16
12174        a_bf16 = torch.tensor([-2., 0., 2.], device=device, dtype=dtype, requires_grad=True)
12175        b_bf16 = torch.nn.functional.leaky_relu_(a_bf16.clone(), 0.0)
12176        b_bf16.backward(torch.ones(3, device=device))
12177        expected_bf16 = torch.tensor([0., 0., 1.], device=device, dtype=dtype)
12178        self.assertEqual(a_bf16.grad, expected_bf16)
12179
12180    @onlyCPU
12181    def test_softshrink(self, device):
12182        x = torch.tensor([[1.21, 0.56, 0.5001, 0.4999, 1.2357, -0.4999, -0.5001, -1.154,
12183                           0.254, -0.24, -0.225, 0.104, 0.002, -0.001, 0.0574, 1.2344,
12184                           0.1748, -0.1797, -0.8125, 0.2051, -1.1328, 1.2344, -0.1562, 2.3554,
12185                           -0.1953, 0.0304, -0.3613, -1.3047, 1.0312, 0.1436, -0.6953, 0.5664,
12186                           -0.5820, -0.3301, 0.8203, 0.6133, 0.5938],
12187                          [-0.8203, -1.2344, -0.5234, 2.5312, -0.4551, -0.6875, -1.5547, -0.2217,
12188                           -0.3027, 2.6406, 1.3047, 0.2344, -1.6719, 0.2773, -1.3516, 3.4575,
12189                           0.4414, 0.2656, 2.1094, -1.5156, 1.2344, -0.4336, 0.6797, -3.5486,
12190                           0.9766, -0.4062, 1.4844, 0.7500, -1.7578, 0.7461, 1.6094, 8.5458,
12191                           0.3730, -0.3477, -1.0625, 0.3848, 0.0557]], device=device)
12192        expected = torch.tensor([[0.71, 0.06, 0.0001, 0., 0.7357, 0., -0.0001, -0.654,
12193                                  0., 0., 0., 0., 0., 0., 0., 0.7344,
12194                                  0., 0., -0.3125, 0., -0.6328, 0.7344, 0., 1.8554,
12195                                  0., 0., 0., -0.8047, 0.5312, 0., -0.1953, 0.0664,
12196                                  -0.0820, 0.0, 0.3203, 0.1133, 0.0938],
12197                                 [-0.3203, -0.7344, -0.0234, 2.0312, 0.0, -0.1875, -1.0547, 0.,
12198                                  0.0, 2.1406, 0.8047, 0., -1.1719, 0., -0.8516, 2.9575,
12199                                  0., 0., 1.6094, -1.0156, 0.7344, 0., 0.1797, -3.0486,
12200                                  0.4766, 0., 0.9844, 0.2500, -1.2578, 0.2461, 1.1094, 8.0458,
12201                                  0., 0., -0.5625, 0., 0.]])
12202        softshrink = torch.nn.Softshrink()
12203        out = softshrink(x)
12204        self.assertEqual(out, expected, atol=1e-2, rtol=0)
12205
12206    def test_threshold_inplace_overlap(self, device):
12207        # Inplace threshold is okay, because it is idempotent
12208        x = torch.randn((1, 6), device=device).expand((6, 6))
12209        F.threshold(x, 0.5, 0.5, inplace=True)
12210        F.threshold_(x, 0.5, 0.5)
12211
12212    @onlyNativeDeviceTypes
12213    def test_triplet_margin_with_distance_loss_default_parity(self, device):
12214        # Test for `nn.TripletMarginWithDistanceLoss` and
12215        # `F.triplet_margin_with_distance_loss`.  Checks
12216        # for parity against the respective non-distance-agnostic
12217        # implementations of triplet margin loss (``nn.TripletMarginLoss`
12218        # and `F.triplet_margin_loss`) under *default args*.
12219
12220        for extra_args in \
12221                itertools.product((0.5, 1, 1.5), (True, False), ('none', 'mean', 'sum')):
12222            kwargs = {'margin': extra_args[0], 'swap': extra_args[1], 'reduction': extra_args[2]}
12223
12224            anchor = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double)
12225            positive = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double)
12226            negative = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double)
12227
12228            # Test forward, functional
12229            expected = F.triplet_margin_loss(anchor, positive, negative, **kwargs)
12230            actual = F.triplet_margin_with_distance_loss(anchor, positive, negative, **kwargs)
12231            self.assertEqual(actual, expected, rtol=1e-6, atol=1e-6)
12232
12233            # Test forward, module
12234            loss_ref = nn.TripletMarginLoss(**kwargs)
12235            loss_op = nn.TripletMarginWithDistanceLoss(**kwargs)
12236            self.assertEqual(loss_op(anchor, positive, negative),
12237                             loss_ref(anchor, positive, negative),
12238                             rtol=1e-6, atol=1e-6)
12239
12240            # Test backward
12241            self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_with_distance_loss(
12242                a, p, n, **kwargs), (anchor, positive, negative)))
12243            self.assertTrue(gradcheck(lambda a, p, n: loss_op(a, p, n),
12244                            (anchor, positive, negative)))
12245
12246    @onlyNativeDeviceTypes
12247    def test_triplet_margin_with_distance_loss(self, device):
12248        # Test for parity between `nn.TripletMarginWithDistanceLoss` and
12249        # `F.triplet_margin_with_distance_loss`.
12250
12251        pairwise_distance = nn.PairwiseDistance()
12252
12253        def cosine_distance(x, y):
12254            return 1.0 - F.cosine_similarity(x, y)
12255
12256        distance_functions = (pairwise_distance, cosine_distance,
12257                              lambda x, y: 1.0 - F.cosine_similarity(x, y))
12258
12259        reductions = ('mean', 'none', 'sum')
12260        margins = (1.0, 1.5, 0.5)
12261        swaps = (True, False)
12262
12263        for distance_fn, reduction, margin, swap \
12264                in itertools.product(distance_functions, reductions, margins, swaps):
12265            anchor = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double)
12266            positive = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double)
12267            negative = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double)
12268
12269            # Test backward
12270            self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_with_distance_loss(
12271                a, p, n, distance_function=distance_fn, reduction=reduction, margin=margin, swap=swap),
12272                (anchor, positive, negative)))
12273            loss_op = nn.TripletMarginWithDistanceLoss(distance_function=distance_fn,
12274                                                       reduction=reduction, margin=margin, swap=swap)
12275            self.assertTrue(gradcheck(lambda a, p, n: loss_op(
12276                a, p, n), (anchor, positive, negative)))
12277            traced_loss_op = torch.jit.trace(loss_op, (anchor, positive, negative))
12278            self.assertTrue(gradcheck(lambda a, p, n: traced_loss_op(
12279                a, p, n), (anchor, positive, negative)))
12280
12281            # Test forward parity
12282            functional = F.triplet_margin_with_distance_loss(anchor, positive, negative,
12283                                                             distance_function=distance_fn,
12284                                                             reduction=reduction, margin=margin, swap=swap)
12285            modular = loss_op(anchor, positive, negative)
12286            traced = traced_loss_op(anchor, positive, negative)
12287            self.assertEqual(functional, modular, atol=1e-6, rtol=1e-6)
12288            self.assertEqual(traced, modular, atol=1e-6, rtol=1e-6)
12289
12290    @dtypesIfMPS(torch.cfloat, torch.float)
12291    @dtypes(torch.cfloat, torch.cdouble, torch.float)
12292    def test_to_complex(self, device, dtype):
12293        m = nn.Linear(3, 5).to(device)
12294        self.assertIs(m, m.to(device))
12295        m.to(dtype)
12296        self.assertIs(m.weight.dtype, dtype)
12297        with warnings.catch_warnings(record=True) as w:
12298            # Trigger warning
12299            m.to(torch.cfloat)
12300            # Check warning occurs
12301            self.assertEqual(len(w), 1)
12302            self.assertTrue("Complex modules are a new feature" in str(w[-1].message))
12303
12304    @skipMeta
12305    @dtypesIfMPS(torch.float32)
12306    @dtypes(torch.float32, torch.float64)
12307    def test_module_to_empty(self, device, dtype):
12308        class MyModule(nn.Module):
12309            def __init__(self, in_features, out_features, device=None, dtype=None):
12310                super().__init__()
12311                factory_kwargs = {"device": device, "dtype": dtype}
12312                self.weight = nn.Parameter(torch.randn(in_features, out_features, **factory_kwargs))
12313
12314            def forward(self, x):
12315                return x @ self.weight
12316
12317        # Test meta module instantiation.
12318        input = torch.randn(5, 10, device=device, dtype=dtype)
12319        m = MyModule(10, 1, device='meta', dtype=dtype)
12320        m(input)
12321
12322        # Test empty meta module error with torch.nn.Module.to().
12323        with self.assertRaisesRegex(
12324            NotImplementedError,
12325            re.escape(
12326                "Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() "
12327                "instead of torch.nn.Module.to() when moving module from meta to a different "
12328                "device."
12329            ),
12330        ):
12331            m.to(device)
12332
12333        # Test materializing meta module on a real device.
12334        m.to_empty(device=device)
12335        m(input)
12336        with torch.no_grad():
12337            torch.nn.init.kaiming_uniform_(m.weight)
12338        m(input)
12339
12340        # Test creating meta module from materialized module.
12341        m.to_empty(device='meta')
12342        m(input)
12343
12344    def test_module_to_empty_non_recursive(self, device):
12345        class Layer(nn.Module):
12346            def __init__(self, in_features, out_features):
12347                super().__init__()
12348                self.weight = nn.Parameter(torch.randn(in_features, out_features))
12349                self.register_buffer('buf', torch.randn(out_features))
12350
12351            def forward(self, x):
12352                return x @ self.weight + self.buf
12353
12354        class MyModule(nn.Module):
12355            def __init__(self, in_features, out_features):
12356                super().__init__()
12357                self.weight = nn.Parameter(torch.randn(in_features, out_features))
12358                self.register_buffer('buf1', torch.randn(out_features))
12359                self.layer = Layer(out_features, out_features)
12360
12361            def forward(self, x):
12362                return self.layer(x @ self.weight + self.buf1)
12363
12364        with torch.device('meta'):
12365            m = MyModule(3, 5)
12366
12367        m.to_empty(device=device, recurse=False)
12368
12369        # params/buffers of parent should have been materialized on device
12370        self.assertTrue(not m.weight.is_meta)
12371        self.assertTrue(not m.buf1.is_meta)
12372
12373        # parameters/buffers of children submodules should still be on meta
12374        for p in (*m.layer.parameters(), *m.layer.buffers()):
12375            self.assertTrue(p.is_meta)
12376
12377    @skipMeta
12378    def test_skip_init(self, device):
12379        torch.manual_seed(1)
12380        m_initialized = torch.nn.Linear(5, 1)
12381        m_initialized.to(device)
12382
12383        torch.manual_seed(1)
12384        m_uninitialized = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1, device=device)
12385
12386        self.assertEqual(m_initialized.weight.device, m_uninitialized.weight.device)
12387        self.assertFalse(torch.allclose(m_initialized.weight, m_uninitialized.weight))
12388
12389    @skipIfRocm(msg='See https://github.com/pytorch/pytorch/issues/135150')
12390    @skipIfMps  # TODO(hvaara): Investigate as possible bug. macOS 13 passes, while 14 and 15 fails.
12391    @dtypes(torch.float)
12392    @dtypesIfCUDA(torch.double, torch.float, torch.half)
12393    def test_transformerencoderlayer(self, device, dtype):
12394        if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half:
12395            self.skipTest("Skip on ROCM due to Flash Attention tolerances")
12396        # this is a deterministic test for TransformerEncoderLayer
12397        d_model = 4
12398        nhead = 2
12399        dim_feedforward = 16
12400        dropout = 0.0
12401        bsz = 2
12402
12403        atol = 1e-5
12404        rtol = 1e-7
12405        if "cuda" in device:
12406            atol = 1e-3
12407            rtol = 1e-2
12408
12409        def _test(training, batch_first, atol, rtol):
12410            def perm_fn(x):
12411                return x.transpose(1, 0) if batch_first else x
12412
12413            model = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
12414                                               batch_first=batch_first, device=device, dtype=dtype)
12415
12416            if not training:
12417                assert dropout == 0
12418                model = model.eval()
12419
12420            # set constant weights of the model
12421            for idx, p in enumerate(model.parameters()):
12422                x = p.data
12423                sz = x.view(-1).size(0)
12424                shape = x.shape
12425                x = torch.cos(torch.arange(0, sz).float().view(shape))
12426                p.data.copy_(x)
12427
12428            # deterministic input
12429            encoder_input = torch.tensor([[[20., 30., 40., 50.]]], device=device, dtype=dtype)
12430            result = model(encoder_input)
12431            ref_output = torch.tensor([[[2.258703, 0.127985, -0.697881, 0.170862]]], device=device, dtype=dtype)
12432            self.assertEqual(result.shape, ref_output.shape)
12433            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12434            # 0 values are NOT masked. This shouldn't mask anything.
12435            mask = torch.tensor([[0]], device=device) == 1
12436            # TODO: enable fast path for calls with a mask!
12437            result = model(encoder_input, src_key_padding_mask=mask)
12438            self.assertEqual(result.shape, ref_output.shape)
12439            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12440            mask = torch.tensor([[1]], device=device) == 1
12441            result = model(encoder_input, src_key_padding_mask=mask)
12442            fast_path_device = result.is_cuda or result.is_cpu
12443            result = result.cpu().detach().numpy()
12444            # Non Fast Paths
12445            if training or not batch_first or TEST_WITH_CROSSREF or not fast_path_device:
12446                # We changed the semenatic, on the non fast path so that fully masked out rows return
12447                # 0 from attention thus NaNs should no longer be present and the output should be nonzero
12448                # due to skip connections
12449                self.assertTrue(not np.isnan(result).any())
12450            else:
12451                # Fast Paths
12452                self.assertTrue(np.isnan(result).all())
12453
12454
12455            # deterministic input
12456            encoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
12457                                                  [[5., 6., 7., 8.]]], device=device, dtype=dtype))
12458            result = model(encoder_input)
12459            ref_output = perm_fn(torch.tensor([[[2.272644, 0.119035, -0.691669, 0.153486]],
12460                                               [[2.272644, 0.119035, -0.691669, 0.153486]]], device=device, dtype=dtype))
12461            self.assertEqual(result.shape, ref_output.shape)
12462            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12463            # all 0 which is no masking
12464            mask = torch.tensor([[0, 0]], device=device) == 1
12465            result = model(encoder_input, src_key_padding_mask=mask)
12466            self.assertEqual(result.shape, ref_output.shape)
12467            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12468            mask = torch.tensor([[1, 0]], device=device) == 1
12469            result = model(encoder_input, src_key_padding_mask=mask)
12470            ref_output = perm_fn(torch.tensor([[[2.301516, 0.092249, -0.679101, 0.103088]],
12471                                               [[2.301516, 0.092249, -0.679101, 0.103088]]], device=device, dtype=dtype))
12472            self.assertEqual(result.shape, ref_output.shape)
12473            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12474
12475            # deterministic input
12476            encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
12477                                                   [0.5387, 0.1655, 0.3565, 0.0471]],
12478                                                  [[0.8335, 0.2799, 0.5031, 0.2947],
12479                                                   [0.1402, 0.0318, 0.7636, 0.1346]],
12480                                                  [[0.6333, 0.9344, 0.1376, 0.9938],
12481                                                   [0.8924, 0.2872, 0.6692, 0.2944]],
12482                                                  [[0.9897, 0.6915, 0.3154, 0.1733],
12483                                                   [0.8645, 0.3513, 0.3064, 0.0767]],
12484                                                  [[0.8117, 0.2366, 0.4838, 0.7881],
12485                                                   [0.3718, 0.4945, 0.9511, 0.0864]]], device=device, dtype=dtype))
12486            result = model(encoder_input)
12487            ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249],
12488                                                [2.427987, 0.021213, -0.602496, -0.084103]],
12489                                               [[2.424689, 0.019155, -0.604793, -0.085672],
12490                                                [2.413863, 0.022211, -0.612486, -0.072490]],
12491                                               [[2.433774, 0.021598, -0.598343, -0.087548],
12492                                                [2.425104, 0.019748, -0.604515, -0.084839]],
12493                                               [[2.436185, 0.022682, -0.596625, -0.087261],
12494                                                [2.433556, 0.021891, -0.598509, -0.086832]],
12495                                               [[2.416246, 0.017512, -0.610712, -0.082961],
12496                                                [2.422901, 0.024187, -0.606178, -0.074929]]], device=device, dtype=dtype))
12497            self.assertEqual(result.shape, ref_output.shape)
12498            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12499
12500            # all 0
12501            mask = torch.zeros([2, 5], device=device) == 1
12502            result = model(encoder_input, src_key_padding_mask=mask)
12503            self.assertEqual(result.shape, ref_output.shape)
12504            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12505            mask[0, 1] = 1
12506            mask[1, 3] = 1
12507            mask[1, 4] = 1
12508            result = model(encoder_input, src_key_padding_mask=mask)
12509            ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642],
12510                                                [2.428811, 0.021445, -0.601912, -0.084252]],
12511                                               [[2.425009, 0.019155, -0.604566, -0.085899],
12512                                                [2.415408, 0.02249 , -0.611415, -0.073]],
12513                                               [[2.434199, 0.021682, -0.598039, -0.087699],
12514                                                [2.42598, 0.019941, -0.603896, -0.085091]],
12515                                               [[2.436457, 0.022736, -0.59643 , -0.08736],
12516                                                [2.434021, 0.022093, -0.598179, -0.08679]],
12517                                               [[2.416531, 0.017498, -0.610513, -0.083181],
12518                                                [2.4242, 0.024653, -0.605266, -0.074959]]], device=device, dtype=dtype))
12519            self.assertEqual(result.shape, ref_output.shape)
12520            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12521
12522            # NestedTensor is only supported for the fast path
12523            # currently, which won't be used if training.
12524            if (batch_first and not training and
12525                    ('cuda' in str(device) or 'cpu' in str(device)) and not TEST_WITH_CROSSREF):
12526                encoder_input[0][-1] = torch.zeros_like(encoder_input[0][1])
12527                mask = torch.zeros(encoder_input.shape[:-1], device=device, dtype=torch.bool)
12528                mask[0][-1] = True
12529
12530                nt = torch.nested.nested_tensor([encoder_input[0][:-1], encoder_input[1]], device=device)
12531                result = model(nt)
12532                ref_output = torch.tensor(
12533                    [
12534                        [
12535                            [2.4268184, 0.02042419, -0.603311, -0.08476824],
12536                            [2.423306, 0.01889652, -0.6057701, -0.08519465],
12537                            [2.431538, 0.02078694, -0.5999354, -0.08746159],
12538                            [2.4348664, 0.02212971, -0.5975677, -0.08733892],
12539                            [2.423133, 0.02097577, -0.60594773, -0.08113337],
12540                        ],
12541                        [
12542                            [2.4279876, 0.02121329, -0.60249615, -0.08410317],
12543                            [2.4138637, 0.02221113, -0.6124869, -0.07249016],
12544                            [2.4251041, 0.01974815, -0.6045152, -0.08483928],
12545                            [2.4335563, 0.0218913, -0.59850943, -0.08683228],
12546                            [2.4229012, 0.02418739, -0.6061784, -0.07492948],
12547                        ],
12548                    ],
12549                    device=device, dtype=dtype
12550                )
12551                result = result.to_padded_tensor(0)
12552                ref_output[0][-1] = torch.zeros_like(
12553                    ref_output[0][-1], device=device, dtype=dtype
12554                )
12555                result[0][-1] = torch.zeros_like(
12556                    result[0][-1], device=device, dtype=dtype
12557                )
12558                self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
12559                if 'cuda' in device:
12560                    if dtype == torch.float:
12561                        atol = 2e-4
12562                        rtol = 4e-3
12563                    else:
12564                        atol = 7e-4
12565                        rtol = 2e-2
12566                    torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12567                else:
12568                    torch.testing.assert_close(result, ref_output)
12569
12570
12571        for batch_first in (True, False):
12572            for training in (True, False):
12573                if training:
12574                    cm = contextlib.nullcontext()
12575                else:
12576                    # Fast path requires inference mode.
12577                    cm = torch.no_grad()
12578                with cm:
12579                    _test(batch_first=batch_first, training=training, atol=atol, rtol=rtol)
12580
12581    @onlyCPU
12582    @dtypes(torch.double)
12583    def test_transformerencoderlayer_fast_path(self, device, dtype):
12584        """
12585        Test transformer fast path on CPU with different valid mask types and shapes
12586        """
12587        d_model = 512
12588        nhead = 8
12589        batch_size = 32
12590        src_len = 10
12591
12592        model = torch.nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True,
12593                                                 device=device, dtype=dtype, dropout=0)
12594        model.eval()
12595
12596        # Batched inputs
12597        src = torch.rand(batch_size, src_len, 512, dtype=dtype)
12598
12599        # Attention mask of shape (src_len, src_len)
12600        src_mask = torch.zeros(src_len, src_len).to(torch.bool)
12601        with torch.no_grad():
12602            model(src, src_mask=src_mask)
12603
12604        # Padding mask of shape (batch_size, src_len)
12605        src_key_padding_mask = torch.zeros(batch_size, src_len).to(torch.bool)
12606        with torch.no_grad():
12607            model(src, src_key_padding_mask=src_key_padding_mask)
12608
12609        # Provide both masks
12610        with torch.no_grad():
12611            model(src, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)
12612
12613
12614    @dtypes(torch.float)
12615    @dtypesIfCUDA(torch.half, torch.float)
12616    def test_transformerencoderlayer_gelu(self, device, dtype):
12617        if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half:
12618            self.skipTest("Skip on ROCM due to Flash Attention tolerances")
12619        # this is a deterministic test for TransformerEncoderLayer with gelu activation
12620        d_model = 4
12621        nhead = 2
12622        dim_feedforward = 16
12623        dropout = 0.0
12624        bsz = 2
12625
12626        atol = 0
12627        rtol = 1e-5
12628        if "cuda" in device:
12629            atol = 1e-3
12630            rtol = 1e-2
12631
12632        def _test(activation, batch_first, training):
12633            def perm_fn(x):
12634                return x.transpose(1, 0) if batch_first else x
12635
12636            model = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
12637                                               activation, batch_first=batch_first, device=device, dtype=dtype)
12638            if not training:
12639                assert dropout == 0
12640                model = model.eval()
12641
12642            # set constant weights of the model
12643            for idx, p in enumerate(model.parameters()):
12644                x = p.data
12645                sz = x.view(-1).size(0)
12646                shape = x.shape
12647                x = torch.cos(torch.arange(0, sz).float().view(shape))
12648                p.data.copy_(x)
12649
12650            # deterministic input
12651            encoder_input = torch.tensor([[[20., 30., 40., 50.]]], device=device, dtype=dtype)
12652            result = model(encoder_input)
12653            ref_output = torch.tensor([[[2.249815, 0.131006, -0.702199, 0.177868]]], device=device, dtype=dtype)
12654            torch.testing.assert_close(result, ref_output, rtol=rtol, atol=atol)
12655
12656            # deterministic input
12657            encoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
12658                                                  [[5., 6., 7., 8.]]], device=device, dtype=dtype))
12659            result = model(encoder_input)
12660            ref_output = perm_fn(torch.tensor([[[2.264103, 0.121417, -0.696012, 0.159724]],
12661                                               [[2.264103, 0.121417, -0.696012, 0.159724]]], device=device, dtype=dtype))
12662            torch.testing.assert_close(result, ref_output, rtol=rtol, atol=atol)
12663
12664            # deterministic input
12665            encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
12666                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
12667                                                  [[0.8335, 0.2799, 0.5031, 0.2947],
12668                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
12669                                                  [[0.6333, 0.9344, 0.1376, 0.9938],
12670                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
12671                                                  [[0.9897, 0.6915, 0.3154, 0.1733],
12672                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
12673                                                  [[0.8117, 0.2366, 0.4838, 0.7881],
12674                                                  [0.3718, 0.4945, 0.9511, 0.0864]]], device=device, dtype=dtype))
12675            result = model(encoder_input)
12676            ref_output = perm_fn(torch.tensor([[[2.42163188, 0.03227153, -0.60714219, -0.05908082],
12677                                                [2.42151276, 0.03302179, -0.60722523, -0.05762651]],
12678                                               [[2.41926761, 0.02974034, -0.60879519, -0.0621269],
12679                                                [2.41626395, 0.03539356, -0.61087842, -0.04978623]],
12680                                               [[2.42382808, 0.03218872, -0.6055963, -0.06073591],
12681                                                [2.41983477, 0.03085259, -0.60840145, -0.06046414]],
12682                                               [[2.42500749, 0.03328855, -0.60476388, -0.0595334],
12683                                                [2.4237977, 0.03290575, -0.60561789, -0.05940082]],
12684                                               [[2.41383916, 0.02686345, -0.61256377, -0.06380707],
12685                                                [2.42000277, 0.03800944, -0.60824798, -0.04754947]]], device=device, dtype=dtype))
12686            torch.testing.assert_close(result, ref_output, rtol=rtol, atol=atol)
12687        for activation, batch_first, training in product(('gelu', F.gelu, nn.GELU()), (True, False), (True, False)):
12688            # Fast path requires inference mode.
12689            if training:
12690                cm = contextlib.nullcontext()
12691            else:
12692                cm = torch.no_grad()
12693            with cm:
12694                _test(activation=activation, batch_first=batch_first, training=training)
12695
12696    @skipIfMps  # RuntimeError: foreach=True was passed, but can't use the foreach API on mps tensors
12697    @parametrize_test('foreach', (False, True))
12698    def test_clip_grad_value(self, foreach, device):
12699        if torch.device(device).type == 'xla' and foreach:
12700            raise SkipTest('foreach not supported on XLA')
12701
12702        l = nn.Linear(10, 10).to(device)
12703        clip_value = 2.5
12704
12705        grad_w, grad_b = torch.arange(-50., 50, device=device).view(10, 10).div_(5), torch.ones(10, device=device).mul_(2)
12706        for grad_list in [[grad_w, grad_b], [grad_w, None]]:
12707            for p, g in zip(l.parameters(), grad_list):
12708                p._grad = g.clone().view_as(p.data) if g is not None else g
12709
12710            clip_grad_value_(l.parameters(), clip_value, foreach=foreach)
12711            for p in filter(lambda p: p.grad is not None, l.parameters()):
12712                self.assertLessEqual(p.grad.data.max(), clip_value)
12713                self.assertGreaterEqual(p.grad.data.min(), -clip_value)
12714
12715        # Should accept a single Tensor as input
12716        p1, p2 = torch.randn(10, 10, device=device), torch.randn(10, 10, device=device)
12717        g = torch.arange(-50., 50, device=device).view(10, 10).div_(5)
12718        p1._grad = g.clone()
12719        p2._grad = g.clone()
12720        clip_grad_value_(p1, clip_value, foreach=foreach)
12721        clip_grad_value_([p2], clip_value, foreach=foreach)
12722        self.assertEqual(p1.grad, p2.grad)
12723
12724    @skipIfMps  # TypeError: the MPS framework doesn't support float64
12725    @parametrize_test('foreach', (False, True))
12726    @parametrize_test('norm_type', (0.5, 1.5, 2, 4, 'inf'))
12727    def test_clip_grad_norm(self, norm_type, foreach, device):
12728        if torch.device(device).type == 'xla' and foreach:
12729            raise SkipTest('foreach not supported on XLA')
12730
12731        l = nn.Linear(10, 10).to(device)
12732        max_norm = 2
12733
12734        def compute_norm(norm_type):
12735            norm_type = float(norm_type)
12736            if norm_type != inf:
12737                total_norm = 0
12738                for p in l.parameters():
12739                    total_norm += p.grad.data.abs().pow(norm_type).sum()
12740                return pow(total_norm, 1. / norm_type)
12741            else:
12742                return max(p.grad.data.abs().max() for p in l.parameters())
12743
12744        def compare_scaling(grads):
12745            p_scale = [p.grad.data.div(g).view(-1) for p, g in zip(l.parameters(), grads)]
12746            scale = torch.cat(p_scale)
12747            self.assertEqual(scale.std(), 0)
12748            return scale[0]
12749
12750        grads = torch.arange(1., 101, device=device).view(10, 10), torch.ones(10, device=device).div(1000)
12751        for p, g in zip(l.parameters(), grads):
12752            p._grad = g.clone().view_as(p.data)
12753        norm_before = compute_norm(norm_type)
12754        norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type, foreach=foreach)
12755        norm_after = compute_norm(norm_type)
12756        self.assertEqual(norm, norm_before)
12757        self.assertEqual(norm_after, max_norm)
12758        self.assertLessEqual(norm_after, norm_before)
12759        compare_scaling(grads)
12760
12761        # Small gradients should be left unchanged
12762        grads = torch.rand(10, 10, device=device).div(10000), torch.ones(10, device=device).div(500)
12763        for p, g in zip(l.parameters(), grads):
12764            p.grad.data.copy_(g)
12765        norm_before = compute_norm(norm_type)
12766        norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type, foreach=foreach)
12767        norm_after = compute_norm(norm_type)
12768        self.assertEqual(norm, norm_before)
12769        self.assertEqual(norm_before, norm_after)
12770        self.assertLessEqual(norm_after, max_norm)
12771        scale = compare_scaling(grads)
12772        self.assertEqual(scale, 1)
12773
12774        # Should accept a single Tensor as input
12775        p1, p2 = torch.randn(10, 10, device=device), torch.randn(10, 10, device=device)
12776        g = torch.arange(1., 101, device=device).view(10, 10)
12777        p1._grad = g.clone()
12778        p2._grad = g.clone()
12779        clip_grad_norm_(p1, max_norm, norm_type=norm_type, foreach=foreach)
12780        clip_grad_norm_([p2], max_norm, norm_type=norm_type, foreach=foreach)
12781        self.assertEqual(p1.grad, p2.grad)
12782
12783    # reference issue: https://github.com/pytorch/pytorch/issues/111484
12784    @onlyCUDA
12785    @largeTensorTest("42GB", "cuda")
12786    def test_softmax_forward_64bit_indexing(self, device):
12787        batch_size = 70
12788        seq_len = 2048
12789        vocab_size = 50000
12790
12791        shift_labels = torch.zeros(batch_size, seq_len - 1, dtype=torch.long, device=device)
12792        logits = torch.ones(batch_size, seq_len - 1, vocab_size, dtype=torch.float16, device=device)
12793        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
12794        nll = loss_fct(logits.permute(0, 2, 1), shift_labels).float()
12795        rtol, atol = torch.testing._comparison.get_tolerances(torch.float16, rtol=None, atol=None)
12796        self.assertEqual(nll, torch.ones_like(nll) * torch.log(torch.tensor(vocab_size)), rtol=rtol, atol=atol)
12797
12798    @onlyCUDA
12799    @largeTensorTest("20GB", "cuda")
12800    def test_softmax_backward_64bit_indexing(self, device):
12801        for numel in (2147483650, 2147483650 + 1):
12802            x = torch.empty([1, 1, numel], device=device, dtype=torch.float16)
12803            x.fill_(1.0 / numel)
12804            out = torch._softmax_backward_data(x, x, 2, x.dtype)
12805            self.assertEqual(out[0, 0, 0], 1 / numel)
12806
12807    # reference issue: https://github.com/pytorch/pytorch/issues/68248
12808    @onlyCUDA
12809    def test_adaptiveavg_pool1d_shmem(self, device):
12810        x = torch.randn(1, 256, 1, 5000, device=device).to(memory_format=torch.channels_last)
12811        x_cpu = x.cpu()
12812        x_cpu.requires_grad_()
12813        x.requires_grad_()
12814        y = torch.nn.functional.adaptive_avg_pool2d(x, (1, 256))
12815        y_cpu = torch.nn.functional.adaptive_avg_pool2d(x_cpu, (1, 256))
12816        grad = torch.randn_like(y)
12817        grad_cpu = grad.cpu()
12818        y.backward(grad)
12819        y_cpu.backward(grad_cpu)
12820        self.assertEqual(x.grad, x_cpu.grad)
12821
12822    @skipMeta
12823    @expectedFailureMPS  # NotImplementedError: aten::channel_shuffle https://github.com/pytorch/pytorch/issues/77764
12824    def test_channel_shuffle(self, device):
12825        #  3D tensor
12826        x = torch.tensor(
12827            [[[1, 2],
12828              [5, 6],
12829              [9, 10],
12830              [13, 14],
12831              ]], device=device
12832        )
12833        y_ref = torch.tensor(
12834            [[[1, 2],
12835              [9, 10],
12836              [5, 6],
12837              [13, 14],
12838              ]], device=device
12839        )
12840        #  ChannelsFirst
12841        with warnings.catch_warnings(record=True) as w:
12842            y = F.channel_shuffle(x, 2).to(device)
12843            self.assertEqual(len(w), 0)
12844        self.assertEqual(y, y_ref)
12845        #  ChannelsLast not supported for 3dim
12846
12847        #  4D tensor
12848        x = torch.tensor(
12849            [[[[1, 2],
12850               [3, 4]],
12851              [[5, 6],
12852               [7, 8]],
12853              [[9, 10],
12854               [11, 12]],
12855              [[13, 14],
12856               [15, 16]],
12857              ]], device=device
12858        )
12859        y_ref = torch.tensor(
12860            [[[[1, 2],
12861               [3, 4]],
12862              [[9, 10],
12863               [11, 12]],
12864              [[5, 6],
12865               [7, 8]],
12866              [[13, 14],
12867               [15, 16]],
12868              ]], device=device
12869        )
12870        #  ChannelsFirst NCHW
12871        with warnings.catch_warnings(record=True) as w:
12872            y = F.channel_shuffle(x, 2).to(device)
12873            self.assertEqual(len(w), 0)
12874        self.assertEqual(y, y_ref)
12875        #  ChannelsLast NHWC
12876        with warnings.catch_warnings(record=True) as w:
12877            y = F.channel_shuffle(x.contiguous(memory_format=torch.channels_last), 2).to(device)
12878            self.assertEqual(len(w), 0)
12879        y = y.contiguous(memory_format=torch.contiguous_format)
12880        self.assertEqual(y, y_ref)
12881
12882        #  5D tensor
12883        x = torch.tensor(
12884            [[[[[1, 2],
12885               [3, 4]]],
12886              [[[5, 6],
12887               [7, 8]]],
12888              [[[9, 10],
12889               [11, 12]]],
12890              [[[13, 14],
12891               [15, 16]]],
12892              ]], device=device
12893        )
12894        y_ref = torch.tensor(
12895            [[[[[1, 2],
12896               [3, 4]]],
12897              [[[9, 10],
12898               [11, 12]]],
12899              [[[5, 6],
12900               [7, 8]]],
12901              [[[13, 14],
12902               [15, 16]]],
12903              ]], device=device
12904        )
12905        #  ChannelsFirst NCHW
12906        with warnings.catch_warnings(record=True) as w:
12907            y = F.channel_shuffle(x, 2).to(device)
12908            self.assertEqual(len(w), 0)
12909        self.assertEqual(y, y_ref)
12910        #  ChannelsLast NHWC
12911        with warnings.catch_warnings(record=True) as w:
12912            y = F.channel_shuffle(x.contiguous(memory_format=torch.channels_last_3d), 2).to(device)
12913            self.assertEqual(len(w), 0)
12914        y = y.contiguous(memory_format=torch.contiguous_format)
12915        self.assertEqual(y, y_ref)
12916
12917
12918class TestFunctionalPickle(TestCase):
12919
12920    # issue gh-38137
12921    def test_pickle_softsign(self):
12922        # Make sure it does not throw an exception
12923        s = pickle.dumps(F.softsign)
12924
12925
12926class TestFusionUtils(TestCase):
12927    def test_fuse_conv_bn_requires_grad(self):
12928        conv = torch.nn.Conv2d(3, 3, 3)
12929        bn = torch.nn.BatchNorm2d(3)
12930        cases = itertools.product([True, False], [True, False])
12931        for w_rg, b_rg in cases:
12932            conv.weight.requires_grad = w_rg
12933            conv.bias.requires_grad = b_rg
12934            weight, bias = \
12935                fuse_conv_bn_weights(conv.weight, conv.bias,
12936                                     bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)
12937            self.assertEqual(weight.requires_grad, w_rg)
12938            self.assertEqual(bias.requires_grad, b_rg)
12939
12940    def test_fuse_linear_bn_requires_grad(self):
12941        linear = torch.nn.Linear(3, 3)
12942        bn = torch.nn.BatchNorm1d(3)
12943        cases = itertools.product([True, False], [True, False])
12944        for w_rg, b_rg in cases:
12945            linear.weight.requires_grad = w_rg
12946            linear.bias.requires_grad = b_rg
12947            weight, bias = \
12948                fuse_linear_bn_weights(linear.weight, linear.bias,
12949                                       bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)
12950            self.assertEqual(weight.requires_grad, w_rg)
12951            self.assertEqual(bias.requires_grad, b_rg)
12952
12953class TestUtils(TestCase):
12954    def test_consume_prefix_in_state_dict_if_present(self):
12955        class Block(nn.Module):
12956            def __init__(self) -> None:
12957                super().__init__()
12958                self.conv1 = nn.Conv2d(3, 3, 3, bias=True)
12959                self.conv2 = nn.Conv2d(3, 3, 3, bias=False)
12960
12961        class Net(nn.Module):
12962            def __init__(self) -> None:
12963                super().__init__()
12964                self.linear1 = nn.Linear(5, 5)
12965                self.linear2 = nn.Linear(5, 5)
12966                net.bn = nn.BatchNorm2d(2)
12967                self.block = Block()
12968
12969        # 0. Case non-DDP model empty state_dict
12970        net = nn.Module()
12971        state_dict = net.state_dict()
12972        nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, 'module.')
12973        # check they are the same preserving order
12974        self.assertEqual(list(state_dict.keys()), list(net.state_dict().keys()))
12975        self.assertEqual(list(state_dict._metadata.keys()), list(net.state_dict()._metadata.keys()))
12976
12977        # 1. Case non-DDP model test example state_dict
12978        net = Net()
12979        state_dict = net.state_dict()
12980        nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, 'module.')
12981        # Check they are the same preserving order
12982        self.assertEqual(list(state_dict.keys()), list(net.state_dict().keys()))
12983        self.assertEqual(list(state_dict._metadata.keys()), list(net.state_dict()._metadata.keys()))
12984
12985        # 2. Case DDP model test example state_dict
12986        state_dict = net.state_dict()
12987        metadata = state_dict._metadata
12988        ddp_state_dict = OrderedDict((f'module.{k}', v) for k, v in state_dict.items())
12989        ddp_state_dict._metadata = OrderedDict({'': metadata['']})
12990        ddp_state_dict._metadata.update(('module' if k == '' else f'module.{k}', v) for k, v in metadata.items())
12991        nn.modules.utils.consume_prefix_in_state_dict_if_present(ddp_state_dict, 'module.')
12992        # Check they are the same preserving order
12993        self.assertEqual(list(state_dict.keys()), list(ddp_state_dict.keys()))
12994        self.assertEqual(list(state_dict._metadata.keys()), list(ddp_state_dict._metadata.keys()))
12995
12996
12997instantiate_device_type_tests(TestNNDeviceType, globals(), allow_mps=True)
12998instantiate_parametrized_tests(TestNN)
12999
13000if __name__ == '__main__':
13001    TestCase._default_dtype_check_enabled = True
13002    run_tests()
13003