# Owner(s): ["module: nn"]

import contextlib
import math
import random
import unittest
import io
import itertools
import warnings
import pickle
import re
from copy import deepcopy
from itertools import product
from functools import partial
from collections import OrderedDict
from unittest import SkipTest

import torch
from torch import inf, nan
import torch.autograd.forward_ad as fwAD
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.nn.utils.fusion import fuse_conv_bn_weights
from torch.nn.utils.fusion import fuse_linear_bn_weights
from torch.nn import Buffer, Parameter
from torch.nn.parallel._functions import Broadcast
from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes, floating_types
from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
    TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \
    download_file, get_function_arglist, load_tests, skipIfMps, \
    IS_PPC, \
    parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
    skipIfTorchDynamo, gcIfJetson, set_default_dtype
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
    module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
    ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \
    dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
    skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \
    onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, expectedFailureMPS, \
    skipMeta, get_all_device_types

from hypothesis import given
import torch.testing._internal.hypothesis_utils as hu
from torch.testing._internal.common_utils import _assertGradAndGradgradChecks, gradcheck, gradgradcheck, \
    GRADCHECK_NONDET_TOL
from torch.testing._internal.common_utils import dtype2prec_DONTUSE
from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on
from torch.types import _TensorOrTensors
from torch.testing._internal.common_mkldnn import bf32_on_and_off

AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

if TEST_SCIPY:
    import scipy.signal
    import scipy.ndimage

if TEST_NUMPY:
    import numpy as np


# WARNING: If you add a new top-level test case to this file, you MUST
# update test/run_test.py to list it, otherwise it will NOT be run in
# CI.

class TestNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def _forward(self, module, input: _TensorOrTensors):
        with freeze_rng_state():
            if isinstance(input, tuple):
                return module(*input)
            else:
                return module(input)

    def _backward(self, module, input: _TensorOrTensors, output, grad_output, create_graph=False):
        output.backward(grad_output, retain_graph=True, create_graph=create_graph)
        if isinstance(input, tuple):
            return tuple(i.grad.data if i.grad is not None else None for i in input)
        else:
            return input.grad.data if input.grad is not None else None

    def _forward_criterion(self, criterion, input, target, extra_args=None):
        if extra_args is None:
            extra_args = ()
        if isinstance(input, tuple):
            args = input + (target,) + extra_args
            output = criterion(*args)
        else:
            output = criterion(input, target, *extra_args)
        return output

    def _backward_criterion(self, criterion, input, output, target, gradOutput=None, extra_args=None):
        if extra_args is None:
            extra_args = ()
        input_tuple = input if isinstance(input, tuple) else (input,)
        output_tuple = output if isinstance(output, tuple) else (output,)
        for i in input_tuple:
            if i.grad is not None:
                i.grad.data.zero_()
        args = input_tuple + (target,) + extra_args
        if gradOutput is None:
            gradOutput = torch.ones(())
        criterion(*args).backward(gradOutput.to(output_tuple[0]))
        if isinstance(input, tuple):
            return tuple(i.grad.data for i in input)
        else:
            return input.grad.data

    def _zero_grad_parameters(self, module):
        for p in module.parameters():
            if p.grad is not None:
                with torch.no_grad():
                    p.grad.zero_()
                p.grad.detach_()

    def _get_parameters(self, module):
        params = []
        d_params = []
        for p in module.parameters():
            params.append(p)
            d_params.append(p.grad)
        return params, d_params

    def test_parse_to(self):
        # Test for buggy use of THPMemoryFormat_New
        self.assertEqual(
            repr(torch._C._nn._parse_to(memory_format=torch.contiguous_format)[3]),
            "torch.contiguous_format"
        )

    def test_requires_grad_(self):
        m = _create_basic_net()[-1]
        assert len(list(m.buffers())) > 0, 'invalid test'
        assert all(not b.requires_grad for b in m.buffers()) > 0, 'invalid test'
        assert len(list(m.parameters())) > 0, 'invalid test'
        assert all(p.requires_grad for p in m.parameters()) > 0, 'invalid test'
        for requires_grad in (False, True):
            self.assertIs(m.requires_grad_(requires_grad), m)
            for p in m.parameters():
                self.assertEqual(p.requires_grad, requires_grad)
            for b in m.buffers():
                self.assertFalse(b.requires_grad)

    def test_module_backcompat(self):
        from torch.serialization import SourceChangeWarning
        path = download_file('https://download.pytorch.org/test_data/linear.pt')
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', SourceChangeWarning)
            # weights_only=False as this is legacy code that saves the model
            m = torch.load(path, weights_only=False)
        input = torch.randn(2, 3, dtype=torch.float)
        self.assertEqual(m(input).size(), (2, 5))

    def test_module_super_init(self):
        class MyMixin:
            def __init__(self, *a, **kw):
                super().__init__(*a, **kw)
                self.mixin_init = True

        class MyModuleWithMixinBefore(MyMixin, nn.Module):
            pass

        class MyModuleWithMixinAfter(nn.Module, MyMixin):
            pass

        self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init'))
        self.assertFalse(hasattr(MyModuleWithMixinAfter(), 'mixin_init'))

        nn.Module.call_super_init = True
        self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init'))
        self.assertTrue(hasattr(MyModuleWithMixinAfter(), 'mixin_init'))
        nn.Module.call_super_init = False

        MyModuleWithMixinBefore.call_super_init = True
        MyModuleWithMixinAfter.call_super_init = True
        self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init'))
        self.assertTrue(hasattr(MyModuleWithMixinAfter(), 'mixin_init'))
        MyModuleWithMixinBefore.call_super_init = False
        MyModuleWithMixinAfter.call_super_init = False

    def test_share_memory(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = nn.Parameter(torch.eye(5))
                self.par = nn.ParameterList()
                self.par.append(nn.Parameter(torch.randn(10)))

            def forward(self, inp):
                # NB: dead code
                return inp.clone()

        net = Net()
        for p in net.parameters():
            self.assertFalse(p.storage().is_shared())
        for b in net.buffers():
            self.assertFalse(b.storage().is_shared())
        net.share_memory()
        for p in net.parameters():
            self.assertTrue(p.storage().is_shared())
        for b in net.buffers():
            self.assertTrue(b.storage().is_shared())

    def test_to(self):
        m = nn.Linear(3, 5)
        self.assertIs(m, m.to('cpu'))
        self.assertIs(m, m.to('cpu', dtype=torch.float32))
        self.assertEqual(m.double(), m.to(torch.float64))
        self.assertRaises(RuntimeError, lambda: m.to('cpu', copy=True))

        if torch.cuda.is_available():
            for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
                m2 = m.cuda(device=cuda)
                self.assertIs(m2, m2.to(cuda))
                self.assertEqual(m, m2.to('cpu'))
                self.assertEqual(m2, m.to(cuda))
                self.assertIs(m2, m2.to(dtype=torch.float32))
                self.assertEqual(m2.double(), m2.to(dtype=torch.float64))

    def test_zero_grad(self):
        i = torch.randn(2, 5, requires_grad=True)
        module = nn.Linear(5, 5)
        for p in module.parameters():
            p.requires_grad = False
        module.zero_grad()

        module.weight.requires_grad = True
        module.zero_grad()
        self.assertIsNone(module.weight.grad)  # uninitialized grad

        module(i).sum().backward()
        self.assertIsNotNone(module.weight.grad)
        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
        module.zero_grad()
        self.assertIsNone(module.weight.grad)

        module.bias.requires_grad = True
        module.zero_grad()
        self.assertIsNone(module.weight.grad)
        self.assertIsNone(module.bias.grad)
        module(i).sum().backward()
        self.assertIsNotNone(module.weight.grad)
        self.assertIsNotNone(module.bias.grad)
        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
        self.assertGreater(module.bias.grad.data.abs().sum(), 0)
        module.zero_grad(set_to_none=False)  # Force set to zeros.
        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
        self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())

        module.zero_grad()
        self.assertIsNone(module.weight.grad)
        self.assertIsNone(module.bias.grad)

    def test_no_grad(self):
        for dtype in [torch.bfloat16, torch.float, torch.double]:
            module = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype)
            input = torch.randn(1, 2, 10, 10).to(dtype)
            x = input
            y = input.clone()

            output = module(x)
            self.assertTrue(output.requires_grad)
            output.backward(torch.ones(1, 5, 10, 10))

            with torch.no_grad():
                output2 = module(y)
                self.assertFalse(output2.requires_grad)
                self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10)))

    def test_parameters_and_named_parameters(self):
        def names(named_parameters):
            return [k for k, _ in named_parameters]

        l, n, s = _create_basic_net()

        self.assertEqual(len(list(l.parameters())), 1)
        self.assertEqual(
            names(l.named_parameters()),
            ['layer_dummy_param'])

        self.assertEqual(len(list(n.parameters())), 2)
        self.assertEqual(
            names(n.named_parameters()),
            ['dummy_param', 'l1.layer_dummy_param'])

        self.assertEqual(len(list(n.parameters(recurse=False))), 1)
        self.assertEqual(
            names(n.named_parameters(recurse=False)),
            ['dummy_param'])

        self.assertEqual(len(list(s.parameters())), 2)
        self.assertEqual(
            names(s.named_parameters()),
            ['0.dummy_param', '0.l1.layer_dummy_param'])

    def test_named_parameters_remove_duplicate(self):
        def names(named_parameters):
            return [k for k, _ in named_parameters]

        class M1(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param1 = nn.Parameter(torch.empty(3, 3))
                self.param2 = self.param1

        m1 = M1()
        self.assertEqual(names(m1.named_parameters()),
                         ["param1"])
        self.assertEqual(names(m1.named_parameters(remove_duplicate=False)),
                         ["param1", "param2"])

        class M2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod1 = nn.Linear(3, 4, bias=False)
                self.mod2 = self.mod1

        m2 = M2()
        self.assertEqual(names(m2.named_parameters()),
                         ["mod1.weight"])
        self.assertEqual(names(m2.named_parameters(remove_duplicate=False)),
                         ["mod1.weight", "mod2.weight"])

    def test_buffers_and_named_buffers(self):
        def names(named_buffers):
            return [k for k, _ in named_buffers]

        l, n, s = _create_basic_net()

        self.assertEqual(len(list(l.buffers())), 1)
        self.assertEqual(
            names(l.named_buffers()),
            ['layer_dummy_buf'])

        self.assertEqual(len(list(n.buffers())), 2)
        self.assertEqual(
            names(n.named_buffers()),
            ['dummy_buf', 'l1.layer_dummy_buf'])

        self.assertEqual(len(list(n.buffers(recurse=False))), 1)
        self.assertEqual(
            names(n.named_buffers(recurse=False)),
            ['dummy_buf'])

        self.assertEqual(len(list(s.buffers())), 2)
        self.assertEqual(
            names(s.named_buffers()),
            ['0.dummy_buf', '0.l1.layer_dummy_buf'])

        # test remove_duplicate
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = Buffer(torch.empty(3, 5))
                self.buffer2 = self.buffer1

        m = M()
        self.assertEqual(names(m.named_buffers()),
                         ["buffer1"])
        self.assertEqual(names(m.named_buffers(remove_duplicate=False)),
                         ["buffer1", "buffer2"])

    def test_buffer_bad_module_subclass(self):
        class MyBadModule(nn.Linear):
            def __init__(self) -> None:
                super().__init__(2, 2)
                self.bar = Buffer(torch.rand(2, 2))

            def register_buffer(self, name, value):
                # persistent is explicitly missing!
                super().register_buffer(name, value, True)

        foo = MyBadModule()
        self.assertIsNotNone(foo.bar)

    def test_call_supports_python_dict_output(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l1 = nn.Linear(10, 20)
                self.register_backward_hook(self.hook)
                self.check_backward_hook_flag = False

            def hook(self, module, grad_out, grad_in):
                self.check_backward_hook_flag = True

            def forward(self, inputs):
                return {"output": self.l1(inputs).sum()}

        net = Net()
        model_output = net(torch.randn([5, 10]))
        model_output["output"].backward()
        self.assertTrue(net.check_backward_hook_flag)

    def test_children(self):
        l1 = nn.Linear(2, 2)
        l2 = nn.Linear(2, 2)
        l3 = nn.Linear(2, 2)
        l4 = nn.Linear(2, 2)
        subnet = nn.Sequential(l3, l4)
        s = nn.Sequential(l1, l2, l1, l2, subnet)
        self.assertEqual(list(s.children()), [l1, l2, subnet])

    def test_train_errors_for_invalid_mode(self):
        class SubclassNet(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l1 = nn.Linear(2, 2)

            def forward(self, inputs):
                return self.l1(inputs)

        subclass_net = SubclassNet()
        sequential_net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

        error_modes = ["invalid_str", torch.device('cpu')]
        modules_to_check = [subclass_net, sequential_net]

        for error_mode, module in itertools.product(error_modes, modules_to_check):
            with self.assertRaises(ValueError):
                module.train(error_mode)

    def test_dir(self):
        linear = nn.Linear(2, 2)
        linear._test_submodule = nn.Linear(2, 2)
        linear._test_parameter = Parameter(torch.empty(2, 2))
        linear._test_buffer = Buffer(torch.empty(2, 2))
        keys = dir(linear)
        self.assertIn('_test_submodule', keys)
        self.assertIn('_test_parameter', keys)
        self.assertIn('_test_buffer', keys)

        for key in keys:
            self.assertTrue(hasattr(linear, key))

    def test_repr(self):
        # no extra information or sub-modules
        empty_sequential = nn.Sequential()
        expected_repr_empty = 'Sequential()'
        self.assertEqual(repr(empty_sequential), expected_repr_empty)

        # one liner extra information
        linear = nn.Linear(1, 1)
        expected_repr_linear = 'Linear(in_features=1, out_features=1, bias=True)'
        self.assertEqual(repr(linear), expected_repr_linear)

        # sub-modules repr
        sequential = nn.Sequential(linear)
        expected_repr_sequential = 'Sequential(\n' \
            '  (0): Linear(in_features=1, out_features=1, bias=True)\n' \
            ')'
        self.assertEqual(repr(sequential), expected_repr_sequential)

    def test_dir_digit(self):
        model = nn.Sequential(nn.Linear(2, 2))
        keys = dir(model)
        self.assertNotIn('0', keys)

    def test_named_children(self):
        l1 = nn.Linear(2, 2)
        l2 = nn.Linear(2, 2)
        l3 = nn.Linear(2, 2)
        l4 = nn.Linear(2, 2)
        subnet = nn.Sequential(l3, l4)
        s = nn.Sequential()
        with self.assertRaises(KeyError):
            s.add_module('', l1)
        with self.assertRaises(KeyError):
            s.add_module('name.with.dot', l1)
        s.add_module('layer1', l1)
        s.add_module('layer2', l2)
        s.add_module('layer3', l1)
        s.add_module('layer4', l2)
        s.add_module('subnet', subnet)
        self.assertEqual(list(s.named_children()), [('layer1', l1), ('layer2', l2), ('subnet', subnet)])

    def test_modules(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l1 = l
                self.l2 = l
                self.param = torch.empty(3, 5)

        l = nn.Linear(10, 20)
        n = Net()
        s = nn.Sequential(n, n, n, n)
        self.assertEqual(list(s.modules()), [s, n, l])

    def test_named_modules(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l1 = l
                self.l2 = l
                self.param = torch.empty(3, 5)
                self.block = block
        l = nn.Linear(10, 20)
        l1 = nn.Linear(10, 20)
        l2 = nn.Linear(10, 20)
        block = nn.Sequential()
        block.add_module('linear1', l1)
        block.add_module('linear2', l2)
        n = Net()
        s = nn.Sequential(n, n)
        self.assertEqual(list(s.named_modules()), [('', s), ('0', n), ('0.l1', l),
                                                   ('0.block', block), ('0.block.linear1', l1),
                                                   ('0.block.linear2', l2)])
        # test the option to not remove duplicate module instances
        self.assertEqual(list(s.named_modules(remove_duplicate=False)), [
            ('', s), ('0', n), ('0.l1', l), ('0.l2', l),
            ('0.block', block), ('0.block.linear1', l1),
            ('0.block.linear2', l2),
            ('1', n), ('1.l1', l), ('1.l2', l),
            ('1.block', block), ('1.block.linear1', l1),
            ('1.block.linear2', l2)])

    def test_register_buffer_raises_error_if_name_is_not_string(self):
        m = nn.Module()
        expected_error = 'buffer name should be a string. Got '
        with self.assertRaisesRegex(TypeError, expected_error + 'int'):
            m.register_buffer(1, torch.rand(5))
        with self.assertRaisesRegex(TypeError, expected_error + 'NoneType'):
            m.register_buffer(None, torch.rand(5))

    def test_register_buffer_raises_error_if_attr_exists(self):
        m = nn.Module()
        m.attribute_name = 5
        with self.assertRaises(KeyError):
            m.register_buffer('attribute_name', torch.rand(5))

        with self.assertRaises(KeyError):
            m.attribute_name = Buffer(torch.rand(5))

        del m.attribute_name
        m.register_parameter('attribute_name', nn.Parameter())
        with self.assertRaises(KeyError):
            m.register_buffer('attribute_name', torch.rand(5))

        del m.attribute_name
        m.add_module('attribute_name', nn.Module())
        with self.assertRaises(KeyError):
            m.register_buffer('attribute_name', torch.rand(5))

    def test_register_buffer_raises_error_if_not_tensor(self):
        m = nn.Module()
        with self.assertRaises(TypeError):
            m.register_buffer('attribute_name', 5)

    def test_register_buffer_allows_overwriting_with_same_name(self):
        m = nn.Module()
        buffer1 = torch.rand(5)
        buffer2 = buffer1 + 5
        buffer3 = None
        m.register_buffer('buffer_name', buffer1)
        self.assertEqual(m.buffer_name, buffer1)
        m.register_buffer('buffer_name', buffer2)
        self.assertEqual(m.buffer_name, buffer2)
        m.register_buffer('buffer_name', buffer3)
        self.assertEqual(m.buffer_name, buffer3)
        m.buffer_name = Buffer(buffer1)
        self.assertEqual(m.buffer_name, Buffer(buffer1))
        m.buffer_name = Buffer(buffer2)
        self.assertEqual(m.buffer_name, Buffer(buffer2))
        m.buffer_name = Buffer(buffer3)
        self.assertEqual(m.buffer_name, Buffer(buffer3))

    def test_get_buffer(self):
        m = nn.Module()
        buffer1 = torch.randn(2, 3)
        buffer2 = torch.randn(4, 5)
        m.foo = Buffer(buffer1)
        m.register_buffer('bar', buffer2)
        self.assertEqual(buffer1, m.get_buffer('foo'))
        self.assertEqual(buffer2, m.get_buffer('bar'))

    def test_get_buffer_from_submodules(self):
        class MyModule(nn.Module):
            def __init__(self, foo, bar):
                super().__init__()
                self.sub = Sub(foo, bar)

        class Sub(nn.Module):
            def __init__(self, foo, bar):
                super().__init__()
                self.foo = Buffer(foo)
                self.subsub = SubSub(bar)

        class SubSub(nn.Module):
            def __init__(self, bar):
                super().__init__()
                self.bar = Buffer(bar)

        foo = torch.randn(2, 3)
        bar = torch.randn(4, 5)
        m = MyModule(foo, bar)
        self.assertEqual(foo, m.get_buffer('sub.foo'))
        self.assertEqual(bar, m.get_buffer('sub.subsub.bar'))

    def test_buffer_not_persistent(self):
        m = nn.Module()
        m.buf = nn.Buffer(torch.rand(5), persistent=False)
        self.assertTrue(len(list(m.buffers())) == 1)
        self.assertTrue(len(m.state_dict()) == 0)

    def test_buffer_not_persistent_del(self):
        m = nn.Module()
        m.buf = nn.Buffer(torch.rand(5), persistent=False)
        del m.buf
        self.assertTrue(len(list(m.buffers())) == 0)

    def test_buffer_not_persistent_overwrite(self):
        m = nn.Module()
        m.buf = nn.Buffer(torch.rand(5), persistent=False)
        m.buf = nn.Buffer(torch.rand(5))

        # can we overwrite a non-persistent buffer with a persistent one?
        self.assertTrue(len(list(m.buffers())) == 1)
        self.assertTrue(len(m.state_dict()) == 1)

        # can we overwrite a persistent buffer with a non-persistent one?
        m.buf = nn.Buffer(torch.rand(5), persistent=False)
        self.assertTrue(len(list(m.buffers())) == 1)
        self.assertTrue(len(m.state_dict()) == 0)

    def test_buffer_not_persistent_assign(self):
        m = nn.Module()
        m.buf = nn.Buffer(torch.rand(5), persistent=False)
        self.assertTrue(len(list(m.buffers())) == 1)
        self.assertTrue(len(m.state_dict()) == 0)

        # Assigning None removes the buffer but if we then assign a new Tensor
        # to the same property, it should still be marked as a buffer.
        m.buf = None
        self.assertTrue(len(list(m.buffers())) == 0)
        self.assertTrue(len(m.state_dict()) == 0)
        m.buf = torch.rand(5)
        self.assertTrue(len(list(m.buffers())) == 1)
        self.assertTrue(len(m.state_dict()) == 0)

        # Assigning a Parameter removes the buffer.
        m.buf = nn.Parameter(torch.rand(5))
        self.assertTrue(len(list(m.buffers())) == 0)
        self.assertTrue(len(m.state_dict()) == 1)

    def test_buffer_not_persistent_load(self):
        m = nn.Module()
        m.buf = nn.Buffer(torch.rand(5), persistent=False)
        m.load_state_dict({})

    def test_register_parameter_raises_error_if_name_is_not_string(self):
        m = nn.Module()
        expected_error = 'parameter name should be a string. Got '
        with self.assertRaisesRegex(TypeError, expected_error + 'int'):
            m.register_parameter(1, nn.Parameter())
        with self.assertRaisesRegex(TypeError, expected_error + 'NoneType'):
            m.register_parameter(None, nn.Parameter())

    def test_register_parameter_raises_error_if_attr_exists(self):
        m = nn.Module()
        m.attribute_name = 5
        with self.assertRaises(KeyError):
            m.register_parameter('attribute_name', nn.Parameter())

        del m.attribute_name
        m.register_buffer('attribute_name', torch.rand(5))
        with self.assertRaises(KeyError):
            m.register_parameter('attribute_name', nn.Parameter())

        del m.attribute_name
        m.attribute_name = Buffer(torch.rand(5))
        with self.assertRaises(KeyError):
            m.register_parameter('attribute_name', nn.Parameter())

        del m.attribute_name
        m.add_module('attribute_name', nn.Module())
        with self.assertRaises(KeyError):
            m.register_parameter('attribute_name', nn.Parameter())

    def test_register_parameter_allows_overwriting_with_same_name(self):
        m = nn.Module()
        param1 = nn.Parameter(torch.rand(5))
        param2 = nn.Parameter(param1.data + 5)
        param3 = None
        m.register_parameter('param_name', param1)
        self.assertEqual(m.param_name, param1)
        m.register_parameter('param_name', param2)
        self.assertEqual(m.param_name, param2)
        m.register_parameter('param_name', param3)
        self.assertEqual(m.param_name, param3)

    def test_add_module_raises_error_if_attr_exists(self):
        methods_to_test = ['add_module', 'register_module']
        for fn in methods_to_test:
            m = nn.Module()
            m.attribute_name = 5
            with self.assertRaises(KeyError):
                getattr(m, fn)('attribute_name', nn.Module())

            del m.attribute_name
            m.register_buffer('attribute_name', torch.rand(5))
            with self.assertRaises(KeyError):
                getattr(m, fn)('attribute_name', nn.Module())

            del m.attribute_name
            m.register_parameter('attribute_name', nn.Parameter())
            with self.assertRaises(KeyError):
                getattr(m, fn)('attribute_name', nn.Module())

    @unittest.expectedFailure
    def test_getattr_with_property(self):
        class Model(nn.Module):
            @property
            def some_property(self):
                return self.something_that_doesnt_exist

        model = Model()

        with self.assertRaisesRegex(
                AttributeError,
                r"'Model' object has no attribute 'something_that_doesnt_exist'"):
            model.some_property

    def test_Sequential_getitem(self):
        l1 = nn.Linear(10, 20)
        l2 = nn.Linear(20, 30)
        l3 = nn.Linear(30, 40)
        l4 = nn.Linear(40, 50)
        n = nn.Sequential(l1, l2, l3, l4)
        self.assertIs(n[0], l1)
        self.assertIs(n[1], l2)
        self.assertIs(n[2], l3)
        self.assertIs(n[3], l4)
        self.assertIs(n[torch.tensor(3, dtype=torch.int64)], l4)
        self.assertEqual(n[1:], nn.Sequential(l2, l3, l4))
        self.assertEqual(n[3:], nn.Sequential(l4))
        self.assertEqual(n[:-1], nn.Sequential(l1, l2, l3))
        self.assertEqual(n[:-3], nn.Sequential(l1))
        self.assertEqual(n[::-1], nn.Sequential(l4, l3, l2, l1))

    def test_Sequential_setitem(self):
        l1 = nn.Linear(10, 20)
        l2 = nn.Linear(20, 30)
        l3 = nn.Linear(30, 40)
        l4 = nn.Linear(40, 50)
        n = nn.Sequential(l1, l2, l3)
    def test_Sequential_getitem(self):
        l1 = nn.Linear(10, 20)
        l2 = nn.Linear(20, 30)
        l3 = nn.Linear(30, 40)
        l4 = nn.Linear(40, 50)
        n = nn.Sequential(l1, l2, l3, l4)
        self.assertIs(n[0], l1)
        self.assertIs(n[1], l2)
        self.assertIs(n[2], l3)
        self.assertIs(n[3], l4)
        self.assertIs(n[torch.tensor(3, dtype=torch.int64)], l4)
        self.assertEqual(n[1:], nn.Sequential(l2, l3, l4))
        self.assertEqual(n[3:], nn.Sequential(l4))
        self.assertEqual(n[:-1], nn.Sequential(l1, l2, l3))
        self.assertEqual(n[:-3], nn.Sequential(l1))
        self.assertEqual(n[::-1], nn.Sequential(l4, l3, l2, l1))

    def test_Sequential_setitem(self):
        l1 = nn.Linear(10, 20)
        l2 = nn.Linear(20, 30)
        l3 = nn.Linear(30, 40)
        l4 = nn.Linear(40, 50)
        n = nn.Sequential(l1, l2, l3)
        n[0] = l4
        n[-1] = l4
        n[torch.tensor(1, dtype=torch.int16)] = l1
        self.assertIs(n[0], l4)
        self.assertIs(n[1], l1)
        self.assertIs(n[2], l4)

    def test_Sequential_setitem_named(self):
        l1 = nn.Linear(10, 20)
        l2 = nn.Linear(20, 30)
        l3 = nn.Linear(30, 40)
        l4 = nn.Linear(40, 50)
        n = nn.Sequential(OrderedDict([
            ('linear1', l1),
            ('linear2', l2),
            ('linear3', l3),
        ]))

        n[0] = l4
        n[-1] = l4
        self.assertEqual(n.linear1, l4)
        self.assertEqual(n.linear3, l4)

    def test_Sequential_delitem(self):
        l1 = nn.Linear(10, 20)
        l2 = nn.Linear(20, 30)
        l3 = nn.Linear(30, 40)
        l4 = nn.Linear(40, 50)
        n = nn.Sequential(l1, l2, l3, l4)
        del n[-1]
        self.assertEqual(n, nn.Sequential(l1, l2, l3))
        del n[1::2]
        self.assertEqual(n, nn.Sequential(l1, l3))

    def test_Sequential_add(self):
        l1 = nn.Linear(1, 2)
        l2 = nn.Linear(2, 3)
        l3 = nn.Linear(3, 4)
        l4 = nn.Linear(4, 5)
        n = nn.Sequential(l1, l2)
        other = nn.Sequential(l3, l4)
        self.assertEqual(n + other, nn.Sequential(l1, l2, l3, l4))

    def test_Sequential_iadd(self):
        l1 = nn.Linear(10, 20)
        l2 = nn.Linear(20, 30)
        l3 = nn.Linear(30, 40)
        l4 = nn.Linear(40, 50)
        n = nn.Sequential(l1, l2, l3)
        n2 = nn.Sequential(l4)
        n += n2
        n2 += n
        self.assertEqual(n, nn.Sequential(l1, l2, l3, l4))
        self.assertEqual(n2, nn.Sequential(l4, l1, l2, l3, l4))

    def test_Sequential_mul(self):
        l1 = nn.Linear(10, 20)
        l2 = nn.Linear(20, 30)
        l3 = nn.Linear(30, 40)
        l4 = nn.Linear(40, 50)
        n = nn.Sequential(l1, l2, l3, l4)
        n2 = n * 2
        self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4))

    def test_Sequential_rmul(self):
        l1 = nn.Linear(10, 20)
        l2 = nn.Linear(20, 30)
        l3 = nn.Linear(30, 40)
        l4 = nn.Linear(40, 50)
        n = nn.Sequential(l1, l2, l3, l4)
        n2 = 2 * n
        self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4))

    def test_Sequential_imul(self):
        l1 = nn.Linear(10, 20)
        l2 = nn.Linear(20, 30)
        l3 = nn.Linear(30, 40)
        l4 = nn.Linear(40, 50)
        n = nn.Sequential(l1, l2, l3, l4)
        n *= 2
        self.assertEqual(n, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4))
        n *= 2
        self.assertEqual(
            n,
            nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4, l1, l2, l3, l4, l1, l2, l3, l4)
        )

    def test_Sequential_append(self):
        l1 = nn.Linear(10, 20)
        l2 = nn.Linear(20, 30)
        l3 = nn.Linear(30, 40)
        l4 = nn.Linear(40, 50)
        n = nn.Sequential(l1, l2, l3)
        n2 = n.append(l4)
        self.assertEqual(n, nn.Sequential(l1, l2, l3, l4))
        self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4))
        self.assertEqual(nn.Sequential(l1).append(l2).append(l4), nn.Sequential(l1, l2, l4))

    def test_Sequential_pop(self):
        l1 = nn.Linear(1, 2)
        l2 = nn.Linear(2, 3)
        l3 = nn.Linear(3, 4)
        l4 = nn.Linear(4, 5)
        n1 = nn.Sequential(l1, l2, l3, l4)
        self.assertEqual(l4, n1.pop(3))
        n2 = nn.Sequential(l1, l2, l3)
        self.assertEqual(n1, n2)
        # check order of the index
        for k, mod in zip(range(len(n1)), n1):
            self.assertIs(n1[k], mod)

    def test_Sequential_insert(self):
        l1 = nn.Linear(1, 2)
        l2 = nn.Linear(2, 3)
        l3 = nn.Linear(3, 4)

        n1 = nn.Sequential(l1, l2, l3)
        module_1 = nn.Linear(4, 5)
        n2 = nn.Sequential(l1, module_1, l2, l3)
        self.assertEqual(n1.insert(1, module_1), n2)

        # test for negative support
        n3 = nn.Sequential(l1, l2, l3)
        module_2 = nn.Linear(5, 6)
        n4 = nn.Sequential(l1, module_2, l2, l3)
        self.assertEqual(n3.insert(-2, module_2), n4)

    def test_Sequential_insert_fail_case(self):
        l1 = nn.Linear(1, 2)
        l2 = nn.Linear(2, 3)
        l3 = nn.Linear(3, 4)

        module = nn.Linear(5, 6)

        # test for error case
        n1 = nn.Sequential(l1, l2, l3)
        with self.assertRaises(IndexError):
            n1.insert(-5, module)

        with self.assertRaises(AssertionError):
            n1.insert(1, [nn.Linear(6, 7)])

    def test_Sequential_extend(self):
        l1 = nn.Linear(10, 20)
        l2 = nn.Linear(20, 30)
        l3 = nn.Linear(30, 40)
        l4 = nn.Linear(40, 50)
        n1 = nn.Sequential(l1, l2)
        n2 = nn.Sequential(l3, l4)
        n3 = nn.Sequential(l1, l2)
        for l in n2:
            n1.append(l)
        n3.extend(n2)
        self.assertEqual(n3, n1)

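    # nn.ModuleList registers every contained module as a submodule (visible to
    # children() and parameters()) while otherwise behaving like a Python list.
    # The test below keeps a plain list in lockstep and re-checks identity after
    # each mutation.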
    def test_ModuleList(self):
        modules = [nn.ReLU(), nn.Linear(5, 5)]
        module_list = nn.ModuleList(modules)

        def check():
            self.assertEqual(len(module_list), len(modules))
            for m1, m2 in zip(modules, module_list):
                self.assertIs(m1, m2)
            for m1, m2 in zip(modules, module_list.children()):
                self.assertIs(m1, m2)
            for i in range(len(modules)):
                self.assertIs(module_list[i], modules[i])

        check()
        modules += [nn.Conv2d(3, 4, 3)]
        module_list += [modules[-1]]
        check()
        modules = modules + [nn.Conv2d(3, 4, 3, bias=False), nn.GELU()]
        module_list = module_list + nn.ModuleList(modules[-2:])
        check()
        modules.insert(1, nn.Linear(3, 2))
        module_list.insert(1, modules[1])
        check()
        modules.append(nn.Tanh())
        module_list.append(modules[-1])
        check()
        next_modules = [nn.Linear(5, 5), nn.Sigmoid()]
        modules.extend(next_modules)
        module_list.extend(next_modules)
        check()
        modules[2] = nn.Conv2d(5, 3, 2)
        module_list[2] = modules[2]
        check()
        modules[-1] = nn.Conv2d(5, 2, 1)
        module_list[-1] = modules[-1]
        check()
        idx = torch.tensor(2, dtype=torch.int32)
        modules[2] = nn.Conv2d(5, 3, 2)
        module_list[idx] = modules[2]
        self.assertIs(module_list[idx], modules[2])
        check()
        self.assertEqual(module_list[1:], nn.ModuleList(modules[1:]))
        self.assertEqual(module_list[3:], nn.ModuleList(modules[3:]))
        self.assertEqual(module_list[:-1], nn.ModuleList(modules[:-1]))
        self.assertEqual(module_list[:-3], nn.ModuleList(modules[:-3]))
        self.assertEqual(module_list[::-1], nn.ModuleList(modules[::-1]))
        del module_list[-1]
        self.assertEqual(module_list, nn.ModuleList(modules[:-1]))
        del module_list[1::2]
        self.assertEqual(module_list, nn.ModuleList(modules[:-1][0::2]))

        with self.assertRaises(TypeError):
            module_list += nn.ReLU()
        with self.assertRaises(TypeError):
            module_list.extend(nn.ReLU())

        l1 = nn.Linear(1, 2)
        l2 = nn.Linear(2, 3)
        l3 = nn.Linear(3, 2)
        l4 = nn.Linear(2, 3)
        subnet = nn.Sequential(l3, l4)
        s = nn.Sequential(
            OrderedDict([
                ("layer1", l1),
                ("layer2", l2),
                ("layer3", l3),
                ("layer4", l4),
                ("subnet_layer", subnet)
            ])
        )
        modules = list(s.modules())
        module_list = nn.ModuleList()
        module_list.extend(s.modules())
        check()

        modules = [nn.ReLU(), nn.Linear(5, 5), nn.Conv2d(3, 4, 3)]
        module_list = nn.ModuleList(modules)
        self.assertEqual(modules.pop(1), module_list.pop(1))
        self.assertEqual(modules, module_list)
        # check order of the index
        for k, mod in zip(range(len(module_list)), module_list):
            self.assertIs(module_list[k], mod)

        # verify the right exception is thrown when trying to "forward" through a ModuleList
        self.assertRaises(NotImplementedError, module_list)
        self.assertRaises(NotImplementedError, module_list, torch.rand(1, 3))

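    # nn.ModuleDict keeps insertion order and registers its values as submodules;
    # update() accepts key/value pairs, an OrderedDict, a plain dict, or another
    # ModuleDict, mirroring ordinary dict semantics.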
    def test_ModuleDict(self):
        modules = OrderedDict([
            ('act', nn.ReLU()),
            ('conv', nn.Conv2d(10, 10, 5)),
            ('fc', nn.Linear(5, 5)),
        ])

        module_dict = nn.ModuleDict(modules)

        def check():
            self.assertEqual(len(module_dict), len(modules))
            for k1, m2 in zip(modules, module_dict.children()):
                self.assertIs(modules[k1], m2)
            for k1, k2 in zip(modules, module_dict):
                self.assertIs(modules[k1], module_dict[k2])
            for k in module_dict:
                self.assertIs(module_dict[k], modules[k])
            for k in module_dict.keys():
                self.assertIs(module_dict[k], modules[k])
            for k, v in module_dict.items():
                self.assertIs(modules[k], v)
            for k1, m2 in zip(modules, module_dict.values()):
                self.assertIs(modules[k1], m2)
            for k in modules.keys():
                self.assertTrue(k in module_dict)
        check()

        modules['conv'] = nn.Conv2d(3, 4, 3)
        module_dict['conv'] = modules['conv']
        check()

        next_modules = [
            ('fc2', nn.Linear(5, 5)),
            ('act', nn.Sigmoid()),
        ]
        modules.update(next_modules)
        module_dict.update(next_modules)
        check()

        next_modules = OrderedDict([
            ('fc3', nn.Linear(5, 5)),
            ('act2', nn.Sigmoid()),
        ])
        modules.update(next_modules)
        module_dict.update(next_modules)
        check()

        next_modules = {
            'fc4': nn.Linear(5, 5),
            'act3': nn.Sigmoid()
        }
        modules.update(next_modules.items())
        module_dict.update(next_modules)
        check()

        next_modules = nn.ModuleDict([
            ('fc5', nn.Linear(5, 5)),
            ('act4', nn.Sigmoid()),
        ])
        modules.update(next_modules)
        module_dict.update(next_modules)
        check()

        del module_dict['fc']
        del modules['fc']
        check()

        with self.assertRaises(TypeError):
            module_dict.update(nn.ReLU())

        with self.assertRaises(TypeError):
            module_dict.update([nn.ReLU()])

        with self.assertRaises(ValueError):
            module_dict.update([[nn.ReLU()]])

        with self.assertRaises(TypeError):
            module_dict[1] = nn.ReLU()

        s = nn.Sequential(modules)
        module_dict = nn.ModuleDict(s.named_children())
        check()

        c = module_dict.pop('conv')
        self.assertIs(c, modules['conv'])
        modules.pop('conv')
        check()

        module_dict.clear()
        self.assertEqual(len(module_dict), 0)
        modules.clear()
        check()

        # verify the right exception is thrown when trying to "forward" through a ModuleDict
        self.assertRaises(NotImplementedError, module_dict)
        self.assertRaises(NotImplementedError, module_dict, torch.rand(1, 3))

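    # nn.ParameterList mimics a Python list: plain tensors appended to it are
    # promoted to nn.Parameter, while non-tensor entries (e.g. strings) are
    # stored as-is, which is what the isinstance assertions below rely on.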
    @skipIfTorchDynamo()
    def test_ParameterList(self):
        def make_param():
            return Parameter(torch.randn(2, 2))
        parameters = [make_param(), make_param()]
        param_list = nn.ParameterList(parameters)

        def check():
            self.assertEqual(len(parameters), len(param_list))
            for p1, p2 in zip(parameters, param_list):
                self.assertIs(p1, p2)
            for p1, p2 in zip(filter(lambda x: isinstance(x, Parameter), parameters), param_list.parameters()):
                self.assertIs(p1, p2)
            for i in range(len(parameters)):
                self.assertIs(parameters[i], param_list[i])

        check()
        parameters += [make_param()]
        param_list += [parameters[-1]]
        check()
        parameters.append(make_param())
        param_list.append(parameters[-1])
        check()
        next_params = [make_param(), make_param()]
        parameters.extend(next_params)
        param_list.extend(next_params)
        check()
        parameters[2] = make_param()
        param_list[2] = parameters[2]
        check()
        parameters[-1] = make_param()
        param_list[-1] = parameters[-1]
        check()
        idx = torch.tensor(2, dtype=torch.int32)
        parameters[2] = make_param()
        param_list[idx] = parameters[2]
        self.assertIs(param_list[idx], parameters[2])
        check()
        self.assertEqual(param_list[1:], nn.ParameterList(parameters[1:]))
        self.assertEqual(param_list[3:], nn.ParameterList(parameters[3:]))
        self.assertEqual(param_list[:-1], nn.ParameterList(parameters[:-1]))
        self.assertEqual(param_list[:-3], nn.ParameterList(parameters[:-3]))
        self.assertEqual(param_list[::-1], nn.ParameterList(parameters[::-1]))

        with self.assertRaises(TypeError):
            param_list += make_param()
        with self.assertRaises(TypeError):
            param_list.extend(make_param())

Worker ("layer4", l4), 1157*da0073e9SAndroid Build Coastguard Worker ("subnet_layer", subnet) 1158*da0073e9SAndroid Build Coastguard Worker ]) 1159*da0073e9SAndroid Build Coastguard Worker ) 1160*da0073e9SAndroid Build Coastguard Worker parameters = list(s.parameters()) 1161*da0073e9SAndroid Build Coastguard Worker param_list = nn.ParameterList() 1162*da0073e9SAndroid Build Coastguard Worker param_list.extend(s.parameters()) 1163*da0073e9SAndroid Build Coastguard Worker check() 1164*da0073e9SAndroid Build Coastguard Worker 1165*da0073e9SAndroid Build Coastguard Worker param_list.append(torch.rand(2, 2)) 1166*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(param_list[-1], Parameter) 1167*da0073e9SAndroid Build Coastguard Worker parameters.append(param_list[-1]) 1168*da0073e9SAndroid Build Coastguard Worker 1169*da0073e9SAndroid Build Coastguard Worker param_list.extend([torch.rand(2, 2), "foo"]) 1170*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(param_list[-2], Parameter) 1171*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(param_list[-1], str) 1172*da0073e9SAndroid Build Coastguard Worker parameters.extend(param_list[-2:]) 1173*da0073e9SAndroid Build Coastguard Worker 1174*da0073e9SAndroid Build Coastguard Worker param_list += ["bar", torch.rand(2, 2)] 1175*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(param_list[-2], str) 1176*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(param_list[-1], Parameter) 1177*da0073e9SAndroid Build Coastguard Worker parameters += param_list[-2:] 1178*da0073e9SAndroid Build Coastguard Worker check() 1179*da0073e9SAndroid Build Coastguard Worker 1180*da0073e9SAndroid Build Coastguard Worker def test_ParameterList_meta(self): 1181*da0073e9SAndroid Build Coastguard Worker p = torch.nn.Parameter(torch.empty(1, device='meta')) 1182*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(p), """\ 1183*da0073e9SAndroid Build Coastguard WorkerParameter containing: 1184*da0073e9SAndroid Build Coastguard Workertensor(..., device='meta', size=(1,), requires_grad=True)""") 1185*da0073e9SAndroid Build Coastguard Worker pl = torch.nn.ParameterList([p]) 1186*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(pl), """ParameterList( (0): Parameter containing: [torch.float32 of size 1])""") 1187*da0073e9SAndroid Build Coastguard Worker 1188*da0073e9SAndroid Build Coastguard Worker def test_ParameterList_replication(self): 1189*da0073e9SAndroid Build Coastguard Worker # The actual replication code from DP cannot be used on CPU so doing it manually here 1190*da0073e9SAndroid Build Coastguard Worker def make_param(): 1191*da0073e9SAndroid Build Coastguard Worker return Parameter(torch.randn(2, 2)) 1192*da0073e9SAndroid Build Coastguard Worker parameters = [make_param(), make_param()] 1193*da0073e9SAndroid Build Coastguard Worker param_list = nn.ParameterList(parameters) 1194*da0073e9SAndroid Build Coastguard Worker 1195*da0073e9SAndroid Build Coastguard Worker new_param_list = param_list._replicate_for_data_parallel() 1196*da0073e9SAndroid Build Coastguard Worker 1197*da0073e9SAndroid Build Coastguard Worker for n, p in param_list.named_parameters(): 1198*da0073e9SAndroid Build Coastguard Worker # Do a view here so that we can check the base later 1199*da0073e9SAndroid Build Coastguard Worker setattr(new_param_list, n, p.view_as(p)) 1200*da0073e9SAndroid Build Coastguard Worker 1201*da0073e9SAndroid Build Coastguard Worker for p, p2 in zip(param_list, new_param_list): 
    def test_ParameterDict(self):
        parameters = OrderedDict([
            ('p1', Parameter(torch.randn(10, 10))),
            ('p2', Parameter(torch.randn(10, 10))),
            ('p3', Parameter(torch.randn(10, 10))),
        ])

        parameter_dict = nn.ParameterDict(parameters)

        def check():
            self.assertEqual(len(parameter_dict), len(parameters))
            for i, (k1, (k2, m2)) in enumerate(zip(parameters, parameter_dict.named_parameters())):
                self.assertEqual(k1, k2)
                self.assertIs(parameters[k1], m2)
            for k1, k2 in zip(parameters, parameter_dict):
                self.assertIs(parameters[k1], parameter_dict[k2])
            for k in parameter_dict:
                self.assertIs(parameter_dict[k], parameters[k])
            for k in parameter_dict.keys():
                self.assertIs(parameter_dict[k], parameters[k])
            for k, v in parameter_dict.items():
                self.assertIs(v, parameters[k])
            for k1, m2 in zip(parameters, parameter_dict.values()):
                self.assertIs(parameters[k1], m2)
            for k in parameters.keys():
                self.assertTrue(k in parameter_dict)

        check()

        parameters['p4'] = Parameter(torch.randn(10, 10))
        parameter_dict['p4'] = parameters['p4']
        check()

        next_parameters = [
            ('p5', Parameter(torch.randn(10, 10))),
            ('p2', Parameter(torch.randn(10, 10))),
        ]
        parameters.update(next_parameters)
        parameter_dict.update(next_parameters)
        check()

        next_parameters = OrderedDict([
            ('p6', Parameter(torch.randn(10, 10))),
            ('p5', Parameter(torch.randn(10, 10))),
        ])
        parameters.update(next_parameters)
        parameter_dict.update(next_parameters)
        check()

        next_parameters = {
            'p8': Parameter(torch.randn(10, 10)),
            'p7': Parameter(torch.randn(10, 10))
        }
        parameters.update(sorted(next_parameters.items()))
        parameter_dict.update(next_parameters)
        check()

        next_parameters = nn.ParameterDict([
            ('p10', Parameter(torch.randn(10, 10))),
            ('p9', Parameter(torch.randn(10, 10))),
        ])
        parameters.update(next_parameters)
        parameter_dict.update(next_parameters)
        check()

        del parameter_dict['p3']
        del parameters['p3']
        check()

        with self.assertRaises(TypeError):
            parameter_dict.update(1)

        with self.assertRaises(TypeError):
            parameter_dict.update([1])

        with self.assertRaises(ValueError):
            parameter_dict.update(Parameter(torch.randn(10, 10)))

        p_pop = parameter_dict.pop('p4')
        self.assertIs(p_pop, parameters['p4'])
        parameters.pop('p4')
        check()

        # Check reverse works
        forward = list(iter(parameter_dict))
        backward = list(reversed(parameter_dict))
        self.assertEqual(len(forward), len(backward))
        n = len(forward)
        for i in range(n):
            self.assertIs(forward[i], backward[n - i - 1])
        check()

        # Check copy works
        copy = parameter_dict.copy()

        # Check all keys are present and have shallow copied values
        for key in parameter_dict:
            self.assertTrue(key in copy)
            self.assertEqual(parameter_dict[key], copy[key])
            self.assertIs(parameter_dict[key], copy[key])
        check()

        parameter_dict["p20"] = Parameter(torch.randn(10, 10))
        copy["p21"] = Parameter(torch.randn(9, 10))

        self.assertTrue("p20" in parameter_dict)
        self.assertFalse("p20" in copy)
        self.assertFalse("p21" in parameter_dict)
        self.assertTrue("p21" in copy)
        parameter_dict.pop("p20")
        check()

        p = Parameter(torch.randn(10, 10))
        parameter_dict['p12'] = p
        p_popitem = parameter_dict.popitem()
        self.assertEqual(p_popitem[0], 'p12')
        self.assertIs(p_popitem[1], p)
        check()

        # Unit test for setdefault
        # 1. Ensure parameter is correctly inserted when
        #    the key is not present in `ParameterDict`
        assert 'p11' not in parameter_dict
        assert 'p11' not in parameters
        parameters['p11'] = Parameter(torch.randn(10, 10))
        p_setdefault = parameter_dict.setdefault('p11', parameters['p11'])
        self.assertIs(p_setdefault, parameters['p11'])
        self.assertIs(p_setdefault, parameter_dict['p11'])
        check()
        # 2. Ensure parameter is NOT inserted when the
        #    key is already present in `ParameterDict`
        p = Parameter(torch.randn(10, 10))
        self.assertFalse(parameter_dict.setdefault('p11', p) is p)
        check()
        # 3. Ensure `None` is inserted when the key is not
        #    present in `ParameterDict` and no default is specified
        self.assertIs(parameter_dict.setdefault('p26'), None)
        del parameter_dict['p26']
        check()

        parameters2 = OrderedDict([
            ('p13', Parameter(torch.randn(10, 10))),
            ('p2', Parameter(torch.randn(10, 10))),
            ('p3', Parameter(torch.randn(10, 10))),
        ])
        parameter_dict2 = nn.ParameterDict(parameters2)
        parameters.update(parameters2)
        parameter_dict |= parameter_dict2
        check()

        parameters2 = OrderedDict()
        parameter_dict2 = nn.ParameterDict(parameters2)
        parameters.update(parameters2)
        parameter_dict |= parameter_dict2
        check()

        parameters2 = OrderedDict([
            ('p14', Parameter(torch.randn(10, 10))),
            ('p15', Parameter(torch.randn(10, 10))),
            ('p13', Parameter(torch.randn(10, 10))),
        ])
        parameter_dict2 = nn.ParameterDict(parameters2)
        parameters.update(parameters2)
        parameter_dict |= parameter_dict2
        check()

        # Check __or__ and __ror__ work
        parameters2 = OrderedDict([
            ('p20', Parameter(torch.randn(10, 10))),
            ('p21', Parameter(torch.randn(10, 10))),
            ('p22', Parameter(torch.randn(10, 10))),
        ])
        parameter_dict2 = nn.ParameterDict(parameters2)
        parameters.update(parameters2)
        parameter_dict = parameter_dict | parameter_dict2
        check()

        parameters2 = OrderedDict([
            ('p23', Parameter(torch.randn(10, 10))),
            ('p24', Parameter(torch.randn(10, 10))),
            ('p25', Parameter(torch.randn(10, 10))),
        ])
        parameter_dict2 = nn.ParameterDict(parameters2)
        parameters2.update(parameters)
        parameters = parameters2
        parameter_dict = parameter_dict2 | parameter_dict
        check()

        parameters['p17'] = Parameter(torch.randn(10, 10))
        parameter_dict['p17'] = parameters['p17']
        self.assertIs(parameters['p17'], parameter_dict.get('p17'))
        temp_param = Parameter(torch.randn(10, 10))
        self.assertIs(parameters['p17'], parameter_dict.get('p17', temp_param))
        self.assertIs(None, parameter_dict.get('p18'))
        self.assertIs(temp_param, parameter_dict.get('p18', temp_param))
        check()

        parameter_dict.clear()
        self.assertEqual(len(parameter_dict), 0)
        parameters.clear()
        check()

        parameter_dict2 = parameter_dict.fromkeys(['p19', 'p20'])
        self.assertEqual({'p19': None, 'p20': None}, parameter_dict2)
        check()

        parameter_dict2 = parameter_dict.fromkeys(['p19', 'p20'], temp_param)
        self.assertEqual({'p19': temp_param, 'p20': temp_param}, parameter_dict2)
        check()

        parameter_dict['p21'] = torch.rand(2, 2)
        self.assertIsInstance(parameter_dict['p21'], Parameter)
        parameters['p21'] = parameter_dict['p21']

        parameter_dict.update({'p22': torch.rand(2, 2), 'foo': 'bar'})
        self.assertIsInstance(parameter_dict['p22'], Parameter)
        self.assertIsInstance(parameter_dict['foo'], str)
        parameters['p22'] = parameter_dict['p22']
        parameters['foo'] = parameter_dict['foo']

    def test_ParameterDict_replication(self):
        # The actual replication code from DP cannot be used on CPU so doing it manually here
        def make_param():
            return Parameter(torch.randn(2, 2))
        parameters = {"foo": make_param(), "bar": make_param()}
        param_dict = nn.ParameterDict(parameters)

        new_param_dict = param_dict._replicate_for_data_parallel()

        for n, p in param_dict.named_parameters():
            # Do a view here so that we can check the base later
            setattr(new_param_dict, n, p.view_as(p))

        for (k, p), (k2, p2) in zip(param_dict.items(), new_param_dict.items()):
            self.assertEqual(k, k2)
            self.assertEqual(p, p2)
            self.assertIsNotNone(p2.grad_fn)
            self.assertIs(p2._base, p)

        self.assertEqual(param_dict["foo"], new_param_dict["foo"])

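    # add_module and register_module are aliases; both reject non-string names,
    # allow registering None, and allow re-registering an existing name.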
    def test_add_module(self):
        methods_to_test = ['add_module', 'register_module']
        for fn in methods_to_test:
            l = nn.Linear(10, 20)
            net = nn.Module()
            net.l = l
            net.l2 = l
            getattr(net, fn)('empty', None)
            self.assertEqual(net.l, l)
            self.assertEqual(net.l2, l)
            self.assertEqual(net.empty, None)
            getattr(net, fn)('l3', l)
            self.assertEqual(net.l3, l)
            l3 = nn.Linear(20, 10)
            getattr(net, fn)('l', l3)
            self.assertEqual(net.l, l3)
            self.assertRaises(TypeError, lambda: getattr(net, fn)('x', 'non-module'))
            self.assertRaisesRegex(TypeError, 'module name should be a string. Got int',
                                   lambda: getattr(net, fn)(1, l))
            self.assertRaisesRegex(TypeError, 'module name should be a string. Got NoneType',
                                   lambda: getattr(net, fn)(None, l))

    def test_set_submodule(self):
        net = nn.Module()
        net.t = nn.Module()
        l = nn.Linear(1, 2)
        target = "t.l"
        net.set_submodule(target, l)
        self.assertEqual(net.get_submodule(target), l)
        l2 = nn.Linear(2, 1)
        net.set_submodule(target, l2)
        self.assertEqual(net.get_submodule(target), l2)
        self.assertRaises(ValueError, net.set_submodule, "", l)
        self.assertRaises(AttributeError, net.set_submodule, "a.l", l)

    def test_module_to_argparse(self):
        net = nn.Sequential(nn.Linear(3, 3))
        cpu = torch.device('cpu')
        with self.assertRaises(TypeError):
            net.to(cpu, True)
        with self.assertRaises(TypeError):
            net.to(torch.long)
        with self.assertRaises(TypeError):
            net.to(None, True)
        with self.assertRaises(TypeError):
            net.to(cpu, torch.long, True)
        with self.assertRaises(TypeError):
            net.to(cpu, dtype=torch.long, non_blocking=True)
        with self.assertRaises(TypeError):
            net.to([])
        with self.assertRaises(TypeError):
            net.to({}, non_blocking=True)
        with self.assertRaises(TypeError):
            net.to(torch.tensor(3, dtype=torch.long), non_blocking=True)
        with self.assertRaises(TypeError):
            net.to(cpu, torch.tensor(3, dtype=torch.long), non_blocking=True)

    def test_RNN_nonlinearity(self):
        rnn = torch.nn.RNN(1, 10)
        self.assertEqual(rnn.nonlinearity, 'tanh')

        rnn = torch.nn.RNN(1, 10, nonlinearity='relu')
        self.assertEqual(rnn.nonlinearity, 'relu')

        with self.assertRaisesRegex(ValueError, 'Unknown nonlinearity'):
            rnn = torch.nn.RNN(1, 10, nonlinearity='garbage')

    def test_RNN_nonlinearity_passed_as_arg(self):
        rnn = torch.nn.RNN(2, 3, 1, 'relu')
        self.assertEqual(rnn.nonlinearity, 'relu')

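    # Autograd tracks in-place mutations through a per-tensor version counter
    # (tensor._version), and backward() raises if a tensor saved for the backward
    # pass was modified in place, e.g.: v = t._version; t.add_(1); assert t._version > v.
    # The next test relies on that to verify Module._apply() with an in-place
    # function bumps the counters of the module's parameters and their gradients.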
    def test_module_apply_inplace_op(self):
        def add_one_inplace(t):
            return t.add_(1.0)

        # Test that applying an in-place operation to a module would bump
        # the module's parameters' version counter.
        m = nn.Linear(20, 10)
        pvm = m.weight.mul(m.weight)
        m_weight_version_saved = m.weight._version
        m = m._apply(add_one_inplace)
        self.assertGreater(m.weight._version, m_weight_version_saved)
        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
            pvm.backward(torch.randn(10, 20))

        # Test that applying an in-place operation to a module would bump
        # the module's parameters' gradients' version counter.
        m = nn.Linear(20, 10)
        m.weight.grad = torch.randn(10, 20).requires_grad_()
        pgm = m.weight.grad.mul(m.weight.grad)
        m_weight_grad_version_saved = m.weight.grad._version
        m = m._apply(add_one_inplace)
        self.assertGreater(m.weight.grad._version, m_weight_grad_version_saved)
        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
            pgm.backward(torch.randn(10, 20))

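    # The test below covers torch.__future__.set_overwrite_module_params_on_conversion,
    # which controls whether Module._apply() overwrites a module's parameters with
    # new tensors during conversions such as .double(), affecting whether earlier
    # references and views into those parameters remain usable afterwards.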

    def test_overwrite_module_params_on_conversion(self):
        # Test that if the conversion function passed to `module._apply()`
        # changes the TensorImpl type of `module`'s parameters, the `module`'s
        # parameters are always overwritten, regardless of the value of
        # `torch.__future__.get_overwrite_module_params_on_conversion()`.
        m = nn.Linear(20, 10)
        m.weight.grad = torch.randn(10, 20)
        weight_ref = m.weight
        weight_grad_ref = m.weight.grad
        m = m._apply(lambda t: torch.sparse_coo_tensor(torch.zeros([2, 1]), torch.ones([1]), torch.Size([10, 20])))
        self.assertNotEqual(weight_ref.layout, m.weight.layout)
        self.assertNotEqual(weight_grad_ref.layout, m.weight.grad.layout)

        # Test that under the current default settings
        # (`torch.__future__.get_overwrite_module_params_on_conversion() == False`),
        # a view to a module's parameters is not pointing to the same storage as
        # its base variable after converting the module to a different dtype.
        m = nn.Linear(20, 10).float()
        mw = m.weight[:]
        m.double()
        with torch.no_grad():
            mw[0][0] = 5
        self.assertTrue(mw[0][0].dtype == torch.float)
        self.assertTrue(mw._base[0][0].dtype == torch.double)

        try:
            torch.__future__.set_overwrite_module_params_on_conversion(True)

            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
            # a view to a module's parameters is still pointing to the same storage as
            # its base variable after converting the module to a different dtype.
            m = nn.Linear(20, 10).float()
            mw = m.weight[:]
            m.double()
            with torch.no_grad():
                mw[0][0] = 5
            self.assertTrue(mw[0][0] == mw._base[0][0])

            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
            # `float_module.double()` doesn't preserve previous references to
            # `float_module`'s parameters or gradients.
            m = nn.Linear(20, 10).float()
            m.weight.grad = torch.randn(10, 20).float()
            weight_ref = m.weight
            weight_grad_ref = m.weight.grad
            m.double()
            self.assertNotEqual(weight_ref.dtype, m.weight.dtype)
            self.assertNotEqual(weight_grad_ref.dtype, m.weight.grad.dtype)

            def add_one_inplace(t):
                return t.add_(1.0)

            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
            # applying an in-place operation to a module would bump the module's
            # original parameters' version counter.
            m = nn.Linear(20, 10)
            pvm = m.weight.mul(m.weight)
            weight_ref = m.weight
            m_weight_version_saved = weight_ref._version
            m = m._apply(add_one_inplace)
            # Test that the in-place operation bumps the original parameter's version counter
            self.assertGreater(weight_ref._version, m_weight_version_saved)
            with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
                pvm.backward(torch.randn(10, 20))

            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
            # applying an in-place operation to a module would bump the module's
            # original parameters' gradients' version counter.
            m = nn.Linear(20, 10)
            m.weight.grad = torch.randn(10, 20).requires_grad_()
            pgm = m.weight.grad.mul(m.weight.grad)
            weight_grad_ref = m.weight.grad
            m_weight_grad_version_saved = weight_grad_ref._version
            m = m._apply(add_one_inplace)
            self.assertGreater(weight_grad_ref._version, m_weight_grad_version_saved)
            with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
                pgm.backward(torch.randn(10, 20))

            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
            # applying an out-of-place operation to a module doesn't bump
            # the module's original parameters' version counter.
            m = nn.Linear(20, 10)
            weight_ref = m.weight
            m_weight_version_saved = weight_ref._version
            m = m._apply(lambda t: torch.randn(t.shape))
            self.assertEqual(weight_ref._version, m_weight_version_saved)

            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
            # applying an out-of-place operation to a module doesn't bump
            # the module's original parameters' gradients' version counter.
            m = nn.Linear(20, 10)
            m.weight.grad = torch.randn(10, 20).requires_grad_()
            weight_grad_ref = m.weight.grad
            m_weight_grad_version_saved = weight_grad_ref._version
            m = m._apply(lambda t: torch.randn(t.shape))
            self.assertEqual(weight_grad_ref._version, m_weight_grad_version_saved)
        finally:
            torch.__future__.set_overwrite_module_params_on_conversion(False)
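
    # Note on the test below: with torch.__future__.set_swap_module_params_on_conversion(True),
    # conversions such as `m.half()` swap parameter tensors in place via swap_tensors instead of
    # reassigning `.data`. Rough mental model (a simplification): AccumulateGrad nodes created by
    # an earlier forward still point at the pre-swap tensors, so running that backward afterwards
    # is deliberately "poisoned" and raises.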

    def test_swap_module_params_poisons_acc_grad(self):
        try:
            torch.__future__.set_swap_module_params_on_conversion(True)
            # (1) backward cannot be run after _apply
            # forward will init AccumulateGrad nodes, which bumps use_count of parameters' at::Tensors
            # additionally, if any Tensors are saved for backward, their use_count will be bumped
            m = torch.nn.Linear(2, 3)
            inp = torch.randn(2, 2)
            out = m(inp)
            m.half()
            self.assertTrue(all(p.dtype == torch.float16 for p in m.parameters()))
            with self.assertRaisesRegex(RuntimeError, "Trying to execute AccumulateGrad node that was poisoned by swap_tensors"):
                out.sum().backward()
            # (2) _apply can be run after backward()
            # After running backward, all the references generated by "save for backward" will be cleared,
            # so the use_count will be 2 (1 from the Tensor itself and 1 from the AccumulateGrad node),
            # which swap_tensors allows.
            inp2 = torch.randn(2, 2, dtype=torch.half)
            out2 = m(inp2)
            out2.sum().backward()
            m.float()
            self.assertTrue(all(p.dtype == torch.float32 for p in m.parameters()))
            out3 = m(inp)
        finally:
            torch.__future__.set_swap_module_params_on_conversion(False)

    def test_type(self):
        l = nn.Linear(10, 20)
        net = nn.Module()
        net.l = l
        net.l2 = l
        net.add_module('empty', None)
        net.indices = Buffer(torch.LongTensor(1))
        net.float()
        self.assertIsInstance(l.weight.data, torch.FloatTensor)
        self.assertIsInstance(l.bias.data, torch.FloatTensor)
        self.assertIsInstance(net.indices, torch.LongTensor)
        net.double()
        self.assertIsInstance(l.weight.data, torch.DoubleTensor)
        self.assertIsInstance(l.bias.data, torch.DoubleTensor)
        self.assertIsInstance(net.indices, torch.LongTensor)
        net.to(torch.half)
        self.assertIsInstance(l.weight.data, torch.HalfTensor)
        self.assertIsInstance(l.bias.data, torch.HalfTensor)
        self.assertIsInstance(net.indices, torch.LongTensor)
        if TEST_CUDA:
            net.float().cuda()
            self.assertIsInstance(l.weight.data, torch.cuda.FloatTensor)
            self.assertIsInstance(l.bias.data, torch.cuda.FloatTensor)
            self.assertIsInstance(net.indices, torch.cuda.LongTensor)
            net.cpu()
            self.assertIsInstance(l.weight.data, torch.FloatTensor)
            self.assertIsInstance(l.bias.data, torch.FloatTensor)
            self.assertIsInstance(net.indices, torch.LongTensor)
            net.to("cuda", torch.double, True)
            self.assertIsInstance(l.weight.data, torch.cuda.DoubleTensor)
            self.assertIsInstance(l.bias.data, torch.cuda.DoubleTensor)
            self.assertIsInstance(net.indices, torch.cuda.LongTensor)
            net.to(torch.empty(1, device="cuda:0", dtype=torch.half))
            self.assertIsInstance(l.weight.data, torch.cuda.HalfTensor)
            self.assertIsInstance(l.bias.data, torch.cuda.HalfTensor)
            self.assertIsInstance(net.indices, torch.cuda.LongTensor)
        net.to(torch.device("cpu"), non_blocking=True)
        self.assertIsInstance(l.weight.data, torch.HalfTensor)
        self.assertIsInstance(l.bias.data, torch.HalfTensor)
        self.assertIsInstance(net.indices, torch.LongTensor)
        net.to(torch.float)
        self.assertIsInstance(l.weight.data, torch.FloatTensor)
        self.assertIsInstance(l.bias.data, torch.FloatTensor)
        net.to(torch.DoubleTensor(1))
        self.assertIsInstance(l.weight.data, torch.DoubleTensor)
        self.assertIsInstance(l.bias.data, torch.DoubleTensor)
        if TEST_CUDA:
            net.to(device='cuda', dtype=torch.float)
            self.assertIsInstance(l.weight.data, torch.cuda.FloatTensor)
            self.assertIsInstance(l.bias.data, torch.cuda.FloatTensor)

    def test_non_leaf_parameters(self):
        l1 = nn.Linear(10, 10)
        l2 = nn.Linear(10, 10)

        def assign_weight():
            l2.weight = l1.weight + 2

        self.assertRaises(TypeError, assign_weight)
        # This should work though
        l2.weight = Parameter(torch.randn(10, 10))

    def test_parameters_to_vector(self):
        conv1 = nn.Conv2d(3, 10, 5)
        fc1 = nn.Linear(10, 20)
        model = nn.Sequential(conv1, fc1)

        vec = parameters_to_vector(model.parameters())
        self.assertEqual(vec.size(0), 980)
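
    # The 980-element vector used above and below is simply the total parameter count of the
    # model: Conv2d(3, 10, 5) contributes 10 * 3 * 5 * 5 + 10 = 760 and Linear(10, 20)
    # contributes 20 * 10 + 20 = 220, so 760 + 220 = 980.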

    def test_vector_to_parameters(self):
        conv1 = nn.Conv2d(3, 10, 5)
        fc1 = nn.Linear(10, 20)
        model = nn.Sequential(conv1, fc1)

        vec = torch.arange(0., 980)
        vector_to_parameters(vec, model.parameters())

        sample = next(model.parameters())[0, 0, 0]
        self.assertTrue(torch.equal(sample.data, vec.data[:5]))

    def test_rnn_weight_norm(self):
        def check_weight_norm(l, name, num_params):
            # This Module has 4 or 5 parameters called:
            # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', 'weight_hr_l0'

            # Applying weight norm on one of them causes it to become a tensor
            l = torch.nn.utils.weight_norm(l, name=name)
            self.assertEqual(
                sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights),
                num_params - 1,
            )

            # Removing the weight norm reparametrization restores the Parameter
            l = torch.nn.utils.remove_weight_norm(l, name=name)
            self.assertEqual(
                sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights),
                num_params,
            )

            # Make sure that, upon removal of the reparametrization, the
            # `._parameters` and `.named_parameters` contain the right params.
            # Specifically, the original weight ('weight_ih_l0') should be placed
            # back in the parameters, while the reparametrization components
            # ('weight_ih_l0_v' and 'weight_ih_l0_g') should be removed.
            self.assertTrue(name in l._parameters)
            self.assertIsNotNone(l._parameters[name])
            self.assertTrue(name + '_v' not in l._parameters)
            self.assertTrue(name + '_g' not in l._parameters)
            self.assertTrue(name in dict(l.named_parameters()))
            self.assertIsNotNone(dict(l.named_parameters())[name])
            self.assertTrue(name + '_v' not in dict(l.named_parameters()))
            self.assertTrue(name + '_g' not in dict(l.named_parameters()))

        check_weight_norm(torch.nn.LSTM(32, 32), 'weight_ih_l0', 4)
        check_weight_norm(torch.nn.LSTM(32, 32, proj_size=16), 'weight_hr_l0', 5)
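
    # Reminder of the weight_norm reparametrization exercised by the tests around here:
    # the weight is split into a magnitude and a direction, w = g * v / ||v||, where
    # `weight_g` holds the norms taken along `dim` and `weight_v` holds the unnormalized
    # direction. For a (5, 4) Linear weight this makes `weight_g` (5, 1) with the default
    # dim=0 and (1, 4) with dim=1, which is what test_weight_norm below checks.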

    def test_weight_norm(self):
        for dtype in [torch.float, torch.bfloat16]:
            input = torch.randn(3, 4, dtype=dtype)
            m = nn.Linear(4, 5).to(dtype=dtype)
            expected_output = m(input)

            # add weight normalization
            m = torch.nn.utils.weight_norm(m)
            self.assertEqual(m.weight_v.size(), m.weight.size())
            self.assertEqual(m.weight_g.size(), (5, 1))
            self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0)

            # remove weight norm
            m = torch.nn.utils.remove_weight_norm(m)
            self.assertFalse(hasattr(m, 'weight_g'))
            self.assertFalse(hasattr(m, 'weight_v'))
            self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0)

            # test with dim=1
            m = torch.nn.utils.weight_norm(m, dim=1)
            self.assertEqual(m.weight_v.size(), m.weight.size())
            self.assertEqual(m.weight_g.size(), (1, 4))
            self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0)

            # test with dim=None
            m = nn.Linear(4, 5).to(dtype=dtype)
            expected_output = m(input)
            m = torch.nn.utils.weight_norm(m, dim=None)
            self.assertEqual(m(input), expected_output)

            with self.assertRaisesRegex(RuntimeError, 'register two weight_norm hooks'):
                m = torch.nn.utils.weight_norm(m)
                m = torch.nn.utils.weight_norm(m)

        # For float16, the forward of the Module doesn't work but we must still be able
        # to register the weight norm as this is often done before sending the Module to
        # CUDA.
        m = nn.Linear(4, 5, dtype=torch.float16)
        m = torch.nn.utils.weight_norm(m)

    def test_parameterlistdict_setting_attributes(self):
        with warnings.catch_warnings(record=True) as w:
            mod = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
        self.assertTrue(len(w) == 0)

        with warnings.catch_warnings(record=True) as w:
            mod.train()
            mod.eval()
        self.assertTrue(len(w) == 0)

        with warnings.catch_warnings(record=True) as w:
            mod = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
        self.assertTrue(len(w) == 0)

        with warnings.catch_warnings(record=True) as w:
            mod.train()
            mod.eval()
        self.assertTrue(len(w) == 0)

    def test_parameterlistdict_pickle(self):
        m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
        with warnings.catch_warnings(record=True) as w:
            m = pickle.loads(pickle.dumps(m))
        self.assertTrue(len(w) == 0)

        # Test whether loading from older checkpoints works without triggering warnings
        m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
        del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
        with warnings.catch_warnings(record=True) as w:
            m = pickle.loads(pickle.dumps(m))
        self.assertTrue(len(w) == 0)

        m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
        with warnings.catch_warnings(record=True) as w:
            m = pickle.loads(pickle.dumps(m))
        self.assertTrue(len(w) == 0)

        # Test whether loading from older checkpoints works without triggering warnings
        m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
        del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
        with warnings.catch_warnings(record=True) as w:
            m = pickle.loads(pickle.dumps(m))
        self.assertTrue(len(w) == 0)

    def test_weight_norm_pickle(self):
        m = torch.nn.utils.weight_norm(nn.Linear(5, 7))
        m = pickle.loads(pickle.dumps(m))
        self.assertIsInstance(m, nn.Linear)
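
    # Informal sketch of what spectral_norm sets up (the tests below rely on these names):
    # the original weight is kept as the Parameter `weight_orig`, `weight_u` / `weight_v`
    # are buffers updated by one power-iteration step per training-mode forward, and
    # `weight` is recomputed on the fly as roughly weight_orig / sigma, where
    # sigma ~= u^T W v estimates the largest singular value.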

    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
    @set_default_dtype(torch.double)
    def test_spectral_norm(self):
        input = torch.randn(3, 5)
        m = nn.Linear(5, 7)
        m = torch.nn.utils.spectral_norm(m)

        self.assertEqual(m.weight_u.size(), torch.Size([m.weight.size(0)]))
        # weight_orig should be trainable
        self.assertTrue(hasattr(m, 'weight_orig'))
        self.assertTrue('weight_orig' in m._parameters)
        # weight_u should be just a reused buffer
        self.assertTrue(hasattr(m, 'weight_u'))
        self.assertTrue('weight_u' in m._buffers)
        self.assertTrue('weight_v' in m._buffers)
        # weight should be a plain attribute, not counted as a buffer or a param
        self.assertFalse('weight' in m._buffers)
        self.assertFalse('weight' in m._parameters)
        # it should also share storage with `weight_orig`
        self.assertEqual(m.weight_orig.storage(), m.weight.storage())
        self.assertEqual(m.weight_orig.size(), m.weight.size())
        self.assertEqual(m.weight_orig.stride(), m.weight.stride())

        m = torch.nn.utils.remove_spectral_norm(m)
        self.assertFalse(hasattr(m, 'weight_orig'))
        self.assertFalse(hasattr(m, 'weight_u'))
        # weight should be converted back to a parameter
        self.assertTrue(hasattr(m, 'weight'))
        self.assertTrue('weight' in m._parameters)

        with self.assertRaisesRegex(RuntimeError, 'register two spectral_norm hooks'):
            m = torch.nn.utils.spectral_norm(m)
            m = torch.nn.utils.spectral_norm(m)

        # test correctness in training/eval modes and cpu/multi-gpu settings
        for apply_dp in (True, False):
            if apply_dp:
                if not TEST_MULTIGPU:
                    continue
                device = torch.device('cuda:0')

                def maybe_wrap(m):
                    return torch.nn.DataParallel(m, [0, 1])
            else:
                device = torch.device('cpu')

                def maybe_wrap(m):
                    return m

            for requires_grad in (True, False):
                m = nn.Linear(3, 4).to(device)
                m.weight.requires_grad_(requires_grad)
                m = torch.nn.utils.spectral_norm(m)
                wrapped_m = maybe_wrap(m)
                self.assertTrue(hasattr(m, 'weight_u'))
                u0 = m.weight_u.clone()
                v0 = m.weight_v.clone()

                # TEST TRAINING BEHAVIOR

                # assert that u and v are updated
                input = torch.randn(2, 3, device=device)
                out = wrapped_m(input)
                self.assertNotEqual(u0, m.weight_u)
                self.assertNotEqual(v0, m.weight_v)

                # assert that backprop reaches weight_orig
                # can't use gradcheck because the function changes as we
                # activate through it in training mode
                if requires_grad:
                    torch.autograd.grad(out.sum(), m.weight_orig)

                # test that backward works with multiple forwards
                # it uses training mode so we need to reset `u` and `v` vectors
                # to the same value at the beginning for the finite difference test to pass
                saved_u = m.weight_u.clone()
                saved_v = m.weight_v.clone()

                def fn(input):
                    m.weight_u.data.copy_(saved_u)
                    m.weight_v.data.copy_(saved_v)
                    out0 = wrapped_m(input)
                    out1 = wrapped_m(input)
                    return out0 + out1

                gradcheck(fn, (input.clone().requires_grad_(),), check_batched_grad=False)

                # test removing
                pre_remove_out = wrapped_m(input)
                m = torch.nn.utils.remove_spectral_norm(m)
                self.assertEqual(wrapped_m(input), pre_remove_out)

                m = torch.nn.utils.spectral_norm(m)
                for _ in range(3):
                    pre_remove_out = wrapped_m(input)
                m = torch.nn.utils.remove_spectral_norm(m)
                self.assertEqual(wrapped_m(input), pre_remove_out)

                # TEST EVAL BEHAVIOR

                m = torch.nn.utils.spectral_norm(m)
                wrapped_m(input)
                last_train_out = wrapped_m(input)
                last_train_u = m.weight_u.clone()
                last_train_v = m.weight_v.clone()
                wrapped_m.zero_grad()
                wrapped_m.eval()

                eval_out0 = wrapped_m(input)
                # assert that eval gives the same result as the last training iteration
                self.assertEqual(eval_out0, last_train_out)
                # assert that doing more iterations in eval doesn't change things
                self.assertEqual(eval_out0, wrapped_m(input))
                self.assertEqual(last_train_u, m.weight_u)
                self.assertEqual(last_train_v, m.weight_v)

                # FIXME: the code below is flaky when executed with DataParallel
                # see https://github.com/pytorch/pytorch/issues/13818
                if apply_dp:
                    continue

                # test that backward works with multiple forwards in mixed training
                # and eval modes
                # it uses training mode so we need to reset `u` and `v` vectors
                # to the same value at the beginning for the finite difference test to pass
                saved_u = m.weight_u.clone()
                saved_v = m.weight_v.clone()

                def fn(input):
                    m.weight_u.data.copy_(saved_u)
                    m.weight_v.data.copy_(saved_v)
                    wrapped_m.train()
                    out0 = wrapped_m(input)
                    wrapped_m.eval()
                    out1 = wrapped_m(input)
                    wrapped_m.train()
                    out2 = wrapped_m(input)
                    wrapped_m.eval()
                    out3 = wrapped_m(input)
                    return out0 + out1 + out2 + out3

                gradcheck(fn, (input.clone().requires_grad_(),))

                # assert that backprop reaches weight_orig in eval
                if requires_grad:
                    def fn(weight):
                        return wrapped_m(input)

                    gradcheck(fn, (m.weight_orig,))

    @skipIfNoLapack
    def test_spectral_norm_load_state_dict(self):
        inp = torch.randn(2, 3)
        for activate_times in (0, 3):
            # Test backward compatibility
            # At version None -> 1: weight stops being a buffer and the v vector becomes a buffer
            m = nn.Linear(3, 5)
            snm = torch.nn.utils.spectral_norm(m)
            snm.train()
            for _ in range(activate_times):
                snm(inp)

            version_latest_ref_state_dict = deepcopy(snm.state_dict())
            self.assertEqual({'weight_orig', 'bias', 'weight_u', 'weight_v'}, set(version_latest_ref_state_dict.keys()))

            # test that non-strict loading works
            non_strict_state_dict = deepcopy(version_latest_ref_state_dict)
            non_strict_state_dict['nonsense'] = 'nonsense'
            with self.assertRaisesRegex(RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'):
                snm.load_state_dict(non_strict_state_dict, strict=True)
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict['weight_orig']
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict['weight_u']
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict['weight_v']
            snm.load_state_dict(non_strict_state_dict, strict=False)
            non_strict_state_dict['weight'] = snm.weight.detach().clone()  # set W as a buffer
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict._metadata['']['spectral_norm']  # remove metadata info
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict['weight']  # remove W buffer
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict['bias']
            snm.load_state_dict(non_strict_state_dict, strict=False)

            # craft a version None state_dict
            version_none_state_dict = deepcopy(version_latest_ref_state_dict)
            self.assertIn('spectral_norm', version_none_state_dict._metadata[''])
            del version_none_state_dict._metadata['']['spectral_norm']  # remove metadata info
            del version_none_state_dict['weight_v']  # remove v vector
            version_none_state_dict['weight'] = snm.weight.detach().clone()  # set W as a buffer

            # normal state_dict
            for version_latest_with_metadata in [True, False]:
                version_latest_state_dict = deepcopy(version_latest_ref_state_dict)

                if not version_latest_with_metadata:
                    # We want to still load a user-crafted state_dict, one without metadata
                    del version_latest_state_dict._metadata['']['spectral_norm']

                # test that re-wrapping does not matter
                m = torch.nn.utils.remove_spectral_norm(snm)
                snm = torch.nn.utils.spectral_norm(m)

                snm.load_state_dict(version_latest_ref_state_dict)
                with torch.no_grad():
                    snm.eval()
                    out0_eval = snm(inp)
                    snm.train()
                    out1_train = snm(inp)
                    out2_train = snm(inp)
                    snm.eval()
                    out3_eval = snm(inp)

                # test that re-wrapping does not matter
                m = torch.nn.utils.remove_spectral_norm(snm)
                snm = torch.nn.utils.spectral_norm(m)

                snm.load_state_dict(version_none_state_dict)
                if activate_times > 0:
                    # since loading a version None state dict assumes that the values in it
                    # have gone through at least one forward, we only test for equivalence
                    # when activate_times > 0.
                    with torch.no_grad():
                        snm.eval()
                        self.assertEqual(out0_eval, snm(inp))
                        snm.train()
                        self.assertEqual(out1_train, snm(inp))
                        self.assertEqual(out2_train, snm(inp))
                        snm.eval()
                        self.assertEqual(out3_eval, snm(inp))

                # test that re-wrapping does not matter
                m = torch.nn.utils.remove_spectral_norm(snm)
                snm = torch.nn.utils.spectral_norm(m)

                # Test normal loading
                snm.load_state_dict(version_latest_state_dict)
                with torch.no_grad():
                    snm.eval()
                    self.assertEqual(out0_eval, snm(inp))
                    snm.train()
                    self.assertEqual(out1_train, snm(inp))
                    self.assertEqual(out2_train, snm(inp))
                    snm.eval()
                    self.assertEqual(out3_eval, snm(inp))
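
    # Note for test_spectral_norm_dim below: spectral_norm normalizes over the "output"
    # dimension, which defaults to dim 0 but is dim 1 for ConvTranspose*d, whose weight is
    # laid out as (in_channels, out_channels, *kernel_size). The shape check on `weight_u`
    # relies on that choice.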

    def test_spectral_norm_dim(self):
        inp = torch.randn(2, 3, 10, 12)
        m = nn.ConvTranspose2d(3, 4, (5, 6))
        m = torch.nn.utils.spectral_norm(m)
        # this should not run into incompatible shapes
        x = m(inp)
        # check that u refers to the same dimension
        self.assertEqual(m.weight_u.shape, m.weight_orig[0, :, 0, 0].shape)

    def test_spectral_norm_forward(self):
        input = torch.randn(3, 5)
        m = nn.Linear(5, 7)
        m = torch.nn.utils.spectral_norm(m)
        # naive forward
        _weight, _bias, _u = m.weight_orig, m.bias, m.weight_u
        _weight_mat = _weight.view(_weight.size(0), -1)
        _v = torch.mv(_weight_mat.t(), _u)
        _v = F.normalize(_v, dim=0, eps=1e-12)
        _u = torch.mv(_weight_mat, _v)
        _u = F.normalize(_u, dim=0, eps=1e-12)
        _weight.data /= torch.dot(_u, torch.matmul(_weight_mat, _v))
        out_hat = torch.nn.functional.linear(input, _weight, _bias)
        expect_out = m(input)
        self.assertEqual(expect_out, out_hat)

    def test_spectral_norm_pickle(self):
        m = torch.nn.utils.spectral_norm(nn.Linear(5, 7))
        m = pickle.loads(pickle.dumps(m))
        self.assertIsInstance(m, nn.Linear)

    def test_threshold_int(self):
        x = torch.tensor([-3, -2, -1, 0, 1, 2, 3])
        expected = torch.tensor([99, 99, 99, 99, 1, 2, 3])
        self.assertEqual(F.threshold(x, 0, 99), expected)

    def test_threshold_bfloat16_half(self):
        x = torch.randn(100)
        for dtype in [torch.bfloat16, torch.half]:
            for threshold in [0, -0.5, 0.5, float('inf'), float('-inf'), float('nan')]:
                expected = F.threshold(x, threshold, 0).to(dtype=dtype).float()
                res_bf16 = F.threshold(x.to(dtype=dtype), threshold, 0).float()
                self.assertEqual(res_bf16, expected)

    @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                         'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs'
                         ' with instruction set support avx2 or newer.')
    def test_fb_fc_packed(self):
        X = np.random.rand(16, 16).astype(np.float32) - 0.5
        W = np.random.rand(16, 16).astype(np.float32) - 0.5
        b = np.random.rand(16).astype(np.float32) - 0.5

        def fc_op(X, W, b):
            return np.dot(X, W.T) + b

        x_tensor = torch.tensor(X)
        w_tensor = torch.tensor(W)
        b_tensor = torch.tensor(b)
        packed_w_tensor = torch.fbgemm_pack_gemm_matrix_fp16(w_tensor)
        actual_output = torch.fbgemm_linear_fp16_weight(x_tensor, packed_w_tensor, b_tensor)
        expected_output = fc_op(X, W, b)
        torch.testing.assert_close(torch.from_numpy(expected_output), actual_output.cpu(), atol=1e-3, rtol=1e-3)

    def test_pad_scalar_error(self):
        inputs = torch.tensor(0., requires_grad=True)
        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (1, 1)))
        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (1,)))
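
    # The mask tests below use "padding format" masks: within each row every True precedes
    # every False, i.e. row i is [True] * len_i + [False] * (L - len_i), which is the layout
    # torch._nested_tensor_from_mask expects.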
2203*da0073e9SAndroid Build Coastguard Worker 2204*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, input_convert) 2205*da0073e9SAndroid Build Coastguard Worker 2206*da0073e9SAndroid Build Coastguard Worker def test_nested_tensor_from_mask_error(self): 2207*da0073e9SAndroid Build Coastguard Worker N, L, D = 10, 12, 14 2208*da0073e9SAndroid Build Coastguard Worker 2209*da0073e9SAndroid Build Coastguard Worker input = torch.rand(N, L, D) 2210*da0073e9SAndroid Build Coastguard Worker # Mask is not bool 2211*da0073e9SAndroid Build Coastguard Worker mask = torch.zeros(N, L, dtype=torch.float) 2212*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask)) 2213*da0073e9SAndroid Build Coastguard Worker 2214*da0073e9SAndroid Build Coastguard Worker # Mask size is not 2 2215*da0073e9SAndroid Build Coastguard Worker mask = torch.zeros(N, L, D, dtype=torch.bool) 2216*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask)) 2217*da0073e9SAndroid Build Coastguard Worker 2218*da0073e9SAndroid Build Coastguard Worker # Input size is not 3 2219*da0073e9SAndroid Build Coastguard Worker mask = torch.zeros(N, L, dtype=torch.bool) 2220*da0073e9SAndroid Build Coastguard Worker input = torch.rand(N, L) 2221*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask)) 2222*da0073e9SAndroid Build Coastguard Worker 2223*da0073e9SAndroid Build Coastguard Worker # Mask size does not match input 2224*da0073e9SAndroid Build Coastguard Worker mask = torch.zeros(N + 1, L + 1, dtype=torch.bool) 2225*da0073e9SAndroid Build Coastguard Worker input = torch.rand(N, L, D) 2226*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask)) 2227*da0073e9SAndroid Build Coastguard Worker 2228*da0073e9SAndroid Build Coastguard Worker # Mask is not padding format 2229*da0073e9SAndroid Build Coastguard Worker mask = torch.ones(N, L, dtype=torch.bool) 2230*da0073e9SAndroid Build Coastguard Worker mask[0, 0] = False 2231*da0073e9SAndroid Build Coastguard Worker mask[0, 2] = False 2232*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask)) 2233*da0073e9SAndroid Build Coastguard Worker 2234*da0073e9SAndroid Build Coastguard Worker def test_normalize(self): 2235*da0073e9SAndroid Build Coastguard Worker inputs = torch.randn(1, 3, 4, 4, requires_grad=True, dtype=torch.double) 2236*da0073e9SAndroid Build Coastguard Worker self.assertTrue(gradcheck(lambda x: F.normalize(x, p=1, dim=-1), (inputs,))) 2237*da0073e9SAndroid Build Coastguard Worker self.assertTrue(gradcheck(lambda x: F.normalize(x, p=2, dim=-2), (inputs,))) 2238*da0073e9SAndroid Build Coastguard Worker 2239*da0073e9SAndroid Build Coastguard Worker inputs = torch.randn((), requires_grad=True) 2240*da0073e9SAndroid Build Coastguard Worker self.assertTrue(gradcheck(lambda x: F.normalize(x, p=1, dim=-1), (inputs,))) 2241*da0073e9SAndroid Build Coastguard Worker 2242*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") 2243*da0073e9SAndroid Build Coastguard Worker # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190 2244*da0073e9SAndroid Build Coastguard Worker @skipIfRocm 2245*da0073e9SAndroid Build Coastguard Worker def test_broadcast_double_backwards_gpu(self): 
        tensors = (torch.randn(4, 4, device='cuda', requires_grad=True, dtype=torch.double),
                   torch.randn(4, 4, device='cuda', requires_grad=True, dtype=torch.double),
                   torch.randn(4, 4, device='cuda', requires_grad=True, dtype=torch.double))
        # TODO(#50743): the following segfaults with check_batched_grad=True
        _assertGradAndGradgradChecks(self, lambda *i: Broadcast.apply((0, 1), *i), tensors,
                                     check_batched_grad=False)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_broadcast_not_requiring_grad(self):
        variables = [
            torch.randn(1, 2, device='cuda', requires_grad=True),
            torch.randn(1, 2, device='cuda', requires_grad=False),
            torch.randn(1, 2, device='cuda', requires_grad=False),
            torch.randn(1, 2, device='cuda', requires_grad=True),
            torch.randn(1, 2, device='cuda', requires_grad=True),
        ]
        broadcasted_variables = Broadcast.apply((0, 1), *variables)
        for output_idx, broadcasted_var in enumerate(broadcasted_variables):
            input_var = variables[output_idx % len(variables)]
            self.assertEqual(input_var.requires_grad, broadcasted_var.requires_grad)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_broadcast_no_grad(self):
        x = torch.randn(1, 2, dtype=torch.float32, requires_grad=True, device='cuda')
        with torch.no_grad():
            broadcasted = Broadcast.apply((0, 1), x)
        self.assertTrue(x.requires_grad)
        for output in broadcasted:
            self.assertFalse(output.requires_grad)

    def test_state_dict(self):
        l = nn.Linear(5, 5)
        block = nn.Module()
        block.conv = nn.Conv2d(3, 3, 3, bias=False)
        net = nn.Module()
        net.linear1 = l
        net.linear2 = l
        net.bn = nn.BatchNorm2d(2)
        net.block = block
        net.add_module('empty', None)

        state_dict = net.state_dict()
        self.assertEqual(len(state_dict), 10)
        self.assertEqual(len(state_dict._metadata), 6)
        self.assertIn('', state_dict._metadata)
        self.assertIn('linear1', state_dict._metadata)
        self.assertIn('linear1.weight', state_dict)
        self.assertIn('linear1.bias', state_dict)
        self.assertIn('linear2', state_dict._metadata)
        self.assertIn('linear2.weight', state_dict)
        self.assertIn('linear2.bias', state_dict)
        self.assertIn('block', state_dict._metadata)
        self.assertIn('block.conv', state_dict._metadata)
        self.assertIn('block.conv.weight', state_dict)
        self.assertIn('block.conv.weight', state_dict)
        self.assertNotIn('block.conv.bias', state_dict)
        self.assertIn('bn', state_dict._metadata)
        self.assertIn('bn.weight', state_dict)
        self.assertIn('bn.bias', state_dict)
        self.assertIn('bn.running_var', state_dict)
        self.assertIn('bn.running_mean', state_dict)
        self.assertIn('bn.num_batches_tracked', state_dict)
        self.assertFalse(any(k.startswith('empty') for k in state_dict.keys()))
        for k, v in state_dict.items():
            param = net
            for component in k.split('.'):
                param = getattr(param, component)
                if isinstance(param, Parameter):
                    param = param.data
            self.assertEqual(v.data_ptr(), param.data_ptr())

        l = nn.Linear(5, 5)
        state_dict = l.state_dict()
        self.assertEqual(len(state_dict), 2)
        self.assertEqual(len(state_dict._metadata), 1)
        self.assertIn('', state_dict._metadata)
        self.assertTrue(state_dict._metadata['']['version'] >= 0)
        self.assertEqual(state_dict['weight'].data_ptr(), l.weight.data_ptr())
        self.assertEqual(state_dict['bias'].data_ptr(), l.bias.data_ptr())

        # Reference https://github.com/pytorch/pytorch/pull/75507#issuecomment-1110291545
        self.assertNotWarn(lambda: l.state_dict(destination={}),
                           "Should not warn kwarg destination w/o _metadata")

    def test_extra_state(self):

        class SubModule(torch.nn.Module):
            def __init__(self, foo):
                super().__init__()
                self.foo = foo

            def get_extra_state(self):
                return {
                    'foo': self.foo
                }

            def set_extra_state(self, state):
                self.foo = state['foo']

        class MyModule(torch.nn.Module):
            def __init__(self, foo, bar):
                super().__init__()
                self.sub = SubModule(foo)
                self.bar = bar

            def get_extra_state(self):
                return {
                    'bar': self.bar
                }

            def set_extra_state(self, state):
                self.bar = state['bar']

        # Ensure state_dict contains the extra state by loading it into another module.
        m = MyModule(3, 'something')
        m2 = MyModule(5, 'something else')
        m2.load_state_dict(m.state_dict())
        self.assertEqual(m.state_dict(), m2.state_dict())
        self.assertEqual(m2.bar, m.bar)
        self.assertEqual(m2.sub.foo, m.sub.foo)

    def test_extra_state_non_dict(self):

        class MyModule(torch.nn.Module):
            def __init__(self, foo):
                super().__init__()
                self.foo = foo

            def get_extra_state(self):
                return self.foo

            def set_extra_state(self, state):
                self.foo = state

        # Test various types of extra state.
        for state in ('something', 5, MyModule(3)):
            m = MyModule(state)
            m2 = MyModule('something else')
            m2.load_state_dict(m.state_dict())
            self.assertEqual(m.state_dict(), m2.state_dict())
            self.assertEqual(m.foo, m2.foo)

    def test_extra_state_missing_set_extra_state(self):

        class MyModule(torch.nn.Module):
            def get_extra_state(self):
                return {
                    'foo': 5
                }

        m = MyModule()
        with self.assertRaisesRegex(RuntimeError, 'Unexpected key'):
            m.load_state_dict(m.state_dict())

    def test_extra_state_missing_get_extra_state(self):

        class MyModule(torch.nn.Module):
            def set_extra_state(self):
                pass

        m = MyModule()
        with self.assertRaisesRegex(RuntimeError, 'Missing key'):
            m.load_state_dict(m.state_dict())

    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
    def test_parameter_assignment(self):
        l = nn.Linear(5, 5)

        def num_params():
            return len(list(l.parameters()))

        self.assertEqual(num_params(), 2)

        new_param = Parameter(torch.randn(5, 5))
        l.param_name = new_param
        self.assertEqual(num_params(), 3)
        self.assertObjectIn(new_param, l.parameters())

        var = torch.randn(5, 5)
        l.var_name = var
        self.assertEqual(num_params(), 3)
        self.assertNotIn(id(var), map(id, l.parameters()))

        # Make sure Variables are not saved as parameters
        l.variable_attr = torch.empty(5, 5)
        self.assertEqual(num_params(), 3)
        l.param_attr = Parameter(torch.empty(5, 5))
        self.assertEqual(num_params(), 4)

        # It shouldn't be possible to replace a parameter with a Variable
        def assign_var():
            l.param_attr = torch.empty(5, 5)

        self.assertRaises(TypeError, assign_var)
        # But replacing it with None should be fine
        l.param_attr = None
        self.assertEqual(num_params(), 3)

    def test_assignment(self):
        l = nn.Module()
        a = nn.Parameter(torch.randn(2))
        b = nn.Parameter(torch.randn(3))
        c = nn.Parameter(torch.randn(4))
        q = nn.Linear(4, 4)
        r = nn.Linear(5, 5)
        w = nn.Linear(6, 6)

        def test_assignments(get_list, a, b, c):
            # Check that None can be shadowed
            l.a = None
            self.assertIsNone(l.a)
            self.assertIn('a', l.__dict__)
            l.a = a
            self.assertIs(l.a, a)
            self.assertEqual(get_list(), [a])
            self.assertNotIn('a', l.__dict__)

            # Assign second object
            l.b = None
            self.assertIsNone(l.b)
            self.assertIn('b', l.__dict__)
            l.b = b
            self.assertIs(l.b, b)
            self.assertEqual(get_list(), [a, b])
            self.assertNotIn('b', l.__dict__)

            # Remove and add the object back. Order should be unchanged.
            l.a = None
            self.assertIsNone(l.a)
            self.assertEqual(get_list(), [b])
            l.a = a
            self.assertIs(l.a, a)
            self.assertEqual(get_list(), [a, b])

            # Replace object with another one. Order should be unchanged.
            l.a = c
            self.assertIs(l.a, c)
            self.assertEqual(get_list(), [c, b])

            # Remove and reassign an attribute. It should appear at the end of the list now.
            del l.a
            self.assertFalse(hasattr(l, 'a'))
            l.a = a
            self.assertIs(l.a, a)
            self.assertEqual(get_list(), [b, a])

        test_assignments(lambda: list(l.parameters()), a, b, c)
        del l.a, l.b
        self.assertEqual(list(l.parameters()), [])

        test_assignments(lambda: list(l.children()), q, r, w)
        del l.a, l.b
        self.assertEqual(list(l.children()), [])

        buf = Buffer(torch.randn(10))
        l.buf = buf
        self.assertIs(l.buf, buf)
        l.buf = None
        self.assertIs(l.buf, None)
        self.assertNotIn('buf', l.__dict__)  # should be stored in l._buffers
        l.buf = buf
        self.assertIn('buf', l.state_dict())
        self.assertEqual(l.state_dict()['buf'], buf)

    def test_container_copy(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(4, 5)

            def forward(self, input):
                return self.linear(input)

        input = torch.randn(2, 4)

        model = Model()
        model_cp = deepcopy(model)
        self.assertEqual(model(input).data, model_cp(input).data)

        model_cp.linear.weight.data[:] = 2
        self.assertNotEqual(model(input).data, model_cp(input).data)

    def test_RNN_cell(self):
        # this is just a smoke test; these modules are implemented through
        # autograd so no Jacobian test is needed
        for module in (nn.RNNCell, nn.GRUCell):
            for bias in (True, False):
                input = torch.randn(3, 10)
                hx = torch.randn(3, 20)
                cell = module(10, 20, bias=bias)
                for _ in range(6):
                    hx = cell(input, hx)

                hx.sum().backward()

    def test_RNN_cell_forward_zero_hidden_size(self):
        input = torch.randn(3, 10)
        hx = torch.randn(3, 0)
        cell_shared_param = (10, 0)
        for cell in (nn.RNNCell(*cell_shared_param, nonlinearity="relu"),
                     nn.RNNCell(*cell_shared_param, nonlinearity="tanh"),
                     nn.GRUCell(*cell_shared_param)):
            self.assertEqual(cell(input, hx).shape, torch.Size([3, 0]))

    def _test_loss_equal_input_target_shape(self, cast):
        # Tests losses whose inputs should have the same size.
        losses = {
            'mse_loss': lambda x, y: F.mse_loss(x, y),
            'l1_loss': lambda x, y: F.l1_loss(x, y),
            'smooth_l1_loss': lambda x, y: F.smooth_l1_loss(x, y),
            'huber_loss': lambda x, y: F.huber_loss(x, y),
            'kl_div': lambda x, y: F.kl_div(x, y),
            'poisson_nll_loss': lambda x, y: F.poisson_nll_loss(x, y),
        }

        input = cast(torch.randn(3, 5))
        target = cast(torch.randn(5, 3))
        for fn in losses.values():
            self.assertRaises(Exception, lambda: fn(input, target))

    def test_loss_equal_input_target_shape(self):
        self._test_loss_equal_input_target_shape(lambda x: x)

    def test_mse_loss_size_warning(self):
        i = torch.randn((10, 1), requires_grad=True)
        t = torch.randn((10,))
        with warnings.catch_warnings(record=True) as w:
            # Ensure warnings are being shown
            warnings.simplefilter("always")
            # Trigger Warning
            F.mse_loss(i, t)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertIn('Please ensure they have the same size.', str(w[0]))

    def test_gaussian_nll_loss_broadcasting(self):
        input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
        target_full = torch.tensor([[1., 2., 3.], [1., 2., 3.]])
        target_part = torch.tensor([[1., 2., 3.]])
        var_full = torch.tensor([[0.5, 0.5, 0.5], [1.5, 1.5, 1.5]])
        var_part1 = torch.tensor([[0.5], [1.5]])
        var_part2 = torch.tensor([0.5, 1.5])
        component_wise_loss = 0.5 * (torch.log(var_full) + (input - target_full)**2 / var_full)
        self.assertEqual(component_wise_loss,
                         F.gaussian_nll_loss(input, target_part, var_full, reduction='none'))
        self.assertEqual(component_wise_loss,
                         F.gaussian_nll_loss(input, target_full, var_part1, reduction='none'))
        self.assertEqual(component_wise_loss,
                         F.gaussian_nll_loss(input, target_full, var_part2, reduction='none'))
        self.assertEqual(component_wise_loss,
                         F.gaussian_nll_loss(input, target_part, var_part1, reduction='none'))
        self.assertEqual(component_wise_loss,
                         F.gaussian_nll_loss(input, target_part, var_part2, reduction='none'))

    def test_gaussian_nll_loss_args(self):
        input = torch.randn(3, 5)
        with self.assertRaisesRegex(ValueError, 'var is of incorrect size'):
            target = torch.randn(3, 5)
            var = torch.ones(3, 3)
            torch.nn.functional.gaussian_nll_loss(input, target, var)
        with self.assertRaisesRegex(ValueError, 'var has negative entry/entries'):
            var = -1 * torch.ones(3, 5)
            torch.nn.functional.gaussian_nll_loss(input, target, var)

    def test_KLDivLoss_batch_mean(self):
        input_shape = (2, 5)
        log_prob1 = F.log_softmax(torch.randn(input_shape), 1)
        prob2 = F.softmax(torch.randn(input_shape), 1)

        loss = nn.KLDivLoss(reduction='batchmean')
        l = loss(log_prob1, prob2)

        loss_none_reduce = nn.KLDivLoss(reduction='sum')(log_prob1, prob2)
        expected = loss_none_reduce / input_shape[0]

        self.assertEqual(l, expected)

    def test_KLDivLoss_batch_mean_log_target(self):
        input_shape = (2, 5)
        log_prob1 = F.log_softmax(torch.randn(input_shape), 1)
        log_prob2 = F.log_softmax(torch.randn(input_shape), 1)

        loss = nn.KLDivLoss(reduction='batchmean', log_target=True)
        l = loss(log_prob1, log_prob2)

        loss_none_reduce = nn.KLDivLoss(reduction='sum', log_target=True)(log_prob1, log_prob2)
        expected = loss_none_reduce / input_shape[0]

        self.assertEqual(l, expected)

    def test_CTCLoss_typechecks(self):
        target_lengths = torch.tensor([30, 25, 20])
        input_lengths = torch.tensor([50, 50, 50])
        targets = torch.randint(1, 15, (sum(target_lengths),), dtype=torch.int)
        log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
        with self.assertRaises(RuntimeError):
            _input_lengths = input_lengths.to(dtype=torch.float)
            torch.nn.functional.ctc_loss(log_probs, targets, _input_lengths, target_lengths)
        with self.assertRaises(RuntimeError):
            target_lengths = target_lengths.to(dtype=torch.float)
            torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    def test_CTCLoss_lengthchecks_cuda(self):
        for target_lengths in [[30, 25, 20], [-1, -1, -1]]:
            for input_lengths in [[50, 50, 50], [-1, -1, -1]]:
                targets = torch.randint(1, 15, (3, 29), dtype=torch.long, device='cuda')
                log_probs = torch.randn(50, 3, 15, dtype=torch.float, device='cuda').log_softmax(2)
                with self.assertRaises(RuntimeError):
                    torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)

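    # Note on the argument convention exercised by the CTC tests in this file
    # (descriptive comment only, not a new test): F.ctc_loss expects
    # log-probabilities of shape (T, N, C) -- e.g. the result of log_softmax
    # over the class dimension -- targets given either as a concatenated 1-D
    # tensor of length sum(target_lengths) or as a padded (N, S) tensor, and
    # per-sample input_lengths / target_lengths. The length-check tests above
    # and below deliberately pass invalid (negative or mistyped) lengths to
    # verify that a RuntimeError is raised.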
    def test_CTCLoss_lengthchecks_cpu(self):
        for target_lengths in [[30, 25, 20], [-1, -1, -1]]:
            for input_lengths in [[50, 50, 50], [-1, -1, -1]]:
                targets = torch.randint(1, 15, (3, 29), dtype=torch.int)
                log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
                with self.assertRaises(RuntimeError):
                    torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    def test_CTCLoss_long_targets(self):
        input_length = 4000
        vocab_size = 3
        batch_size = 4
        target_length = 1200

        log_probs = torch.randn(input_length, batch_size, vocab_size,
                                dtype=torch.double).log_softmax(2).requires_grad_()
        targets = torch.randint(low=1, high=vocab_size - 1, size=(batch_size, target_length), dtype=torch.long)
        input_lengths = batch_size * [input_length]
        target_lengths = batch_size * [target_length]

        res_cpu = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                                               reduction='sum', zero_infinity=True)
        grad_out = torch.randn_like(res_cpu)
        grad_cpu, = torch.autograd.grad(res_cpu, log_probs, grad_out)

        with torch.backends.cudnn.flags(enabled=False):
            res_gpu = torch.nn.functional.ctc_loss(log_probs.cuda(), targets.cuda(), input_lengths, target_lengths,
                                                   reduction='sum', zero_infinity=True)
            grad_gpu, = torch.autograd.grad(res_gpu, log_probs, grad_out.cuda())
        self.assertEqual(res_cpu, res_gpu, atol=1e-4, rtol=0)
        self.assertEqual(grad_cpu, grad_gpu, atol=1e-4, rtol=0)

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    def test_CTCLoss_critical_target_len(self):
        # cudnn has an unexpected problem with target length 256, see issue #53505
        N = 1
        S = 256
        C = 10
        T = 500
        target = torch.randint(low=1, high=C, size=(S,), dtype=torch.int)
        input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.int)
        target_lengths = torch.tensor(S, dtype=torch.int)
        inp = torch.randn(T, N, C, dtype=torch.float, device='cuda').log_softmax(2).requires_grad_()
        with cudnn.flags(enabled=True):
            res_gpu = torch.nn.functional.ctc_loss(inp, target, input_lengths, target_lengths, reduction='none')
        res_cpu = torch.nn.functional.ctc_loss(inp.cpu(), target, input_lengths, target_lengths, reduction='none')
        self.assertEqual(res_cpu, res_gpu, atol=1e-3, rtol=0)

    def test_CTCLoss_zero_lengths(self):
        devices = ['cpu']
        devices += ['cuda'] if TEST_CUDA else []
        N = 3
        S = 2
        C = 200
        T = 1
        target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.int)
        input_lengths = torch.full(size=(N,), fill_value=0, dtype=torch.int)
        target_lengths = torch.full(size=(N,), fill_value=0, dtype=torch.int)
        for device in devices:
            inp = torch.randn(T, N, C, dtype=torch.float, device=device).log_softmax(2).requires_grad_()
            res = torch.nn.functional.ctc_loss(inp, target, input_lengths, target_lengths, reduction='none')
            self.assertTrue((res == 0).all().item())
            res.sum().backward()
            self.assertTrue((inp.grad == 0).all().item())
        target_lengths = torch.full(size=(N,), fill_value=1, dtype=torch.int)
        for device in devices:
            inp = torch.randn(T, N, C, dtype=torch.float, device=device).log_softmax(2).requires_grad_()
            res = torch.nn.functional.ctc_loss(inp, target, input_lengths, target_lengths, reduction='none')
            self.assertTrue((res == torch.inf).all().item())
            res.sum().backward()
            self.assertTrue((inp.grad == 0).all().item())

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    def test_CTCLoss_zero_infinity(self):
        target_lengths = [60, 25, 20]
        input_lengths = [50, 50, 50]
        targets = torch.randint(1, 15, (sum(target_lengths),), dtype=torch.int, device='cuda')
        log_probs = torch.randn(50, 3, 15, dtype=torch.float, device='cuda').log_softmax(2).requires_grad_()
        res = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                                           reduction='sum', zero_infinity=True)
        with torch.backends.cudnn.flags(enabled=False):
            res2 = torch.nn.functional.ctc_loss(log_probs, targets.cuda().long(), input_lengths, target_lengths,
                                                reduction='sum', zero_infinity=True)
        res_cpu = torch.nn.functional.ctc_loss(log_probs.cpu(), targets.cpu(), input_lengths, target_lengths,
                                               reduction='sum', zero_infinity=True)

        self.assertEqual(res2, res, atol=1e-4, rtol=0)
        self.assertEqual(res_cpu, res.cpu(), atol=1e-4, rtol=0)
        g1, = torch.autograd.grad(res, log_probs)
        g2, = torch.autograd.grad(res2, log_probs)
        g3, = torch.autograd.grad(res_cpu, log_probs)
        self.assertEqual(g2, g3, atol=1e-4, rtol=0)
        self.assertEqual(g1, g2, atol=1e-4, rtol=0)
        self.assertTrue((g1 == g1).all().item())  # check that we don't have NaN

    def test_RNN_cell_no_broadcasting(self):
        def test(cell_module, input, hx, input_size, hidden_size):
            cell = cell_module(input_size, hidden_size)
            self.assertRaises(RuntimeError, lambda: cell(input, hx))

        def test_all(hidden_size, bad_hx, good_hx, input_size, input):
            test(nn.RNNCell, input, bad_hx, input_size, hidden_size)
            test(nn.GRUCell, input, bad_hx, input_size, hidden_size)
            test(nn.LSTMCell, input, (bad_hx, good_hx), input_size, hidden_size)
            test(nn.LSTMCell, input, (good_hx, bad_hx), input_size, hidden_size)

        hidden_size = 20
        input_size = 10
        input = torch.randn(3, input_size)
        bad_hx = torch.randn(1, hidden_size)
        good_hx = torch.randn(3, hidden_size)

        # Test hidden/input batch size broadcasting
        test_all(hidden_size, bad_hx, good_hx, input_size, input)

        # Test hx's hidden_size vs module's hidden_size broadcasting
        bad_hx = torch.randn(3, 1)
        test_all(hidden_size, bad_hx, good_hx, input_size, input)

        # Test input's input_size vs module's input_size broadcasting
        bad_input = torch.randn(3, 1)
        test_all(hidden_size, good_hx, good_hx, input_size, bad_input)

    def test_LSTM_cell(self):
        # this is just a smoke test; these modules are implemented through
        # autograd so no Jacobian test is needed
        for bias in (True, False):
            input = torch.randn(3, 10)
            hx = torch.randn(3, 20)
            cx = torch.randn(3, 20)
            lstm = nn.LSTMCell(10, 20, bias=bias)
            for _ in range(6):
                hx, cx = lstm(input, (hx, cx))

            (hx + cx).sum().backward()

    def test_LSTM_cell_forward_input_size(self):
        input = torch.randn(3, 11)
        hx = torch.randn(3, 20)
        cx = torch.randn(3, 20)
        lstm = nn.LSTMCell(10, 20)
        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))

    def test_LSTM_cell_forward_hidden_size(self):
        input = torch.randn(3, 10)
        hx = torch.randn(3, 21)
        cx = torch.randn(3, 20)
        lstm = nn.LSTMCell(10, 20)
        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
        self.assertRaises(Exception, lambda: lstm(input, (cx, hx)))

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    def test_pack_sequence_batch_sizes_throw(self):
        with self.assertRaisesRegex(ValueError, r"batch_sizes should always be on CPU"):
            m = nn.LSTM(3, 4, bidirectional=True, num_layers=2).to('cuda')
            a = torch.rand(5, 3, device='cuda')
            b = torch.tensor([1, 1, 1, 1, 1], device='cuda')
            input = nn.utils.rnn.PackedSequence(a, b)

    def test_Transformer_cell(self):
        # this is just a smoke test; these modules are implemented through
        # autograd so no Jacobian test is needed
        d_model = 512
        nhead = 16
        num_encoder_layers = 4
        num_decoder_layers = 3
        dim_feedforward = 256
        dropout = 0.3
        bsz = 8
        seq_length = 35
        tgt_length = 15
        for batch_first, src_size, tgt_size in zip((True, False),
                                                   [(bsz, seq_length, d_model),
                                                    (seq_length, bsz, d_model)],
                                                   [(bsz, tgt_length, d_model),
                                                    (tgt_length, bsz, d_model)]):
            transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers,
                                         dim_feedforward, dropout, batch_first=batch_first,
                                         dtype=torch.double)
            src = torch.randn(src_size, dtype=torch.double)
            src_mask = transformer.generate_square_subsequent_mask(seq_length).double()
            tgt = torch.randn(tgt_size, dtype=torch.double)
            tgt_mask = transformer.generate_square_subsequent_mask(tgt_length).double()
            memory_mask = torch.randn(tgt_length, seq_length).double()
            src_key_padding_mask = torch.rand(bsz, seq_length) >= 0.5
            tgt_key_padding_mask = torch.rand(bsz, tgt_length) >= 0.5
            memory_key_padding_mask = torch.rand(bsz, seq_length) >= 0.5

            output = transformer(src, tgt,
                                 src_mask=src_mask,
                                 tgt_mask=tgt_mask,
                                 memory_mask=memory_mask,
                                 src_key_padding_mask=src_key_padding_mask,
                                 tgt_key_padding_mask=tgt_key_padding_mask,
                                 memory_key_padding_mask=memory_key_padding_mask)
            output.sum().backward()

    def test_transformerdecoderlayer(self):
        # this is a deterministic test for TransformerDecoderLayer
        d_model = 4
        nhead = 2
        dim_feedforward = 16
        dropout = 0.0
        bsz = 2
        seq_length = 5
        tgt_length = 3

        for batch_first in (False, True):
            def perm_fn(x):
                return x.transpose(1, 0) if batch_first else x

            model = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                               batch_first=batch_first)

            # set constant weights of the model
            for idx, p in enumerate(model.parameters()):
                x = p.data
                sz = x.view(-1).size(0)
                shape = x.shape
                x = torch.cos(torch.arange(0, sz).float().view(shape))
                p.data.copy_(x)

            # deterministic input
            decoder_input = torch.tensor([[[20., 30., 40., 50.]]])
            memory_input = torch.tensor([[[60., 70., 80., 90.]]])
            result = model(decoder_input, memory_input)
            ref_output = torch.tensor([[[2.314351, 0.094805, -0.671322, 0.101977]]])
            result = result.detach().numpy()
            ref_output = ref_output.detach().numpy()
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            np.testing.assert_allclose(result, ref_output, atol=1e-5)

            # deterministic input
            decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
                                                  [[11., 12., 13., 14.]]]))
            memory_input = torch.tensor([[[1., 2., 3., 4.]]])
            result = model(decoder_input, memory_input)
            result = result.detach().numpy()
            ref_output = perm_fn(torch.tensor([[[2.422245, 0.051716, -0.606338, -0.024756]],
                                               [[2.422245, 0.051716, -0.606338, -0.024756]]]))
            ref_output = ref_output.detach().numpy()
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            np.testing.assert_allclose(result, ref_output, atol=1e-5)

            # deterministic input
            decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
                                                  [[5., 6., 7., 8.]]]))
            memory_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
                                                 [[11., 12., 13., 14.]]]))
            result = model(decoder_input, memory_input)
            ref_output = perm_fn(torch.tensor([[[2.343536, 0.085561, -0.654954, 0.074991]],
                                               [[2.343536, 0.085561, -0.654954, 0.074991]]]))
            result = result.detach().numpy()
            ref_output = ref_output.detach().numpy()
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            np.testing.assert_allclose(result, ref_output, atol=1e-5)

            # deterministic input
            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
                                                   [0.2678, 0.3677, 0.4459, 0.7166]],
                                                  [[0.8100, 0.3716, 0.4096, 0.1976],
                                                   [0.6958, 0.8844, 0.6081, 0.8315]],
                                                  [[0.0494, 0.9343, 0.5955, 0.3830],
                                                   [0.5404, 0.3464, 0.9378, 0.6200]]]))
            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
                                                 [[0.8335, 0.2799, 0.5031, 0.2947],
                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
                                                 [[0.6333, 0.9344, 0.1376, 0.9938],
                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
                                                 [[0.9897, 0.6915, 0.3154, 0.1733],
                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
                                                 [[0.8117, 0.2366, 0.4838, 0.7881],
                                                  [0.3718, 0.4945, 0.9511, 0.0864]]]))
            result = model(decoder_input, memory_input)
            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
                                                [2.431935, 0.028907, -0.599809, -0.072488]],
                                               [[2.428457, 0.027053, -0.602275, -0.073462],
                                                [2.431970, 0.029387, -0.599789, -0.071621]],
                                               [[2.431934, 0.028196, -0.599802, -0.073809],
                                                [2.432306, 0.028858, -0.599542, -0.072846]]]))
            result = result.detach().numpy()
            ref_output = ref_output.detach().numpy()
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            np.testing.assert_allclose(result, ref_output, atol=1e-5)

            # key_padding_mask
            key_padding_mask = torch.zeros(2, 3) == 1
            result = model(decoder_input, memory_input, tgt_key_padding_mask=key_padding_mask)
            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
                                                [2.431935, 0.028907, -0.599809, -0.072488]],
                                               [[2.428457, 0.027053, -0.602275, -0.073462],
                                                [2.431970, 0.029387, -0.599789, -0.071621]],
                                               [[2.431934, 0.028196, -0.599802, -0.073809],
                                                [2.432306, 0.028858, -0.599542, -0.072846]]]))
            result = result.detach().numpy()
            ref_output = ref_output.detach().numpy()
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            np.testing.assert_allclose(result, ref_output, atol=1e-5)

            # key_padding_mask
            key_padding_mask[0, 2] = 1
            key_padding_mask[1, 1] = 1
            key_padding_mask[1, 2] = 1
            result = model(decoder_input, memory_input, tgt_key_padding_mask=key_padding_mask)
            ref_output = perm_fn(torch.tensor([[[2.430025, 0.027643, -0.601164, -0.073476],
                                                [2.4323, 0.029375, -0.599553, -0.071881]],
                                               [[2.428523, 0.026838, -0.602226, -0.07391],
                                                [2.432634, 0.029842, -0.599318, -0.071253]],
                                               [[2.432278, 0.028152, -0.599555, -0.074139],
                                                [2.432659, 0.029244, -0.599294, -0.072382]]]))
            result = result.detach().numpy()
            ref_output = ref_output.detach().numpy()
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            np.testing.assert_allclose(result, ref_output, atol=1e-5)

            # memory_key_padding_mask
            key_padding_mask = torch.zeros(2, 5) == 1
            result = model(decoder_input, memory_input, memory_key_padding_mask=key_padding_mask)
            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
                                                [2.431935, 0.028907, -0.599809, -0.072488]],
                                               [[2.428457, 0.027053, -0.602275, -0.073462],
                                                [2.431970, 0.029387, -0.599789, -0.071621]],
                                               [[2.431934, 0.028196, -0.599802, -0.073809],
                                                [2.432306, 0.028858, -0.599542, -0.072846]]]))
            result = result.detach().numpy()
            ref_output = ref_output.detach().numpy()
Coastguard Worker self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) 2985*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(result, ref_output, atol=1e-5) 2986*da0073e9SAndroid Build Coastguard Worker 2987*da0073e9SAndroid Build Coastguard Worker # memory_key_padding_mask 2988*da0073e9SAndroid Build Coastguard Worker key_padding_mask[0, 4] = 1 2989*da0073e9SAndroid Build Coastguard Worker key_padding_mask[1, 3] = 1 2990*da0073e9SAndroid Build Coastguard Worker key_padding_mask[1, 4] = 1 2991*da0073e9SAndroid Build Coastguard Worker result = model(decoder_input, memory_input, memory_key_padding_mask=key_padding_mask) 2992*da0073e9SAndroid Build Coastguard Worker ref_output = perm_fn(torch.tensor([[[2.429757, 0.027358, -0.601351, -0.073816], 2993*da0073e9SAndroid Build Coastguard Worker [2.432692, 0.028583, -0.599263, -0.073634]], 2994*da0073e9SAndroid Build Coastguard Worker [[2.428247, 0.02662, -0.602419, -0.074123], 2995*da0073e9SAndroid Build Coastguard Worker [2.432657, 0.029055, -0.599293, -0.072732]], 2996*da0073e9SAndroid Build Coastguard Worker [[2.431515, 0.027687, -0.600096, -0.074459], 2997*da0073e9SAndroid Build Coastguard Worker [2.433075, 0.028543, -0.598987, -0.073985]]])) 2998*da0073e9SAndroid Build Coastguard Worker result = result.detach().numpy() 2999*da0073e9SAndroid Build Coastguard Worker ref_output = ref_output.detach().numpy() 3000*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) 3001*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(result, ref_output, atol=1e-5) 3002*da0073e9SAndroid Build Coastguard Worker 3003*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 3004*da0073e9SAndroid Build Coastguard Worker def test_transformerdecoderlayer_gelu(self): 3005*da0073e9SAndroid Build Coastguard Worker # this is a deterministic test for TransformerDecoderLayer with gelu activation 3006*da0073e9SAndroid Build Coastguard Worker d_model = 4 3007*da0073e9SAndroid Build Coastguard Worker nhead = 2 3008*da0073e9SAndroid Build Coastguard Worker dim_feedforward = 16 3009*da0073e9SAndroid Build Coastguard Worker dropout = 0.0 3010*da0073e9SAndroid Build Coastguard Worker bsz = 2 3011*da0073e9SAndroid Build Coastguard Worker seq_length = 5 3012*da0073e9SAndroid Build Coastguard Worker tgt_length = 3 3013*da0073e9SAndroid Build Coastguard Worker 3014*da0073e9SAndroid Build Coastguard Worker for activation, batch_first in product(('gelu', F.gelu, nn.GELU()), (True, False)): 3015*da0073e9SAndroid Build Coastguard Worker def perm_fn(x): 3016*da0073e9SAndroid Build Coastguard Worker return x.transpose(1, 0) if batch_first else x 3017*da0073e9SAndroid Build Coastguard Worker 3018*da0073e9SAndroid Build Coastguard Worker model = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, 3019*da0073e9SAndroid Build Coastguard Worker activation, batch_first=batch_first) 3020*da0073e9SAndroid Build Coastguard Worker 3021*da0073e9SAndroid Build Coastguard Worker # set constant weights of the model 3022*da0073e9SAndroid Build Coastguard Worker for idx, p in enumerate(model.parameters()): 3023*da0073e9SAndroid Build Coastguard Worker x = p.data 3024*da0073e9SAndroid Build Coastguard Worker sz = x.view(-1).size(0) 3025*da0073e9SAndroid Build Coastguard Worker shape = x.shape 3026*da0073e9SAndroid Build Coastguard Worker x = torch.cos(torch.arange(0, sz).float().view(shape)) 3027*da0073e9SAndroid Build Coastguard Worker p.data.copy_(x) 3028*da0073e9SAndroid 
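    # Illustrative sketch (added for readability; not part of the original suite): the
    # deterministic checks above rely on TransformerDecoderLayer's shape contract --
    # with batch_first=False, tgt is (tgt_len, batch, d_model), memory is
    # (src_len, batch, d_model), and the output has the same shape as tgt.
    # The helper name below is hypothetical.
    def _example_decoder_layer_shapes(self):
        layer = nn.TransformerDecoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0)
        tgt = torch.rand(3, 2, 4)     # (tgt_len, batch, d_model)
        memory = torch.rand(5, 2, 4)  # (src_len, batch, d_model)
        out = layer(tgt, memory)
        self.assertEqual(out.shape, tgt.shape)
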
    @set_default_dtype(torch.double)
    def test_transformerdecoderlayer_gelu(self):
        # this is a deterministic test for TransformerDecoderLayer with gelu activation
        d_model = 4
        nhead = 2
        dim_feedforward = 16
        dropout = 0.0
        bsz = 2
        seq_length = 5
        tgt_length = 3

        for activation, batch_first in product(('gelu', F.gelu, nn.GELU()), (True, False)):
            def perm_fn(x):
                return x.transpose(1, 0) if batch_first else x

            model = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                               activation, batch_first=batch_first)

            # set constant weights of the model
            for idx, p in enumerate(model.parameters()):
                x = p.data
                sz = x.view(-1).size(0)
                shape = x.shape
                x = torch.cos(torch.arange(0, sz).float().view(shape))
                p.data.copy_(x)

            # deterministic input
            decoder_input = torch.tensor([[[20., 30., 40., 50.]]])
            memory_input = torch.tensor([[[60., 70., 80., 90.]]])
            result = model(decoder_input, memory_input)
            ref_output = torch.tensor([[[2.306435, 0.095946, -0.675796, 0.10687]]])
            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)

            # deterministic input
            decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
                                                  [[11., 12., 13., 14.]]]))
            memory_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]]]))
            result = model(decoder_input, memory_input)
            ref_output = perm_fn(torch.tensor([[[2.415448, 0.054389, -0.610932, -0.0156613]],
                                               [[2.415448, 0.054389, -0.610932, -0.0156613]]]))
            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)

            # deterministic input
            decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
                                                  [[5., 6., 7., 8.]]]))
            memory_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
                                                 [[11., 12., 13., 14.]]]))
            result = model(decoder_input, memory_input)
            ref_output = perm_fn(torch.tensor([[[2.338531, 0.087709, -0.65776, 0.080646]],
                                               [[2.338531, 0.087709, -0.65776, 0.080646]]]))
            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)

            # deterministic input
            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034], [0.2678, 0.3677, 0.4459, 0.7166]],
                                                  [[0.8100, 0.3716, 0.4096, 0.1976], [0.6958, 0.8844, 0.6081, 0.8315]],
                                                  [[0.0494, 0.9343, 0.5955, 0.3830], [0.5404, 0.3464, 0.9378, 0.6200]]]))
            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891], [0.5387, 0.1655, 0.3565, 0.0471]],
                                                 [[0.8335, 0.2799, 0.5031, 0.2947], [0.1402, 0.0318, 0.7636, 0.1346]],
                                                 [[0.6333, 0.9344, 0.1376, 0.9938], [0.8924, 0.2872, 0.6692, 0.2944]],
                                                 [[0.9897, 0.6915, 0.3154, 0.1733], [0.8645, 0.3513, 0.3064, 0.0767]],
                                                 [[0.8117, 0.2366, 0.4838, 0.7881], [0.3718, 0.4945, 0.9511, 0.0864]]]))
            result = model(decoder_input, memory_input)
            ref_output = perm_fn(torch.tensor([[[2.42049104, 0.03443088, -0.60793706, -0.05436271], [2.42210631, 0.03546578, -0.60679895, -0.05357488]],
                                               [[2.41907674, 0.0336104, -0.60892977, -0.05490462], [2.42216881, 0.03586554, -0.6067524, -0.05289126]],
                                               [[2.42205716, 0.03488046, -0.60683681, -0.05460596], [2.42240309, 0.0354595, -0.60659063, -0.05378816]]]))
            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)

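    # Illustrative sketch (not part of the original suite): the tests above pin every
    # parameter to a deterministic cosine ramp so that hard-coded reference outputs
    # stay reproducible across runs. The helper name is hypothetical; it mirrors the
    # in-test initialization loop.
    def _example_fill_params_with_cosine(self, module):
        with torch.no_grad():
            for p in module.parameters():
                sz = p.numel()
                # cos(0), cos(1), ... reshaped to the parameter's shape
                p.copy_(torch.cos(torch.arange(0, sz).float().view(p.shape)))
        return module
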
    @skipIfRocm(msg='Large numerical errors')
    def test_transformerdecoder(self):
        def get_a_test_layer(use_cuda, activation, batch_first=False):
            d_model = 4
            nhead = 2
            dim_feedforward = 16
            dropout = 0.0
            device = torch.device("cuda" if use_cuda else "cpu")

            layer = nn.TransformerDecoderLayer(
                d_model,
                nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                batch_first=batch_first).to(device)

            with torch.no_grad():
                # set constant weights of the model
                for idx, p in enumerate(layer.parameters()):
                    x = p.data
                    sz = x.view(-1).size(0)
                    shape = x.shape
                    x = torch.cos(torch.arange(0, sz).float().view(shape))
                    p.data.copy_(x)

            return layer

        # this is a deterministic test for TransformerDecoder
        for batch_first in (False, True):
            def perm_fn(x):
                return x.transpose(1, 0) if batch_first else x
            activation = F.relu
            use_cuda = torch.cuda.is_available()
            device = torch.device("cuda" if use_cuda else "cpu")

            decoder_layer = get_a_test_layer(use_cuda=use_cuda, activation=activation,
                                             batch_first=batch_first)

            model = nn.TransformerDecoder(decoder_layer, 1).to(device)

            # deterministic input
            decoder_input = torch.tensor([[[20., 30., 40., 50.]]]).to(device)
            memory_input = torch.tensor([[[60., 70., 80., 90.]]]).to(device)
            result = model(decoder_input, memory_input)
            ref_output = torch.tensor(
                [[[2.314351, 0.094805, -0.671322, 0.101977]]]).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)

            # deterministic input
            decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
                                                  [[11., 12., 13., 14.]]])).to(device)
            memory_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]]])).to(device)
            result = model(decoder_input, memory_input)
            ref_output = perm_fn(torch.tensor([[[2.422245, 0.051716, -0.606338, -0.024756]],
                                               [[2.422245, 0.051716, -0.606338, -0.024756]]])).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)

            # deterministic input
            decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
                                                  [[5., 6., 7., 8.]]])).to(device)
            memory_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
                                                 [[11., 12., 13., 14.]]])).to(device)
            result = model(decoder_input, memory_input)
            ref_output = perm_fn(torch.tensor([[[2.343536, 0.085561, -0.654954, 0.074991]],
                                               [[2.343536, 0.085561, -0.654954, 0.074991]]])).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)

            # deterministic input
            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034], [0.2678, 0.3677, 0.4459, 0.7166]],
                                                  [[0.8100, 0.3716, 0.4096, 0.1976], [0.6958, 0.8844, 0.6081, 0.8315]],
                                                  [[0.0494, 0.9343, 0.5955, 0.3830], [0.5404, 0.3464, 0.9378, 0.6200]]])).to(device)
            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891], [0.5387, 0.1655, 0.3565, 0.0471]],
                                                 [[0.8335, 0.2799, 0.5031, 0.2947], [0.1402, 0.0318, 0.7636, 0.1346]],
                                                 [[0.6333, 0.9344, 0.1376, 0.9938], [0.8924, 0.2872, 0.6692, 0.2944]],
                                                 [[0.9897, 0.6915, 0.3154, 0.1733], [0.8645, 0.3513, 0.3064, 0.0767]],
                                                 [[0.8117, 0.2366, 0.4838, 0.7881], [0.3718, 0.4945, 0.9511, 0.0864]]])).to(device)
            result = model(decoder_input, memory_input)
            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096], [2.431935, 0.028907, -0.599809, -0.072488]],
                                               [[2.428457, 0.027053, -0.602275, -0.073462], [2.431970, 0.029387, -0.599789, -0.071621]],
                                               [[2.431934, 0.028196, -0.599802, -0.073809], [2.432306, 0.028858, -0.599542, -0.072846]]])).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # key_padding_mask
            key_padding_mask = torch.zeros(2, 3).to(device) == 1
            result = model(decoder_input, memory_input,
                           tgt_key_padding_mask=key_padding_mask)
            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096], [2.431935, 0.028907, -0.599809, -0.072488]],
                                               [[2.428457, 0.027053, -0.602275, -0.073462], [2.431970, 0.029387, -0.599789, -0.071621]],
                                               [[2.431934, 0.028196, -0.599802, -0.073809], [2.432306, 0.028858, -0.599542, -0.072846]]])).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # key_padding_mask
            key_padding_mask[0, 2] = 1
            key_padding_mask[1, 1] = 1
            key_padding_mask[1, 2] = 1
            result = model(decoder_input, memory_input,
                           tgt_key_padding_mask=key_padding_mask)
            ref_output = perm_fn(torch.tensor([[[2.430025, 0.027643, -0.601164, -0.073476], [2.4323, 0.029375, -0.599553, -0.071881]],
                                               [[2.428523, 0.026838, -0.602226, -0.07391], [2.432634, 0.029842, -0.599318, -0.071253]],
                                               [[2.432278, 0.028152, -0.599555, -0.074139], [2.432659, 0.029244, -0.599294, -0.072382]]])).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # memory_key_padding_mask
            key_padding_mask = torch.zeros(2, 5).to(device) == 1
            result = model(decoder_input, memory_input,
                           memory_key_padding_mask=key_padding_mask)
            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096], [2.431935, 0.028907, -0.599809, -0.072488]],
                                               [[2.428457, 0.027053, -0.602275, -0.073462], [2.431970, 0.029387, -0.599789, -0.071621]],
                                               [[2.431934, 0.028196, -0.599802, -0.073809], [2.432306, 0.028858, -0.599542, -0.072846]]])).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # memory_key_padding_mask
            key_padding_mask[0, 4] = 1
            key_padding_mask[1, 3] = 1
            key_padding_mask[1, 4] = 1
            result = model(decoder_input,
                           memory_input,
                           memory_key_padding_mask=key_padding_mask)
            ref_output = perm_fn(torch.tensor([[[2.429757, 0.027358, -0.601351, -0.073816], [2.432692, 0.028583, -0.599263, -0.073634]],
                                               [[2.428247, 0.02662, -0.602419, -0.074123], [2.432657, 0.029055, -0.599293, -0.072732]],
                                               [[2.431515, 0.027687, -0.600096, -0.074459], [2.433075, 0.028543, -0.598987, -0.073985]]])).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # multiple layers no norm
            model = nn.TransformerDecoder(decoder_layer, 2).to(device)

            # deterministic input
            decoder_input = torch.tensor([[[20., 30., 40., 50.]]]).to(device)
            memory_input = torch.tensor([[[60., 70., 80., 90.]]]).to(device)
            result = model(decoder_input, memory_input)
            ref_output = torch.tensor(
                [[[2.31316, 0.0950293, -0.671995, 0.102802]]]).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)

            # multiple layers no norm
            model = nn.TransformerDecoder(decoder_layer, 6).to(device)

            # deterministic input
            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034], [0.2678, 0.3677, 0.4459, 0.7166]],
                                                  [[0.8100, 0.3716, 0.4096, 0.1976], [0.6958, 0.8844, 0.6081, 0.8315]],
                                                  [[0.0494, 0.9343, 0.5955, 0.3830], [0.5404, 0.3464, 0.9378, 0.6200]]])).to(device)
            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891], [0.5387, 0.1655, 0.3565, 0.0471]],
                                                 [[0.8335, 0.2799, 0.5031, 0.2947], [0.1402, 0.0318, 0.7636, 0.1346]],
                                                 [[0.6333, 0.9344, 0.1376, 0.9938], [0.8924, 0.2872, 0.6692, 0.2944]],
                                                 [[0.9897, 0.6915, 0.3154, 0.1733], [0.8645, 0.3513, 0.3064, 0.0767]],
                                                 [[0.8117, 0.2366, 0.4838, 0.7881], [0.3718, 0.4945, 0.9511, 0.0864]]])).to(device)
            result = model(decoder_input, memory_input)
            ref_output = perm_fn(torch.tensor([[[2.42794, 0.026164, -0.60263, -0.0747591], [2.43113, 0.0279516, -0.600376, -0.0736896]],
                                               [[2.42794, 0.026164, -0.60263, -0.0747591], [2.43113, 0.0279516, -0.600376, -0.0736896]],
                                               [[2.42794, 0.026164, -0.60263, -0.0747591], [2.43113, 0.0279516, -0.600376, -0.0736896]]])).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # multiple layers with norm
            # d_model = 4
            norm = nn.LayerNorm(4)
            model = nn.TransformerDecoder(decoder_layer, 2, norm=norm).to(device)

            # deterministic input
            decoder_input = torch.tensor([[[20., 30., 40., 50.]]]).to(device)
            memory_input = torch.tensor([[[60., 70., 80., 90.]]]).to(device)
            result = model(decoder_input, memory_input)
            ref_output = torch.tensor(
                [[[1.66166, -0.326986, -1.01466, -0.320017]]]).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)

            # multiple layers with norm
            model = nn.TransformerDecoder(decoder_layer, 6, norm=norm).to(device)

            # deterministic input
            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034], [0.2678, 0.3677, 0.4459, 0.7166]],
                                                  [[0.8100, 0.3716, 0.4096, 0.1976], [0.6958, 0.8844, 0.6081, 0.8315]],
                                                  [[0.0494, 0.9343, 0.5955, 0.3830], [0.5404, 0.3464, 0.9378, 0.6200]]])).to(device)
            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891], [0.5387, 0.1655, 0.3565, 0.0471]],
                                                 [[0.8335, 0.2799, 0.5031, 0.2947], [0.1402, 0.0318, 0.7636, 0.1346]],
                                                 [[0.6333, 0.9344, 0.1376, 0.9938], [0.8924, 0.2872, 0.6692, 0.2944]],
                                                 [[0.9897, 0.6915, 0.3154, 0.1733], [0.8645, 0.3513, 0.3064, 0.0767]],
                                                 [[0.8117, 0.2366, 0.4838, 0.7881], [0.3718, 0.4945, 0.9511, 0.0864]]])).to(device)
            result = model(decoder_input, memory_input)
            ref_output = perm_fn(torch.tensor([[[1.69559, -0.357291, -0.894741, -0.443553], [1.69571, -0.357363, -0.894154, -0.444196]],
                                               [[1.69559, -0.357291, -0.894741, -0.443553], [1.69571, -0.357363, -0.894154, -0.444196]],
                                               [[1.69559, -0.357291, -0.894741, -0.443553], [1.69571, -0.357363, -0.894154, -0.444196]]])).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # gelu activation test cases
            activation = "gelu"
            use_cuda = torch.cuda.is_available()
            device = torch.device("cuda" if use_cuda else "cpu")

            decoder_layer = get_a_test_layer(use_cuda=use_cuda, activation=activation,
                                             batch_first=batch_first)

            model = nn.TransformerDecoder(decoder_layer, 1).to(device)

            # deterministic input
            decoder_input = torch.tensor([[[20., 30., 40., 50.]]]).to(device)
            memory_input = torch.tensor([[[60., 70., 80., 90.]]]).to(device)
            result = model(decoder_input, memory_input)
            ref_output = torch.tensor([[[2.306435, 0.095946, -0.675796, 0.10687]]]).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)

            # deterministic input
            decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
                                                  [[11., 12., 13., 14.]]])).to(device)
            memory_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]]])).to(device)
            result = model(decoder_input, memory_input)
            ref_output = perm_fn(torch.tensor([[[2.415448, 0.054389, -0.610932, -0.0156613]],
                                               [[2.415448, 0.054389, -0.610932, -0.0156613]]])).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)

            # deterministic input
            decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
                                                  [[5., 6., 7., 8.]]])).to(device)
            memory_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
                                                 [[11., 12., 13., 14.]]])).to(device)
            result = model(decoder_input, memory_input)
            ref_output = perm_fn(torch.tensor([[[2.338531, 0.087709, -0.65776, 0.080646]],
                                               [[2.338531, 0.087709, -0.65776, 0.080646]]])).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)

            # deterministic input
            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034], [0.2678, 0.3677, 0.4459, 0.7166]],
                                                  [[0.8100, 0.3716, 0.4096, 0.1976], [0.6958, 0.8844, 0.6081, 0.8315]],
                                                  [[0.0494, 0.9343, 0.5955, 0.3830], [0.5404, 0.3464, 0.9378, 0.6200]]])).to(device)
            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891], [0.5387, 0.1655, 0.3565, 0.0471]],
                                                 [[0.8335, 0.2799, 0.5031, 0.2947], [0.1402, 0.0318, 0.7636, 0.1346]],
                                                 [[0.6333, 0.9344, 0.1376, 0.9938], [0.8924, 0.2872, 0.6692, 0.2944]],
                                                 [[0.9897, 0.6915, 0.3154, 0.1733], [0.8645, 0.3513, 0.3064, 0.0767]],
                                                 [[0.8117, 0.2366, 0.4838, 0.7881], [0.3718, 0.4945, 0.9511, 0.0864]]])).to(device)
            result = model(decoder_input, memory_input)
            ref_output = perm_fn(torch.tensor([[[2.42049104, 0.03443088, -0.60793706, -0.05436271], [2.42210631, 0.03546578, -0.60679895, -0.05357488]],
                                               [[2.41907674, 0.0336104, -0.60892977, -0.05490462], [2.42216881, 0.03586554, -0.6067524, -0.05289126]],
                                               [[2.42205716, 0.03488046, -0.60683681, -0.05460596], [2.42240309, 0.0354595, -0.60659063, -0.05378816]]])).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

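    # Illustrative sketch (not part of the original suite): nn.TransformerDecoder stacks
    # num_layers copies of a decoder layer and optionally applies a final LayerNorm,
    # which is what the multi-layer and norm cases above exercise. The helper name is
    # hypothetical.
    def _example_stacked_decoder(self):
        layer = nn.TransformerDecoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0)
        decoder = nn.TransformerDecoder(layer, num_layers=2, norm=nn.LayerNorm(4))
        tgt = torch.rand(3, 2, 4)     # (tgt_len, batch, d_model)
        memory = torch.rand(5, 2, 4)  # (src_len, batch, d_model)
        self.assertEqual(decoder(tgt, memory).shape, tgt.shape)
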
    @unittest.skipIf(not (TEST_CUDNN and TEST_MULTIGPU), 'CUDNN or multi-gpu not available')
    def test_cudnn_rnn_dropout_states_device(self):
        rnn = nn.RNN(10, 20, num_layers=2, dropout=.5)
        device = 1
        input = torch.randn(5, 4, 10).cuda(device)
        rnn.cuda(device)
        hx = torch.randn(2, 4, 20).cuda(device)
        output = rnn(input, hx)

    def test_cudnn_forward_exception(self):
        rnns = [
            (nn.LSTM(10, 20, batch_first=True), (torch.zeros(1, 2, 19), torch.zeros(1, 2, 19))),
            (nn.LSTM(10, 20, batch_first=True, proj_size=10), (torch.zeros(1, 2, 19), torch.zeros(1, 2, 19))),
            (nn.GRU(10, 20, batch_first=True), torch.zeros(1, 2, 19)),
            (nn.RNN(10, 20, batch_first=True), torch.zeros(1, 2, 19)),
        ]
        x_wrong = torch.randn(2, 3, 3)
        x_right = torch.randn(2, 3, 10)
        for rnn, hidden in rnns:
            self.assertRaisesRegex(RuntimeError, "Expected hidden.*size.*got", rnn, x_right, hidden)
            self.assertRaisesRegex(RuntimeError, re.escape("input.size(-1) must be equal to input_size"), rnn, x_wrong)

    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
    @skipIfRocm
    def test_cudnn_weight_format(self):
        rnns = [
            nn.LSTM(10, 20, batch_first=True),
            nn.LSTM(10, 20, batch_first=True, proj_size=10),
            nn.GRU(10, 20, batch_first=True),
            nn.RNN(10, 20, batch_first=True)
        ]
        first_warn = True
        for rnn in rnns:
            rnn.cuda()
            input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")
            hx = torch.randn(1, 5, 20, requires_grad=True, device="cuda")
            all_vars = [input, hx] + list(rnn.parameters())
            if isinstance(rnn, nn.LSTM):
                # LSTM with projections has different hx size
                if rnn.proj_size > 0:
                    hx = torch.randn(1, 5, 10, requires_grad=True, device="cuda")
                    all_vars[1] = hx
                cx = torch.randn(1, 5, 20, requires_grad=True, device="cuda")
                all_vars[2:2] = [cx]
                hx = (hx, cx)

            output = rnn(input, hx)
            output[0].sum().backward()
            grads = [v.grad.data.clone() for v in all_vars]
            for v in all_vars:
                v.grad.data.zero_()

            # Weights will no longer view onto the same chunk of memory
            weight = all_vars[4]
            weight_data = weight.data.clone()
            with torch.no_grad():
                weight.set_(weight_data)

            for _ in range(2):
                with warnings.catch_warnings(record=True) as w:
                    output_noncontig = rnn(input, hx)
                if first_warn:
                    self.assertEqual(len(w), 1)
                    self.assertIn('weights are not part of single contiguous chunk of memory', w[0].message.args[0])
                    first_warn = False
                    warnings.resetwarnings()
                output_noncontig[0].sum().backward()
                grads_noncontig = [v.grad.data.clone() for v in all_vars]
                for v in all_vars:
                    v.grad.data.zero_()
                self.assertEqual(output, output_noncontig)
                self.assertEqual(grads_noncontig, grads)

            # Make sure these still share storage
            weight_data[:] = 4
            self.assertEqual(weight_data, all_vars[4].data)

    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
    def test_cudnn_weight_tying(self):
        rnns = [
            nn.LSTM(10, 20, batch_first=True, bidirectional=True),
            nn.LSTM(10, 20, batch_first=True, bidirectional=True, proj_size=10),
            nn.GRU(10, 20, batch_first=True, bidirectional=True),
            nn.RNN(10, 20, batch_first=True, bidirectional=True)
        ]
        for rnn in rnns:
            rnn.bias_ih_l0_reverse = rnn.bias_ih_l0
            rnn.cuda()
            input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")
            hx = torch.randn(2, 5, 20, requires_grad=True, device="cuda")
            all_vars = [input, hx] + list(rnn.parameters())
            opt = torch.optim.SGD(rnn.parameters(), lr=0.1)
            opt.zero_grad()
            if isinstance(rnn, nn.LSTM):
                # LSTM with projections has different hx size
                if rnn.proj_size > 0:
                    hx = torch.randn(2, 5, 10, requires_grad=True, device="cuda")
                    all_vars[1] = hx
                cx = torch.randn(2, 5, 20, requires_grad=True, device="cuda")
                all_vars[2:2] = [cx]
                hx = (hx, cx)

            with warnings.catch_warnings(record=True) as w:
                output = rnn(input, hx)
            output[0].sum().backward()

            opt.step()
            with warnings.catch_warnings(record=True) as w:
                output_cuda = rnn(input, hx)
            rnn.cpu()
            hx = (hx[0].cpu(), hx[1].cpu()) if isinstance(rnn, nn.LSTM) else hx.cpu()
            output_cpu = rnn(input.cpu(), hx)
            self.assertEqual(output_cuda, output_cpu)

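    # Illustrative sketch (not part of the original suite): reassigning an RNN parameter,
    # as the weight-tying test above does, can break the single contiguous weight buffer
    # that cuDNN expects; flatten_parameters() recompacts the parameters into one chunk.
    # The helper name is hypothetical and the call is only meaningful on CUDA.
    def _example_recompact_rnn_weights(self, rnn):
        # Re-fuse the per-parameter tensors into one contiguous chunk for cuDNN.
        rnn.flatten_parameters()
        return rnn
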
    def test_transformer_args_check(self):
        model_name = 'Transformer'
        d_model = 128
        nhead = 4
        num_encoder_layers = 2
        num_decoder_layers = 3
        dim_feedforward = 65
        dropout = 0.3
        bsz = 3
        seq_len = 35
        tgt_len = 15
        activations = [F.relu, F.gelu]

        wrong_bsz = 7
        wrong_d_model = 63
        wrong_nhead = 5
        wrong_activation = "abc"

        def test(encoder_input_shape, decoder_input_shape,
                 src_mask_len=None, tgt_mask_len=None, memory_mask_size=None,
                 src_key_padding_mask_size=None, tgt_key_padding_mask_size=None,
                 memory_key_padding_mask_size=None,
                 src_is_causal=False, tgt_is_causal=False,
                 memory_is_causal=False):

            encoder_input = torch.randn(encoder_input_shape)
            decoder_input = torch.randn(decoder_input_shape)
            model = getattr(nn, model_name)(d_model, nhead, num_encoder_layers,
                                            num_decoder_layers, dim_feedforward, dropout)

            if src_mask_len is not None:
                src_mask = model.generate_square_subsequent_mask(src_mask_len)
            else:
                src_mask = None

            if tgt_mask_len is not None:
                tgt_mask = model.generate_square_subsequent_mask(tgt_mask_len)
            else:
                tgt_mask = None

            if memory_mask_size is not None:
                memory_mask = torch.rand(memory_mask_size)
            else:
                memory_mask = None

            if src_key_padding_mask_size is not None:
                src_key_padding_mask = torch.rand(src_key_padding_mask_size) >= 0.5
            else:
                src_key_padding_mask = None

            if tgt_key_padding_mask_size is not None:
                tgt_key_padding_mask = torch.rand(tgt_key_padding_mask_size) >= 0.5
            else:
                tgt_key_padding_mask = None

            if memory_key_padding_mask_size is not None:
                memory_key_padding_mask = torch.rand(memory_key_padding_mask_size) >= 0.5
            else:
                memory_key_padding_mask = None

            with self.assertRaises(RuntimeError):
                model(encoder_input, decoder_input,
                      src_mask=src_mask,
                      tgt_mask=tgt_mask,
                      memory_mask=memory_mask,
                      src_key_padding_mask=src_key_padding_mask,
                      tgt_key_padding_mask=tgt_key_padding_mask,
                      memory_key_padding_mask=memory_key_padding_mask,
                      src_is_causal=src_is_causal,
                      tgt_is_causal=tgt_is_causal,
                      memory_is_causal=memory_is_causal)

        correct_encoder_input_shape = (seq_len, bsz, d_model)
        correct_decoder_input_shape = (tgt_len, bsz, d_model)

        def update_shape(shape, dim, new_dim_size):
            new_shape = list(shape)
            new_shape[dim] = new_dim_size
            return tuple(new_shape)

        # Incorrect encoder_input batch size
        encoder_input_shape = update_shape(correct_encoder_input_shape, 1, wrong_bsz)
        decoder_input_shape = correct_decoder_input_shape
        test(encoder_input_shape, decoder_input_shape)

        # Incorrect decoder_input batch size
        encoder_input_shape = correct_encoder_input_shape
        decoder_input_shape = update_shape(correct_decoder_input_shape, 1, wrong_bsz)
        test(encoder_input_shape, decoder_input_shape)

        # Incorrect encoder_input input size
        encoder_input_shape = update_shape(correct_encoder_input_shape, 2, wrong_d_model)
        decoder_input_shape = correct_decoder_input_shape
        test(encoder_input_shape, decoder_input_shape)

        # Incorrect decoder_input input size
        encoder_input_shape = correct_encoder_input_shape
        decoder_input_shape = update_shape(correct_decoder_input_shape, 2, wrong_d_model)
        test(encoder_input_shape, decoder_input_shape)

        # Incorrect nhead
        encoder_input_shape = correct_encoder_input_shape
        decoder_input_shape = correct_decoder_input_shape
        with self.assertRaises(AssertionError):
            model = getattr(nn, model_name)(d_model, wrong_nhead, num_encoder_layers,
                                            num_decoder_layers, dim_feedforward, dropout)

        # Incorrect src_mask
        encoder_input_shape = correct_encoder_input_shape
        decoder_input_shape = correct_decoder_input_shape
        wrong_src_mask_size = seq_len + 1
        test(encoder_input_shape, decoder_input_shape, src_mask_len=wrong_src_mask_size)

        # Incorrect tgt_mask
        encoder_input_shape = correct_encoder_input_shape
        decoder_input_shape = correct_decoder_input_shape
        wrong_tgt_mask_size = tgt_len + 1
        test(encoder_input_shape, decoder_input_shape, tgt_mask_len=wrong_tgt_mask_size)

        # Incorrect memory_mask
        encoder_input_shape = correct_encoder_input_shape
        decoder_input_shape = correct_decoder_input_shape
        wrong_tgt_mask_size = tgt_len + 1
        test(encoder_input_shape, decoder_input_shape,
             memory_mask_size=(wrong_tgt_mask_size, wrong_src_mask_size))

        # Incorrect src_key_padding_mask
        encoder_input_shape = correct_encoder_input_shape
        decoder_input_shape = correct_decoder_input_shape
        with self.assertRaises(AssertionError):
            test(encoder_input_shape, decoder_input_shape,
                 src_key_padding_mask_size=(wrong_bsz, wrong_src_mask_size))

        # Incorrect tgt_key_padding_mask
        encoder_input_shape = correct_encoder_input_shape
        decoder_input_shape = correct_decoder_input_shape
        with self.assertRaises(AssertionError):
            test(encoder_input_shape, decoder_input_shape,
                 tgt_key_padding_mask_size=(wrong_bsz, wrong_tgt_mask_size))

        # Incorrect memory_key_padding_mask
        encoder_input_shape = correct_encoder_input_shape
        decoder_input_shape = correct_decoder_input_shape
        with self.assertRaises(AssertionError):
            test(encoder_input_shape, decoder_input_shape,
                 memory_key_padding_mask_size=(wrong_bsz, wrong_src_mask_size))

        # Correct activations
        for activation in activations:
            model = getattr(nn, model_name)(d_model, nhead, num_encoder_layers, num_decoder_layers,
                                            dim_feedforward, dropout, activation)
        # Incorrect activation
        with self.assertRaises(RuntimeError):
            model = getattr(nn, model_name)(d_model, nhead, num_encoder_layers, num_decoder_layers,
                                            dim_feedforward, dropout, wrong_activation)

    def test_transformer_layer_args_check(self):
        model_names = ['TransformerEncoderLayer', 'TransformerDecoderLayer']
        d_model = 128
        nhead = 4
        dim_feedforward = 65
        dropout = 0.3
        bsz = 3
        seq_len = 35
        tgt_len = 15
        activations = [F.relu, F.gelu]

        wrong_activation = "abc"

        encoder_input_shape = (seq_len, bsz, d_model)
        decoder_input_shape = (tgt_len, bsz, d_model)

        encoder_input = torch.randn(encoder_input_shape)
        decoder_input = torch.randn(decoder_input_shape)

        for model_name in model_names:
            for activation in activations:
                model = getattr(nn, model_name)(d_model, nhead, dim_feedforward,
                                                dropout, activation)
        # Incorrect activation
        for model_name in model_names:
            with self.assertRaises(RuntimeError):
                model = getattr(nn, model_name)(d_model, nhead, dim_feedforward,
                                                dropout, wrong_activation)

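    # Illustrative sketch (not part of the original suite): the RNN argument checks below
    # rely on the documented hidden-state shape (num_layers * num_directions, batch,
    # hidden_size), with LSTM taking an (h_0, c_0) tuple. The helper name is hypothetical.
    def _example_expected_hidden_shapes(self):
        rnn = nn.LSTM(input_size=3, hidden_size=5, num_layers=2)
        inp = torch.randn(6, 4, 3)   # (seq_len, batch, input_size)
        h0 = torch.randn(2, 4, 5)    # (num_layers * num_directions, batch, hidden_size)
        c0 = torch.randn(2, 4, 5)
        out, (hn, cn) = rnn(inp, (h0, c0))
        self.assertEqual(out.shape, (6, 4, 5))
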
3716*da0073e9SAndroid Build Coastguard Worker 3717*da0073e9SAndroid Build Coastguard Worker def test(input_shape, hidden_shape, mode): 3718*da0073e9SAndroid Build Coastguard Worker for input, hidden in get_inputs(input_shape, hidden_shape, mode): 3719*da0073e9SAndroid Build Coastguard Worker model = getattr(nn, mode)(input_size, hidden_size, num_layers) 3720*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: model(input, hidden)) 3721*da0073e9SAndroid Build Coastguard Worker 3722*da0073e9SAndroid Build Coastguard Worker correct_input_shape = (seq_len, batch_size, input_size) 3723*da0073e9SAndroid Build Coastguard Worker correct_hidden_shape = (num_layers * num_directions, batch_size, hidden_size) 3724*da0073e9SAndroid Build Coastguard Worker 3725*da0073e9SAndroid Build Coastguard Worker def update_shape(shape, dim, new_dim_size): 3726*da0073e9SAndroid Build Coastguard Worker new_shape = list(shape) 3727*da0073e9SAndroid Build Coastguard Worker new_shape[dim] = new_dim_size 3728*da0073e9SAndroid Build Coastguard Worker return tuple(new_shape) 3729*da0073e9SAndroid Build Coastguard Worker 3730*da0073e9SAndroid Build Coastguard Worker def get_inputs(input_shape, hidden_shape, mode): 3731*da0073e9SAndroid Build Coastguard Worker '''returns list( tuple(input, hidden) ) 3732*da0073e9SAndroid Build Coastguard Worker where input, hidden are inputs to a model''' 3733*da0073e9SAndroid Build Coastguard Worker input = torch.randn(input_shape) 3734*da0073e9SAndroid Build Coastguard Worker hidden = torch.randn(hidden_shape) 3735*da0073e9SAndroid Build Coastguard Worker if mode != 'LSTM': 3736*da0073e9SAndroid Build Coastguard Worker return [(input, hidden)] 3737*da0073e9SAndroid Build Coastguard Worker if hidden_shape == correct_hidden_shape: 3738*da0073e9SAndroid Build Coastguard Worker return [(input, (hidden, hidden))] 3739*da0073e9SAndroid Build Coastguard Worker good_hidden = torch.randn(correct_hidden_shape) 3740*da0073e9SAndroid Build Coastguard Worker return [ 3741*da0073e9SAndroid Build Coastguard Worker (input, (hidden, good_hidden)), 3742*da0073e9SAndroid Build Coastguard Worker (input, (good_hidden, hidden)), 3743*da0073e9SAndroid Build Coastguard Worker ] 3744*da0073e9SAndroid Build Coastguard Worker 3745*da0073e9SAndroid Build Coastguard Worker rnn_modes = ['RNN', 'GRU', 'LSTM'] 3746*da0073e9SAndroid Build Coastguard Worker for mode in rnn_modes: 3747*da0073e9SAndroid Build Coastguard Worker # Incorrect input batch size 3748*da0073e9SAndroid Build Coastguard Worker input_shape = update_shape(correct_input_shape, 1, bad_size) 3749*da0073e9SAndroid Build Coastguard Worker hidden_shape = correct_hidden_shape 3750*da0073e9SAndroid Build Coastguard Worker test(input_shape, hidden_shape, mode) 3751*da0073e9SAndroid Build Coastguard Worker 3752*da0073e9SAndroid Build Coastguard Worker # Incorrect hidden batch size 3753*da0073e9SAndroid Build Coastguard Worker input_shape = correct_input_shape 3754*da0073e9SAndroid Build Coastguard Worker hidden_shape = update_shape(correct_hidden_shape, 1, bad_size) 3755*da0073e9SAndroid Build Coastguard Worker test(input_shape, hidden_shape, mode) 3756*da0073e9SAndroid Build Coastguard Worker 3757*da0073e9SAndroid Build Coastguard Worker # Incorrect input size 3758*da0073e9SAndroid Build Coastguard Worker input_shape = update_shape(correct_input_shape, 2, bad_size) 3759*da0073e9SAndroid Build Coastguard Worker hidden_shape = correct_hidden_shape 3760*da0073e9SAndroid Build Coastguard Worker test(input_shape, hidden_shape, mode) 
3761*da0073e9SAndroid Build Coastguard Worker 3762*da0073e9SAndroid Build Coastguard Worker # Incorrect hidden size 3763*da0073e9SAndroid Build Coastguard Worker input_shape = correct_input_shape 3764*da0073e9SAndroid Build Coastguard Worker hidden_shape = update_shape(correct_hidden_shape, 2, bad_size) 3765*da0073e9SAndroid Build Coastguard Worker test(input_shape, hidden_shape, mode) 3766*da0073e9SAndroid Build Coastguard Worker 3767*da0073e9SAndroid Build Coastguard Worker # Incorrect hidden[0] 3768*da0073e9SAndroid Build Coastguard Worker input_shape = correct_input_shape 3769*da0073e9SAndroid Build Coastguard Worker hidden_shape = update_shape(correct_hidden_shape, 0, bad_size) 3770*da0073e9SAndroid Build Coastguard Worker test(input_shape, hidden_shape, mode) 3771*da0073e9SAndroid Build Coastguard Worker 3772*da0073e9SAndroid Build Coastguard Worker def test_projections_lstm_args_check(self): 3773*da0073e9SAndroid Build Coastguard Worker input_size = 3 3774*da0073e9SAndroid Build Coastguard Worker hidden_size = 5 3775*da0073e9SAndroid Build Coastguard Worker proj_size = 2 3776*da0073e9SAndroid Build Coastguard Worker num_layers = 2 3777*da0073e9SAndroid Build Coastguard Worker batch_size = 4 3778*da0073e9SAndroid Build Coastguard Worker seq_len = 6 3779*da0073e9SAndroid Build Coastguard Worker num_directions = 1 3780*da0073e9SAndroid Build Coastguard Worker bad_size = 7 # prime number so that no size can divide it. 3781*da0073e9SAndroid Build Coastguard Worker 3782*da0073e9SAndroid Build Coastguard Worker def test(input_shape, hidden_h_shape, hidden_c_shape): 3783*da0073e9SAndroid Build Coastguard Worker for input, hidden in get_inputs(input_shape, hidden_h_shape, hidden_c_shape): 3784*da0073e9SAndroid Build Coastguard Worker model = nn.LSTM(input_size, hidden_size, num_layers, proj_size=proj_size) 3785*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: model(input, hidden)) 3786*da0073e9SAndroid Build Coastguard Worker 3787*da0073e9SAndroid Build Coastguard Worker correct_input_shape = (seq_len, batch_size, input_size) 3788*da0073e9SAndroid Build Coastguard Worker correct_hidden_h_shape = (num_layers * num_directions, batch_size, proj_size) 3789*da0073e9SAndroid Build Coastguard Worker correct_hidden_c_shape = (num_layers * num_directions, batch_size, hidden_size) 3790*da0073e9SAndroid Build Coastguard Worker 3791*da0073e9SAndroid Build Coastguard Worker def update_shape(shape, dim, new_dim_size): 3792*da0073e9SAndroid Build Coastguard Worker new_shape = list(shape) 3793*da0073e9SAndroid Build Coastguard Worker new_shape[dim] = new_dim_size 3794*da0073e9SAndroid Build Coastguard Worker return tuple(new_shape) 3795*da0073e9SAndroid Build Coastguard Worker 3796*da0073e9SAndroid Build Coastguard Worker def get_inputs(input_shape, hidden_h_shape, hidden_c_shape): 3797*da0073e9SAndroid Build Coastguard Worker '''returns list( tuple(input, hidden) ) 3798*da0073e9SAndroid Build Coastguard Worker where input, hidden are inputs to a model''' 3799*da0073e9SAndroid Build Coastguard Worker input = torch.randn(input_shape) 3800*da0073e9SAndroid Build Coastguard Worker hidden_h = torch.randn(hidden_h_shape) 3801*da0073e9SAndroid Build Coastguard Worker hidden_c = torch.randn(hidden_c_shape) 3802*da0073e9SAndroid Build Coastguard Worker return [(input, (hidden_h, hidden_c))] 3803*da0073e9SAndroid Build Coastguard Worker 3804*da0073e9SAndroid Build Coastguard Worker # Incorrect input batch size 3805*da0073e9SAndroid Build Coastguard Worker input_shape = 
update_shape(correct_input_shape, 1, bad_size) 3806*da0073e9SAndroid Build Coastguard Worker test(input_shape, correct_hidden_h_shape, correct_hidden_c_shape) 3807*da0073e9SAndroid Build Coastguard Worker 3808*da0073e9SAndroid Build Coastguard Worker # Incorrect hidden batch size 3809*da0073e9SAndroid Build Coastguard Worker input_shape = correct_input_shape 3810*da0073e9SAndroid Build Coastguard Worker hidden_h_shape = update_shape(correct_hidden_h_shape, 1, bad_size) 3811*da0073e9SAndroid Build Coastguard Worker hidden_c_shape = update_shape(correct_hidden_c_shape, 1, bad_size) 3812*da0073e9SAndroid Build Coastguard Worker test(input_shape, hidden_h_shape, hidden_c_shape) 3813*da0073e9SAndroid Build Coastguard Worker 3814*da0073e9SAndroid Build Coastguard Worker # Incorrect input size 3815*da0073e9SAndroid Build Coastguard Worker input_shape = update_shape(correct_input_shape, 2, bad_size) 3816*da0073e9SAndroid Build Coastguard Worker test(input_shape, correct_hidden_h_shape, correct_hidden_c_shape) 3817*da0073e9SAndroid Build Coastguard Worker 3818*da0073e9SAndroid Build Coastguard Worker # Incorrect hidden size 3819*da0073e9SAndroid Build Coastguard Worker input_shape = correct_input_shape 3820*da0073e9SAndroid Build Coastguard Worker hidden_h_shape = update_shape(correct_hidden_h_shape, 2, bad_size) 3821*da0073e9SAndroid Build Coastguard Worker hidden_c_shape = update_shape(correct_hidden_c_shape, 2, bad_size) 3822*da0073e9SAndroid Build Coastguard Worker test(input_shape, hidden_h_shape, hidden_c_shape) 3823*da0073e9SAndroid Build Coastguard Worker 3824*da0073e9SAndroid Build Coastguard Worker # Incorrect hidden[0] 3825*da0073e9SAndroid Build Coastguard Worker input_shape = correct_input_shape 3826*da0073e9SAndroid Build Coastguard Worker hidden_h_shape = update_shape(correct_hidden_h_shape, 0, bad_size) 3827*da0073e9SAndroid Build Coastguard Worker hidden_c_shape = update_shape(correct_hidden_c_shape, 0, bad_size) 3828*da0073e9SAndroid Build Coastguard Worker test(input_shape, hidden_h_shape, hidden_c_shape) 3829*da0073e9SAndroid Build Coastguard Worker 3830*da0073e9SAndroid Build Coastguard Worker # Incorrect proj size = hidden size 3831*da0073e9SAndroid Build Coastguard Worker input_shape = correct_input_shape 3832*da0073e9SAndroid Build Coastguard Worker hidden_h_shape = update_shape(correct_hidden_h_shape, 0, hidden_size) 3833*da0073e9SAndroid Build Coastguard Worker hidden_c_shape = correct_hidden_c_shape 3834*da0073e9SAndroid Build Coastguard Worker test(input_shape, hidden_h_shape, hidden_c_shape) 3835*da0073e9SAndroid Build Coastguard Worker 3836*da0073e9SAndroid Build Coastguard Worker # Incorrect proj size != hidden size 3837*da0073e9SAndroid Build Coastguard Worker input_shape = correct_input_shape 3838*da0073e9SAndroid Build Coastguard Worker hidden_h_shape = update_shape(correct_hidden_h_shape, 0, bad_size) 3839*da0073e9SAndroid Build Coastguard Worker hidden_c_shape = correct_hidden_c_shape 3840*da0073e9SAndroid Build Coastguard Worker test(input_shape, hidden_h_shape, hidden_c_shape) 3841*da0073e9SAndroid Build Coastguard Worker 3842*da0073e9SAndroid Build Coastguard Worker # Incorrect cell size != hidden size 3843*da0073e9SAndroid Build Coastguard Worker input_shape = correct_input_shape 3844*da0073e9SAndroid Build Coastguard Worker hidden_h_shape = correct_hidden_h_shape 3845*da0073e9SAndroid Build Coastguard Worker hidden_c_shape = update_shape(correct_hidden_c_shape, 0, bad_size) 3846*da0073e9SAndroid Build Coastguard Worker test(input_shape, hidden_h_shape, 
hidden_c_shape) 3847*da0073e9SAndroid Build Coastguard Worker 3848*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") 3849*da0073e9SAndroid Build Coastguard Worker def test_rnn_check_device(self): 3850*da0073e9SAndroid Build Coastguard Worker import copy 3851*da0073e9SAndroid Build Coastguard Worker input_size = 3 3852*da0073e9SAndroid Build Coastguard Worker hidden_size = 5 3853*da0073e9SAndroid Build Coastguard Worker num_layers = 2 3854*da0073e9SAndroid Build Coastguard Worker batch_size = 4 3855*da0073e9SAndroid Build Coastguard Worker seq_len = 6 3856*da0073e9SAndroid Build Coastguard Worker num_directions = 1 3857*da0073e9SAndroid Build Coastguard Worker 3858*da0073e9SAndroid Build Coastguard Worker correct_input_shape = (seq_len, batch_size, input_size) 3859*da0073e9SAndroid Build Coastguard Worker correct_hidden_shape = (num_layers * num_directions, batch_size, hidden_size) 3860*da0073e9SAndroid Build Coastguard Worker rnn_modes = ['RNN', 'GRU', 'LSTM'] 3861*da0073e9SAndroid Build Coastguard Worker 3862*da0073e9SAndroid Build Coastguard Worker for mode in rnn_modes: 3863*da0073e9SAndroid Build Coastguard Worker model = getattr(nn, mode)(input_size, hidden_size, num_layers) 3864*da0073e9SAndroid Build Coastguard Worker model_cuda = copy.deepcopy(model).to('cuda:0') 3865*da0073e9SAndroid Build Coastguard Worker input = torch.randn(correct_input_shape) 3866*da0073e9SAndroid Build Coastguard Worker hidden = torch.randn(correct_hidden_shape) 3867*da0073e9SAndroid Build Coastguard Worker 3868*da0073e9SAndroid Build Coastguard Worker # input and weights are not at the same device 3869*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 3870*da0073e9SAndroid Build Coastguard Worker "Input and parameter tensors are not at the same device"): 3871*da0073e9SAndroid Build Coastguard Worker model(input.to('cuda:0')) 3872*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 3873*da0073e9SAndroid Build Coastguard Worker "Input and parameter tensors are not at the same device"): 3874*da0073e9SAndroid Build Coastguard Worker model_cuda(input) 3875*da0073e9SAndroid Build Coastguard Worker 3876*da0073e9SAndroid Build Coastguard Worker # input and hiddens are not at the same device 3877*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 3878*da0073e9SAndroid Build Coastguard Worker r"Input and hidden tensors are not at the same device"): 3879*da0073e9SAndroid Build Coastguard Worker if mode == 'LSTM': 3880*da0073e9SAndroid Build Coastguard Worker model(input, (hidden.to('cuda:0'), hidden.to('cuda:0'))) 3881*da0073e9SAndroid Build Coastguard Worker else: 3882*da0073e9SAndroid Build Coastguard Worker model(input, (hidden.to('cuda:0'))) 3883*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 3884*da0073e9SAndroid Build Coastguard Worker r"Input and hidden tensors are not at the same device"): 3885*da0073e9SAndroid Build Coastguard Worker if mode == 'LSTM': 3886*da0073e9SAndroid Build Coastguard Worker model_cuda(input.to('cuda:0'), (hidden, hidden)) 3887*da0073e9SAndroid Build Coastguard Worker else: 3888*da0073e9SAndroid Build Coastguard Worker model_cuda(input.to('cuda:0'), (hidden)) 3889*da0073e9SAndroid Build Coastguard Worker 3890*da0073e9SAndroid Build Coastguard Worker # hidden tensors are not at the same CUDA device 3891*da0073e9SAndroid Build Coastguard Worker if mode == 'LSTM': 3892*da0073e9SAndroid Build Coastguard 
Worker with self.assertRaisesRegex(RuntimeError, 3893*da0073e9SAndroid Build Coastguard Worker "Input and hidden tensors are not at the same device"): 3894*da0073e9SAndroid Build Coastguard Worker model(input.to('cuda:0'), (hidden.to('cuda:0'), hidden.to('cuda:1'))) 3895*da0073e9SAndroid Build Coastguard Worker 3896*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") 3897*da0073e9SAndroid Build Coastguard Worker def test_projections_lstm_check_device(self): 3898*da0073e9SAndroid Build Coastguard Worker input_size = 3 3899*da0073e9SAndroid Build Coastguard Worker hidden_size = 5 3900*da0073e9SAndroid Build Coastguard Worker proj_size = 2 3901*da0073e9SAndroid Build Coastguard Worker num_layers = 2 3902*da0073e9SAndroid Build Coastguard Worker batch_size = 4 3903*da0073e9SAndroid Build Coastguard Worker seq_len = 6 3904*da0073e9SAndroid Build Coastguard Worker num_directions = 1 3905*da0073e9SAndroid Build Coastguard Worker 3906*da0073e9SAndroid Build Coastguard Worker correct_input_shape = (seq_len, batch_size, input_size) 3907*da0073e9SAndroid Build Coastguard Worker correct_hidden_h_shape = (num_layers * num_directions, batch_size, proj_size) 3908*da0073e9SAndroid Build Coastguard Worker correct_hidden_c_shape = (num_layers * num_directions, batch_size, hidden_size) 3909*da0073e9SAndroid Build Coastguard Worker 3910*da0073e9SAndroid Build Coastguard Worker model = nn.LSTM(input_size, hidden_size, num_layers, proj_size=proj_size) 3911*da0073e9SAndroid Build Coastguard Worker input = torch.randn(correct_input_shape) 3912*da0073e9SAndroid Build Coastguard Worker hidden_h = torch.randn(correct_hidden_h_shape) 3913*da0073e9SAndroid Build Coastguard Worker hidden_c = torch.randn(correct_hidden_c_shape) 3914*da0073e9SAndroid Build Coastguard Worker 3915*da0073e9SAndroid Build Coastguard Worker # input and weights are not at the same device 3916*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 3917*da0073e9SAndroid Build Coastguard Worker "Input and parameter tensors are not at the same device"): 3918*da0073e9SAndroid Build Coastguard Worker model(input.to('cuda:0')) 3919*da0073e9SAndroid Build Coastguard Worker 3920*da0073e9SAndroid Build Coastguard Worker # input and hiddens are not at the same device 3921*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 3922*da0073e9SAndroid Build Coastguard Worker r"Input and hidden tensors are not at the same device"): 3923*da0073e9SAndroid Build Coastguard Worker model(input, (hidden_h.to('cuda:0'), hidden_c.to('cuda:0'))) 3924*da0073e9SAndroid Build Coastguard Worker 3925*da0073e9SAndroid Build Coastguard Worker # hidden tensors are not at the same CUDA device 3926*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 3927*da0073e9SAndroid Build Coastguard Worker "Input and hidden tensors are not at the same device"): 3928*da0073e9SAndroid Build Coastguard Worker model(input.to('cuda:0'), (hidden_h.to('cuda:0'), hidden_c.to('cuda:1'))) 3929*da0073e9SAndroid Build Coastguard Worker 3930*da0073e9SAndroid Build Coastguard Worker def test_rnn_initial_hidden_state(self): 3931*da0073e9SAndroid Build Coastguard Worker rnn_modes = ['RNN', 'GRU', 'LSTM'] 3932*da0073e9SAndroid Build Coastguard Worker for mode in rnn_modes: 3933*da0073e9SAndroid Build Coastguard Worker rnn = getattr(nn, mode)(30, 20, 2) 3934*da0073e9SAndroid Build Coastguard Worker input = torch.randn(10, 32, 30) 3935*da0073e9SAndroid Build Coastguard 
Worker hidden = torch.zeros(2, 32, 20) 3936*da0073e9SAndroid Build Coastguard Worker 3937*da0073e9SAndroid Build Coastguard Worker if mode == 'LSTM': 3938*da0073e9SAndroid Build Coastguard Worker hidden = (hidden, hidden) 3939*da0073e9SAndroid Build Coastguard Worker output1, hidden1 = rnn(input, hidden) 3940*da0073e9SAndroid Build Coastguard Worker output2, hidden2 = rnn(input) 3941*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output1, output2) 3942*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hidden1, hidden2) 3943*da0073e9SAndroid Build Coastguard Worker 3944*da0073e9SAndroid Build Coastguard Worker def test_projections_lstm_initial_hidden_state(self): 3945*da0073e9SAndroid Build Coastguard Worker for bidir in [False, True]: 3946*da0073e9SAndroid Build Coastguard Worker rnn = nn.LSTM(30, 20, 2, bidirectional=bidir, proj_size=10) 3947*da0073e9SAndroid Build Coastguard Worker num_dirs = 2 if bidir else 1 3948*da0073e9SAndroid Build Coastguard Worker input = torch.randn(10, 32, 30) 3949*da0073e9SAndroid Build Coastguard Worker hidden_h = torch.zeros(2 * num_dirs, 32, 10) 3950*da0073e9SAndroid Build Coastguard Worker hidden_c = torch.zeros(2 * num_dirs, 32, 20) 3951*da0073e9SAndroid Build Coastguard Worker hidden = (hidden_h, hidden_c) 3952*da0073e9SAndroid Build Coastguard Worker output1, hidden1 = rnn(input, hidden) 3953*da0073e9SAndroid Build Coastguard Worker output2, hidden2 = rnn(input) 3954*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output1, output2) 3955*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hidden1, hidden2) 3956*da0073e9SAndroid Build Coastguard Worker 3957*da0073e9SAndroid Build Coastguard Worker def test_projections_errors_on_gru_and_rnn(self): 3958*da0073e9SAndroid Build Coastguard Worker error_msg = "proj_size argument is only supported for LSTM, not RNN or GRU" 3959*da0073e9SAndroid Build Coastguard Worker for mode in ['RNN', 'GRU']: 3960*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, error_msg): 3961*da0073e9SAndroid Build Coastguard Worker rnn = getattr(nn, mode)(30, 20, 2, proj_size=10) 3962*da0073e9SAndroid Build Coastguard Worker 3963*da0073e9SAndroid Build Coastguard Worker def _test_RNN_cpu_vs_cudnn(self, dropout, dtype=torch.double): 3964*da0073e9SAndroid Build Coastguard Worker 3965*da0073e9SAndroid Build Coastguard Worker def forward_backward(cuda, rnn, input_val, grad_output, weights_val, hx_val, grad_hy, 3966*da0073e9SAndroid Build Coastguard Worker cx_val=None, grad_cy=None): 3967*da0073e9SAndroid Build Coastguard Worker is_lstm = isinstance(rnn, nn.LSTM) 3968*da0073e9SAndroid Build Coastguard Worker 3969*da0073e9SAndroid Build Coastguard Worker for x_layer, y_layer in zip(rnn.all_weights, weights_val): 3970*da0073e9SAndroid Build Coastguard Worker for x, y in zip(x_layer, y_layer): 3971*da0073e9SAndroid Build Coastguard Worker x.data.copy_(y.data) 3972*da0073e9SAndroid Build Coastguard Worker 3973*da0073e9SAndroid Build Coastguard Worker if isinstance(input_val, rnn_utils.PackedSequence): 3974*da0073e9SAndroid Build Coastguard Worker input = rnn_utils.PackedSequence( 3975*da0073e9SAndroid Build Coastguard Worker input_val.data.data.requires_grad_(True), input_val.batch_sizes) 3976*da0073e9SAndroid Build Coastguard Worker input_var = input.data 3977*da0073e9SAndroid Build Coastguard Worker else: 3978*da0073e9SAndroid Build Coastguard Worker input = input_val.clone().requires_grad_(True) 3979*da0073e9SAndroid Build Coastguard Worker input_var = input 
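            # LSTM expects an (h_0, c_0) tuple for its hidden state; RNN/GRU take a single h_0 tensor.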
3980*da0073e9SAndroid Build Coastguard Worker if is_lstm: 3981*da0073e9SAndroid Build Coastguard Worker if cx_val is None: 3982*da0073e9SAndroid Build Coastguard Worker hx = (hx_val.clone().requires_grad_(True), 3983*da0073e9SAndroid Build Coastguard Worker hx_val.add(1).requires_grad_(True)) 3984*da0073e9SAndroid Build Coastguard Worker else: 3985*da0073e9SAndroid Build Coastguard Worker hx = (hx_val.clone().requires_grad_(True), 3986*da0073e9SAndroid Build Coastguard Worker cx_val.add(1).requires_grad_(True)) 3987*da0073e9SAndroid Build Coastguard Worker else: 3988*da0073e9SAndroid Build Coastguard Worker hx = hx_val.clone().requires_grad_(True) 3989*da0073e9SAndroid Build Coastguard Worker 3990*da0073e9SAndroid Build Coastguard Worker if cuda: 3991*da0073e9SAndroid Build Coastguard Worker rnn.cuda() 3992*da0073e9SAndroid Build Coastguard Worker input_var.data = input_var.data.cuda() 3993*da0073e9SAndroid Build Coastguard Worker if is_lstm: 3994*da0073e9SAndroid Build Coastguard Worker hx[0].data = hx[0].data.cuda() 3995*da0073e9SAndroid Build Coastguard Worker hx[1].data = hx[1].data.cuda() 3996*da0073e9SAndroid Build Coastguard Worker else: 3997*da0073e9SAndroid Build Coastguard Worker hx.data = hx.data.cuda() 3998*da0073e9SAndroid Build Coastguard Worker grad_hy = grad_hy.cuda() 3999*da0073e9SAndroid Build Coastguard Worker if grad_cy is not None: 4000*da0073e9SAndroid Build Coastguard Worker grad_cy = grad_cy.cuda() 4001*da0073e9SAndroid Build Coastguard Worker grad_output = grad_output.cuda() 4002*da0073e9SAndroid Build Coastguard Worker 4003*da0073e9SAndroid Build Coastguard Worker output, hy = rnn(input, hx) 4004*da0073e9SAndroid Build Coastguard Worker 4005*da0073e9SAndroid Build Coastguard Worker if isinstance(output, rnn_utils.PackedSequence): 4006*da0073e9SAndroid Build Coastguard Worker output = output.data 4007*da0073e9SAndroid Build Coastguard Worker 4008*da0073e9SAndroid Build Coastguard Worker if is_lstm: 4009*da0073e9SAndroid Build Coastguard Worker if grad_cy is None: 4010*da0073e9SAndroid Build Coastguard Worker torch.autograd.backward([output, hy[0], hy[1]], [grad_output, grad_hy, grad_hy + 1]) 4011*da0073e9SAndroid Build Coastguard Worker else: 4012*da0073e9SAndroid Build Coastguard Worker torch.autograd.backward([output, hy[0], hy[1]], [grad_output, grad_hy, grad_cy + 1]) 4013*da0073e9SAndroid Build Coastguard Worker else: 4014*da0073e9SAndroid Build Coastguard Worker torch.autograd.backward([output, hy], [grad_output, grad_hy]) 4015*da0073e9SAndroid Build Coastguard Worker 4016*da0073e9SAndroid Build Coastguard Worker return {'output': output.data, 4017*da0073e9SAndroid Build Coastguard Worker 'hy': hy[0].data if is_lstm else hy.data, 4018*da0073e9SAndroid Build Coastguard Worker 'weights': rnn.all_weights, 4019*da0073e9SAndroid Build Coastguard Worker 'grad_input': input_var.grad.data, 4020*da0073e9SAndroid Build Coastguard Worker 'grad_hx': hx[0].grad.data if is_lstm else hx.grad.data, 4021*da0073e9SAndroid Build Coastguard Worker 'cy': hy[1].data if is_lstm else None, 4022*da0073e9SAndroid Build Coastguard Worker 'grad_cx': hx[1].grad.data if is_lstm else None} 4023*da0073e9SAndroid Build Coastguard Worker 4024*da0073e9SAndroid Build Coastguard Worker input_size = 10 4025*da0073e9SAndroid Build Coastguard Worker hidden_size = 6 4026*da0073e9SAndroid Build Coastguard Worker proj_size = 3 4027*da0073e9SAndroid Build Coastguard Worker num_layers = 2 4028*da0073e9SAndroid Build Coastguard Worker seq_length = 7 4029*da0073e9SAndroid Build Coastguard Worker batch = 6 
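        # Shared sizes for the CPU vs. cuDNN comparisons below; proj_size is only used by
        # the LSTM-with-projections sweep at the end of this helper.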
4030*da0073e9SAndroid Build Coastguard Worker 4031*da0073e9SAndroid Build Coastguard Worker def make_noncontig(tensor): 4032*da0073e9SAndroid Build Coastguard Worker ndim = tensor.dim() 4033*da0073e9SAndroid Build Coastguard Worker return torch.stack([tensor.clone().zero_(), tensor], ndim).select(ndim, 1) 4034*da0073e9SAndroid Build Coastguard Worker 4035*da0073e9SAndroid Build Coastguard Worker def compare_cpu_gpu(outputs_cpu, outputs_gpu): 4036*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(outputs_cpu.keys()), list(outputs_gpu.keys())) 4037*da0073e9SAndroid Build Coastguard Worker for key in outputs_cpu.keys(): 4038*da0073e9SAndroid Build Coastguard Worker if key != 'weights': 4039*da0073e9SAndroid Build Coastguard Worker self.assertEqual(outputs_cpu[key], outputs_gpu[key], atol=5e-5, rtol=0, msg=key) 4040*da0073e9SAndroid Build Coastguard Worker 4041*da0073e9SAndroid Build Coastguard Worker # check grad weights separately, as nested dict 4042*da0073e9SAndroid Build Coastguard Worker for cpu_layer_weight, gpu_layer_weight in zip(outputs_cpu['weights'], outputs_gpu['weights']): 4043*da0073e9SAndroid Build Coastguard Worker for (cpu_weight, gpu_weight) in zip(cpu_layer_weight, gpu_layer_weight): 4044*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cpu_weight.grad.data, gpu_weight.grad.data, atol=5e-5, rtol=0) 4045*da0073e9SAndroid Build Coastguard Worker 4046*da0073e9SAndroid Build Coastguard Worker for module in (nn.RNN, nn.LSTM, nn.GRU): 4047*da0073e9SAndroid Build Coastguard Worker for bias, bidirectional, batch_first, contig, variable_len, lens_as_tensor \ 4048*da0073e9SAndroid Build Coastguard Worker in product((True, False), repeat=6): 4049*da0073e9SAndroid Build Coastguard Worker 4050*da0073e9SAndroid Build Coastguard Worker num_directions = 2 if bidirectional else 1 4051*da0073e9SAndroid Build Coastguard Worker if batch_first: 4052*da0073e9SAndroid Build Coastguard Worker input_val = torch.randn(batch, seq_length, input_size, dtype=dtype) 4053*da0073e9SAndroid Build Coastguard Worker grad_output = torch.randn(batch, seq_length, hidden_size * num_directions, dtype=dtype) 4054*da0073e9SAndroid Build Coastguard Worker else: 4055*da0073e9SAndroid Build Coastguard Worker input_val = torch.randn(seq_length, batch, input_size, dtype=dtype) 4056*da0073e9SAndroid Build Coastguard Worker grad_output = torch.randn(seq_length, batch, hidden_size * num_directions, dtype=dtype) 4057*da0073e9SAndroid Build Coastguard Worker 4058*da0073e9SAndroid Build Coastguard Worker hx_val = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype) 4059*da0073e9SAndroid Build Coastguard Worker grad_hy = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype) 4060*da0073e9SAndroid Build Coastguard Worker 4061*da0073e9SAndroid Build Coastguard Worker if not contig: 4062*da0073e9SAndroid Build Coastguard Worker grad_output = make_noncontig(grad_output) 4063*da0073e9SAndroid Build Coastguard Worker grad_hy = make_noncontig(grad_hy) 4064*da0073e9SAndroid Build Coastguard Worker input_var = make_noncontig(input_val) 4065*da0073e9SAndroid Build Coastguard Worker hx_val = make_noncontig(hx_val) 4066*da0073e9SAndroid Build Coastguard Worker 4067*da0073e9SAndroid Build Coastguard Worker if variable_len: 4068*da0073e9SAndroid Build Coastguard Worker lengths = [7, 5, 5, 2, 1, 1] 4069*da0073e9SAndroid Build Coastguard Worker if lens_as_tensor: 4070*da0073e9SAndroid Build Coastguard Worker lengths = torch.tensor(lengths, dtype=torch.long) 4071*da0073e9SAndroid 
Build Coastguard Worker input_val = rnn_utils.pack_padded_sequence(input_val, lengths, batch_first=batch_first) 4072*da0073e9SAndroid Build Coastguard Worker grad_output = rnn_utils.pack_padded_sequence(grad_output, lengths, batch_first=batch_first).data 4073*da0073e9SAndroid Build Coastguard Worker 4074*da0073e9SAndroid Build Coastguard Worker rnn = module(input_size, 4075*da0073e9SAndroid Build Coastguard Worker hidden_size, 4076*da0073e9SAndroid Build Coastguard Worker num_layers, 4077*da0073e9SAndroid Build Coastguard Worker bias=bias, 4078*da0073e9SAndroid Build Coastguard Worker dropout=dropout, 4079*da0073e9SAndroid Build Coastguard Worker bidirectional=bidirectional, 4080*da0073e9SAndroid Build Coastguard Worker batch_first=batch_first).to(dtype) 4081*da0073e9SAndroid Build Coastguard Worker 4082*da0073e9SAndroid Build Coastguard Worker outputs_cpu = forward_backward( 4083*da0073e9SAndroid Build Coastguard Worker False, rnn, input_val, grad_output, rnn.all_weights, hx_val, grad_hy) 4084*da0073e9SAndroid Build Coastguard Worker 4085*da0073e9SAndroid Build Coastguard Worker rnn_gpu = module(input_size, 4086*da0073e9SAndroid Build Coastguard Worker hidden_size, 4087*da0073e9SAndroid Build Coastguard Worker num_layers, 4088*da0073e9SAndroid Build Coastguard Worker bias=bias, 4089*da0073e9SAndroid Build Coastguard Worker dropout=dropout, 4090*da0073e9SAndroid Build Coastguard Worker bidirectional=bidirectional, 4091*da0073e9SAndroid Build Coastguard Worker batch_first=batch_first).to(dtype) 4092*da0073e9SAndroid Build Coastguard Worker 4093*da0073e9SAndroid Build Coastguard Worker outputs_gpu = forward_backward( 4094*da0073e9SAndroid Build Coastguard Worker True, rnn_gpu, input_val, grad_output, rnn.all_weights, hx_val, grad_hy) 4095*da0073e9SAndroid Build Coastguard Worker 4096*da0073e9SAndroid Build Coastguard Worker compare_cpu_gpu(outputs_cpu, outputs_gpu) 4097*da0073e9SAndroid Build Coastguard Worker 4098*da0073e9SAndroid Build Coastguard Worker for nonlinearity in ('tanh', 'relu'): 4099*da0073e9SAndroid Build Coastguard Worker hx_val = torch.randn(num_layers, batch, hidden_size, dtype=dtype) 4100*da0073e9SAndroid Build Coastguard Worker input_val = torch.randn(seq_length, batch, input_size, dtype=dtype) 4101*da0073e9SAndroid Build Coastguard Worker grad_output = torch.randn( 4102*da0073e9SAndroid Build Coastguard Worker seq_length, batch, hidden_size * num_directions, dtype=dtype) 4103*da0073e9SAndroid Build Coastguard Worker grad_hy = torch.randn( 4104*da0073e9SAndroid Build Coastguard Worker num_layers * num_directions, batch, hidden_size, dtype=dtype) 4105*da0073e9SAndroid Build Coastguard Worker 4106*da0073e9SAndroid Build Coastguard Worker rnn = nn.RNN(input_size, hidden_size, num_layers, bias=bias, nonlinearity=nonlinearity).to(dtype) 4107*da0073e9SAndroid Build Coastguard Worker outputs_cpu = forward_backward(False, rnn, input_val, grad_output, rnn.all_weights, hx_val, grad_hy) 4108*da0073e9SAndroid Build Coastguard Worker 4109*da0073e9SAndroid Build Coastguard Worker rnn_gpu = nn.RNN(input_size, hidden_size, num_layers, bias=bias, nonlinearity=nonlinearity).to(dtype) 4110*da0073e9SAndroid Build Coastguard Worker outputs_gpu = forward_backward(True, rnn_gpu, input_val, grad_output, rnn.all_weights, hx_val, grad_hy) 4111*da0073e9SAndroid Build Coastguard Worker 4112*da0073e9SAndroid Build Coastguard Worker compare_cpu_gpu(outputs_cpu, outputs_gpu) 4113*da0073e9SAndroid Build Coastguard Worker 4114*da0073e9SAndroid Build Coastguard Worker # checking LSTM with projections 
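        # With proj_size set, the hidden state h has proj_size features while the cell state c
        # keeps hidden_size, so separate hx/cx values and their gradients are passed through.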
4115*da0073e9SAndroid Build Coastguard Worker for bias, bidirectional, batch_first, contig, variable_len, lens_as_tensor \ 4116*da0073e9SAndroid Build Coastguard Worker in product((True, False), repeat=6): 4117*da0073e9SAndroid Build Coastguard Worker num_directions = 2 if bidirectional else 1 4118*da0073e9SAndroid Build Coastguard Worker if batch_first: 4119*da0073e9SAndroid Build Coastguard Worker input_val = torch.randn(batch, seq_length, input_size, dtype=dtype) 4120*da0073e9SAndroid Build Coastguard Worker grad_output = torch.randn(batch, seq_length, proj_size * num_directions, dtype=dtype) 4121*da0073e9SAndroid Build Coastguard Worker else: 4122*da0073e9SAndroid Build Coastguard Worker input_val = torch.randn(seq_length, batch, input_size, dtype=dtype) 4123*da0073e9SAndroid Build Coastguard Worker grad_output = torch.randn(seq_length, batch, proj_size * num_directions, dtype=dtype) 4124*da0073e9SAndroid Build Coastguard Worker 4125*da0073e9SAndroid Build Coastguard Worker hx_val = torch.randn(num_layers * num_directions, batch, proj_size, dtype=dtype) 4126*da0073e9SAndroid Build Coastguard Worker cx_val = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype) 4127*da0073e9SAndroid Build Coastguard Worker grad_hy = torch.randn(num_layers * num_directions, batch, proj_size, dtype=dtype) 4128*da0073e9SAndroid Build Coastguard Worker grad_cy = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype) 4129*da0073e9SAndroid Build Coastguard Worker 4130*da0073e9SAndroid Build Coastguard Worker if not contig: 4131*da0073e9SAndroid Build Coastguard Worker grad_output = make_noncontig(grad_output) 4132*da0073e9SAndroid Build Coastguard Worker grad_hy = make_noncontig(grad_hy) 4133*da0073e9SAndroid Build Coastguard Worker grad_cy = make_noncontig(grad_cy) 4134*da0073e9SAndroid Build Coastguard Worker input_var = make_noncontig(input_val) 4135*da0073e9SAndroid Build Coastguard Worker hx_val = make_noncontig(hx_val) 4136*da0073e9SAndroid Build Coastguard Worker cx_val = make_noncontig(cx_val) 4137*da0073e9SAndroid Build Coastguard Worker 4138*da0073e9SAndroid Build Coastguard Worker if variable_len: 4139*da0073e9SAndroid Build Coastguard Worker lengths = [7, 5, 5, 2, 1, 1] 4140*da0073e9SAndroid Build Coastguard Worker if lens_as_tensor: 4141*da0073e9SAndroid Build Coastguard Worker lengths = torch.tensor(lengths, dtype=torch.long) 4142*da0073e9SAndroid Build Coastguard Worker input_val = rnn_utils.pack_padded_sequence(input_val, lengths, batch_first=batch_first) 4143*da0073e9SAndroid Build Coastguard Worker grad_output = rnn_utils.pack_padded_sequence(grad_output, lengths, batch_first=batch_first).data 4144*da0073e9SAndroid Build Coastguard Worker 4145*da0073e9SAndroid Build Coastguard Worker rnn = nn.LSTM(input_size, 4146*da0073e9SAndroid Build Coastguard Worker hidden_size, 4147*da0073e9SAndroid Build Coastguard Worker num_layers, 4148*da0073e9SAndroid Build Coastguard Worker bias=bias, 4149*da0073e9SAndroid Build Coastguard Worker dropout=dropout, 4150*da0073e9SAndroid Build Coastguard Worker bidirectional=bidirectional, 4151*da0073e9SAndroid Build Coastguard Worker batch_first=batch_first, 4152*da0073e9SAndroid Build Coastguard Worker proj_size=proj_size).to(dtype) 4153*da0073e9SAndroid Build Coastguard Worker 4154*da0073e9SAndroid Build Coastguard Worker outputs_cpu = forward_backward( 4155*da0073e9SAndroid Build Coastguard Worker False, rnn, input_val, grad_output, rnn.all_weights, 4156*da0073e9SAndroid Build Coastguard Worker hx_val, grad_hy, cx_val, 
grad_cy) 4157*da0073e9SAndroid Build Coastguard Worker 4158*da0073e9SAndroid Build Coastguard Worker rnn_gpu = nn.LSTM(input_size, 4159*da0073e9SAndroid Build Coastguard Worker hidden_size, 4160*da0073e9SAndroid Build Coastguard Worker num_layers, 4161*da0073e9SAndroid Build Coastguard Worker bias=bias, 4162*da0073e9SAndroid Build Coastguard Worker dropout=dropout, 4163*da0073e9SAndroid Build Coastguard Worker bidirectional=bidirectional, 4164*da0073e9SAndroid Build Coastguard Worker batch_first=batch_first, 4165*da0073e9SAndroid Build Coastguard Worker proj_size=proj_size).to(dtype) 4166*da0073e9SAndroid Build Coastguard Worker 4167*da0073e9SAndroid Build Coastguard Worker outputs_gpu = forward_backward( 4168*da0073e9SAndroid Build Coastguard Worker True, rnn_gpu, input_val, grad_output, rnn.all_weights, 4169*da0073e9SAndroid Build Coastguard Worker hx_val, grad_hy, cx_val, grad_cy) 4170*da0073e9SAndroid Build Coastguard Worker compare_cpu_gpu(outputs_cpu, outputs_gpu) 4171*da0073e9SAndroid Build Coastguard Worker 4172*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDNN, "needs cudnn") 4173*da0073e9SAndroid Build Coastguard Worker def test_RNN_cpu_vs_cudnn_no_dropout(self): 4174*da0073e9SAndroid Build Coastguard Worker dtype = torch.double 4175*da0073e9SAndroid Build Coastguard Worker self._test_RNN_cpu_vs_cudnn(0, dtype) 4176*da0073e9SAndroid Build Coastguard Worker 4177*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDNN, "needs cudnn") 4178*da0073e9SAndroid Build Coastguard Worker def test_RNN_cpu_vs_cudnn_with_dropout(self): 4179*da0073e9SAndroid Build Coastguard Worker # Because of dropout randomness, can only compare dropout=0 and dropout=1 4180*da0073e9SAndroid Build Coastguard Worker self._test_RNN_cpu_vs_cudnn(1) 4181*da0073e9SAndroid Build Coastguard Worker 4182*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDNN, "needs cudnn") 4183*da0073e9SAndroid Build Coastguard Worker def test_RNN_cudnn_weight_norm(self): 4184*da0073e9SAndroid Build Coastguard Worker input_size = 10 4185*da0073e9SAndroid Build Coastguard Worker hidden_size = 6 4186*da0073e9SAndroid Build Coastguard Worker num_layers = 2 4187*da0073e9SAndroid Build Coastguard Worker seq_length = 7 4188*da0073e9SAndroid Build Coastguard Worker batch = 6 4189*da0073e9SAndroid Build Coastguard Worker 4190*da0073e9SAndroid Build Coastguard Worker # runs on CPU to acquire expected output 4191*da0073e9SAndroid Build Coastguard Worker def check_weight_norm(m, name): 4192*da0073e9SAndroid Build Coastguard Worker input = torch.randn(seq_length, batch, input_size) 4193*da0073e9SAndroid Build Coastguard Worker expected_output = m(input) 4194*da0073e9SAndroid Build Coastguard Worker 4195*da0073e9SAndroid Build Coastguard Worker # adds weight normalization 4196*da0073e9SAndroid Build Coastguard Worker m = torch.nn.utils.weight_norm(m, name=name) 4197*da0073e9SAndroid Build Coastguard Worker 4198*da0073e9SAndroid Build Coastguard Worker # moves to CUDA 4199*da0073e9SAndroid Build Coastguard Worker m = m.cuda() 4200*da0073e9SAndroid Build Coastguard Worker input = input.cuda() 4201*da0073e9SAndroid Build Coastguard Worker 4202*da0073e9SAndroid Build Coastguard Worker # otherwise, subsequent warnings will be hidden, and further tests rely on them 4203*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") 4204*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m(input), expected_output) 4205*da0073e9SAndroid Build Coastguard Worker 4206*da0073e9SAndroid 
Build Coastguard Worker # remove weight norm 4207*da0073e9SAndroid Build Coastguard Worker m = torch.nn.utils.remove_weight_norm(m, name=name) 4208*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m(input), expected_output) 4209*da0073e9SAndroid Build Coastguard Worker 4210*da0073e9SAndroid Build Coastguard Worker check_weight_norm(nn.LSTM(input_size, hidden_size, num_layers), 'weight_hh_l0') 4211*da0073e9SAndroid Build Coastguard Worker check_weight_norm(nn.LSTM(input_size, hidden_size, num_layers, proj_size=3), 'weight_hr_l0') 4212*da0073e9SAndroid Build Coastguard Worker 4213*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, 'CUDA not available') 4214*da0073e9SAndroid Build Coastguard Worker def test_partial_flat_weights(self): 4215*da0073e9SAndroid Build Coastguard Worker input_size = 10 4216*da0073e9SAndroid Build Coastguard Worker hidden_size = 6 4217*da0073e9SAndroid Build Coastguard Worker num_layers = 2 4218*da0073e9SAndroid Build Coastguard Worker 4219*da0073e9SAndroid Build Coastguard Worker m = nn.LSTM(input_size, hidden_size, num_layers) 4220*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(3, 2, 10) 4221*da0073e9SAndroid Build Coastguard Worker out_expected = m(inp) 4222*da0073e9SAndroid Build Coastguard Worker # deletes an attribute of original LSTM 4223*da0073e9SAndroid Build Coastguard Worker weight_orig = m.weight_hh_l0 4224*da0073e9SAndroid Build Coastguard Worker del m.weight_hh_l0 4225*da0073e9SAndroid Build Coastguard Worker self.assertFalse(hasattr(m, "weight_hh_l0")) 4226*da0073e9SAndroid Build Coastguard Worker # verifies that moving to CUDA with only some attributes defined 4227*da0073e9SAndroid Build Coastguard Worker # does not throw an error 4228*da0073e9SAndroid Build Coastguard Worker m.cuda() 4229*da0073e9SAndroid Build Coastguard Worker # recompute the weight and make sure that module can be used 4230*da0073e9SAndroid Build Coastguard Worker m.weight_hh_l0 = weight_orig.cuda() 4231*da0073e9SAndroid Build Coastguard Worker inp = inp.cuda() 4232*da0073e9SAndroid Build Coastguard Worker # otherwise, subsequent warnings will be hidden, and further tests rely on them 4233*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") 4234*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m(inp)[0].cpu(), out_expected[0]) 4235*da0073e9SAndroid Build Coastguard Worker 4236*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDNN, "needs cudnn") 4237*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 4238*da0073e9SAndroid Build Coastguard Worker def test_RNN_dropout(self): 4239*da0073e9SAndroid Build Coastguard Worker # checking the assumption that cuDNN sticks dropout in between 4240*da0073e9SAndroid Build Coastguard Worker # RNN layers 4241*da0073e9SAndroid Build Coastguard Worker for p in (0, 0.276, 0.731, 1): 4242*da0073e9SAndroid Build Coastguard Worker for train in (True, False): 4243*da0073e9SAndroid Build Coastguard Worker for cuda in (True, False): 4244*da0073e9SAndroid Build Coastguard Worker rnn = nn.RNN(10, 1000, 2, bias=False, dropout=p, nonlinearity='relu') 4245*da0073e9SAndroid Build Coastguard Worker if cuda: 4246*da0073e9SAndroid Build Coastguard Worker rnn.cuda() 4247*da0073e9SAndroid Build Coastguard Worker 4248*da0073e9SAndroid Build Coastguard Worker if train: 4249*da0073e9SAndroid Build Coastguard Worker rnn.train() 4250*da0073e9SAndroid Build Coastguard Worker else: 4251*da0073e9SAndroid Build Coastguard Worker rnn.eval() 4252*da0073e9SAndroid 
Build Coastguard Worker rnn.weight_ih_l0.data.fill_(1) 4253*da0073e9SAndroid Build Coastguard Worker rnn.weight_hh_l0.data.fill_(1) 4254*da0073e9SAndroid Build Coastguard Worker rnn.weight_ih_l1.data.fill_(1) 4255*da0073e9SAndroid Build Coastguard Worker rnn.weight_hh_l1.data.fill_(1) 4256*da0073e9SAndroid Build Coastguard Worker input = torch.ones(1, 1, 10) 4257*da0073e9SAndroid Build Coastguard Worker hx = torch.zeros(2, 1, 1000) 4258*da0073e9SAndroid Build Coastguard Worker if cuda: 4259*da0073e9SAndroid Build Coastguard Worker input = input.cuda() 4260*da0073e9SAndroid Build Coastguard Worker hx = hx.cuda() 4261*da0073e9SAndroid Build Coastguard Worker 4262*da0073e9SAndroid Build Coastguard Worker output, hy = rnn(input, hx) 4263*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.data.min(), output.data.max()) 4264*da0073e9SAndroid Build Coastguard Worker output_val = output.data[0][0][0] 4265*da0073e9SAndroid Build Coastguard Worker if p == 0 or not train: 4266*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_val, 10000) 4267*da0073e9SAndroid Build Coastguard Worker elif p == 1: 4268*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_val, 0) 4269*da0073e9SAndroid Build Coastguard Worker else: 4270*da0073e9SAndroid Build Coastguard Worker self.assertGreater(output_val, 8000) 4271*da0073e9SAndroid Build Coastguard Worker self.assertLess(output_val, 12000) 4272*da0073e9SAndroid Build Coastguard Worker denorm_mod = (output_val * (1 - p)) % 10 4273*da0073e9SAndroid Build Coastguard Worker self.assertLess(min(denorm_mod, 10 - denorm_mod), 1e-2) 4274*da0073e9SAndroid Build Coastguard Worker 4275*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hy[0].data.min(), hy[0].data.max()) 4276*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hy[1].data.min(), hy[1].data.max()) 4277*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hy.data[0][0][0], 10) 4278*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hy.data[1][0][0], output_val) 4279*da0073e9SAndroid Build Coastguard Worker 4280*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDNN, "needs cudnn") 4281*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 4282*da0073e9SAndroid Build Coastguard Worker def test_error_RNN_seq_len_zero(self): 4283*da0073e9SAndroid Build Coastguard Worker # checking error message when RNN has seq_len = 0 4284*da0073e9SAndroid Build Coastguard Worker for module in (nn.RNN, nn.LSTM, nn.GRU): 4285*da0073e9SAndroid Build Coastguard Worker for bidirectional in [True, False]: 4286*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 4287*da0073e9SAndroid Build Coastguard Worker input = torch.ones(0, 10, 5) 4288*da0073e9SAndroid Build Coastguard Worker rnn = module(5, 6, bidirectional=bidirectional) 4289*da0073e9SAndroid Build Coastguard Worker if device == 'cuda': 4290*da0073e9SAndroid Build Coastguard Worker rnn.cuda() 4291*da0073e9SAndroid Build Coastguard Worker input = input.cuda() 4292*da0073e9SAndroid Build Coastguard Worker 4293*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected sequence length to be larger than 0 in RNN"): 4294*da0073e9SAndroid Build Coastguard Worker rnn(input) 4295*da0073e9SAndroid Build Coastguard Worker 4296*da0073e9SAndroid Build Coastguard Worker def test_RNN_input_size_zero(self): 4297*da0073e9SAndroid Build Coastguard Worker for module in (nn.RNN, nn.LSTM, nn.GRU): 4298*da0073e9SAndroid Build Coastguard 
Worker for device in get_all_device_types(): 4299*da0073e9SAndroid Build Coastguard Worker input = torch.zeros((5, 0, 3)) 4300*da0073e9SAndroid Build Coastguard Worker rnn = module(input_size=3, hidden_size=4) 4301*da0073e9SAndroid Build Coastguard Worker if device == 'cuda': 4302*da0073e9SAndroid Build Coastguard Worker rnn.cuda() 4303*da0073e9SAndroid Build Coastguard Worker input = input.cuda() 4304*da0073e9SAndroid Build Coastguard Worker outs = rnn(input) 4305*da0073e9SAndroid Build Coastguard Worker self.assertEqual(outs[0].shape, torch.Size([5, 0, 4])) 4306*da0073e9SAndroid Build Coastguard Worker # Check that backward does not cause a hard error 4307*da0073e9SAndroid Build Coastguard Worker outs[0].sum().backward() 4308*da0073e9SAndroid Build Coastguard Worker 4309*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDNN, "needs cudnn") 4310*da0073e9SAndroid Build Coastguard Worker def test_RNN_dropout_state(self): 4311*da0073e9SAndroid Build Coastguard Worker for p in (0, 0.1234): 4312*da0073e9SAndroid Build Coastguard Worker for train in (True, False): 4313*da0073e9SAndroid Build Coastguard Worker for cuda in (True, False): 4314*da0073e9SAndroid Build Coastguard Worker rnn = nn.RNN(100, 100, 2, bias=False, dropout=p, nonlinearity='relu') 4315*da0073e9SAndroid Build Coastguard Worker if cuda: 4316*da0073e9SAndroid Build Coastguard Worker rnn.cuda() 4317*da0073e9SAndroid Build Coastguard Worker 4318*da0073e9SAndroid Build Coastguard Worker if train: 4319*da0073e9SAndroid Build Coastguard Worker rnn.train() 4320*da0073e9SAndroid Build Coastguard Worker else: 4321*da0073e9SAndroid Build Coastguard Worker rnn.eval() 4322*da0073e9SAndroid Build Coastguard Worker input = torch.rand(1, 1, 100) 4323*da0073e9SAndroid Build Coastguard Worker hx = torch.rand(2, 1, 100) 4324*da0073e9SAndroid Build Coastguard Worker if cuda: 4325*da0073e9SAndroid Build Coastguard Worker input = input.cuda() 4326*da0073e9SAndroid Build Coastguard Worker hx = hx.cuda() 4327*da0073e9SAndroid Build Coastguard Worker 4328*da0073e9SAndroid Build Coastguard Worker output1, hy1 = rnn(input, hx) 4329*da0073e9SAndroid Build Coastguard Worker output2, hy2 = rnn(input, hx) 4330*da0073e9SAndroid Build Coastguard Worker 4331*da0073e9SAndroid Build Coastguard Worker buf = io.BytesIO() 4332*da0073e9SAndroid Build Coastguard Worker rnn_pickle = torch.save(rnn, buf) 4333*da0073e9SAndroid Build Coastguard Worker buf.seek(0) 4334*da0073e9SAndroid Build Coastguard Worker # weights_only=False as this is legacy code that saves the model 4335*da0073e9SAndroid Build Coastguard Worker rnn2 = torch.load(buf, weights_only=False) 4336*da0073e9SAndroid Build Coastguard Worker rnn2.flatten_parameters() 4337*da0073e9SAndroid Build Coastguard Worker output3, hy3 = rnn2(input, hx) 4338*da0073e9SAndroid Build Coastguard Worker 4339*da0073e9SAndroid Build Coastguard Worker if p == 0 or not train: 4340*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output1, output2) 4341*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output1, output3) 4342*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hy1, hy2) 4343*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hy1, hy3) 4344*da0073e9SAndroid Build Coastguard Worker else: 4345*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(output1, output2) 4346*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(output1, output3) 4347*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(hy1, hy2) 4348*da0073e9SAndroid Build Coastguard 
Worker self.assertNotEqual(hy1, hy3) 4349*da0073e9SAndroid Build Coastguard Worker 4350*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDNN, "needs cudnn") 4351*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 4352*da0073e9SAndroid Build Coastguard Worker def test_RNN_change_dropout(self): 4353*da0073e9SAndroid Build Coastguard Worker for train, cuda in product((True, False), repeat=2): 4354*da0073e9SAndroid Build Coastguard Worker rnn = nn.RNN(100, 100, 2, dropout=0, nonlinearity='relu') 4355*da0073e9SAndroid Build Coastguard Worker input = torch.rand(3, 2, 100) 4356*da0073e9SAndroid Build Coastguard Worker if cuda: 4357*da0073e9SAndroid Build Coastguard Worker input.data = input.data.cuda() 4358*da0073e9SAndroid Build Coastguard Worker rnn.cuda() 4359*da0073e9SAndroid Build Coastguard Worker 4360*da0073e9SAndroid Build Coastguard Worker if train: 4361*da0073e9SAndroid Build Coastguard Worker rnn.train() 4362*da0073e9SAndroid Build Coastguard Worker else: 4363*da0073e9SAndroid Build Coastguard Worker rnn.eval() 4364*da0073e9SAndroid Build Coastguard Worker 4365*da0073e9SAndroid Build Coastguard Worker prev_output = None 4366*da0073e9SAndroid Build Coastguard Worker for p in (0, 0.5, 0, 0.7, 0.2, 1, 0.2, 0): 4367*da0073e9SAndroid Build Coastguard Worker rnn.dropout = p 4368*da0073e9SAndroid Build Coastguard Worker output1, hy1 = rnn(input) 4369*da0073e9SAndroid Build Coastguard Worker output2, hy2 = rnn(input) 4370*da0073e9SAndroid Build Coastguard Worker 4371*da0073e9SAndroid Build Coastguard Worker if p == 0 or p == 1 or not train: 4372*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output1, output2) 4373*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hy1, hy2) 4374*da0073e9SAndroid Build Coastguard Worker else: 4375*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(output1, output2) 4376*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(hy1, hy2) 4377*da0073e9SAndroid Build Coastguard Worker 4378*da0073e9SAndroid Build Coastguard Worker if prev_output is not None: 4379*da0073e9SAndroid Build Coastguard Worker if not train: 4380*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output1.data, prev_output) 4381*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output2.data, prev_output) 4382*da0073e9SAndroid Build Coastguard Worker else: 4383*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(output1.data, prev_output) 4384*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(output2.data, prev_output) 4385*da0073e9SAndroid Build Coastguard Worker prev_output = output1.data 4386*da0073e9SAndroid Build Coastguard Worker 4387*da0073e9SAndroid Build Coastguard Worker def test_inplace_thnn(self): 4388*da0073e9SAndroid Build Coastguard Worker modules = [nn.ReLU, nn.ELU, nn.SELU, nn.CELU, nn.RReLU] 4389*da0073e9SAndroid Build Coastguard Worker for mod in modules: 4390*da0073e9SAndroid Build Coastguard Worker r = mod(inplace=True) 4391*da0073e9SAndroid Build Coastguard Worker input = torch.randn(5, 5, requires_grad=True) 4392*da0073e9SAndroid Build Coastguard Worker output = r(input + 0) 4393*da0073e9SAndroid Build Coastguard Worker grad_output = torch.randn(5, 5) 4394*da0073e9SAndroid Build Coastguard Worker grad_output_clone = grad_output.clone() 4395*da0073e9SAndroid Build Coastguard Worker output.backward(grad_output) 4396*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_output, grad_output_clone) 4397*da0073e9SAndroid Build Coastguard Worker 
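    # pixel_shuffle rearranges (*, C * r**2, H, W) -> (*, C, H * r, W * r);
    # pixel_unshuffle performs the inverse rearrangement.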
4398*da0073e9SAndroid Build Coastguard Worker 4399*da0073e9SAndroid Build Coastguard Worker def test_pixel_shuffle_unshuffle(self): 4400*da0073e9SAndroid Build Coastguard Worker def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True, 4401*da0073e9SAndroid Build Coastguard Worker upscale_factor=None): 4402*da0073e9SAndroid Build Coastguard Worker # Function to imperatively ensure pixels are shuffled to the correct locations. 4403*da0073e9SAndroid Build Coastguard Worker # Used to validate the batch operations in pixel_shuffle. 4404*da0073e9SAndroid Build Coastguard Worker def _verify_pixel_shuffle(input, output, upscale_factor): 4405*da0073e9SAndroid Build Coastguard Worker for c in range(output.size(-3)): 4406*da0073e9SAndroid Build Coastguard Worker for h in range(output.size(-2)): 4407*da0073e9SAndroid Build Coastguard Worker for w in range(output.size(-1)): 4408*da0073e9SAndroid Build Coastguard Worker height_idx = h // upscale_factor 4409*da0073e9SAndroid Build Coastguard Worker weight_idx = w // upscale_factor 4410*da0073e9SAndroid Build Coastguard Worker channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \ 4411*da0073e9SAndroid Build Coastguard Worker (c * upscale_factor ** 2) 4412*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx]) 4413*da0073e9SAndroid Build Coastguard Worker 4414*da0073e9SAndroid Build Coastguard Worker upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor 4415*da0073e9SAndroid Build Coastguard Worker # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2. 4416*da0073e9SAndroid Build Coastguard Worker channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1) 4417*da0073e9SAndroid Build Coastguard Worker height = random.randint(5, 10) 4418*da0073e9SAndroid Build Coastguard Worker width = random.randint(5, 10) 4419*da0073e9SAndroid Build Coastguard Worker 4420*da0073e9SAndroid Build Coastguard Worker if num_input_dims == 1: 4421*da0073e9SAndroid Build Coastguard Worker input = torch.rand(channels, requires_grad=True) 4422*da0073e9SAndroid Build Coastguard Worker elif num_input_dims == 2: 4423*da0073e9SAndroid Build Coastguard Worker input = torch.rand(height, width, requires_grad=True) 4424*da0073e9SAndroid Build Coastguard Worker else: 4425*da0073e9SAndroid Build Coastguard Worker batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)] 4426*da0073e9SAndroid Build Coastguard Worker input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True) 4427*da0073e9SAndroid Build Coastguard Worker ps = nn.PixelShuffle(upscale_factor) 4428*da0073e9SAndroid Build Coastguard Worker pus = nn.PixelUnshuffle(downscale_factor=upscale_factor) 4429*da0073e9SAndroid Build Coastguard Worker 4430*da0073e9SAndroid Build Coastguard Worker if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0: 4431*da0073e9SAndroid Build Coastguard Worker output = ps(input) 4432*da0073e9SAndroid Build Coastguard Worker _verify_pixel_shuffle(input, output, upscale_factor) 4433*da0073e9SAndroid Build Coastguard Worker output.backward(output.data) 4434*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input.data, input.grad.data) 4435*da0073e9SAndroid Build Coastguard Worker 4436*da0073e9SAndroid Build Coastguard Worker # Ensure unshuffle properly inverts shuffle. 
4437*da0073e9SAndroid Build Coastguard Worker unshuffle_output = pus(output) 4438*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, unshuffle_output) 4439*da0073e9SAndroid Build Coastguard Worker else: 4440*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: ps(input)) 4441*da0073e9SAndroid Build Coastguard Worker 4442*da0073e9SAndroid Build Coastguard Worker def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True, 4443*da0073e9SAndroid Build Coastguard Worker downscale_factor=None): 4444*da0073e9SAndroid Build Coastguard Worker downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor 4445*da0073e9SAndroid Build Coastguard Worker channels = random.randint(1, 4) 4446*da0073e9SAndroid Build Coastguard Worker # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor. 4447*da0073e9SAndroid Build Coastguard Worker height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1) 4448*da0073e9SAndroid Build Coastguard Worker # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor. 4449*da0073e9SAndroid Build Coastguard Worker width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1) 4450*da0073e9SAndroid Build Coastguard Worker 4451*da0073e9SAndroid Build Coastguard Worker if num_input_dims == 1: 4452*da0073e9SAndroid Build Coastguard Worker input = torch.rand(channels, requires_grad=True) 4453*da0073e9SAndroid Build Coastguard Worker elif num_input_dims == 2: 4454*da0073e9SAndroid Build Coastguard Worker input = torch.rand(height, width, requires_grad=True) 4455*da0073e9SAndroid Build Coastguard Worker else: 4456*da0073e9SAndroid Build Coastguard Worker batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)] 4457*da0073e9SAndroid Build Coastguard Worker input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True) 4458*da0073e9SAndroid Build Coastguard Worker 4459*da0073e9SAndroid Build Coastguard Worker pus = nn.PixelUnshuffle(downscale_factor) 4460*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: pus(input)) 4461*da0073e9SAndroid Build Coastguard Worker 4462*da0073e9SAndroid Build Coastguard Worker def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims): 4463*da0073e9SAndroid Build Coastguard Worker # For 1D - 2D, this is an error case. 4464*da0073e9SAndroid Build Coastguard Worker # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle. 4465*da0073e9SAndroid Build Coastguard Worker _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims) 4466*da0073e9SAndroid Build Coastguard Worker 4467*da0073e9SAndroid Build Coastguard Worker # Error cases for pixel_shuffle. 4468*da0073e9SAndroid Build Coastguard Worker _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, valid_channels_dim=False) 4469*da0073e9SAndroid Build Coastguard Worker _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=0) 4470*da0073e9SAndroid Build Coastguard Worker _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=-2) 4471*da0073e9SAndroid Build Coastguard Worker 4472*da0073e9SAndroid Build Coastguard Worker # Error cases for pixel_unshuffle. 
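            # Height or width not divisible by the factor, and non-positive factors, should raise.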
4473*da0073e9SAndroid Build Coastguard Worker _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False) 4474*da0073e9SAndroid Build Coastguard Worker _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False) 4475*da0073e9SAndroid Build Coastguard Worker _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0) 4476*da0073e9SAndroid Build Coastguard Worker _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2) 4477*da0073e9SAndroid Build Coastguard Worker 4478*da0073e9SAndroid Build Coastguard Worker def test_pixel_shuffle_unshuffle_1D(): 4479*da0073e9SAndroid Build Coastguard Worker _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1) 4480*da0073e9SAndroid Build Coastguard Worker 4481*da0073e9SAndroid Build Coastguard Worker def test_pixel_shuffle_unshuffle_2D(): 4482*da0073e9SAndroid Build Coastguard Worker _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2) 4483*da0073e9SAndroid Build Coastguard Worker 4484*da0073e9SAndroid Build Coastguard Worker def test_pixel_shuffle_unshuffle_3D(): 4485*da0073e9SAndroid Build Coastguard Worker _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3) 4486*da0073e9SAndroid Build Coastguard Worker 4487*da0073e9SAndroid Build Coastguard Worker def test_pixel_shuffle_unshuffle_4D(): 4488*da0073e9SAndroid Build Coastguard Worker _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4) 4489*da0073e9SAndroid Build Coastguard Worker 4490*da0073e9SAndroid Build Coastguard Worker def test_pixel_shuffle_unshuffle_5D(): 4491*da0073e9SAndroid Build Coastguard Worker _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5) 4492*da0073e9SAndroid Build Coastguard Worker 4493*da0073e9SAndroid Build Coastguard Worker test_pixel_shuffle_unshuffle_1D() 4494*da0073e9SAndroid Build Coastguard Worker test_pixel_shuffle_unshuffle_2D() 4495*da0073e9SAndroid Build Coastguard Worker test_pixel_shuffle_unshuffle_3D() 4496*da0073e9SAndroid Build Coastguard Worker test_pixel_shuffle_unshuffle_4D() 4497*da0073e9SAndroid Build Coastguard Worker test_pixel_shuffle_unshuffle_5D() 4498*da0073e9SAndroid Build Coastguard Worker 4499*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 4500*da0073e9SAndroid Build Coastguard Worker def test_pixel_shuffle_nhwc_cpu(self): 4501*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 18, 4, 4, device='cpu') 4502*da0073e9SAndroid Build Coastguard Worker input = input.contiguous(memory_format=torch.channels_last).requires_grad_() 4503*da0073e9SAndroid Build Coastguard Worker grad = torch.randn(3, 18, 4, 4, device='cpu') 4504*da0073e9SAndroid Build Coastguard Worker ps = torch.nn.PixelShuffle(3) 4505*da0073e9SAndroid Build Coastguard Worker pus = torch.nn.PixelUnshuffle(3) 4506*da0073e9SAndroid Build Coastguard Worker 4507*da0073e9SAndroid Build Coastguard Worker ref_input = input.detach().clone().contiguous().requires_grad_(True) 4508*da0073e9SAndroid Build Coastguard Worker ref_grad = grad.detach().clone().contiguous() 4509*da0073e9SAndroid Build Coastguard Worker ref_ps = torch.nn.PixelShuffle(3) 4510*da0073e9SAndroid Build Coastguard Worker ref_pus = torch.nn.PixelUnshuffle(3) 4511*da0073e9SAndroid Build Coastguard Worker 4512*da0073e9SAndroid Build Coastguard Worker out = pus(ps(input)) 4513*da0073e9SAndroid Build Coastguard Worker out.backward(grad) 4514*da0073e9SAndroid Build Coastguard Worker ref_out = ref_pus(ref_ps(ref_input)) 
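    # Illustrative sketch (not part of the original suite): a minimal, hand-checkable
    # example of the index mapping that _verify_pixel_shuffle asserts above. For an
    # input of shape (C*r*r, H, W), pixel_shuffle with upscale factor r produces
    # (C, H*r, W*r) with output[c, h, w] == input[c*r*r + r*(h % r) + (w % r), h // r, w // r].
    # The helper name is hypothetical and is not referenced by any test.
    def _sketch_pixel_shuffle_index_map(self):
        r = 2
        input = torch.arange(1 * r * r * 2 * 2, dtype=torch.float).reshape(1 * r * r, 2, 2)
        output = F.pixel_shuffle(input, r)  # shape (1, 4, 4)
        for h in range(output.size(-2)):
            for w in range(output.size(-1)):
                src_c = r * (h % r) + (w % r)  # source channel (the c * r * r term is 0 here)
                assert output[0, h, w] == input[src_c, h // r, w // r]
        # pixel_unshuffle with the same factor inverts the op exactly.
        assert torch.equal(F.pixel_unshuffle(output, r), input)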
    @set_default_dtype(torch.double)
    def test_pixel_shuffle_nhwc_cpu(self):
        input = torch.randn(3, 18, 4, 4, device='cpu')
        input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
        grad = torch.randn(3, 18, 4, 4, device='cpu')
        ps = torch.nn.PixelShuffle(3)
        pus = torch.nn.PixelUnshuffle(3)

        ref_input = input.detach().clone().contiguous().requires_grad_(True)
        ref_grad = grad.detach().clone().contiguous()
        ref_ps = torch.nn.PixelShuffle(3)
        ref_pus = torch.nn.PixelUnshuffle(3)

        out = pus(ps(input))
        out.backward(grad)
        ref_out = ref_pus(ref_ps(ref_input))
        ref_out.backward(ref_grad)

        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(ref_out.is_contiguous())
        self.assertEqual(out, ref_out)
        self.assertEqual(input.grad, ref_input.grad)

    # These tests should be OpInfo'd
    def test_elu_inplace_on_view(self):
        v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True, dtype=torch.double)

        def func(root):
            x = root.clone()
            view = x.narrow(0, 1, 2)
            res = F.elu(view, inplace=True)
            self.assertIs(res, view)
            return x

        gradcheck(func, [v])
        gradgradcheck(func, [v])

    def test_elu_inplace_gradgrad(self):
        v = torch.randn(8, requires_grad=True, dtype=torch.double)

        def func(root):
            x = root.clone()
            return F.elu(x, inplace=True)

        gradcheck(func, [v])
        gradgradcheck(func, [v])

    def test_relu_inplace_on_view(self):
        v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True, dtype=torch.double)

        def func(root):
            x = root.clone()
            view = x.narrow(0, 1, 2)
            res = F.relu(view, inplace=True)
            self.assertIs(res, view)
            return x

        gradcheck(func, [v])
        gradgradcheck(func, [v])

    def test_PReLU_backward_requires_grad_false(self):
        devices = ['cpu']
        devices += ['cuda'] if TEST_CUDA else []
        for d in devices:
            m = nn.PReLU().to(d)
            x = torch.randn(2, 3, 4, 5, device=d, requires_grad=False)
            y = m(x)
            y.mean().backward()
            self.assertEqual(x.grad, None)
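    # Illustrative sketch (not part of the original suite): the in-place-on-view tests
    # above rely on autograd rebasing the view's graph when an in-place op mutates it,
    # so gradients still flow back to the base tensor. The helper name is hypothetical.
    def _sketch_inplace_on_view_autograd(self):
        base = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True)
        x = base.clone()
        view = x.narrow(0, 1, 2)        # shares storage with x
        F.relu(view, inplace=True)      # mutates x through the view
        x.sum().backward()
        # Elements zeroed by relu through the view get zero gradient; the rest get 1.
        assert torch.equal(base.grad, torch.tensor([1.0, 0.0, 1.0, 1.0]))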
    def test_bce_loss_always_nonnegative(self):
        target = torch.ones(5)
        input = torch.ones(5)
        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)

        target = torch.zeros(5)
        input = torch.zeros(5)
        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)

    def test_bce_with_logits_raises_if_target_and_input_are_different_size(self):
        target = torch.rand(5)
        input = torch.rand(5, 1)
        with self.assertRaises(ValueError):
            nn.BCEWithLogitsLoss()(input, target)

        target = torch.rand(5, 1)
        input = torch.rand(5)
        with self.assertRaises(ValueError):
            nn.BCEWithLogitsLoss()(input, target)

    def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss(self):
        sigmoid = nn.Sigmoid()

        target = torch.rand(64, 4)
        output = torch.rand(64, 4) - 0.5

        self.assertEqual(nn.BCEWithLogitsLoss()(output, target), nn.BCELoss()(sigmoid(output), target))

        weight = torch.rand(4)
        self.assertEqual(nn.BCEWithLogitsLoss(weight)(output, target), nn.BCELoss(weight)(sigmoid(output), target))

        target = torch.zeros(4, 1, dtype=torch.float)
        output = torch.empty(4, 1, dtype=torch.float).fill_(-100)

        self.assertEqual(nn.BCEWithLogitsLoss()(output, target), nn.BCELoss()(sigmoid(output), target))

        self.assertEqual(nn.BCEWithLogitsLoss(reduction='none')(output, target),
                         nn.BCELoss(reduction='none')(sigmoid(output), target))

        weight = torch.rand(1, dtype=torch.float)
        self.assertEqual(nn.BCEWithLogitsLoss(weight)(output, target),
                         nn.BCELoss(weight)(sigmoid(output), target))
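    # Illustrative sketch (not part of the original suite): the identity the test above
    # relies on. BCEWithLogitsLoss(x, y) computes BCELoss(sigmoid(x), y) as one fused,
    # numerically safer op. The helper name is hypothetical.
    def _sketch_bce_with_logits_equivalence(self):
        logits = torch.randn(8, 3)
        target = torch.rand(8, 3)
        fused = nn.BCEWithLogitsLoss()(logits, target)
        composed = nn.BCELoss()(torch.sigmoid(logits), target)
        assert torch.allclose(fused, composed, atol=1e-6)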
    def test_bce_loss_input_range(self):
        bceloss = nn.BCELoss()

        target = torch.rand(25, 25)
        output_valid = torch.rand(25, 25)
        output_too_negative = output_valid - 1.0
        output_too_positive = output_valid + 1.0

        loss_valid = bceloss(output_valid, target)
        with self.assertRaisesRegex(RuntimeError, 'between 0 and 1'):
            loss_too_negative = bceloss(output_too_negative, target)
        with self.assertRaisesRegex(RuntimeError, 'between 0 and 1'):
            loss_too_positive = bceloss(output_too_positive, target)

    def test_bce_loss_size_mismatch(self):
        bceloss = nn.BCELoss()
        a = torch.rand(25)
        b = torch.rand(25, 1)
        with self.assertRaisesRegex(ValueError, r'Using a target size \('):
            bceloss(a, b)

    def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
        x_size = 1024
        y_size = 256
        target = torch.rand(x_size, y_size)

        for reduction in ['none', 'mean', 'sum']:
            output_sig = torch.rand(x_size, y_size) - 0.5
            output_logits = output_sig.clone().detach()

            output_sig.requires_grad = True
            output_logits.requires_grad = True
            weight = torch.rand(y_size)

            loss_sig = nn.BCELoss(weight, reduction=reduction)(
                torch.sigmoid(output_sig), target
            )
            loss_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)(
                output_logits, target
            )

            self.assertEqual(loss_logits, loss_sig)

            if reduction == 'none':
                grad = torch.rand(x_size, y_size)
                loss_sig.backward(grad)
                loss_logits.backward(grad)
            else:
                loss_sig.backward()
                loss_logits.backward()

            self.assertEqual(output_sig.grad, output_logits.grad)
    def test_bce_with_logits_has_correct_forward_grad(self):
        output = torch.randn(3, 5, requires_grad=True, dtype=torch.double)
        target = torch.randn(3, 5, dtype=torch.double)
        for reduction in ('sum', 'mean', 'none'):
            gradcheck(lambda self, target: nn.BCEWithLogitsLoss(reduction=reduction)(self, target),
                      (output, target), check_forward_ad=True)

    def test_bce_with_logits_has_correct_grad_at_zero(self):
        output = torch.zeros(3, 1, requires_grad=True)
        target = torch.zeros(3, 1)
        nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward()
        expected_grad = torch.empty(3, 1).fill_(0.5)
        self.assertEqual(output.grad, expected_grad)

    def test_bce_with_logits_broadcasts_weights(self):
        target = torch.rand(16, 4)
        output = torch.rand(16, 4) - 0.5

        weight = torch.rand(4)
        out1 = nn.BCEWithLogitsLoss(weight)(output, target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCEWithLogitsLoss(weight)(output, target)

        self.assertEqual(out1, out2)

        weight = torch.rand(16, 1)
        out1 = nn.BCEWithLogitsLoss(weight)(output, target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCEWithLogitsLoss(weight)(output, target)

        self.assertEqual(out1, out2)

    def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
        target = torch.rand(64, 4)
        output = torch.rand(64, 4) - 0.5
        pos_weight = torch.ones(64, 4)

        self.assertEqual(nn.BCEWithLogitsLoss()(output, target),
                         nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target))
    def test_bce_with_logits_broadcasts_pos_weights(self):
        target = torch.rand(64, 4)
        output = torch.rand(64, 4) - 0.5
        pos_weight = torch.rand(4)
        out1 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)

        pos_weight1 = pos_weight.expand(1, 4)
        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight1)(output, target)

        pos_weight2 = pos_weight.expand(64, 4)
        out3 = nn.BCEWithLogitsLoss(pos_weight=pos_weight2)(output, target)

        self.assertEqual(out1, out2)
        self.assertEqual(out1, out3)

    def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
        output = torch.zeros(3, 1, requires_grad=True)
        target = torch.zeros(3, 1)
        pos_weight = torch.ones(3, 1)
        nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward()
        expected_grad = torch.empty(3, 1).fill_(0.5)
        grad = output.grad
        self.assertEqual(grad, expected_grad)

    def test_bce_with_logits_stability(self):
        output = torch.tensor([0., -120.])
        target = torch.tensor([0., 1.])
        pos_weight = torch.tensor([1., 1.])

        out1 = nn.BCEWithLogitsLoss()(output, target)
        self.assertTrue(torch.isfinite(out1).all().item())

        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
        self.assertTrue(torch.isfinite(out2).all().item())
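    # Illustrative sketch (not part of the original suite): the stability test above works
    # because BCEWithLogitsLoss uses the log-sum-exp form
    #   loss(x, y) = max(x, 0) - x * y + log(1 + exp(-|x|)),
    # which never exponentiates a large positive number. A naive sigmoid + log would
    # underflow sigmoid(-120) to 0 and produce an infinite loss for y = 1.
    # The helper name is hypothetical.
    def _sketch_bce_with_logits_stable_form(self):
        x = torch.tensor([0., -120.])
        y = torch.tensor([0., 1.])
        stable = x.clamp(min=0) - x * y + torch.log1p(torch.exp(-x.abs()))
        assert torch.isfinite(stable).all()
        assert torch.allclose(stable, nn.BCEWithLogitsLoss(reduction='none')(x, y))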
    def test_bce_loss_broadcasts_weights(self):
        sigmoid = nn.Sigmoid()
        target = torch.rand(16, 4)
        output = torch.rand(16, 4) - 0.5

        weight = torch.rand(4)
        out1 = nn.BCELoss(weight)(sigmoid(output), target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCELoss(weight)(sigmoid(output), target)

        self.assertEqual(out1, out2)

        weight = torch.rand(16, 1)
        out1 = nn.BCELoss(weight)(sigmoid(output), target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCELoss(weight)(sigmoid(output), target)

        self.assertEqual(out1, out2)

    def test_hardtanh_inplace_gradgrad(self):
        v = torch.randn(8, requires_grad=True, dtype=torch.double)

        def func(root):
            x = root.clone()
            return F.hardtanh(x, inplace=True)

        gradcheck(func, [v])
        gradgradcheck(func, [v])

    # test hardtanh backward for large tensor
    def test_hardtanh_backward(self):
        x = torch.randn(128, 10000, requires_grad=True)
        grad = torch.randn(128, 10000)
        z = torch.zeros(128, 10000)
        y = F.hardtanh(x)
        y.backward(grad)
        # ref backward path for hardtanh
        mask = (x > -1) & (x < 1)
        x_grad_ref = torch.where(mask, grad, z)
        self.assertEqual(x.grad, x_grad_ref)
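    # Illustrative sketch (not part of the original suite): the reference backward used in
    # test_hardtanh_backward above. hardtanh clamps to [-1, 1], so the incoming gradient is
    # passed through only where the input lies strictly inside that interval.
    # The helper name is hypothetical.
    def _sketch_hardtanh_backward_mask(self):
        x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
        F.hardtanh(x).sum().backward()
        assert torch.equal(x.grad, torch.tensor([0.0, 1.0, 1.0, 0.0]))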
    def test_batchnorm_nhwc_cpu(self):
        def helper(self, mod, size, dtype, mixed_dtype=False, format=torch.channels_last, precision=None):
            channels = size[1]
            input = torch.randn(size, dtype=dtype, device='cpu', requires_grad=True)
            input = input.contiguous(memory_format=format).to(dtype)
            input.retain_grad()
            grad = torch.randn(size, dtype=dtype, device='cpu')
            grad = grad.contiguous(memory_format=format)
            bn = mod(channels).cpu().to(dtype)
            bn.weight.data.uniform_()
            bn.bias.data.uniform_()

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_bn = mod(channels).cpu().to(dtype)
            ref_bn.load_state_dict(bn.state_dict())

            if mixed_dtype:
                bn.float()
                ref_bn.float()

            out = bn(input)
            out.backward(grad)
            ref_out = ref_bn(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=format))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(bn.weight.grad, ref_bn.weight.grad, atol=precision, rtol=precision)
            self.assertEqual(bn.bias.grad, ref_bn.bias.grad)
            self.assertEqual(input.grad, ref_input.grad)

        # test NC11 and N1HW; test mixed dtype
        for shape in [(4, 8, 10, 10), (4, 1, 9, 9), (4, 9, 1, 1)]:
            for dtype in [torch.float, torch.bfloat16, torch.float16]:
                for mixed_dtype in [False, True]:
                    if dtype == torch.float:
                        mixed_dtype = False
                    helper(self, nn.BatchNorm2d, shape, dtype, mixed_dtype, torch.channels_last)

        precisions = {torch.float: 1e-4, torch.bfloat16: 1e-4, torch.float16: None}
        for shape in [(4, 8, 2, 10, 10), (4, 1, 2, 9, 9), (4, 9, 1, 1, 1)]:
            for dtype in [torch.float, torch.bfloat16, torch.float16]:
                for mixed_dtype in [False, True]:
                    if dtype == torch.float:
                        mixed_dtype = False
                    helper(self, nn.BatchNorm3d, shape, dtype, mixed_dtype, torch.channels_last_3d, precisions[dtype])
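    # Illustrative sketch (not part of the original suite): what the channels_last checks in
    # the test above verify. A 4D tensor in channels_last memory format keeps the same
    # logical shape (N, C, H, W) but uses strides that make C the fastest-moving dimension.
    # The helper name is hypothetical.
    def _sketch_channels_last_memory_format(self):
        x = torch.randn(2, 3, 4, 5).contiguous(memory_format=torch.channels_last)
        assert x.shape == (2, 3, 4, 5)                  # logical shape unchanged
        assert x.stride() == (3 * 4 * 5, 1, 5 * 3, 3)   # NHWC layout: stride of C is 1
        assert x.is_contiguous(memory_format=torch.channels_last)
        assert not x.is_contiguous()                    # not contiguous in the default NCHW sense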
    @parametrize_test(
        'bn_module',
        [
            subtest(torch.nn.BatchNorm2d, name="BatchNorm2d"),
            subtest(torch.nn.SyncBatchNorm, name="SyncBatchNorm"),
        ],
    )
    def test_batchnorm_non_contig_cpu(self, bn_module):
        def helper(self, dtype):
            input = torch.arange(6, dtype=torch.float).reshape(1, 3, 2, 1).cpu()
            input = input.permute(0, 2, 1, 3)

            bn = bn_module(2).cpu().float().eval()
            bn.weight.data.uniform_()
            bn.bias.data.uniform_()

            ref_input = input.detach().clone().contiguous()
            ref_bn = nn.BatchNorm2d(2).cpu().float().eval()
            ref_bn.load_state_dict(bn.state_dict())

            out = bn(input)
            ref_out = ref_bn(ref_input)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)

            input_bf = torch.arange(24, dtype=dtype).reshape(1, 3, 2, 4)
            input_bf = input_bf.permute(0, 2, 1, 3)
            input_f = input_bf.float()
            bn_mix = bn_module(2).float().eval()
            ref_bn_f = deepcopy(bn_mix)
            out_bf = bn_mix(input_bf)
            ref_out_bf = ref_bn_f(input_f)
            self.assertEqual(ref_out_bf, out_bf.float(), atol=0.05, rtol=0.05)

        helper(self, torch.bfloat16)
        helper(self, torch.float16)
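    # Illustrative sketch (not part of the original suite): how the non-contiguous input in
    # the test above is constructed. Permuting a contiguous NCHW tensor does not move data,
    # it only changes strides, so the result is a view that is no longer contiguous in the
    # default sense. The helper name is hypothetical.
    def _sketch_non_contiguous_via_permute(self):
        x = torch.arange(6, dtype=torch.float).reshape(1, 3, 2, 1)
        y = x.permute(0, 2, 1, 3)       # logical shape (1, 2, 3, 1), same storage
        assert y.shape == (1, 2, 3, 1)
        assert not y.is_contiguous()
        assert y.data_ptr() == x.data_ptr()   # no copy was made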
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
    def test_batchnorm_cudnn_nhwc(self):
        def run_test(input, grad_output):
            c = input.size(1)
            mod = nn.BatchNorm2d(c).cuda().float()
            mod.weight.data.uniform_()
            mod.bias.data.uniform_()
            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad_output.detach().clone().contiguous()
            ref_mod = nn.BatchNorm2d(c).cuda().float()
            ref_mod.load_state_dict(mod.state_dict())
            out = mod(input)
            out.backward(grad_output)
            ref_out = ref_mod(ref_input)
            ref_out.backward(ref_grad)
            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(mod.weight.grad, ref_mod.weight.grad)
            self.assertEqual(mod.bias.grad, ref_mod.bias.grad)
            self.assertEqual(input.grad, ref_input.grad)

        input = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
        input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()

        grad = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
        grad = grad.contiguous(memory_format=torch.channels_last)
        run_test(input, grad)
        # see #42588, grad is channels_last contiguous, but grad.suggest_memory_format (rightly) returns "contiguous"
        # not channels_last
        input = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
        input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()
        grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
        grad = grad.permute(0, 2, 1, 3)
        run_test(input, grad)
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_batchnorm_cudnn_half(self):
        # THNN
        input = torch.randint(1, 10, (2, 3, 2, 2), dtype=torch.half, device="cuda", requires_grad=True)
        m = nn.BatchNorm2d(3).half().cuda()
        thnn_output = m(input)
        thnn_output.sum().backward()
        thnn_input_grad = input.grad.data.clone()
        self.assertEqualTypeString(thnn_output, input)
        # cuDNN
        if TEST_CUDNN:
            input.grad = None
            m = m.float()
            cudnn_output = m(input)
            cudnn_output.sum().backward()
            cudnn_input_grad = input.grad.data.clone()
            self.assertEqualTypeString(cudnn_output, input)
            self.assertEqual(cudnn_output, thnn_output)
            self.assertEqual(cudnn_input_grad, thnn_input_grad, atol=1e-3, rtol=0)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_batchnorm_nonaffine_cuda_half_input(self):
        input = torch.randn(16, 3, 24, 24, dtype=torch.half, device="cuda")
        m = nn.BatchNorm2d(3, affine=False).cuda().float()  # keep running stats in FP32
        output = m(input)
        self.assertEqualTypeString(output, input)
        m.eval()
        output = m(input)
        self.assertEqualTypeString(output, input)

    def test_batchnorm_raises_error_if_less_than_one_value_per_channel(self):
        x = torch.rand(10)[None, :, None]
        with self.assertRaises(ValueError):
            torch.nn.BatchNorm1d(10)(x)

    def test_batchnorm_raises_error_if_running_mean_is_not_same_size_as_input(self):
        input = torch.rand(2, 10)
        running_var = torch.rand(10)
        wrong_sizes = [9, 11]
        for size in wrong_sizes:
            with self.assertRaises(RuntimeError):
                F.batch_norm(input, torch.rand(size), running_var)

    def test_batchnorm_raises_error_if_running_var_is_not_same_size_as_input(self):
        input = torch.rand(2, 10)
        running_mean = torch.rand(10)
        wrong_sizes = [9, 11]
        for size in wrong_sizes:
            with self.assertRaises(RuntimeError):
                F.batch_norm(input, running_mean, torch.rand(size))

    def test_batchnorm_raises_error_if_weight_is_not_same_size_as_input(self):
        input = torch.rand(2, 10)
        running_mean = torch.rand(10)
        running_var = torch.rand(10)
        wrong_sizes = [9, 11]
        for size in wrong_sizes:
            with self.assertRaises(RuntimeError):
                F.batch_norm(input, running_mean, running_var, weight=Parameter(torch.rand(size)))
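    # Illustrative sketch (not part of the original suite): the shape contract the error
    # tests above exercise. For an input of shape (N, C, ...), the functional form expects
    # running_mean, running_var, weight and bias to all have shape (C,).
    # The helper name is hypothetical.
    def _sketch_functional_batch_norm_shapes(self):
        input = torch.rand(2, 10)
        running_mean = torch.zeros(10)
        running_var = torch.ones(10)
        weight = torch.ones(10)
        bias = torch.zeros(10)
        out = F.batch_norm(input, running_mean, running_var, weight=weight, bias=bias, training=True)
        assert out.shape == input.shape
        with self.assertRaises(RuntimeError):
            F.batch_norm(input, torch.zeros(9), running_var)  # mismatched running_mean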
    def test_batchnorm_raises_error_if_bias_is_not_same_size_as_input(self):
        input = torch.rand(2, 10)
        running_mean = torch.rand(10)
        running_var = torch.rand(10)
        wrong_sizes = [9, 11]
        for size in wrong_sizes:
            with self.assertRaises(RuntimeError):
                F.batch_norm(input, running_mean, running_var, bias=Parameter(torch.rand(size)))

    def test_batchnorm_raises_error_if_running_var_or_running_mean_have_forward_grad(self):
        args = (
            torch.randn(3, 2, 5),  # input
            torch.randn(2),  # running_mean
            torch.randn(2),  # running_var
        )
        kwargs = {'training': False, 'momentum': -1.2}
        fn = partial(F.batch_norm, **kwargs)

        for dual_indices in ((0,), (1,), (1, 2), (0, 1), (0, 1, 2),):
            tangents = tuple(torch.rand_like(x) for x in args)

            with fwAD.dual_level():
                duals = [fwAD.make_dual(primal, tangent) if i in dual_indices else primal
                         for i, (primal, tangent) in enumerate(zip(args, tangents))]
                msg = "batch_norm is not differentiable wrt running_mean and running_var"
                # 0 needs to have forward grad because otherwise we won't even run batch_norm_jvp
                if (1 in dual_indices or 2 in dual_indices) and 0 in dual_indices:
                    with self.assertRaisesRegex(RuntimeError, msg):
                        fn(*duals)
                else:
                    fn(*duals)

    def test_batchnorm_buffer_update_when_stats_are_not_tracked(self):
        input_size = (32, 4)
        # Instantiate BN with buffers that are not None
        bn = nn.BatchNorm1d(input_size[1], track_running_stats=True)
        # Use buffers for normalization but don't update them
        bn.track_running_stats = False
        # Store initial values
        num_batches = bn.num_batches_tracked.clone()
        running_mean = bn.running_mean.clone()
        running_var = bn.running_var.clone()
        # Forward random tensor
        _ = bn(torch.rand(input_size))
        # Ensure none of the buffers has been updated
        self.assertTrue(torch.equal(num_batches, bn.num_batches_tracked))
        self.assertTrue(torch.equal(running_mean, bn.running_mean))
        self.assertTrue(torch.equal(running_var, bn.running_var))
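    # Illustrative sketch (not part of the original suite): the buffer update that the test
    # above confirms is skipped when track_running_stats is turned off. In training mode with
    # tracking enabled, each forward updates the running stats with an exponential moving
    # average: running = (1 - momentum) * running + momentum * batch_stat.
    # The helper name is hypothetical.
    def _sketch_running_stats_update(self):
        bn = nn.BatchNorm1d(4, momentum=0.1)
        x = torch.randn(32, 4)
        expected_mean = (1 - 0.1) * bn.running_mean + 0.1 * x.mean(dim=0)
        bn(x)  # module is in training mode by default, so the forward updates the buffers
        self.assertEqual(bn.running_mean, expected_mean)
        self.assertEqual(bn.num_batches_tracked, torch.tensor(1))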
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_batchnorm_nhwc_cuda(self):
        for dtype in (torch.half, torch.float):
            (N, C, H, W) = 2, 64, 50, 50
            model = torch.nn.BatchNorm2d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            model = model.eval().cuda().to(dtype)
            inp1 = torch.randn(N, C, H, W, device=torch.device('cuda'), dtype=dtype)
            inp2 = inp1.contiguous(memory_format=torch.channels_last)
            out1 = model(inp1)
            out2 = model(inp2)
            self.assertTrue(torch.equal(out1, out2))

    def test_batchnorm_load_state_dict(self):
        bn = torch.nn.BatchNorm2d(3)
        self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(0))

        bn.num_batches_tracked = torch.tensor(10)
        self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(10))

        empty_dict = OrderedDict()
        bn.load_state_dict(empty_dict, strict=False)
        self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(10))

        # test that when `num_batches_tracked` is not in loaded state_dict,
        # meta num_batches_tracked is still replaced with singleton 0 tensor
        with torch.device('meta'):
            meta_bn = torch.nn.BatchNorm2d(3)
        self.assertTrue(meta_bn.num_batches_tracked.device == torch.device('meta'))
        meta_bn.load_state_dict(empty_dict, assign=True, strict=False)
        self.assertEqual(meta_bn.state_dict()["num_batches_tracked"], torch.tensor(0))
    def test_batch_norm_update_stats(self):
        input = torch.rand(0, 1)
        running_mean = torch.rand(1)
        running_var = torch.rand(1)
        with self.assertRaisesRegex(RuntimeError,
                                    re.escape("input tensor must have at least one element, but got input_sizes = [0, 1]")):
            torch.batch_norm_update_stats(input=input, momentum=0.0, running_mean=running_mean, running_var=running_var)

    def test_pairwise_distance(self):
        input1 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
        input2 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
        self.assertTrue(gradcheck(lambda x, y: F.pairwise_distance(x, y), (input1, input2)))

    # TODO: Create an OpInfo for pdist
    def test_pdist(self):
        for device, trans in itertools.product(device_(), [False, True]):
            inp = torch.randn(4, 5, dtype=torch.double, device=device, requires_grad=True)
            if trans:
                inp = inp.transpose(0, 1)
            for p in [0, 1, 2, 0.5, 1.5, 2.5, float('inf')]:
                self.assertTrue(gradcheck(lambda x: F.pdist(x, p), (inp,)))

    def test_pdist_zeros(self):
        """Test that grad is still valid when dist is 0"""
        for device in device_():
            inp = torch.randn(1, 3, dtype=torch.double, device=device, requires_grad=True).repeat([2, 1])
            for p in [0, 1, 2, 0.5, 1.5, 2.5, float('inf')]:
                self.assertTrue(gradcheck(lambda x: F.pdist(x, p), (inp,)))

    def test_pdist_empty_row(self):
        for device in device_():
            inp = torch.randn(1, 3, dtype=torch.double, device=device, requires_grad=True)
            self.assertTrue(gradcheck(F.pdist, (inp,)))

    def test_pdist_empty_col(self):
        for device in device_():
            inp = torch.randn(4, 0, dtype=torch.double, device=device, requires_grad=True)
            self.assertTrue(gradcheck(F.pdist, (inp,)))
    @unittest.expectedFailure
    def test_pdist_cpu_gradgrad_unimplemented(self):
        inp = torch.randn(4, 5, requires_grad=True)
        gradgradcheck(F.pdist, (inp,))

    @unittest.expectedFailure
    def test_pdist_cuda_gradgrad_unimplemented(self):
        inp = torch.randn(4, 5, device='cuda', requires_grad=True)
        gradgradcheck(F.pdist, (inp,))

    # Merge into OpInfo?
    # test for backward in https://github.com/pytorch/pytorch/issues/15511
    def test_pdist_large(self):
        for device in device_():
            def func(x):
                return torch.pdist(x, p=2)

            # shape[0] should be able to be (roughly) arbitrarily large, but the kernel
            # is currently limited to smaller sizes (see issue above); this is just testing
            # a floor.
            shape = (1000, 1)
            x = torch.randn(shape, device=device).requires_grad_()
            output = torch.pdist(x, p=2)
            # just run a single backward, as gradcheck/gradgradcheck is expensive here
            output.sum().backward()
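    # Illustrative sketch (not part of the original suite): what pdist returns. For an
    # (N, D) input it produces the N * (N - 1) / 2 pairwise distances of the strict upper
    # triangle, i.e. the condensed form of torch.cdist(x, x). The helper name is hypothetical.
    def _sketch_pdist_condensed_form(self):
        x = torch.randn(4, 5)
        condensed = torch.pdist(x, p=2)
        self.assertEqual(condensed.shape, torch.Size([4 * 3 // 2]))
        full = torch.cdist(x, x, p=2)
        row, col = torch.triu_indices(4, 4, offset=1)
        self.assertEqual(condensed, full[row, col])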
    def test_cosine_embedding_loss_with_diff_type(self):
        for device in device_():
            input1 = torch.tensor([[2, 3, 4], [6, 2, 4]], dtype=torch.double, device=device)
            input2 = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
            target = torch.tensor([1, -1], dtype=torch.int, device=device)
            expected = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
            for dt1 in get_all_math_dtypes(device):
                for dt2 in get_all_math_dtypes(device):
                    for dt3 in get_all_math_dtypes(device):
                        # dt3 is used as dtype for target = [1, -1], so let's skip unsigned type
                        if dt3 == torch.uint8:
                            continue
                        if dt1.is_complex or dt2.is_complex or dt3.is_complex:
                            continue
                        input1 = input1.to(dt1)
                        input2 = input2.to(dt2)
                        target = target.to(dt3)
                        result = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
                        self.assertEqual(result.item(), expected.item(), atol=0.001, rtol=0)

    def test_cosine_embedding_loss_error_on_diff_shapes(self):
        for device in device_():
            input1 = torch.empty((0, 0), dtype=torch.double, device=device)
            input2 = torch.empty((0,), dtype=torch.double, device=device)
            target = torch.empty((0,), dtype=torch.int, device=device)
            with self.assertRaisesRegex(RuntimeError, ".*expects 2D.*"):
                torch.nn.functional.cosine_embedding_loss(input1, input2, target)

    def test_cosine_embedding_loss_error_on_nonexpandable_shapes(self):
        for device in device_():
            input1 = torch.empty((1, 5), dtype=torch.double, device=device)
            input2 = torch.empty((1, 6), dtype=torch.double, device=device)
            target = torch.ones((1,), dtype=torch.int, device=device)
            with self.assertRaisesRegex(RuntimeError, ".*must match the size.*"):
                torch.nn.functional.cosine_embedding_loss(input1, input2, target)
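    # Illustrative sketch (not part of the original suite): the per-sample definition the
    # cosine_embedding_loss tests above rely on:
    #   loss = 1 - cos(x1, x2)                 if y == 1
    #   loss = max(0, cos(x1, x2) - margin)    if y == -1
    # The helper name is hypothetical.
    def _sketch_cosine_embedding_loss_definition(self):
        x1 = torch.randn(3, 5)
        x2 = torch.randn(3, 5)
        y = torch.tensor([1, -1, 1])
        cos = F.cosine_similarity(x1, x2, dim=1)
        expected = torch.where(y == 1, 1 - cos, cos.clamp(min=0))  # margin = 0 here
        actual = F.cosine_embedding_loss(x1, x2, y, margin=0.0, reduction='none')
        self.assertEqual(actual, expected)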
    def test_kl_div_with_diff_type(self):
        for device in device_():
            input = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
            target = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.double, device=device)
            expected = torch.nn.functional.kl_div(input, target)
            real_dtypes = (torch.float32, torch.float64, torch.float16)
            for input_dtype, target_dtype in product(real_dtypes, repeat=2):
                if (torch.device(device).type == 'cpu' and target_dtype == torch.float16):
                    continue
                input = input.to(input_dtype)
                target = target.to(target_dtype)
                result = torch.nn.functional.kl_div(input, target)
                self.assertEqual(result.item(), expected.item(), atol=0.001, rtol=0)

    def test_kl_div_with_diff_type_log_target(self):
        for device in device_():
            input = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
            target = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.double, device=device).log()
            expected = torch.nn.functional.kl_div(input, target, log_target=True)
            real_dtypes = (torch.float32, torch.float64, torch.float16)
            for input_dtype, target_dtype in product(real_dtypes, repeat=2):
                if (torch.device(device).type == 'cpu' and target_dtype == torch.float16):
                    continue
                input = input.to(input_dtype)
                target = target.to(target_dtype)
                result = torch.nn.functional.kl_div(input, target, log_target=True)
                self.assertEqual(result.item(), expected.item(), atol=0.001, rtol=0)

    def test_kl_div_log_softmax_target(self):
        for device in device_():
            a = torch.tensor([[1.0, 2, 3], [5.0, 5, 5]], device=device)
            b = torch.tensor([[1.0, 2, 3], [5.0, 5, 5]], device=device)
            self.assertEqual(
                F.kl_div(F.log_softmax(a, 1), F.log_softmax(b, 1), reduction='none', log_target=True),
                torch.zeros_like(a)
            )
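    # Illustrative sketch (not part of the original suite): the pointwise definition the
    # kl_div tests above compare against. With reduction='none',
    #   kl_div(input, target)                   = target * (log(target) - input)
    #   kl_div(input, target, log_target=True)  = exp(target) * (target - input)
    # where `input` is expected to hold log-probabilities. The helper name is hypothetical.
    def _sketch_kl_div_pointwise(self):
        input = F.log_softmax(torch.randn(2, 4), dim=1)
        target = F.softmax(torch.randn(2, 4), dim=1)
        expected = target * (target.log() - input)
        self.assertEqual(F.kl_div(input, target, reduction='none'), expected)
        log_target = target.log()
        expected_log = log_target.exp() * (log_target - input)
        self.assertEqual(F.kl_div(input, log_target, reduction='none', log_target=True), expected_log)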

    def test_cosine_embedding_loss_no_reduce(self):
        input1 = torch.randn(15, 10, requires_grad=True, dtype=torch.double)
        input2 = torch.randn(15, 10, requires_grad=True, dtype=torch.double)
        target = torch.randn(15, dtype=torch.double).sign()
        self.assertTrue(gradcheck(lambda x, y, z: F.cosine_embedding_loss(
            x, y, z, reduction='none'), (input1, input2, target)))
        self.assertEqual(F.cosine_embedding_loss(input1, input2, target, reduction='none'),
                         loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, reduction='none'))

    def test_cosine_embedding_loss_margin_no_reduce(self):
        input1 = torch.randn(15, 10, requires_grad=True, dtype=torch.double)
        input2 = torch.randn(15, 10, requires_grad=True, dtype=torch.double)
        target = torch.randn(15, dtype=torch.double).sign()
        self.assertTrue(gradcheck(lambda x, y, z: F.cosine_embedding_loss(
            x, y, z, margin=0.5, reduction='none'), (input1, input2, target)))
        self.assertEqual(F.cosine_embedding_loss(input1, input2, target, margin=0.5, reduction='none'),
                         loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target,
                                                                   margin=0.5, reduction='none'))

    def test_cosine_embedding_loss_invalid_shape(self):
        input1 = torch.randn(15, 10)
        input2 = torch.randn(15, 10)
        target = torch.randn(15, 1).sign()

        with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
            F.cosine_embedding_loss(input1, input2, target)

        with self.assertRaisesRegex(RuntimeError, "1D target tensor expects 2D input tensors"):
            F.cosine_embedding_loss(torch.randn(10), torch.randn(10), torch.randn(10))

        with self.assertRaisesRegex(RuntimeError, "0D target tensor expects 1D input tensors"):
            F.cosine_embedding_loss(torch.randn(2, 5), torch.randn(2, 5), torch.randn(()))

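    # Reference for the cosine embedding checks above (a sketch of the documented
    # per-sample form, with cos() denoting cosine similarity):
    #
    #   cosine_embedding_loss(x1, x2, y) = 1 - cos(x1, x2)              if y == 1
    #                                      max(0, cos(x1, x2) - margin) if y == -1
    #
    # which is why the targets in these tests are generated with .sign().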

    def test_margin_ranking_loss_no_reduce(self):
        input1 = torch.randn(15, dtype=torch.double).mul_(10).requires_grad_()
        input2 = torch.randn(15, dtype=torch.double).mul_(10).requires_grad_()
        target = torch.randn(15, dtype=torch.double).sign()
        self.assertTrue(gradcheck(lambda x, y, z: F.margin_ranking_loss(
            x, y, z, reduction='none'), (input1, input2, target)))
        self.assertEqual(F.margin_ranking_loss(input1, input2, target, reduction='none'),
                         loss_reference_fns['MarginRankingLoss'](input1, input2, target, reduction='none'))

    def test_margin_ranking_loss_margin_no_reduce(self):
        input1 = torch.randn(15, dtype=torch.double).mul_(10).requires_grad_()
        input2 = torch.randn(15, dtype=torch.double).mul_(10).requires_grad_()
        target = torch.randn(15, dtype=torch.double).sign()
        self.assertTrue(gradcheck(lambda x, y, z: F.margin_ranking_loss(
            x, y, z, margin=0.5, reduction='none'), (input1, input2, target)))
        self.assertEqual(F.margin_ranking_loss(input1, input2, target, margin=0.5, reduction='none'),
                         loss_reference_fns['MarginRankingLoss'](input1, input2, target, margin=0.5, reduction='none'))

    def test_triplet_margin_loss(self):
        input1 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
        input2 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
        input3 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
            x1, x2, x3), (input1, input2, input3)))
        self.assertEqual(F.triplet_margin_loss(input1, input2, input3),
                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3))

    def test_triplet_margin_loss_swap(self):
        input1 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
        input2 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
        input3 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
            x1, x2, x3, swap=True), (input1, input2, input3)))
        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True),
                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True))

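    # Sketch of the documented per-sample forms exercised by the ranking/triplet
    # tests in this area (d(., .) is the pairwise distance, p=2 by default):
    #
    #   margin_ranking_loss(x1, x2, y) = max(0, -y * (x1 - x2) + margin)
    #   triplet_margin_loss(a, p, n)   = max(d(a, p) - d(a, n) + margin, 0)
    #
    # With swap=True the negative distance d(a, n) is replaced by
    # min(d(a, n), d(p, n)) before the loss is computed ("distance swap").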

    def test_triplet_margin_loss_no_reduce(self):
        input1 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
        input2 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
        input3 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
            x1, x2, x3, reduction='none'), (input1, input2, input3)))
        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='none'),
                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, reduction='none'))

    def test_triplet_margin_loss_swap_no_reduce(self):
        input1 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
        input2 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
        input3 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
            x1, x2, x3, swap=True, reduction='none'), (input1, input2, input3)))
        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'),
                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none'))

    def test_pointwise_loss_target_grad_none_reduction(self):
        i = torch.randn(5, 10)
        t = torch.randn(5, 10, requires_grad=True)
        self.assertEqual(F.mse_loss(i, t, reduction='none').size(), t.size())
        self.assertEqual(F.l1_loss(i, t, reduction='none').size(), t.size())

    def test_pointwise_loss_broadcast(self):
        losses = {
            'mse_loss': lambda x, y, r: F.mse_loss(x, y, reduction=r),
            'l1_loss': lambda x, y, r: F.l1_loss(x, y, reduction=r),
            'smooth_l1_loss': lambda x, y, r: F.smooth_l1_loss(x, y, reduction=r),
            'huber_loss': lambda x, y, r: F.huber_loss(x, y, reduction=r),
        }

        input = torch.randn(2, 1, requires_grad=True, dtype=torch.double)
        for fn in losses.values():
            for requires_grad in [True, False]:
                # When target.requires_grad=True, its impl is in Python, while the other is in TH.
                target = torch.randn(2, 10, requires_grad=requires_grad, dtype=torch.double)
                for reduction in ['none', 'mean', 'sum']:
                    l = fn(input, target, reduction)
                    if reduction == 'none':
                        self.assertEqual(l.size(), target.size())
                    self.assertTrue(gradcheck(fn, (input, target, reduction)))

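    # Note on the broadcast test above: these pointwise losses broadcast input and
    # target under the usual rules, so a (2, 1) input against a (2, 10) target
    # yields a (2, 10) elementwise loss with reduction='none'. Broadcasting a
    # target of a different size is generally discouraged (some of these losses
    # warn about it), but the gradients must still be correct, which is what the
    # gradcheck above verifies.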

    # https://github.com/pytorch/pytorch/issues/27692 reports
    # that l1_loss gets a wrong result for big batch sizes
    def test_l1_loss_correct(self):
        for dtype in [torch.float, torch.cfloat]:
            for N in range(1, 50, 10):
                input = torch.rand(N, 3, 1024, 1024, dtype=dtype)
                self.assertEqual(
                    torch.nn.L1Loss()(input, torch.zeros_like(input)),
                    input.abs().mean())

    def test_smoothl1loss_intergral_target(self):
        def _input_grad(input, target, reduction):
            output = F.smooth_l1_loss(input, target, reduction=reduction, beta=0.5)
            output.sum().backward()
            return input.grad

        for device, dtype, reduction in product(device_(),
                                                integral_types(),
                                                ('none', 'sum', 'mean')):
            input = torch.randn(2, 2, device=device, requires_grad=True)
            target = torch.randint(0, 9, (2, 2), device=device, dtype=dtype)

            input_grad_with_float_target = _input_grad(input, target.float(), reduction)

            input_grad = _input_grad(input.detach().clone().requires_grad_(True),
                                     target,
                                     reduction)
            self.assertEqual(input_grad, input_grad_with_float_target)

    def test_smoothl1loss_negative_beta_not_supported(self):
        with self.assertRaises(RuntimeError):
            F.smooth_l1_loss(torch.randn(2, 2), torch.randn(2, 2), beta=-1.0)

    def test_huber_loss_invalid_delta(self):
        def _test_huber_loss_delta_error_helper(delta):
            input, target = torch.randn(2, 2), torch.randn(2, 2)
            loss = torch.nn.HuberLoss(delta=delta)
            with self.assertRaises(RuntimeError):
                loss(input, target)

        def test_huber_loss_negative_delta():
            _test_huber_loss_delta_error_helper(delta=-0.5)

        def test_huber_loss_zero_delta():
            _test_huber_loss_delta_error_helper(delta=0.0)

        test_huber_loss_negative_delta()
        test_huber_loss_zero_delta()

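    # Context for the beta/delta error checks above (a sketch of the documented
    # piecewise definitions, writing e = |input - target|):
    #
    #   smooth_l1_loss(beta):  0.5 * e**2 / beta  if e < beta,  else e - 0.5 * beta
    #   huber_loss(delta):     0.5 * e**2         if e < delta, else delta * (e - 0.5 * delta)
    #
    # so huber_loss(delta) == delta * smooth_l1_loss(beta=delta); a negative beta or
    # a non-positive delta has no meaningful interpretation, hence the RuntimeErrors.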

    @set_default_dtype(torch.double)
    def test_cosine_similarity(self):
        # Check cosine_similarity input/output shapes
        input_size = (1, 3, 2, 1)
        expected_size = (1, 2, 1)
        input1 = torch.randn(input_size, requires_grad=True)
        input2 = torch.randn(input_size, requires_grad=True)
        self.assertEqual(F.cosine_similarity(input1, input2, dim=1).size(), expected_size)

        # Check numerical precision, issue #18057
        vv1 = torch.tensor([float(i) for i in range(84)]).unsqueeze(0)
        vv2 = torch.tensor([float(i) for i in range(84)]).unsqueeze(0)
        out = F.cosine_similarity(vv1, vv2)
        self.assertLessEqual(out, 1.0)

        # Check division by 0.
        # previous behavior: <x, y> / max(eps, ||x|| * ||y||)
        # current behavior:  <x / max(eps, ||x||), y / max(eps, ||y||)>
        # if f(x, y) is the cosine similarity, then
        # df/dx = y / (||x|| * ||y||) - (x * <x, y> * ||y|| / ||x||) / (||x|| * ||y||)^2
        # The tests below check division by zero in the backward formula when
        # x := input2 = 0 and y := input1 != 0.
        # For these inputs the gradient wrt x simplifies to g(x, y) := y / (||x|| * ||y||).
        # The previous test checked g(x, y) == y / eps;
        # the current test checks g(x, y) == (y / ||y||) / eps.
        input1 = torch.randn(10).requires_grad_()
        input2 = torch.zeros_like(input1).requires_grad_()
        torch.cosine_similarity(input1, input2, 0).sum().backward()
        self.assertEqual(input1.grad, torch.zeros_like(input1))
        self.assertEqual(input2.grad, input1 / input1.norm() * 1e8)

        # Check type promotion, issue #61454
        input = torch.tensor(12.)
        out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)
        self.assertEqual(out, 1.)

        # Check broadcasting #109333
        a = torch.ones(2, 3, dtype=torch.float)
        b = torch.ones(1, 1, dtype=torch.float)
        out = F.cosine_similarity(a, b)
        self.assertEqual(out, torch.ones(2, dtype=torch.float))

        a = torch.ones(2, 3, dtype=torch.float)
        b = torch.ones(1, dtype=torch.float)
        out = F.cosine_similarity(a, b)
        self.assertEqual(out, torch.ones(2, dtype=torch.float))

    def test_grid_sample_error_checking(self):
        input = torch.empty(1, 1, 2, 2)
        grid = torch.empty(1, 1, 1, 2)

        # assert no error
        F.grid_sample(input, grid, align_corners=False)

        with self.assertRaisesRegex(ValueError, "but got: 'garbage'"):
            F.grid_sample(input, grid, mode='garbage', align_corners=False)

        with self.assertRaisesRegex(ValueError, "but got: 'garbage'"):
            F.grid_sample(input, grid, padding_mode='garbage', align_corners=False)

        with self.assertRaisesRegex(RuntimeError, "expected grid to have size 1 in last dimension"):
            F.grid_sample(input[0], grid, align_corners=False)

        with self.assertRaisesRegex(RuntimeError, "expected grid to have size 2 in last dimension"):
            F.grid_sample(input, torch.empty(1, 1, 1, 1, 3), align_corners=False)

        with self.assertRaisesRegex(RuntimeError, "expected grid and input to have same batch size"):
            F.grid_sample(input, torch.empty(2, 1, 1, 2), align_corners=False)

        with self.assertRaisesRegex(RuntimeError, "expected grid to have size 2 in last dimension"):
            F.grid_sample(input, torch.empty(1, 1, 1, 3), align_corners=False)

        with self.assertRaisesRegex(RuntimeError, "expected input to have non-empty spatial dimensions"):
            F.grid_sample(torch.empty(1, 1, 0, 2), grid, align_corners=False)

        with self.assertRaisesRegex(RuntimeError, "bicubic interpolation only supports 4D input"):
            F.grid_sample(torch.empty(1, 1, 2, 2, 2), torch.empty(1, 1, 1, 1, 3), mode='bicubic')

        if TEST_CUDA:
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                F.grid_sample(input.cuda(), grid, align_corners=False)

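    # Shape and coordinate conventions assumed by the error checks above (per the
    # grid_sample docs): a 4D input is (N, C, H_in, W_in) with a grid of shape
    # (N, H_out, W_out, 2), and a 5D input is (N, C, D_in, H_in, W_in) with a grid
    # of shape (N, D_out, H_out, W_out, 3); the last grid dimension holds
    # normalized (x, y[, z]) sampling locations in [-1, 1].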

    def test_affine_grid_error_checking(self):
        # 2D affine
        theta = torch.empty(1, 2, 3, dtype=torch.double)
        size = torch.Size([1, 1, 2, 2])

        # assert no error
        F.affine_grid(theta, size, align_corners=False)

        # check for warning for empty span along dimension
        with warnings.catch_warnings(record=True) as w:
            # Ensure warnings are being shown
            warnings.simplefilter("always")
            # Should not trigger warning
            F.affine_grid(theta, torch.Size([1, 1, 2, 1]), align_corners=False)
            # Check no warning occurs
            self.assertNotIn('See the documentation of affine_grid for details.', ' '.join(map(str, w)))
            # Should trigger warning
            F.affine_grid(theta, torch.Size([1, 1, 2, 1]), align_corners=True)
            # Check warning occurs
            self.assertIn('See the documentation of affine_grid for details.', ' '.join(map(str, w)))

        with self.assertRaisesRegex(ValueError, "Expected theta to have floating point type"):
            F.affine_grid(theta.int(), size, align_corners=False)

        with self.assertRaisesRegex(ValueError, "Expected a batch of 2D affine matrices of shape Nx2x3"):
            F.affine_grid(theta[0], size, align_corners=False)

        with self.assertRaisesRegex(ValueError, "Expected a batch of 2D affine matrices of shape Nx2x3"):
            F.affine_grid(theta.unsqueeze(0), size, align_corners=False)

        with self.assertRaisesRegex(ValueError, "Expected a batch of 2D affine matrices of shape Nx2x3"):
            F.affine_grid(theta.repeat(1, 2, 1), size, align_corners=False)

        with self.assertRaisesRegex(ValueError, "Expected a batch of 2D affine matrices of shape Nx2x3"):
            F.affine_grid(theta.repeat(1, 1, 2), size, align_corners=False)

        # 3D affine
        theta = torch.empty(1, 3, 4, dtype=torch.double)
        size = torch.Size([1, 1, 2, 2, 2])

        # assert no error
        F.affine_grid(theta, size, align_corners=False)

        # check for warning for empty span along dimension
        with warnings.catch_warnings(record=True) as w:
            # Ensure warnings are being shown
            warnings.simplefilter("always")
            # Should not trigger warning
            F.affine_grid(theta, torch.Size([1, 1, 3, 2, 1]), align_corners=False)
            # Check no warning occurs
            self.assertNotIn('See the documentation of affine_grid for details.', ' '.join(map(str, w)))
            # Should trigger warning
            F.affine_grid(theta, torch.Size([1, 1, 3, 2, 1]), align_corners=True)
            # Check warning occurs
            self.assertIn('See the documentation of affine_grid for details.', ' '.join(map(str, w)))

        with self.assertRaisesRegex(ValueError, "Expected a batch of 3D affine matrices of shape Nx3x4"):
            F.affine_grid(theta[0], size, align_corners=False)

        with self.assertRaisesRegex(ValueError, "Expected a batch of 3D affine matrices of shape Nx3x4"):
            F.affine_grid(theta.unsqueeze(0), size, align_corners=False)

        with self.assertRaisesRegex(ValueError, "Expected a batch of 3D affine matrices of shape Nx3x4"):
            F.affine_grid(theta.repeat(1, 2, 1), size, align_corners=False)

        with self.assertRaisesRegex(ValueError, "Expected a batch of 3D affine matrices of shape Nx3x4"):
            F.affine_grid(theta.repeat(1, 1, 2), size, align_corners=False)

        with self.assertRaisesRegex(NotImplementedError, "affine_grid only supports 4D and 5D sizes"):
            F.affine_grid(theta, torch.Size([1, 2, 2]), align_corners=False)

        with self.assertRaisesRegex(NotImplementedError, "affine_grid only supports 4D and 5D sizes"):
            F.affine_grid(theta, torch.Size([1, 1, 2, 2, 2, 2]), align_corners=False)

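    # What the checks above validate (a sketch of the documented affine_grid
    # contract): a 2D theta of shape (N, 2, 3) and a size (N, C, H, W) produce a
    # grid of shape (N, H, W, 2), while a 3D theta of shape (N, 3, 4) and a size
    # (N, C, D, H, W) produce a grid of shape (N, D, H, W, 3); each grid entry is
    # the affine image of the normalized output coordinates, ready for grid_sample.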

    @parametrize_test('device', ['cpu'] + (['cuda'] if TEST_CUDA else []))
    @parametrize_test('nd', [2, 3])
    def test_affine_grid_backward_cl_cf_consistency(self, device, nd):
        # Test based on reported issue: https://github.com/pytorch/pytorch/issues/124154

        theta = torch.rand([6, nd, nd + 1], requires_grad=True, device=device)
        size = [6, 3, 4, 5] if nd == 2 else [6, 3, 4, 5, 5]
        grid = torch.nn.functional.affine_grid(theta, size, align_corners=False)

        grad_tensor = torch.rand(grid.shape, device=device)

        memory_format_cl = torch.channels_last if nd == 2 else torch.channels_last_3d
        grad_tensor_cl = grad_tensor.contiguous(memory_format=memory_format_cl)

        assert theta.grad is None
        grid.backward(grad_tensor_cl)
        theta_grad_cl = theta.grad.clone().contiguous()

        theta.grad.zero_()
        grid.backward(grad_tensor)
        theta_grad_cf = theta.grad

        self.assertEqual(theta_grad_cf, theta_grad_cl)

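    # The consistency check above relies on the fact that
    # grad_tensor.contiguous(memory_format=torch.channels_last) (channels_last_3d
    # in the 5D case) only changes the physical layout/strides of grad_output, not
    # its values, so the accumulated theta.grad must match for both layouts.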

    @set_default_dtype(torch.double)
    def test_grid_sample(self):
        # The backward pass of the native C++ and CUDA kernels branches depending on
        # whether the input requires gradient, so we test both cases.
        def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad):
            def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners):
                for grid_dim_contig_order in [(0, 1, 2, 3), (0, 3, 1, 2), (3, 0, 1, 2), (0, 2, 1, 3)]:
                    # grid_dim_contig_order specifies the dimension order that makes
                    # grid contiguous, i.e., grid.permute(grid_dim_contig_order) is contiguous.
                    # e.g., with grid_dim_contig_order=[0, 3, 1, 2], grid should be
                    # initialized as a contiguous tensor of shape [N, 2, H, W]
                    # and permuted to [N, H, W, 2] afterwards.
                    grid_shape = [N, H, W, 2]
                    grid_init_shape = [grid_shape[d] for d in grid_dim_contig_order]
                    grid_fwd_permute = [None, None, None, None]
                    for i, d in enumerate(grid_dim_contig_order):
                        grid_fwd_permute[d] = i

                    def get_grid(device='cpu', data=None):
                        if data is not None:
                            assert list(data.shape) == grid_shape
                            data = data.permute(grid_dim_contig_order).to(device)
                        else:
                            data = torch.randn(grid_init_shape, device=device)
                        grid = data.permute(grid_fwd_permute)
                        assert grid.permute(grid_dim_contig_order).is_contiguous()
                        return grid

                    input_cpu = torch.randn(C, N, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad)
                    grid_cpu = get_grid().requires_grad_()
                    out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
                                            align_corners=align_corners)
                    self.assertTrue(out_cpu.size() == torch.Size([N, C, H, W]))

                    gradients = torch.randn_like(out_cpu)
                    out_cpu.backward(gradients)

                    # Compare against unvectorized CPU fallback

                    # NOTE [ grid_sample CPU fallback ]
                    # grid_sample uses AVX for 2d images, but that requires 32-bit indexing for
                    # 32-bit floats. So we also have a fallback that is used only for float tensors
                    # requiring 64-bit indexing. That requires too much memory to run on CI, so we
                    # also export the fallback and test it here to ensure feature parity with
                    # the vectorized version.
                    input_fallback = input_cpu.float().detach_().requires_grad_()
                    grid_fallback = grid_cpu.float().detach_().requires_grad_()
                    out_fallback = torch._grid_sampler_2d_cpu_fallback(
                        input_fallback, grid_fallback,
                        F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
                        F.GRID_SAMPLE_PADDING_MODES[padding_mode],
                        align_corners)
                    self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5)

                    out_fallback.backward(gradients.float())
                    if input_requires_grad:
                        self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5)
                    self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5)

                    if TEST_CUDA:
                        input_cuda = input_cpu.detach().transpose(0, 1).cuda().transpose(0, 1).requires_grad_(input_requires_grad)
                        grid_cuda = get_grid('cuda', grid_cpu.detach()).requires_grad_()
                        out_cuda = F.grid_sample(input_cuda, grid_cuda, mode=mode, padding_mode=padding_mode,
                                                 align_corners=align_corners)
                        self.assertEqual(out_cpu, out_cuda)

                        out_cuda.backward(gradients.cuda())
                        if input_requires_grad:
                            self.assertEqual(input_cpu.grad, input_cuda.grad)
                        self.assertEqual(grid_cpu.grad, grid_cuda.grad, atol=5e-5, rtol=0)

                        # check that zero-dimensional input strides don't error out
                        base_input = torch.randn(N, C, 1, IW)
                        input_cpu = base_input.expand_as(input_cuda).requires_grad_(input_requires_grad)
                        out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
                                                align_corners=align_corners)

                        input_cuda = base_input.cuda().expand_as(input_cuda).requires_grad_(input_requires_grad)
                        out_cuda = F.grid_sample(input_cuda, grid_cuda, mode=mode, padding_mode=padding_mode,
                                                 align_corners=align_corners)
                        self.assertEqual(out_cpu, out_cuda)

            # test same size output
            test_shape(N, C, H, W, H, W, mode, padding_mode, align_corners)

            # test larger output
            N = random.randint(2, 8)
            C = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            H = random.randint(IH + 1, 12)
            W = random.randint(IW + 1, 12)
            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

            # test smaller output
            N = random.randint(2, 8)
            C = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            H = random.randint(2, IH)
            W = random.randint(2, IW)
            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

            # test 1x1 input
            N = random.randint(2, 8)
            C = random.randint(2, 8)
            IH = 1
            IW = 1
            H = random.randint(2, 5)
            W = random.randint(2, 5)
            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

            # testing empty grid
            N = random.randint(2, 8)
            C = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            W = random.randint(3, IW + 2)
            test_shape(N, C, IH, IW, 0, W, mode, padding_mode, align_corners)

            # testing empty channel
            N = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            H = random.randint(3, IH + 2)
            W = random.randint(3, IW + 2)
            test_shape(N, 0, IH, IW, H, W, mode, padding_mode, align_corners)

            # testing empty batch
            C = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            H = random.randint(3, IH + 2)
            W = random.randint(3, IW + 2)
            test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners)

        for mode in ('bilinear', 'nearest', 'bicubic'):
            for padding_mode in ('zeros', 'border', 'reflection'):
                for align_corners in (True, False):
                    # test known input on CPU
                    input = torch.arange(1., 11).view(1, 1, 2, 5)
                    grid = torch.tensor(
                        [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]],
                         [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]]).view(1, 2, 5, 2)
                    if mode == 'bilinear':
                        if padding_mode == 'zeros':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[0.0000, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 0.0000]]).view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[0.0000, 6.5000000000, 1.2500, 4.6675000191, 4.6250],
                                     [0.5000, 7.1665000916, 1.2500, 5.0000000000, 0.0000]]).view(1, 1, 2, 5)
                        elif padding_mode == 'border':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[1.2000, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 8.7500]]).view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[1.0000, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
                                     [1.0000, 7.1665000916, 5.0000, 5.0000000000, 10.0000]]).view(1, 1, 2, 5)
                        elif padding_mode == 'reflection':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[3.4500, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 7.7500]]).view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[3.0000004768, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
                                     [1.0000000000, 7.1665000916, 5.0000, 5.0000000000, 9.2500]]).view(1, 1, 2, 5)
                        else:
                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                    elif mode == 'nearest':
                        if padding_mode == 'zeros':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[0., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 0.]]).view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[0., 8., 5., 7., 0.],
                                     [1., 8., 5., 8., 0.]]).view(1, 1, 2, 5)
                        elif padding_mode == 'border':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[1., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 10.]]).view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[1., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 10.]]).view(1, 1, 2, 5)
                        elif padding_mode == 'reflection':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[1., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 9.]]).view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[1., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 9.]]).view(1, 1, 2, 5)
                        else:
                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                    elif mode == 'bicubic':
                        if padding_mode == 'zeros':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000],
                                     [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]]).view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264],
                                     [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]]).view(1, 1, 2, 5)
                        elif padding_mode == 'border':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000],
                                     [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]]).view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781],
                                     [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]]).view(1, 1, 2, 5)
                        elif padding_mode == 'reflection':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000],
                                     [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]]).view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531],
                                     [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]]).view(1, 1, 2, 5)
                        else:
                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")

                    else:
                        raise AssertionError(f"missing groundtruth test for interpolation mode '{mode}'")
                    output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
                                           align_corners=align_corners)
                    self.assertEqual(output, groundtruth, atol=1e-5, rtol=0,
                                     msg=f"groundtruth comparison failed for mode={mode}, "
                                     f"padding_mode={padding_mode}")

                    # See NOTE [ grid_sample CPU fallback ]
                    output = torch._grid_sampler_2d_cpu_fallback(
                        input.float(), grid.float(),
                        F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
                        F.GRID_SAMPLE_PADDING_MODES[padding_mode],
                        align_corners)
                    self.assertEqual(output, groundtruth.float(), atol=1e-5, rtol=0)

                    # explicit check for gradient edge cases
                    input = torch.arange(0., 5).expand((1, 1, 5, 5))
                    grid = torch.tensor(
                        [[[1.0, 1.0], [1.0, -1.0], [0.8, 0.8], [0.8, -0.8]],
                         [[-1.0, -1.0], [-1.0, 1.0], [-0.8, -0.8], [-0.8, 0.8]]]).view(1, 2, 4, 2).requires_grad_()
                    if mode == 'bilinear':
                        if padding_mode == 'zeros':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[[[-8., -8.], [-8., 0.], [2., 0.], [2., 0.]],
                                      [[2., 0.], [2., 0.], [2., 0.], [2., 0.]]]]).view(1, 2, 4, 2)
                            else:
                                groundtruth = torch.tensor(
                                    [[[[-5., -5.], [-5., 5.], [-10., -10.], [-10., 10.]],
                                      [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2)
                        elif padding_mode == 'border':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[[[-0., -0.], [-0., 0.], [2., 0.], [2., 0.]],
                                      [[0., 0.], [0., 0.], [2., 0.], [2., 0.]]]]).view(1, 2, 4, 2)
                            else:
                                groundtruth = torch.tensor(
                                    [[[[-0., -0.], [-0., 0.], [-0., -0.], [-0., 0.]],
                                      [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2)
                        elif padding_mode == 'reflection':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[[[-0., -0.], [-0., 0.], [2., 0.], [2., 0.]],
                                      [[0., 0.], [0., 0.], [2., 0.], [2., 0.]]]]).view(1, 2, 4, 2)
                            else:
                                groundtruth = torch.tensor(
                                    [[[[-0., -0.], [-0., 0.], [-0., -0.], [-0., 0.]],
                                      [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2)
                        else:
                            raise AssertionError(f"missing gradient groundtruth test for padding mode '{padding_mode}'")
                    elif mode == 'nearest':
                        groundtruth = torch.tensor(
                            [[[[-0., -0.], [-0., 0.], [-0., -0.], [-0., 0.]],
                              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2)
                    elif mode == 'bicubic':
                        if padding_mode == 'zeros':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[[[-4.5, -6.], [-4.5, 6.], [2.725679, 0.740878], [2.725679, -0.740878]],
                                      [[1.5, 0.], [1.5, 0.], [1.927921, -0.05688], [1.927921, 0.05688]]]]).view(1, 2, 4, 2)
                            else:
                                groundtruth = torch.tensor(
                                    [[[[-5.859375, -5.888672], [-5.859375, 5.888672], [-5.6250, -7.5000], [-5.6250, 7.5000]],
                                      [[-0.234375, -0.263672], [-0.234375, 0.263672], [1.8750, 0.], [1.8750, 0.]]]]
                                ).view(1, 2, 4, 2)
                        elif padding_mode == 'border':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[[[1.5, 0.], [1.5, 0.], [1.74, 0.], [1.74, 0.]],
                                      [[1.5, 0.], [1.5, 0.], [1.74, 0.], [1.74, 0.]]]]).view(1, 2, 4, 2)
                            else:
                                groundtruth = torch.tensor(
                                    [[[[-0.46875, 0.], [-0.46875, 0.], [1.8750, 0.], [1.8750, 0.]],
                                      [[-0.46875, 0.], [-0.46875, 0.], [1.8750, 0.], [1.8750, 0.]]]]).view(1, 2, 4, 2)
                        elif padding_mode == 'reflection':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[[[0., 0.], [0., 0.], [1.92, 0.], [1.92, 0.]],
                                      [[0., 0.], [0., 0.], [1.92, 0.], [1.92, 0.]]]]).view(1, 2, 4, 2)
                            else:
                                groundtruth = torch.tensor(
                                    [[[[0., 0.], [0., 0.], [1.875, 0.], [1.875, 0.]],
                                      [[0., 0.], [0., 0.], [1.875, 0.], [1.875, 0.]]]]).view(1, 2, 4, 2)
                        else:
                            raise AssertionError(f"missing gradient groundtruth test for padding mode '{padding_mode}'")
                    else:
                        raise AssertionError(f"missing gradient groundtruth test for interpolation mode '{mode}'")
                    for input_requires_grad in [False, True]:
                        input = input.requires_grad_(input_requires_grad)
                        F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
                                      align_corners=align_corners).sum().backward()
                        self.assertEqual(grid.grad, groundtruth, atol=1e-5, rtol=0,
                                         msg=f"gradient groundtruth comparison failed for mode={mode}, "
                                         f"padding_mode={padding_mode}, input_requires_grad={input_requires_grad}")
                        grid.grad.zero_()

                    # See NOTE [ grid_sample CPU fallback ]
                    torch._grid_sampler_2d_cpu_fallback(
                        input.float(), grid.float(),
                        F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
                        F.GRID_SAMPLE_PADDING_MODES[padding_mode],
                        align_corners).sum().backward()
                    self.assertEqual(grid.grad, groundtruth, atol=1e-5, rtol=0)

                    # do gradcheck
                    N = random.randint(2, 8)
                    C = random.randint(2, 6)
                    H = random.randint(2, 8)
                    W = random.randint(2, 8)
5876*da0073e9SAndroid Build Coastguard Worker input = torch.randn(N, C, H, W, requires_grad=True) 5877*da0073e9SAndroid Build Coastguard Worker grid = torch.randn(N, H, W, 2, requires_grad=True) 5878*da0073e9SAndroid Build Coastguard Worker 5879*da0073e9SAndroid Build Coastguard Worker for input_requires_grad in [False, True]: 5880*da0073e9SAndroid Build Coastguard Worker input.requires_grad_(input_requires_grad) 5881*da0073e9SAndroid Build Coastguard Worker self.assertTrue(gradcheck( 5882*da0073e9SAndroid Build Coastguard Worker lambda inp, grd: F.grid_sample(inp, grd, mode=mode, padding_mode=padding_mode, 5883*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners), 5884*da0073e9SAndroid Build Coastguard Worker (input, grid))) 5885*da0073e9SAndroid Build Coastguard Worker test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad) 5886*da0073e9SAndroid Build Coastguard Worker if TEST_CUDNN: 5887*da0073e9SAndroid Build Coastguard Worker with cudnn.flags(enabled=False): 5888*da0073e9SAndroid Build Coastguard Worker test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad) 5889*da0073e9SAndroid Build Coastguard Worker 5890*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 5891*da0073e9SAndroid Build Coastguard Worker def test_grid_sample_3d(self): 5892*da0073e9SAndroid Build Coastguard Worker # Backward pass of native C++ and CUDA kernels branch depending on whether input requires gradient, 5893*da0073e9SAndroid Build Coastguard Worker # so we test both cases. 5894*da0073e9SAndroid Build Coastguard Worker def test(N, C, D, H, W, mode, padding_mode, align_corners, input_requires_grad): 5895*da0073e9SAndroid Build Coastguard Worker def test_shape(N, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners): 5896*da0073e9SAndroid Build Coastguard Worker input_cpu = torch.randn(C, N, ID, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad) 5897*da0073e9SAndroid Build Coastguard Worker grid_cpu = torch.randn(D, N, H, W, 3).transpose(0, 1).requires_grad_() 5898*da0073e9SAndroid Build Coastguard Worker out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode, 5899*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners) 5900*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out_cpu.size() == torch.Size([N, C, D, H, W])) 5901*da0073e9SAndroid Build Coastguard Worker 5902*da0073e9SAndroid Build Coastguard Worker gradients = torch.randn_like(out_cpu) 5903*da0073e9SAndroid Build Coastguard Worker out_cpu.backward(gradients) 5904*da0073e9SAndroid Build Coastguard Worker 5905*da0073e9SAndroid Build Coastguard Worker if TEST_CUDA: 5906*da0073e9SAndroid Build Coastguard Worker input_cuda = input_cpu.detach().transpose(0, 1).cuda().transpose(0, 1).requires_grad_(input_requires_grad) 5907*da0073e9SAndroid Build Coastguard Worker grid_cuda = grid_cpu.detach().transpose(0, 1).cuda().transpose(0, 1).requires_grad_() 5908*da0073e9SAndroid Build Coastguard Worker out_cuda = F.grid_sample(input_cuda, grid_cuda, mode=mode, padding_mode=padding_mode, 5909*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners) 5910*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_cpu, out_cuda) 5911*da0073e9SAndroid Build Coastguard Worker 5912*da0073e9SAndroid Build Coastguard Worker out_cuda.backward(gradients.cuda()) 5913*da0073e9SAndroid Build Coastguard Worker if input_requires_grad: 5914*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input_cpu.grad, input_cuda.grad) 
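# A minimal standalone sketch of the 5-D call convention exercised by test_shape: input is
# (N, C, D_in, H_in, W_in), grid holds normalized (x, y, z) coordinates in [-1, 1] with shape
# (N, D_out, H_out, W_out, 3), and the output comes back as (N, C, D_out, H_out, W_out).
import torch
import torch.nn.functional as F
inp = torch.randn(2, 3, 4, 5, 6)                 # (N, C, D_in, H_in, W_in)
grid = torch.rand(2, 3, 4, 5, 3) * 2 - 1         # (N, D_out, H_out, W_out, 3)
out = F.grid_sample(inp, grid, mode='bilinear', align_corners=False)
print(out.shape)  # torch.Size([2, 3, 3, 4, 5])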
5915*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grid_cpu.grad, grid_cuda.grad, atol=5e-5, rtol=0) 5916*da0073e9SAndroid Build Coastguard Worker 5917*da0073e9SAndroid Build Coastguard Worker # check that zero-dimensional input strides don't error out 5918*da0073e9SAndroid Build Coastguard Worker base_input = torch.randn(N, C, 1, IH, IW) 5919*da0073e9SAndroid Build Coastguard Worker input_cpu = base_input.expand_as(input_cuda).requires_grad_(input_requires_grad) 5920*da0073e9SAndroid Build Coastguard Worker grid_cpu = torch.randn(N, D, H, W, 3, requires_grad=True) 5921*da0073e9SAndroid Build Coastguard Worker out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode, 5922*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners) 5923*da0073e9SAndroid Build Coastguard Worker 5924*da0073e9SAndroid Build Coastguard Worker input_cuda = base_input.cuda().expand_as(input_cuda).requires_grad_(input_requires_grad) 5925*da0073e9SAndroid Build Coastguard Worker grid_cuda = grid_cpu.detach().cuda().requires_grad_() 5926*da0073e9SAndroid Build Coastguard Worker out_cuda = F.grid_sample(input_cuda, grid_cuda, mode=mode, padding_mode=padding_mode, 5927*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners) 5928*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_cpu, out_cuda) 5929*da0073e9SAndroid Build Coastguard Worker 5930*da0073e9SAndroid Build Coastguard Worker # test same size output 5931*da0073e9SAndroid Build Coastguard Worker test_shape(N, C, D, H, W, D, H, W, mode, padding_mode, align_corners) 5932*da0073e9SAndroid Build Coastguard Worker 5933*da0073e9SAndroid Build Coastguard Worker # test larger output 5934*da0073e9SAndroid Build Coastguard Worker N = random.randint(2, 7) 5935*da0073e9SAndroid Build Coastguard Worker C = random.randint(2, 5) 5936*da0073e9SAndroid Build Coastguard Worker ID = random.randint(2, 7) 5937*da0073e9SAndroid Build Coastguard Worker IH = random.randint(2, 7) 5938*da0073e9SAndroid Build Coastguard Worker IW = random.randint(2, 7) 5939*da0073e9SAndroid Build Coastguard Worker D = random.randint(ID + 1, 10) 5940*da0073e9SAndroid Build Coastguard Worker H = random.randint(IH + 1, 10) 5941*da0073e9SAndroid Build Coastguard Worker W = random.randint(IW + 1, 10) 5942*da0073e9SAndroid Build Coastguard Worker test_shape(N, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners) 5943*da0073e9SAndroid Build Coastguard Worker 5944*da0073e9SAndroid Build Coastguard Worker # test smaller output 5945*da0073e9SAndroid Build Coastguard Worker N = random.randint(2, 7) 5946*da0073e9SAndroid Build Coastguard Worker C = random.randint(2, 5) 5947*da0073e9SAndroid Build Coastguard Worker ID = random.randint(2, 7) 5948*da0073e9SAndroid Build Coastguard Worker IH = random.randint(2, 7) 5949*da0073e9SAndroid Build Coastguard Worker IW = random.randint(2, 7) 5950*da0073e9SAndroid Build Coastguard Worker D = random.randint(2, ID) 5951*da0073e9SAndroid Build Coastguard Worker H = random.randint(2, IH) 5952*da0073e9SAndroid Build Coastguard Worker W = random.randint(2, IW) 5953*da0073e9SAndroid Build Coastguard Worker test_shape(N, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners) 5954*da0073e9SAndroid Build Coastguard Worker 5955*da0073e9SAndroid Build Coastguard Worker # test 1x1 input 5956*da0073e9SAndroid Build Coastguard Worker N = random.randint(2, 7) 5957*da0073e9SAndroid Build Coastguard Worker C = random.randint(2, 7) 5958*da0073e9SAndroid Build Coastguard Worker ID = 1 5959*da0073e9SAndroid Build
Coastguard Worker IH = 1 5960*da0073e9SAndroid Build Coastguard Worker IW = 1 5961*da0073e9SAndroid Build Coastguard Worker H = random.randint(2, 5) 5962*da0073e9SAndroid Build Coastguard Worker W = random.randint(2, 5) 5963*da0073e9SAndroid Build Coastguard Worker test_shape(N, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners) 5964*da0073e9SAndroid Build Coastguard Worker 5965*da0073e9SAndroid Build Coastguard Worker # testing empty grid 5966*da0073e9SAndroid Build Coastguard Worker N = random.randint(2, 7) 5967*da0073e9SAndroid Build Coastguard Worker C = random.randint(2, 5) 5968*da0073e9SAndroid Build Coastguard Worker ID = random.randint(2, 7) 5969*da0073e9SAndroid Build Coastguard Worker IH = random.randint(2, 7) 5970*da0073e9SAndroid Build Coastguard Worker IW = random.randint(2, 7) 5971*da0073e9SAndroid Build Coastguard Worker D = random.randint(3, ID + 2) 5972*da0073e9SAndroid Build Coastguard Worker W = random.randint(3, IW + 2) 5973*da0073e9SAndroid Build Coastguard Worker test_shape(N, C, ID, IH, IW, D, 0, W, mode, padding_mode, align_corners) 5974*da0073e9SAndroid Build Coastguard Worker 5975*da0073e9SAndroid Build Coastguard Worker # testing empty channel 5976*da0073e9SAndroid Build Coastguard Worker N = random.randint(2, 7) 5977*da0073e9SAndroid Build Coastguard Worker ID = random.randint(2, 5) 5978*da0073e9SAndroid Build Coastguard Worker IH = random.randint(2, 7) 5979*da0073e9SAndroid Build Coastguard Worker IW = random.randint(2, 7) 5980*da0073e9SAndroid Build Coastguard Worker D = random.randint(3, ID + 2) 5981*da0073e9SAndroid Build Coastguard Worker H = random.randint(3, IH + 2) 5982*da0073e9SAndroid Build Coastguard Worker W = random.randint(3, IW + 2) 5983*da0073e9SAndroid Build Coastguard Worker test_shape(N, 0, ID, IH, IW, D, H, W, mode, padding_mode, align_corners) 5984*da0073e9SAndroid Build Coastguard Worker 5985*da0073e9SAndroid Build Coastguard Worker # testing empty batch 5986*da0073e9SAndroid Build Coastguard Worker C = random.randint(2, 5) 5987*da0073e9SAndroid Build Coastguard Worker ID = random.randint(2, 7) 5988*da0073e9SAndroid Build Coastguard Worker IH = random.randint(2, 7) 5989*da0073e9SAndroid Build Coastguard Worker IW = random.randint(2, 7) 5990*da0073e9SAndroid Build Coastguard Worker D = random.randint(3, ID + 2) 5991*da0073e9SAndroid Build Coastguard Worker H = random.randint(3, IH + 2) 5992*da0073e9SAndroid Build Coastguard Worker W = random.randint(3, IW + 2) 5993*da0073e9SAndroid Build Coastguard Worker test_shape(0, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners) 5994*da0073e9SAndroid Build Coastguard Worker 5995*da0073e9SAndroid Build Coastguard Worker for mode in ('bilinear', 'nearest'): 5996*da0073e9SAndroid Build Coastguard Worker for padding_mode in ('zeros', 'border', 'reflection'): 5997*da0073e9SAndroid Build Coastguard Worker for align_corners in (True, False): 5998*da0073e9SAndroid Build Coastguard Worker # do gradcheck 5999*da0073e9SAndroid Build Coastguard Worker N = random.randint(2, 5) 6000*da0073e9SAndroid Build Coastguard Worker C = random.randint(2, 4) 6001*da0073e9SAndroid Build Coastguard Worker D = random.randint(2, 5) 6002*da0073e9SAndroid Build Coastguard Worker H = random.randint(2, 5) 6003*da0073e9SAndroid Build Coastguard Worker W = random.randint(2, 5) 6004*da0073e9SAndroid Build Coastguard Worker input = torch.randn(N, C, D, H, W, requires_grad=True) 6005*da0073e9SAndroid Build Coastguard Worker grid = torch.randn(N, D, H, W, 3, requires_grad=True) 6006*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(gradcheck( 6007*da0073e9SAndroid Build Coastguard Worker lambda inp, grid: F.grid_sample(inp, grid, mode=mode, padding_mode=padding_mode, 6008*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners), 6009*da0073e9SAndroid Build Coastguard Worker (input, grid))) 6010*da0073e9SAndroid Build Coastguard Worker input = input.requires_grad_(False) 6011*da0073e9SAndroid Build Coastguard Worker self.assertTrue(gradcheck( 6012*da0073e9SAndroid Build Coastguard Worker lambda grid: F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode, 6013*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners), 6014*da0073e9SAndroid Build Coastguard Worker (grid,))) 6015*da0073e9SAndroid Build Coastguard Worker 6016*da0073e9SAndroid Build Coastguard Worker for input_requires_grad in [False, True]: 6017*da0073e9SAndroid Build Coastguard Worker test(N, C, D, H, W, mode, padding_mode, align_corners, input_requires_grad) 6018*da0073e9SAndroid Build Coastguard Worker 6019*da0073e9SAndroid Build Coastguard Worker def test_grid_sample_nearest_neighbor_rounding_mode_consistency(self): 6020*da0073e9SAndroid Build Coastguard Worker 6021*da0073e9SAndroid Build Coastguard Worker device_list = ['cpu'] 6022*da0073e9SAndroid Build Coastguard Worker if TEST_CUDA: 6023*da0073e9SAndroid Build Coastguard Worker device_list.append('cuda') 6024*da0073e9SAndroid Build Coastguard Worker 6025*da0073e9SAndroid Build Coastguard Worker def normalize_indices(indices_unnormalized: torch.Tensor, dim_size: int, align_corners: bool): 6026*da0073e9SAndroid Build Coastguard Worker if align_corners: 6027*da0073e9SAndroid Build Coastguard Worker indices_normalized = 2 * indices_unnormalized / (dim_size - 1) - 1 6028*da0073e9SAndroid Build Coastguard Worker else: 6029*da0073e9SAndroid Build Coastguard Worker indices_normalized = (indices_unnormalized * 2 + 1) / dim_size - 1 6030*da0073e9SAndroid Build Coastguard Worker return indices_normalized 6031*da0073e9SAndroid Build Coastguard Worker 6032*da0073e9SAndroid Build Coastguard Worker test_dim_size = 10 6033*da0073e9SAndroid Build Coastguard Worker non_test_dim_size = 9 6034*da0073e9SAndroid Build Coastguard Worker step_size = 0.1 6035*da0073e9SAndroid Build Coastguard Worker 6036*da0073e9SAndroid Build Coastguard Worker batch_size = 1 6037*da0073e9SAndroid Build Coastguard Worker channel_size = 1 6038*da0073e9SAndroid Build Coastguard Worker 6039*da0073e9SAndroid Build Coastguard Worker mode = 'nearest' 6040*da0073e9SAndroid Build Coastguard Worker for device in device_list: 6041*da0073e9SAndroid Build Coastguard Worker for padding_mode in ('zeros', 'border', 'reflection'): 6042*da0073e9SAndroid Build Coastguard Worker for align_corners in (True, False): 6043*da0073e9SAndroid Build Coastguard Worker # Unnormalized inquiry indices 6044*da0073e9SAndroid Build Coastguard Worker inquiry_indices_unnormalized = torch.arange( 6045*da0073e9SAndroid Build Coastguard Worker 0, 6046*da0073e9SAndroid Build Coastguard Worker test_dim_size - 1 + step_size, step_size, 6047*da0073e9SAndroid Build Coastguard Worker dtype=torch.float32, 6048*da0073e9SAndroid Build Coastguard Worker device=device 6049*da0073e9SAndroid Build Coastguard Worker ) 6050*da0073e9SAndroid Build Coastguard Worker # Note that even though we are trying to create normalized indices 6051*da0073e9SAndroid Build Coastguard Worker # which results in x.0 and x.5 indices after unnormalization, 6052*da0073e9SAndroid Build Coastguard Worker # because of the numerical error, 6053*da0073e9SAndroid Build 
Coastguard Worker # the rounding direction might not always be expected as designed. 6054*da0073e9SAndroid Build Coastguard Worker # The best we could do is to ensure the rounding behaviors across 6055*da0073e9SAndroid Build Coastguard Worker # different implementations for different dimensions are 6056*da0073e9SAndroid Build Coastguard Worker # exactly the same. 6057*da0073e9SAndroid Build Coastguard Worker inquiry_indices = normalize_indices( 6058*da0073e9SAndroid Build Coastguard Worker indices_unnormalized=inquiry_indices_unnormalized, 6059*da0073e9SAndroid Build Coastguard Worker dim_size=test_dim_size, 6060*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners 6061*da0073e9SAndroid Build Coastguard Worker ) 6062*da0073e9SAndroid Build Coastguard Worker num_inqueries = inquiry_indices.shape[0] 6063*da0073e9SAndroid Build Coastguard Worker inquiry_fixed_indices = torch.full((num_inqueries,), 0.5, dtype=torch.float32, device=device) 6064*da0073e9SAndroid Build Coastguard Worker array_data = torch.rand(test_dim_size, dtype=torch.float32, device=device) 6065*da0073e9SAndroid Build Coastguard Worker # 2D grid sample x-dim interpolation 6066*da0073e9SAndroid Build Coastguard Worker # The input_tensor_2d_x is of shape 6067*da0073e9SAndroid Build Coastguard Worker # [batch_size, channel_size, non_test_dim_size, test_dim_size] 6068*da0073e9SAndroid Build Coastguard Worker input_tensor_2d_x = array_data.reshape(1, test_dim_size).repeat( 6069*da0073e9SAndroid Build Coastguard Worker batch_size, 6070*da0073e9SAndroid Build Coastguard Worker channel_size, 6071*da0073e9SAndroid Build Coastguard Worker non_test_dim_size, 6072*da0073e9SAndroid Build Coastguard Worker 1 6073*da0073e9SAndroid Build Coastguard Worker ) 6074*da0073e9SAndroid Build Coastguard Worker # The grid_tensor_2d_x is of shape 6075*da0073e9SAndroid Build Coastguard Worker # [batch_size, 1, num_inqueries] 6076*da0073e9SAndroid Build Coastguard Worker grid_tensor_2d_x = torch.cat( 6077*da0073e9SAndroid Build Coastguard Worker tensors=( 6078*da0073e9SAndroid Build Coastguard Worker inquiry_indices.reshape(num_inqueries, 1), 6079*da0073e9SAndroid Build Coastguard Worker inquiry_fixed_indices.reshape(num_inqueries, 1), 6080*da0073e9SAndroid Build Coastguard Worker ), 6081*da0073e9SAndroid Build Coastguard Worker dim=1 6082*da0073e9SAndroid Build Coastguard Worker ).repeat(batch_size, 1, 1, 1) 6083*da0073e9SAndroid Build Coastguard Worker # The output_tensor_2d_x is of shape 6084*da0073e9SAndroid Build Coastguard Worker # [batch_size, channel_size, 1, num_inqueries] 6085*da0073e9SAndroid Build Coastguard Worker output_tensor_2d_x = F.grid_sample( 6086*da0073e9SAndroid Build Coastguard Worker input=input_tensor_2d_x, 6087*da0073e9SAndroid Build Coastguard Worker grid=grid_tensor_2d_x, 6088*da0073e9SAndroid Build Coastguard Worker mode=mode, 6089*da0073e9SAndroid Build Coastguard Worker padding_mode=padding_mode, 6090*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners, 6091*da0073e9SAndroid Build Coastguard Worker ) 6092*da0073e9SAndroid Build Coastguard Worker # 2D grid sample y-dim interpolation 6093*da0073e9SAndroid Build Coastguard Worker # The input_tensor_2d_y is of shape 6094*da0073e9SAndroid Build Coastguard Worker # [batch_size, channel_size, test_dim_size, non_test_dim_size] 6095*da0073e9SAndroid Build Coastguard Worker input_tensor_2d_y = torch.transpose(input_tensor_2d_x, 3, 2) 6096*da0073e9SAndroid Build Coastguard Worker # The grid_tensor_2d_y is of shape 6097*da0073e9SAndroid Build Coastguard 
Worker # [batch_size, 1, num_inqueries] 6098*da0073e9SAndroid Build Coastguard Worker grid_tensor_2d_y = torch.index_select( 6099*da0073e9SAndroid Build Coastguard Worker grid_tensor_2d_x, 6100*da0073e9SAndroid Build Coastguard Worker -1, 6101*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 0], dtype=torch.int64, device=device) 6102*da0073e9SAndroid Build Coastguard Worker ) 6103*da0073e9SAndroid Build Coastguard Worker # The output_tensor_2d_y is of shape 6104*da0073e9SAndroid Build Coastguard Worker # [batch_size, channel_size, 1, num_inqueries] 6105*da0073e9SAndroid Build Coastguard Worker output_tensor_2d_y = F.grid_sample( 6106*da0073e9SAndroid Build Coastguard Worker input=input_tensor_2d_y, 6107*da0073e9SAndroid Build Coastguard Worker grid=grid_tensor_2d_y, 6108*da0073e9SAndroid Build Coastguard Worker mode=mode, 6109*da0073e9SAndroid Build Coastguard Worker padding_mode=padding_mode, 6110*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners, 6111*da0073e9SAndroid Build Coastguard Worker ) 6112*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_2d_y[0, 0, 0, :], atol=0, rtol=0) 6113*da0073e9SAndroid Build Coastguard Worker # 3D grid sample x-dim interpolation 6114*da0073e9SAndroid Build Coastguard Worker # The input_tensor_3d_x is of shape 6115*da0073e9SAndroid Build Coastguard Worker # [batch_size, channel_size, non_test_dim_size, non_test_dim_size, test_dim_size] 6116*da0073e9SAndroid Build Coastguard Worker input_tensor_3d_x = array_data.reshape(1, test_dim_size).repeat( 6117*da0073e9SAndroid Build Coastguard Worker batch_size, channel_size, non_test_dim_size, non_test_dim_size, 1) 6118*da0073e9SAndroid Build Coastguard Worker # The grid_tensor_3d_x is of shape 6119*da0073e9SAndroid Build Coastguard Worker # [batch_size, 1, 1, num_inqueries] 6120*da0073e9SAndroid Build Coastguard Worker grid_tensor_3d_x = torch.cat( 6121*da0073e9SAndroid Build Coastguard Worker tensors=( 6122*da0073e9SAndroid Build Coastguard Worker inquiry_indices.reshape(num_inqueries, 1), 6123*da0073e9SAndroid Build Coastguard Worker inquiry_fixed_indices.reshape(num_inqueries, 1), 6124*da0073e9SAndroid Build Coastguard Worker inquiry_fixed_indices.reshape(num_inqueries, 1), 6125*da0073e9SAndroid Build Coastguard Worker ), 6126*da0073e9SAndroid Build Coastguard Worker dim=1 6127*da0073e9SAndroid Build Coastguard Worker ).repeat(batch_size, 1, 1, 1, 1) 6128*da0073e9SAndroid Build Coastguard Worker # The output_tensor_3d_x is of shape 6129*da0073e9SAndroid Build Coastguard Worker # [batch_size, channel_size, 1, 1, num_inqueries] 6130*da0073e9SAndroid Build Coastguard Worker output_tensor_3d_x = F.grid_sample( 6131*da0073e9SAndroid Build Coastguard Worker input=input_tensor_3d_x, 6132*da0073e9SAndroid Build Coastguard Worker grid=grid_tensor_3d_x, 6133*da0073e9SAndroid Build Coastguard Worker mode=mode, 6134*da0073e9SAndroid Build Coastguard Worker padding_mode=padding_mode, 6135*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners, 6136*da0073e9SAndroid Build Coastguard Worker ) 6137*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_3d_x[0, 0, 0, 0, :], atol=0, rtol=0) 6138*da0073e9SAndroid Build Coastguard Worker # 3D grid sample y-dim interpolation 6139*da0073e9SAndroid Build Coastguard Worker # The input_tensor_3d_y is of shape 6140*da0073e9SAndroid Build Coastguard Worker # [batch_size, channel_size, non_test_dim_size, test_dim_size, non_test_dim_size] 
6141*da0073e9SAndroid Build Coastguard Worker input_tensor_3d_y = torch.transpose(input_tensor_3d_x, 4, 3) 6142*da0073e9SAndroid Build Coastguard Worker # The grid_tensor_3d_y is of shape 6143*da0073e9SAndroid Build Coastguard Worker # [batch_size, 1, 1, num_inqueries] 6144*da0073e9SAndroid Build Coastguard Worker grid_tensor_3d_y = torch.index_select( 6145*da0073e9SAndroid Build Coastguard Worker grid_tensor_3d_x, 6146*da0073e9SAndroid Build Coastguard Worker -1, 6147*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 0, 2], dtype=torch.int64, device=device) 6148*da0073e9SAndroid Build Coastguard Worker ) 6149*da0073e9SAndroid Build Coastguard Worker # The output_tensor_3d_y is of shape 6150*da0073e9SAndroid Build Coastguard Worker # [batch_size, channel_size, 1, 1, num_inqueries] 6151*da0073e9SAndroid Build Coastguard Worker output_tensor_3d_y = F.grid_sample( 6152*da0073e9SAndroid Build Coastguard Worker input=input_tensor_3d_y, 6153*da0073e9SAndroid Build Coastguard Worker grid=grid_tensor_3d_y, 6154*da0073e9SAndroid Build Coastguard Worker mode=mode, 6155*da0073e9SAndroid Build Coastguard Worker padding_mode=padding_mode, 6156*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners, 6157*da0073e9SAndroid Build Coastguard Worker ) 6158*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_3d_y[0, 0, 0, 0, :], atol=0, rtol=0) 6159*da0073e9SAndroid Build Coastguard Worker # 3D grid sample z-dim interpolation 6160*da0073e9SAndroid Build Coastguard Worker # The input_tensor_3d_z is of shape 6161*da0073e9SAndroid Build Coastguard Worker # [batch_size, channel_size, test_dim_size, non_test_dim_size, non_test_dim_size] 6162*da0073e9SAndroid Build Coastguard Worker input_tensor_3d_z = torch.transpose(input_tensor_3d_x, 4, 2) 6163*da0073e9SAndroid Build Coastguard Worker # The grid_tensor_3d_z is of shape 6164*da0073e9SAndroid Build Coastguard Worker # [batch_size, 1, 1, num_inqueries] 6165*da0073e9SAndroid Build Coastguard Worker grid_tensor_3d_z = torch.index_select( 6166*da0073e9SAndroid Build Coastguard Worker grid_tensor_3d_x, 6167*da0073e9SAndroid Build Coastguard Worker -1, 6168*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 2, 0], dtype=torch.int64, device=device) 6169*da0073e9SAndroid Build Coastguard Worker ) 6170*da0073e9SAndroid Build Coastguard Worker # The output_tensor_3d_z is of shape 6171*da0073e9SAndroid Build Coastguard Worker # [batch_size, channel_size, 1, 1, num_inqueries] 6172*da0073e9SAndroid Build Coastguard Worker output_tensor_3d_z = F.grid_sample( 6173*da0073e9SAndroid Build Coastguard Worker input=input_tensor_3d_z, 6174*da0073e9SAndroid Build Coastguard Worker grid=grid_tensor_3d_z, 6175*da0073e9SAndroid Build Coastguard Worker mode=mode, 6176*da0073e9SAndroid Build Coastguard Worker padding_mode=padding_mode, 6177*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners, 6178*da0073e9SAndroid Build Coastguard Worker ) 6179*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_3d_z[0, 0, 0, 0, :], atol=0, rtol=0) 6180*da0073e9SAndroid Build Coastguard Worker 6181*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 6182*da0073e9SAndroid Build Coastguard Worker def test_affine_grid(self): 6183*da0073e9SAndroid Build Coastguard Worker # test known input on CPU 6184*da0073e9SAndroid Build Coastguard Worker input = torch.arange(1., 7).view(1, 2, 3) 6185*da0073e9SAndroid Build Coastguard Worker output =
F.affine_grid(input, torch.Size([1, 1, 2, 2]), align_corners=True) 6186*da0073e9SAndroid Build Coastguard Worker groundtruth = torch.tensor( 6187*da0073e9SAndroid Build Coastguard Worker [[[0., -3.], [2., 5.]], [[4., 7.], [6., 15.]]]).view(1, 2, 2, 2) 6188*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, groundtruth) 6189*da0073e9SAndroid Build Coastguard Worker output = F.affine_grid(input, torch.Size([1, 1, 2, 2]), align_corners=False) 6190*da0073e9SAndroid Build Coastguard Worker groundtruth = torch.tensor( 6191*da0073e9SAndroid Build Coastguard Worker [[[1.5, 1.5], [2.5, 5.5]], [[3.5, 6.5], [4.5, 10.5]]]).view(1, 2, 2, 2) 6192*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, groundtruth) 6193*da0073e9SAndroid Build Coastguard Worker 6194*da0073e9SAndroid Build Coastguard Worker for align_corners in (True, False): 6195*da0073e9SAndroid Build Coastguard Worker # do gradcheck 6196*da0073e9SAndroid Build Coastguard Worker N = random.randint(1, 8) 6197*da0073e9SAndroid Build Coastguard Worker C = random.randint(1, 8) 6198*da0073e9SAndroid Build Coastguard Worker H = random.randint(1, 8) 6199*da0073e9SAndroid Build Coastguard Worker W = random.randint(1, 8) 6200*da0073e9SAndroid Build Coastguard Worker sz = torch.Size([N, C, H, W]) 6201*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(N, 2, 3, requires_grad=True) 6202*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True): 6203*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # python2 requires this so other tests can trigger 6204*da0073e9SAndroid Build Coastguard Worker self.assertTrue(gradcheck( 6205*da0073e9SAndroid Build Coastguard Worker lambda inp: F.affine_grid(inp, sz, align_corners=align_corners), 6206*da0073e9SAndroid Build Coastguard Worker (inp,))) 6207*da0073e9SAndroid Build Coastguard Worker 6208*da0073e9SAndroid Build Coastguard Worker # test CPU against CUDA 6209*da0073e9SAndroid Build Coastguard Worker if TEST_CUDA: 6210*da0073e9SAndroid Build Coastguard Worker N = random.randint(1, 8) 6211*da0073e9SAndroid Build Coastguard Worker C = random.randint(1, 8) 6212*da0073e9SAndroid Build Coastguard Worker H = random.randint(1, 8) 6213*da0073e9SAndroid Build Coastguard Worker W = random.randint(1, 8) 6214*da0073e9SAndroid Build Coastguard Worker sz = torch.Size([N, C, H, W]) 6215*da0073e9SAndroid Build Coastguard Worker for align_corners in (True, False): 6216*da0073e9SAndroid Build Coastguard Worker input_cpu = torch.randn(N, 2, 3, requires_grad=True) 6217*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True): 6218*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # python2 requires this so other tests can trigger 6219*da0073e9SAndroid Build Coastguard Worker out_cpu = F.affine_grid(input_cpu, sz, align_corners=align_corners) 6220*da0073e9SAndroid Build Coastguard Worker gradients = torch.randn(out_cpu.size()) 6221*da0073e9SAndroid Build Coastguard Worker out_cpu.backward(gradients) 6222*da0073e9SAndroid Build Coastguard Worker input_gpu = input_cpu.detach().cuda().requires_grad_() 6223*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True): 6224*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # python2 requires this so other tests can trigger 6225*da0073e9SAndroid Build Coastguard Worker out_cuda = F.affine_grid(input_gpu, sz, align_corners=align_corners) 6226*da0073e9SAndroid Build Coastguard Worker 
out_cuda.backward(gradients.cuda()) 6227*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_cpu, out_cuda) 6228*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input_cpu.grad, input_gpu.grad) 6229*da0073e9SAndroid Build Coastguard Worker 6230*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 6231*da0073e9SAndroid Build Coastguard Worker def test_affine_grid_3d(self): 6232*da0073e9SAndroid Build Coastguard Worker # test known input on CPU 6233*da0073e9SAndroid Build Coastguard Worker input = torch.arange(1., 13).view(1, 3, 4) 6234*da0073e9SAndroid Build Coastguard Worker output = F.affine_grid(input, torch.Size([1, 1, 2, 2, 2]), align_corners=True) 6235*da0073e9SAndroid Build Coastguard Worker groundtruth = torch.tensor( 6236*da0073e9SAndroid Build Coastguard Worker [[[[[-2., -10., -18.], [0., 0., 0.]], [[2., 2., 2.], [4., 12., 20.]]], 6237*da0073e9SAndroid Build Coastguard Worker [[[4., 4., 4.], [6., 14., 22.]], [[8., 16., 24.], [10., 26., 42.]]]]]).view(1, 2, 2, 2, 3) 6238*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, groundtruth) 6239*da0073e9SAndroid Build Coastguard Worker output = F.affine_grid(input, torch.Size([1, 1, 2, 2, 2]), align_corners=False) 6240*da0073e9SAndroid Build Coastguard Worker groundtruth = torch.tensor( 6241*da0073e9SAndroid Build Coastguard Worker [[[[[1., -1., -3.], [2., 4., 6.]], [[3., 5., 7.], [4., 10., 16.]]], 6242*da0073e9SAndroid Build Coastguard Worker [[[4., 6., 8.], [5., 11., 17.]], [[6., 12., 18.], [7., 17., 27.]]]]]).view(1, 2, 2, 2, 3) 6243*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, groundtruth) 6244*da0073e9SAndroid Build Coastguard Worker 6245*da0073e9SAndroid Build Coastguard Worker for align_corners in (True, False): 6246*da0073e9SAndroid Build Coastguard Worker # do gradcheck 6247*da0073e9SAndroid Build Coastguard Worker N = random.randint(1, 8) 6248*da0073e9SAndroid Build Coastguard Worker C = random.randint(1, 8) 6249*da0073e9SAndroid Build Coastguard Worker D = random.randint(1, 8) 6250*da0073e9SAndroid Build Coastguard Worker H = random.randint(1, 8) 6251*da0073e9SAndroid Build Coastguard Worker W = random.randint(1, 8) 6252*da0073e9SAndroid Build Coastguard Worker sz = torch.Size([N, C, D, H, W]) 6253*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(N, 3, 4, requires_grad=True) 6254*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True): 6255*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # python2 requires this so other tests can trigger 6256*da0073e9SAndroid Build Coastguard Worker self.assertTrue(gradcheck( 6257*da0073e9SAndroid Build Coastguard Worker lambda inp: F.affine_grid(inp, sz, align_corners=align_corners), 6258*da0073e9SAndroid Build Coastguard Worker (inp,))) 6259*da0073e9SAndroid Build Coastguard Worker 6260*da0073e9SAndroid Build Coastguard Worker # test CPU against CUDA 6261*da0073e9SAndroid Build Coastguard Worker if TEST_CUDA: 6262*da0073e9SAndroid Build Coastguard Worker N = random.randint(1, 8) 6263*da0073e9SAndroid Build Coastguard Worker C = random.randint(1, 8) 6264*da0073e9SAndroid Build Coastguard Worker D = random.randint(1, 8) 6265*da0073e9SAndroid Build Coastguard Worker H = random.randint(1, 8) 6266*da0073e9SAndroid Build Coastguard Worker W = random.randint(1, 8) 6267*da0073e9SAndroid Build Coastguard Worker sz = torch.Size([N, C, D, H, W]) 6268*da0073e9SAndroid Build Coastguard Worker for align_corners in (True, False): 6269*da0073e9SAndroid Build Coastguard 
Worker input_cpu = torch.randn(N, 3, 4, requires_grad=True) 6270*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True): 6271*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # python2 requires this so other tests can trigger 6272*da0073e9SAndroid Build Coastguard Worker out_cpu = F.affine_grid(input_cpu, sz, align_corners=align_corners) 6273*da0073e9SAndroid Build Coastguard Worker gradients = torch.randn(out_cpu.size()) 6274*da0073e9SAndroid Build Coastguard Worker out_cpu.backward(gradients) 6275*da0073e9SAndroid Build Coastguard Worker input_gpu = input_cpu.detach().cuda().requires_grad_() 6276*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True): 6277*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") # python2 requires this so other tests can trigger 6278*da0073e9SAndroid Build Coastguard Worker out_cuda = F.affine_grid(input_gpu, sz, align_corners=align_corners) 6279*da0073e9SAndroid Build Coastguard Worker out_cuda.backward(gradients.cuda()) 6280*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_cpu, out_cuda) 6281*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input_cpu.grad, input_gpu.grad) 6282*da0073e9SAndroid Build Coastguard Worker 6283*da0073e9SAndroid Build Coastguard Worker def test_channel_shuffle_return_alias_of_self(self): 6284*da0073e9SAndroid Build Coastguard Worker # gh-76616: nn.ChannelShuffle will return alias of self with an empty input tensor 6285*da0073e9SAndroid Build Coastguard Worker groups = 3 6286*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.rand([0, 9, 4, 4]) 6287*da0073e9SAndroid Build Coastguard Worker output = torch.nn.ChannelShuffle(groups)(input_tensor) 6288*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(output, input_tensor) 6289*da0073e9SAndroid Build Coastguard Worker 6290*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") 6291*da0073e9SAndroid Build Coastguard Worker def test_native_channel_shuffle_return_alias_of_self(self): 6292*da0073e9SAndroid Build Coastguard Worker groups = 3 6293*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.rand([0, 9, 4, 4]) 6294*da0073e9SAndroid Build Coastguard Worker output = torch.native_channel_shuffle(input_tensor, groups) 6295*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(output, input_tensor) 6296*da0073e9SAndroid Build Coastguard Worker 6297*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 6298*da0073e9SAndroid Build Coastguard Worker def test_upsamplingLinear1d(self): 6299*da0073e9SAndroid Build Coastguard Worker for align_corners in [True, False]: 6300*da0073e9SAndroid Build Coastguard Worker for recompute_scale_factor in [True, False]: 6301*da0073e9SAndroid Build Coastguard Worker kwargs = dict( 6302*da0073e9SAndroid Build Coastguard Worker mode='linear', align_corners=align_corners, recompute_scale_factor=recompute_scale_factor 6303*da0073e9SAndroid Build Coastguard Worker ) 6304*da0073e9SAndroid Build Coastguard Worker # test float scale factor up & downsampling 6305*da0073e9SAndroid Build Coastguard Worker for scale_factor in [0.5, 1.5, 2]: 6306*da0073e9SAndroid Build Coastguard Worker m = nn.Upsample(scale_factor=scale_factor, **kwargs) 6307*da0073e9SAndroid Build Coastguard Worker in_t = torch.ones(1, 1, 2) 6308*da0073e9SAndroid Build Coastguard Worker out_size = int(math.floor(in_t.shape[-1] * scale_factor)) 
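# A minimal standalone sketch of how the expected out_size is derived: when only scale_factor is
# given, Upsample/interpolate computes the output length as floor(input_length * scale_factor).
import math
import torch
import torch.nn as nn
up = nn.Upsample(scale_factor=1.5, mode='linear', align_corners=False)
x = torch.ones(1, 1, 2)
print(up(x).shape[-1], int(math.floor(x.shape[-1] * 1.5)))  # 3 3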
6309*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 6310*da0073e9SAndroid Build Coastguard Worker out_t = m(in_t) 6311*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ones(1, 1, out_size), out_t.data) 6312*da0073e9SAndroid Build Coastguard Worker 6313*da0073e9SAndroid Build Coastguard Worker input = torch.randn(1, 1, 2, requires_grad=True) 6314*da0073e9SAndroid Build Coastguard Worker if not recompute_scale_factor: 6315*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), (input,)) 6316*da0073e9SAndroid Build Coastguard Worker else: 6317*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x: F.interpolate(x, scale_factor=scale_factor, **kwargs), (input,)) 6318*da0073e9SAndroid Build Coastguard Worker 6319*da0073e9SAndroid Build Coastguard Worker def test_upsamplingLinear1d_spatial_invariance(self): 6320*da0073e9SAndroid Build Coastguard Worker m = nn.Upsample(scale_factor=3, mode='linear', align_corners=False) 6321*da0073e9SAndroid Build Coastguard Worker in_t_9 = torch.zeros(1, 1, 9) 6322*da0073e9SAndroid Build Coastguard Worker in_t_9[:, :, :4].normal_() 6323*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 6324*da0073e9SAndroid Build Coastguard Worker out_t_9 = m(in_t_9) 6325*da0073e9SAndroid Build Coastguard Worker out_t_5 = m(in_t_9[:, :, :5]) 6326*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_t_9[:, :, :15], out_t_5) 6327*da0073e9SAndroid Build Coastguard Worker 6328*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 6329*da0073e9SAndroid Build Coastguard Worker def test_upsampling_not_recompute_scale_factor(self): 6330*da0073e9SAndroid Build Coastguard Worker # test output against known input: result must match opencv 6331*da0073e9SAndroid Build Coastguard Worker in_t = torch.arange(8.).view(1, 2, 2, 2) 6332*da0073e9SAndroid Build Coastguard Worker expected_out_t = torch.tensor( 6333*da0073e9SAndroid Build Coastguard Worker [[[[-0.32725, -0.08843, 0.37933, 0.79744], 6334*da0073e9SAndroid Build Coastguard Worker [0.15039, 0.38921, 0.85697, 1.27508], 6335*da0073e9SAndroid Build Coastguard Worker [1.08591, 1.32473, 1.79249, 2.21060], 6336*da0073e9SAndroid Build Coastguard Worker [1.92213, 2.16095, 2.62871, 3.04682]], 6337*da0073e9SAndroid Build Coastguard Worker 6338*da0073e9SAndroid Build Coastguard Worker [[3.67275, 3.91157, 4.37933, 4.79744], 6339*da0073e9SAndroid Build Coastguard Worker [4.15039, 4.38921, 4.85697, 5.27508], 6340*da0073e9SAndroid Build Coastguard Worker [5.08591, 5.32473, 5.79249, 6.21060], 6341*da0073e9SAndroid Build Coastguard Worker [5.92213, 6.16095, 6.62871, 7.04682]]]]) 6342*da0073e9SAndroid Build Coastguard Worker if IS_PPC: 6343*da0073e9SAndroid Build Coastguard Worker # Both OpenCV and PyTorch give a slightly different result on PPC 6344*da0073e9SAndroid Build Coastguard Worker expected_out_t = torch.tensor( 6345*da0073e9SAndroid Build Coastguard Worker [[[[-0.32725, -0.08843, 0.37933, 0.79744], 6346*da0073e9SAndroid Build Coastguard Worker [0.15039, 0.38921, 0.85697, 1.27508], 6347*da0073e9SAndroid Build Coastguard Worker [1.08591, 1.32473, 1.79249, 2.21060], 6348*da0073e9SAndroid Build Coastguard Worker [1.92212, 2.16094, 2.62870, 3.04681]], 6349*da0073e9SAndroid Build Coastguard Worker 6350*da0073e9SAndroid Build Coastguard Worker [[3.67275, 3.91157, 4.37933, 4.79743], 6351*da0073e9SAndroid Build Coastguard Worker [4.15039, 4.38921, 4.85697, 5.27508], 
6352*da0073e9SAndroid Build Coastguard Worker [5.08591, 5.32473, 5.79249, 6.21059], 6353*da0073e9SAndroid Build Coastguard Worker [5.92212, 6.16094, 6.62870, 7.04680]]]]) 6354*da0073e9SAndroid Build Coastguard Worker out_t = F.interpolate(in_t, scale_factor=2.3, mode='bicubic', align_corners=False, recompute_scale_factor=False) 6355*da0073e9SAndroid Build Coastguard Worker torch.set_printoptions(precision=5) 6356*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_t, expected_out_t, atol=1e-4, rtol=0) 6357*da0073e9SAndroid Build Coastguard Worker 6358*da0073e9SAndroid Build Coastguard Worker device_list = ['cpu'] 6359*da0073e9SAndroid Build Coastguard Worker if TEST_CUDA: 6360*da0073e9SAndroid Build Coastguard Worker device_list.append('cuda') 6361*da0073e9SAndroid Build Coastguard Worker 6362*da0073e9SAndroid Build Coastguard Worker for align_corners in [True, False]: 6363*da0073e9SAndroid Build Coastguard Worker kwargs = dict(mode='bicubic', align_corners=align_corners) 6364*da0073e9SAndroid Build Coastguard Worker # test float scale factor up & downsampling 6365*da0073e9SAndroid Build Coastguard Worker for device in device_list: 6366*da0073e9SAndroid Build Coastguard Worker for scale_factor in [0.6, 1.6, 2.3]: 6367*da0073e9SAndroid Build Coastguard Worker in_t = torch.ones(2, 2, 2, 2).to(device) 6368*da0073e9SAndroid Build Coastguard Worker out_t = F.interpolate(in_t, scale_factor=scale_factor, **kwargs) 6369*da0073e9SAndroid Build Coastguard Worker out_size = int(math.floor(in_t.shape[-1] * scale_factor)) 6370*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ones(2, 2, out_size, out_size), out_t.data, atol=1e-5, rtol=0) 6371*da0073e9SAndroid Build Coastguard Worker 6372*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 2, 2, 2, requires_grad=True) 6373*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input]) 6374*da0073e9SAndroid Build Coastguard Worker 6375*da0073e9SAndroid Build Coastguard Worker def test_upsamplingBilinear2d_spatial_invariance(self): 6376*da0073e9SAndroid Build Coastguard Worker m = nn.Upsample(scale_factor=3, mode='bilinear', align_corners=False) 6377*da0073e9SAndroid Build Coastguard Worker in_t_9 = torch.zeros(1, 1, 9, 9) 6378*da0073e9SAndroid Build Coastguard Worker in_t_9[:, :, :4, :4].normal_() 6379*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 6380*da0073e9SAndroid Build Coastguard Worker out_t_9 = m(in_t_9) 6381*da0073e9SAndroid Build Coastguard Worker out_t_5 = m(in_t_9[:, :, :5, :5]) 6382*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_t_9[:, :, :15, :15], out_t_5) 6383*da0073e9SAndroid Build Coastguard Worker 6384*da0073e9SAndroid Build Coastguard Worker def test_upsamplingTrilinear3d_spatial_invariance(self): 6385*da0073e9SAndroid Build Coastguard Worker m = nn.Upsample(scale_factor=3, mode='trilinear', align_corners=False) 6386*da0073e9SAndroid Build Coastguard Worker in_t_9 = torch.zeros(1, 1, 9, 9, 9) 6387*da0073e9SAndroid Build Coastguard Worker in_t_9[:, :, :4, :4, :4].normal_() 6388*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 6389*da0073e9SAndroid Build Coastguard Worker out_t_9 = m(in_t_9) 6390*da0073e9SAndroid Build Coastguard Worker out_t_5 = m(in_t_9[:, :, :5, :5, :5]) 6391*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_t_9[:, :, :15, :15, :15], out_t_5) 6392*da0073e9SAndroid Build Coastguard Worker 6393*da0073e9SAndroid Build Coastguard 
Worker def test_upsampling_small_scale(self): 6394*da0073e9SAndroid Build Coastguard Worker m = torch.nn.Upsample(scale_factor=0.5, mode="bilinear") 6395*da0073e9SAndroid Build Coastguard Worker in_t = torch.arange(1, 5, dtype=torch.get_default_dtype()).reshape(1, 1, 2, 2) 6396*da0073e9SAndroid Build Coastguard Worker out_t = m(in_t) 6397*da0073e9SAndroid Build Coastguard Worker expected_out_t = torch.tensor([[[[2.5]]]]) 6398*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_out_t, out_t) 6399*da0073e9SAndroid Build Coastguard Worker 6400*da0073e9SAndroid Build Coastguard Worker def test_upsampling_bfloat16(self, dtype=torch.bfloat16): 6401*da0073e9SAndroid Build Coastguard Worker def helper(size, scale_factor, mode, device, memory_format=torch.contiguous_format): 6402*da0073e9SAndroid Build Coastguard Worker input = torch.randn(size, device=device, dtype=dtype).to(memory_format=memory_format).detach().requires_grad_(True) 6403*da0073e9SAndroid Build Coastguard Worker inputf = input.to(torch.float32).to(memory_format=torch.contiguous_format).detach().requires_grad_(True) 6404*da0073e9SAndroid Build Coastguard Worker m = nn.Upsample(scale_factor=scale_factor, mode=mode) 6405*da0073e9SAndroid Build Coastguard Worker 6406*da0073e9SAndroid Build Coastguard Worker outf = m(inputf) 6407*da0073e9SAndroid Build Coastguard Worker out = m(input) 6408*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.to(torch.float32), outf, atol=0.05, rtol=0) 6409*da0073e9SAndroid Build Coastguard Worker 6410*da0073e9SAndroid Build Coastguard Worker ginput = torch.randn(out.shape, device=device, dtype=dtype).to(memory_format=memory_format) 6411*da0073e9SAndroid Build Coastguard Worker ginputf = ginput.to(torch.float32).to(memory_format=torch.contiguous_format) 6412*da0073e9SAndroid Build Coastguard Worker out.backward(ginput) 6413*da0073e9SAndroid Build Coastguard Worker outf.backward(ginputf) 6414*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input.grad.to(torch.float32), inputf.grad, atol=0.01, rtol=0.01) 6415*da0073e9SAndroid Build Coastguard Worker 6416*da0073e9SAndroid Build Coastguard Worker for device in ['cpu']: 6417*da0073e9SAndroid Build Coastguard Worker helper([3, 20, 11, 7], 2, 'nearest', device) 6418*da0073e9SAndroid Build Coastguard Worker helper([3, 20, 11, 7], 2, 'nearest', device, torch.channels_last) 6419*da0073e9SAndroid Build Coastguard Worker helper([3, 20, 11, 7, 3], 2, 'nearest', device) 6420*da0073e9SAndroid Build Coastguard Worker helper([3, 20, 30], 2, 'linear', device) 6421*da0073e9SAndroid Build Coastguard Worker helper([3, 20, 11, 7], 2, 'bilinear', device) 6422*da0073e9SAndroid Build Coastguard Worker helper([3, 20, 11, 7], 2, 'bilinear', device, torch.channels_last) 6423*da0073e9SAndroid Build Coastguard Worker helper([1, 3, 11, 7], 2, 'bicubic', device) 6424*da0073e9SAndroid Build Coastguard Worker helper([1, 3, 11, 7], 2, 'bicubic', device, torch.channels_last) 6425*da0073e9SAndroid Build Coastguard Worker helper([3, 20, 11, 7, 3], 2, 'trilinear', device) 6426*da0073e9SAndroid Build Coastguard Worker 6427*da0073e9SAndroid Build Coastguard Worker helper([3, 5, 5], 257., 'nearest', device) 6428*da0073e9SAndroid Build Coastguard Worker helper([3, 20, 11, 7], 20, 'nearest', device) 6429*da0073e9SAndroid Build Coastguard Worker helper([3, 20, 11, 7, 3], 20, 'nearest', device) 6430*da0073e9SAndroid Build Coastguard Worker helper([1, 2, 11, 7], 257, 'nearest', device, torch.channels_last) 6431*da0073e9SAndroid Build Coastguard Worker helper([1, 2, 
2000, 2000], 1 / 377., 'nearest', device) 6432*da0073e9SAndroid Build Coastguard Worker helper([1, 2, 2000, 2000], 1 / 257., 'nearest', device, torch.channels_last) 6433*da0073e9SAndroid Build Coastguard Worker helper([3, 2, 11, 7, 3], 20, 'nearest', device, torch.channels_last_3d) 6434*da0073e9SAndroid Build Coastguard Worker helper([3, 5, 5], 10, 'linear', device) 6435*da0073e9SAndroid Build Coastguard Worker helper([3, 5, 5], 257, 'linear', device) 6436*da0073e9SAndroid Build Coastguard Worker helper([1, 2, 11, 7], 257, 'bilinear', device) 6437*da0073e9SAndroid Build Coastguard Worker helper([1, 2, 11, 7], 257, 'bilinear', device, torch.channels_last) 6438*da0073e9SAndroid Build Coastguard Worker helper([1, 3, 11, 7], 10, 'bicubic', device) 6439*da0073e9SAndroid Build Coastguard Worker helper([1, 3, 11, 7], 10, 'bicubic', device, torch.channels_last) 6440*da0073e9SAndroid Build Coastguard Worker helper([1, 1, 11, 7], 257, 'bicubic', device) 6441*da0073e9SAndroid Build Coastguard Worker helper([3, 2, 11, 7, 3], 20, 'trilinear', device) 6442*da0073e9SAndroid Build Coastguard Worker helper([3, 2, 11, 7, 3], 20, 'trilinear', device, torch.channels_last_3d) 6443*da0073e9SAndroid Build Coastguard Worker 6444*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 6445*da0073e9SAndroid Build Coastguard Worker def test_interpolate_illegal_memory_access(self): 6446*da0073e9SAndroid Build Coastguard Worker in_s = 45 6447*da0073e9SAndroid Build Coastguard Worker out_s = 14 6448*da0073e9SAndroid Build Coastguard Worker 6449*da0073e9SAndroid Build Coastguard Worker input = torch.ones((1, 1, in_s), device='cuda', requires_grad=True) 6450*da0073e9SAndroid Build Coastguard Worker # note we allocated grad_output to be larger so out of bound access 6451*da0073e9SAndroid Build Coastguard Worker # would be visible in grad_input 6452*da0073e9SAndroid Build Coastguard Worker grad = torch.ones((1, 1, out_s * 2), device='cuda', requires_grad=True) 6453*da0073e9SAndroid Build Coastguard Worker grad = grad[:, :, :out_s] 6454*da0073e9SAndroid Build Coastguard Worker 6455*da0073e9SAndroid Build Coastguard Worker input_ref = input.detach().cpu().requires_grad_() 6456*da0073e9SAndroid Build Coastguard Worker grad_ref = grad.cpu() 6457*da0073e9SAndroid Build Coastguard Worker 6458*da0073e9SAndroid Build Coastguard Worker out = F.interpolate(input, size=(out_s,), mode='nearest') 6459*da0073e9SAndroid Build Coastguard Worker out.backward(grad) 6460*da0073e9SAndroid Build Coastguard Worker 6461*da0073e9SAndroid Build Coastguard Worker out_ref = F.interpolate(input_ref, size=(out_s,), mode='nearest') 6462*da0073e9SAndroid Build Coastguard Worker out_ref.backward(grad_ref) 6463*da0073e9SAndroid Build Coastguard Worker 6464*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out) 6465*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input_ref.grad, input.grad) 6466*da0073e9SAndroid Build Coastguard Worker 6467*da0073e9SAndroid Build Coastguard Worker def test_interpolate_undefined_behavior_casting(self): 6468*da0073e9SAndroid Build Coastguard Worker x = torch.ones([1, 1, 16, 16]) 6469*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: F.interpolate(x, scale_factor=-1e20, mode="bilinear")) 6470*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: F.interpolate(x, scale_factor=1e20, mode="bilinear")) 6471*da0073e9SAndroid Build Coastguard Worker 6472*da0073e9SAndroid Build Coastguard Worker def 
test_interpolate_buffer_overflow(self): 6473*da0073e9SAndroid Build Coastguard Worker # Test buffer overflow issue due to inaccurate floating point 6474*da0073e9SAndroid Build Coastguard Worker # representation for integer values. See issue below for details. 6475*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/88939 6476*da0073e9SAndroid Build Coastguard Worker 6477*da0073e9SAndroid Build Coastguard Worker def helper(size, dtype, mode, device, is_channels_last): 6478*da0073e9SAndroid Build Coastguard Worker input = torch.ones(size, dtype=dtype, device=device) 6479*da0073e9SAndroid Build Coastguard Worker if is_channels_last: 6480*da0073e9SAndroid Build Coastguard Worker if len(size) == 3: 6481*da0073e9SAndroid Build Coastguard Worker input = input.transpose(1, 2).contiguous().transpose(1, 2) 6482*da0073e9SAndroid Build Coastguard Worker elif len(size) == 4: 6483*da0073e9SAndroid Build Coastguard Worker input = input.to(memory_format=torch.channels_last) 6484*da0073e9SAndroid Build Coastguard Worker else: 6485*da0073e9SAndroid Build Coastguard Worker input = input.to(memory_format=torch.channels_last_3d) 6486*da0073e9SAndroid Build Coastguard Worker output1 = F.interpolate(input, 2, mode=mode, align_corners=True) 6487*da0073e9SAndroid Build Coastguard Worker # reset the corner value and expect the output is changed as well 6488*da0073e9SAndroid Build Coastguard Worker # the output won't be changed on buffer overflow 6489*da0073e9SAndroid Build Coastguard Worker input[(-1,) * len(size)] = 0.5 6490*da0073e9SAndroid Build Coastguard Worker output2 = F.interpolate(input, 2, mode=mode, align_corners=True) 6491*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(output1, output2) 6492*da0073e9SAndroid Build Coastguard Worker 6493*da0073e9SAndroid Build Coastguard Worker size_dtype_list = [] 6494*da0073e9SAndroid Build Coastguard Worker # We set the size larger than the floating point exactly representable range 6495*da0073e9SAndroid Build Coastguard Worker # float: exact representable range (-2**24,2**24) 6496*da0073e9SAndroid Build Coastguard Worker size_dtype_list.append(([1, 10, 2**24 + 4], torch.float)) 6497*da0073e9SAndroid Build Coastguard Worker size_dtype_list.append(([1, 10, 2, 2**24 + 4], torch.float)) 6498*da0073e9SAndroid Build Coastguard Worker size_dtype_list.append(([1, 10, 2, 2, 2**24 + 4], torch.float)) 6499*da0073e9SAndroid Build Coastguard Worker # bfloat16: exact representable range (-2**8, 2**8) 6500*da0073e9SAndroid Build Coastguard Worker size_dtype_list.append(([1, 10, 2**8 + 4], torch.bfloat16)) 6501*da0073e9SAndroid Build Coastguard Worker size_dtype_list.append(([1, 10, 2, 2**8 + 4], torch.bfloat16)) 6502*da0073e9SAndroid Build Coastguard Worker size_dtype_list.append(([1, 10, 2, 2, 2**8 + 4], torch.bfloat16)) 6503*da0073e9SAndroid Build Coastguard Worker # half: exact representable range (-2**11, 2**11) 6504*da0073e9SAndroid Build Coastguard Worker size_dtype_list.append(([1, 10, 2**11 + 4], torch.half)) 6505*da0073e9SAndroid Build Coastguard Worker size_dtype_list.append(([1, 10, 2, 2**11 + 4], torch.half)) 6506*da0073e9SAndroid Build Coastguard Worker size_dtype_list.append(([1, 10, 2, 2, 2**11 + 4], torch.half)) 6507*da0073e9SAndroid Build Coastguard Worker 6508*da0073e9SAndroid Build Coastguard Worker # TODO: turn on cuda test after buffer overflow issue is fixed in cuda kernel 6509*da0073e9SAndroid Build Coastguard Worker # devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else []) 6510*da0073e9SAndroid Build 
Coastguard Worker devices = ['cpu'] 6511*da0073e9SAndroid Build Coastguard Worker 6512*da0073e9SAndroid Build Coastguard Worker for mode in ('linear', 'bilinear', 'bicubic', 'trilinear'): 6513*da0073e9SAndroid Build Coastguard Worker for size_dtype in size_dtype_list: 6514*da0073e9SAndroid Build Coastguard Worker size, dtype = size_dtype 6515*da0073e9SAndroid Build Coastguard Worker if ( 6516*da0073e9SAndroid Build Coastguard Worker mode == 'linear' and len(size) != 3 6517*da0073e9SAndroid Build Coastguard Worker or (mode == 'bilinear' and len(size) != 4) 6518*da0073e9SAndroid Build Coastguard Worker or (mode == 'bicubic' and len(size) != 4) 6519*da0073e9SAndroid Build Coastguard Worker or (mode == 'trilinear' and len(size) != 5) 6520*da0073e9SAndroid Build Coastguard Worker ): 6521*da0073e9SAndroid Build Coastguard Worker continue 6522*da0073e9SAndroid Build Coastguard Worker for device in devices: 6523*da0073e9SAndroid Build Coastguard Worker if ( 6524*da0073e9SAndroid Build Coastguard Worker device == 'cpu' and dtype == torch.half 6525*da0073e9SAndroid Build Coastguard Worker or (device == 'cuda' and dtype == torch.bfloat16) 6526*da0073e9SAndroid Build Coastguard Worker ): 6527*da0073e9SAndroid Build Coastguard Worker # no half precision support on cpu or bfloat16 on cuda yet 6528*da0073e9SAndroid Build Coastguard Worker continue 6529*da0073e9SAndroid Build Coastguard Worker for is_channels_last in (True, False): 6530*da0073e9SAndroid Build Coastguard Worker helper(size, dtype, mode, device, is_channels_last) 6531*da0073e9SAndroid Build Coastguard Worker 6532*da0073e9SAndroid Build Coastguard Worker 6533*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 6534*da0073e9SAndroid Build Coastguard Worker def test_interpolate(self): 6535*da0073e9SAndroid Build Coastguard Worker def _test_interpolate_non_integer_size_warning(in_t, out_size, dim, **kwargs): 6536*da0073e9SAndroid Build Coastguard Worker test_sizes = [float(out_size), 6537*da0073e9SAndroid Build Coastguard Worker torch.tensor(out_size, dtype=torch.float)] 6538*da0073e9SAndroid Build Coastguard Worker for size in test_sizes: 6539*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(TypeError, 6540*da0073e9SAndroid Build Coastguard Worker "(expected size to be one of int or).*", 6541*da0073e9SAndroid Build Coastguard Worker F.interpolate, in_t, size=(size,) * dim, **kwargs) 6542*da0073e9SAndroid Build Coastguard Worker 6543*da0073e9SAndroid Build Coastguard Worker def _test_interpolate_helper(in_t, scale_factor, layer): 6544*da0073e9SAndroid Build Coastguard Worker out_size = int(math.floor(in_t.shape[-1] * scale_factor)) 6545*da0073e9SAndroid Build Coastguard Worker dim = len(in_t.shape) - 2 6546*da0073e9SAndroid Build Coastguard Worker out_shape = [1, 1] + [out_size] * dim 6547*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 6548*da0073e9SAndroid Build Coastguard Worker out_t = layer(in_t) 6549*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ones(out_shape), out_t) 6550*da0073e9SAndroid Build Coastguard Worker 6551*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 6552*da0073e9SAndroid Build Coastguard Worker F.interpolate(in_t, (out_size,) * dim, **kwargs), 6553*da0073e9SAndroid Build Coastguard Worker F.interpolate(in_t, scale_factor=scale_factor, **kwargs)) 6554*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [in_t], nondet_tol=GRADCHECK_NONDET_TOL) 6555*da0073e9SAndroid Build 
Coastguard Worker gradgradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [in_t], nondet_tol=GRADCHECK_NONDET_TOL) 6556*da0073e9SAndroid Build Coastguard Worker _test_interpolate_non_integer_size_warning(in_t, out_size, dim, **kwargs) 6557*da0073e9SAndroid Build Coastguard Worker 6558*da0073e9SAndroid Build Coastguard Worker def _make_input(dim, device): 6559*da0073e9SAndroid Build Coastguard Worker size = [1, 1] 6560*da0073e9SAndroid Build Coastguard Worker size += [2] * dim 6561*da0073e9SAndroid Build Coastguard Worker return torch.ones(size, requires_grad=True, device=device) 6562*da0073e9SAndroid Build Coastguard Worker 6563*da0073e9SAndroid Build Coastguard Worker device_list = ['cpu'] 6564*da0073e9SAndroid Build Coastguard Worker if TEST_CUDA: 6565*da0073e9SAndroid Build Coastguard Worker device_list.append('cuda') 6566*da0073e9SAndroid Build Coastguard Worker 6567*da0073e9SAndroid Build Coastguard Worker for device in device_list: 6568*da0073e9SAndroid Build Coastguard Worker for scale_factor in [0.5, 1.5, 2]: 6569*da0073e9SAndroid Build Coastguard Worker for mode in ['nearest', 'area']: 6570*da0073e9SAndroid Build Coastguard Worker kwargs = dict(mode=mode) 6571*da0073e9SAndroid Build Coastguard Worker m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device) 6572*da0073e9SAndroid Build Coastguard Worker for input in [_make_input(1, device), _make_input(2, device), _make_input(3, device)]: 6573*da0073e9SAndroid Build Coastguard Worker _test_interpolate_helper(input, scale_factor, m) 6574*da0073e9SAndroid Build Coastguard Worker 6575*da0073e9SAndroid Build Coastguard Worker for align_corners in [True, False]: 6576*da0073e9SAndroid Build Coastguard Worker kwargs = dict(mode='linear', align_corners=align_corners) 6577*da0073e9SAndroid Build Coastguard Worker m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device) 6578*da0073e9SAndroid Build Coastguard Worker _test_interpolate_helper(_make_input(1, device), scale_factor, m) 6579*da0073e9SAndroid Build Coastguard Worker 6580*da0073e9SAndroid Build Coastguard Worker kwargs = dict(mode='bilinear', align_corners=align_corners) 6581*da0073e9SAndroid Build Coastguard Worker m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device) 6582*da0073e9SAndroid Build Coastguard Worker _test_interpolate_helper(_make_input(2, device), scale_factor, m) 6583*da0073e9SAndroid Build Coastguard Worker 6584*da0073e9SAndroid Build Coastguard Worker kwargs = dict(mode='bicubic', align_corners=align_corners) 6585*da0073e9SAndroid Build Coastguard Worker 6586*da0073e9SAndroid Build Coastguard Worker def m(t): 6587*da0073e9SAndroid Build Coastguard Worker return F.interpolate(t, scale_factor=scale_factor, **kwargs).to(device) 6588*da0073e9SAndroid Build Coastguard Worker _test_interpolate_helper(_make_input(2, device), scale_factor, m) 6589*da0073e9SAndroid Build Coastguard Worker 6590*da0073e9SAndroid Build Coastguard Worker kwargs = dict(mode='trilinear', align_corners=align_corners) 6591*da0073e9SAndroid Build Coastguard Worker m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device) 6592*da0073e9SAndroid Build Coastguard Worker _test_interpolate_helper(_make_input(3, device), scale_factor, m) 6593*da0073e9SAndroid Build Coastguard Worker 6594*da0073e9SAndroid Build Coastguard Worker def test_linear_broadcasting(self): 6595*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(5, 8) 6596*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(2, 3, 5) 6597*da0073e9SAndroid Build Coastguard Worker expected = m(inp.view(6, 
5)).view(2, 3, 8) 6598*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, m(inp)) 6599*da0073e9SAndroid Build Coastguard Worker 6600*da0073e9SAndroid Build Coastguard Worker def test_linear_raise_on_scalar_input(self): 6601*da0073e9SAndroid Build Coastguard Worker # This used to cause an int underflow issue when reshaping the input 6602*da0073e9SAndroid Build Coastguard Worker # see https://github.com/pytorch/pytorch/issues/119161 6603*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(1, 1) 6604*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(1).squeeze() 6605*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, ".*both arguments.*1D.*"): 6606*da0073e9SAndroid Build Coastguard Worker m(inp) 6607*da0073e9SAndroid Build Coastguard Worker 6608*da0073e9SAndroid Build Coastguard Worker @parametrize_test('device', ['cpu'] + (['cuda'] if TEST_CUDA else [])) 6609*da0073e9SAndroid Build Coastguard Worker @parametrize_test('bias', [ 6610*da0073e9SAndroid Build Coastguard Worker subtest(False, name='nobias'), subtest(True, name='bias')]) 6611*da0073e9SAndroid Build Coastguard Worker @parametrize_test('weight_layout', [ 6612*da0073e9SAndroid Build Coastguard Worker subtest(torch.strided, name='weightStrided'), 6613*da0073e9SAndroid Build Coastguard Worker subtest(torch.sparse_coo, name='weightCOO'), 6614*da0073e9SAndroid Build Coastguard Worker subtest(torch.sparse_csr, name='weightCSR'), 6615*da0073e9SAndroid Build Coastguard Worker subtest(torch.sparse_csc, name='weightCSC'), 6616*da0073e9SAndroid Build Coastguard Worker # TODO: addmm: computation on CPU is not implemented for Strided + Strided @ SparseBsr 6617*da0073e9SAndroid Build Coastguard Worker # subtest(torch.sparse_bsr, name='weightBSR'), 6618*da0073e9SAndroid Build Coastguard Worker # subtest(torch.sparse_bsc, name='weightBSC'), 6619*da0073e9SAndroid Build Coastguard Worker ]) 6620*da0073e9SAndroid Build Coastguard Worker def test_linear_autograd(self, device, bias, weight_layout): 6621*da0073e9SAndroid Build Coastguard Worker module = nn.Linear(4, 4, bias=bias, device=device) 6622*da0073e9SAndroid Build Coastguard Worker if weight_layout == torch.strided: 6623*da0073e9SAndroid Build Coastguard Worker pass 6624*da0073e9SAndroid Build Coastguard Worker elif weight_layout == torch.sparse_csr: 6625*da0073e9SAndroid Build Coastguard Worker module.weight = nn.Parameter(module.weight.to_sparse_csr()) 6626*da0073e9SAndroid Build Coastguard Worker elif weight_layout == torch.sparse_csc: 6627*da0073e9SAndroid Build Coastguard Worker module.weight = nn.Parameter(module.weight.to_sparse_csc()) 6628*da0073e9SAndroid Build Coastguard Worker elif weight_layout == torch.sparse_bsr: 6629*da0073e9SAndroid Build Coastguard Worker module.weight = nn.Parameter(module.weight.to_sparse_bsr((2, 2))) 6630*da0073e9SAndroid Build Coastguard Worker elif weight_layout == torch.sparse_bsc: 6631*da0073e9SAndroid Build Coastguard Worker module.weight = nn.Parameter(module.weight.to_sparse_bsc((2, 2))) 6632*da0073e9SAndroid Build Coastguard Worker elif weight_layout == torch.sparse_coo: 6633*da0073e9SAndroid Build Coastguard Worker module.weight = nn.Parameter(module.weight.to_sparse_coo()) 6634*da0073e9SAndroid Build Coastguard Worker else: 6635*da0073e9SAndroid Build Coastguard Worker raise AssertionError 6636*da0073e9SAndroid Build Coastguard Worker 6637*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(4, requires_grad=True, device=device) 6638*da0073e9SAndroid Build Coastguard Worker res = 
module(inp) 6639*da0073e9SAndroid Build Coastguard Worker if bias: 6640*da0073e9SAndroid Build Coastguard Worker expected = (torch.einsum("i,ji->j", inp, module.weight.to_dense())) + module.bias 6641*da0073e9SAndroid Build Coastguard Worker else: 6642*da0073e9SAndroid Build Coastguard Worker expected = (torch.einsum("i,ji->j", inp, module.weight.to_dense())) 6643*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 6644*da0073e9SAndroid Build Coastguard Worker 6645*da0073e9SAndroid Build Coastguard Worker grad_output = torch.randn(4, device=device) 6646*da0073e9SAndroid Build Coastguard Worker grads = torch.autograd.grad(res, [module.weight, inp], grad_output) 6647*da0073e9SAndroid Build Coastguard Worker grads_expected = torch.autograd.grad(expected, [module.weight, inp], grad_output) 6648*da0073e9SAndroid Build Coastguard Worker 6649*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grads_expected[0].layout, weight_layout) 6650*da0073e9SAndroid Build Coastguard Worker 6651*da0073e9SAndroid Build Coastguard Worker for g, ge in zip(grads, grads_expected): 6652*da0073e9SAndroid Build Coastguard Worker self.assertEqual(g, ge) 6653*da0073e9SAndroid Build Coastguard Worker 6654*da0073e9SAndroid Build Coastguard Worker def test_bilinear(self): 6655*da0073e9SAndroid Build Coastguard Worker module = nn.Bilinear(10, 10, 8) 6656*da0073e9SAndroid Build Coastguard Worker input1 = torch.randn(4, 10, requires_grad=True) 6657*da0073e9SAndroid Build Coastguard Worker input2 = torch.randn(4, 10, requires_grad=True) 6658*da0073e9SAndroid Build Coastguard Worker grad_output = torch.randn(4, 8) 6659*da0073e9SAndroid Build Coastguard Worker res = module(input1, input2) 6660*da0073e9SAndroid Build Coastguard Worker expected = (torch.einsum("bi,kij,bj->bk", input1, module.weight, input2) + 6661*da0073e9SAndroid Build Coastguard Worker module.bias) 6662*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 6663*da0073e9SAndroid Build Coastguard Worker grads = torch.autograd.grad(res, [module.weight, module.bias, input1, input2], grad_output) 6664*da0073e9SAndroid Build Coastguard Worker grads_expected = torch.autograd.grad(expected, [module.weight, module.bias, input1, input2], grad_output) 6665*da0073e9SAndroid Build Coastguard Worker for g, ge in zip(grads, grads_expected): 6666*da0073e9SAndroid Build Coastguard Worker self.assertEqual(g, ge) 6667*da0073e9SAndroid Build Coastguard Worker 6668*da0073e9SAndroid Build Coastguard Worker def test_bilinear_non_contiguous(self): 6669*da0073e9SAndroid Build Coastguard Worker module = nn.Bilinear(7, 7, 5) 6670*da0073e9SAndroid Build Coastguard Worker input1 = torch.randn(4, 7, 10, requires_grad=True) 6671*da0073e9SAndroid Build Coastguard Worker input2 = torch.randn(4, 7, 10, requires_grad=True) 6672*da0073e9SAndroid Build Coastguard Worker input1_tp = input1.transpose(1, 2) 6673*da0073e9SAndroid Build Coastguard Worker input2_tp = input2.transpose(1, 2) 6674*da0073e9SAndroid Build Coastguard Worker 6675*da0073e9SAndroid Build Coastguard Worker grad_output = torch.randn(4, 10, 5) 6676*da0073e9SAndroid Build Coastguard Worker 6677*da0073e9SAndroid Build Coastguard Worker def run(input1_tp, input2_tp): 6678*da0073e9SAndroid Build Coastguard Worker input1.grad = input2.grad = None 6679*da0073e9SAndroid Build Coastguard Worker output = module(input1_tp, input2_tp) 6680*da0073e9SAndroid Build Coastguard Worker output.backward(grad_output) 6681*da0073e9SAndroid Build Coastguard Worker 6682*da0073e9SAndroid Build Coastguard 
Worker return output.data, input1.grad.data, input2.grad.data 6683*da0073e9SAndroid Build Coastguard Worker 6684*da0073e9SAndroid Build Coastguard Worker out_nc, g1_nc, g2_nc = run(input1_tp, input2_tp) 6685*da0073e9SAndroid Build Coastguard Worker input1_tp = input1_tp.contiguous() 6686*da0073e9SAndroid Build Coastguard Worker input2_tp = input2_tp.contiguous() 6687*da0073e9SAndroid Build Coastguard Worker out, g1, g2 = run(input1_tp, input2_tp) 6688*da0073e9SAndroid Build Coastguard Worker 6689*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, out_nc) 6690*da0073e9SAndroid Build Coastguard Worker self.assertEqual(g1, g1_nc) 6691*da0073e9SAndroid Build Coastguard Worker self.assertEqual(g2, g2_nc) 6692*da0073e9SAndroid Build Coastguard Worker 6693*da0073e9SAndroid Build Coastguard Worker def test_bilinear_no_bias(self): 6694*da0073e9SAndroid Build Coastguard Worker module = nn.Bilinear(10, 10, 8, dtype=torch.double) 6695*da0073e9SAndroid Build Coastguard Worker module_no_bias = nn.Bilinear(10, 10, 8, False, dtype=torch.double) 6696*da0073e9SAndroid Build Coastguard Worker 6697*da0073e9SAndroid Build Coastguard Worker module.bias.data.zero_() 6698*da0073e9SAndroid Build Coastguard Worker module.weight.data.copy_(module_no_bias.weight) 6699*da0073e9SAndroid Build Coastguard Worker 6700*da0073e9SAndroid Build Coastguard Worker input1 = torch.randn(4, 10, requires_grad=True, dtype=torch.double) 6701*da0073e9SAndroid Build Coastguard Worker input2 = torch.randn(4, 10, requires_grad=True, dtype=torch.double) 6702*da0073e9SAndroid Build Coastguard Worker grad_output = torch.randn(4, 8, dtype=torch.double) 6703*da0073e9SAndroid Build Coastguard Worker 6704*da0073e9SAndroid Build Coastguard Worker def run(net): 6705*da0073e9SAndroid Build Coastguard Worker input1.grad = input2.grad = None 6706*da0073e9SAndroid Build Coastguard Worker output = net(input1, input2) 6707*da0073e9SAndroid Build Coastguard Worker output.backward(grad_output) 6708*da0073e9SAndroid Build Coastguard Worker 6709*da0073e9SAndroid Build Coastguard Worker return output.data, input1.grad.data, input2.grad.data 6710*da0073e9SAndroid Build Coastguard Worker 6711*da0073e9SAndroid Build Coastguard Worker out, g1, g2 = run(module) 6712*da0073e9SAndroid Build Coastguard Worker out_nb, g1_nb, g2_nb = run(module_no_bias) 6713*da0073e9SAndroid Build Coastguard Worker 6714*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, out_nb) 6715*da0073e9SAndroid Build Coastguard Worker self.assertEqual(g1, g1_nb) 6716*da0073e9SAndroid Build Coastguard Worker self.assertEqual(g2, g2_nb) 6717*da0073e9SAndroid Build Coastguard Worker 6718*da0073e9SAndroid Build Coastguard Worker _assertGradAndGradgradChecks(self, 6719*da0073e9SAndroid Build Coastguard Worker lambda x1, x2: F.bilinear(x1, x2, module_no_bias.weight, module_no_bias.bias), 6720*da0073e9SAndroid Build Coastguard Worker (input1, input2)) 6721*da0073e9SAndroid Build Coastguard Worker 6722*da0073e9SAndroid Build Coastguard Worker def test_bilinear_broadcasting(self): 6723*da0073e9SAndroid Build Coastguard Worker m = nn.Bilinear(5, 6, 8) 6724*da0073e9SAndroid Build Coastguard Worker input1 = torch.randn(2, 3, 5) 6725*da0073e9SAndroid Build Coastguard Worker input2 = torch.randn(2, 3, 6) 6726*da0073e9SAndroid Build Coastguard Worker expected = m(input1.view(6, 5), input2.view(6, 6)).view(2, 3, 8) 6727*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, m(input1, input2)) 6728*da0073e9SAndroid Build Coastguard Worker 6729*da0073e9SAndroid Build 
Coastguard Worker def test_fold_invalid_arg(self): 6730*da0073e9SAndroid Build Coastguard Worker # input.size(1) not divisible by \prod(kernel_size) 6731*da0073e9SAndroid Build Coastguard Worker 6732*da0073e9SAndroid Build Coastguard Worker fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 3)) 6733*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"be divisible by the product of kernel_size"): 6734*da0073e9SAndroid Build Coastguard Worker fold(torch.randn(1, 5, 9)) 6735*da0073e9SAndroid Build Coastguard Worker 6736*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"be divisible by the product of kernel_size"): 6737*da0073e9SAndroid Build Coastguard Worker fold(torch.randn(1, 19, 9)) 6738*da0073e9SAndroid Build Coastguard Worker 6739*da0073e9SAndroid Build Coastguard Worker # input.size(2) not matching the total number of sliding blocks 6740*da0073e9SAndroid Build Coastguard Worker 6741*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"match the calculated number of sliding blocks"): 6742*da0073e9SAndroid Build Coastguard Worker fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 3)) 6743*da0073e9SAndroid Build Coastguard Worker fold(torch.randn(1, 6, 10)) 6744*da0073e9SAndroid Build Coastguard Worker 6745*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"match the calculated number of sliding blocks"): 6746*da0073e9SAndroid Build Coastguard Worker fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 3), stride=(2, 2)) 6747*da0073e9SAndroid Build Coastguard Worker fold(torch.randn(1, 6, 5)) 6748*da0073e9SAndroid Build Coastguard Worker 6749*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"match the calculated number of sliding blocks"): 6750*da0073e9SAndroid Build Coastguard Worker fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 3), stride=(2, 2), dilation=(1, 2), padding=(2, 0)) 6751*da0073e9SAndroid Build Coastguard Worker fold(torch.randn(1, 6, 5)) # should be 4 * 1 = 4 sliding blocks 6752*da0073e9SAndroid Build Coastguard Worker 6753*da0073e9SAndroid Build Coastguard Worker fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2), stride=1, dilation=8, padding=0) 6754*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"calculated shape of the array of sliding blocks as"): 6755*da0073e9SAndroid Build Coastguard Worker fold(torch.randn(1, 12, 12)) 6756*da0073e9SAndroid Build Coastguard Worker 6757*da0073e9SAndroid Build Coastguard Worker def test_unfold_invalid_arg(self): 6758*da0073e9SAndroid Build Coastguard Worker # input wrong dimension 6759*da0073e9SAndroid Build Coastguard Worker 6760*da0073e9SAndroid Build Coastguard Worker unfold = nn.Unfold(kernel_size=(2, 3)) 6761*da0073e9SAndroid Build Coastguard Worker 6762*da0073e9SAndroid Build Coastguard Worker # calculated output shape is too small 6763*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"its components must be at least one"): 6764*da0073e9SAndroid Build Coastguard Worker unfold = nn.Unfold(kernel_size=(2, 3)) 6765*da0073e9SAndroid Build Coastguard Worker unfold(torch.randn(1, 2, 2, 2)) 6766*da0073e9SAndroid Build Coastguard Worker 6767*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"its components must be at least one"): 6768*da0073e9SAndroid Build Coastguard Worker unfold = nn.Unfold(kernel_size=(5, 3), padding=(1, 1)) 6769*da0073e9SAndroid Build 
Coastguard Worker unfold(torch.randn(1, 2, 2, 3)) 6770*da0073e9SAndroid Build Coastguard Worker 6771*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"its components must be at least one"): 6772*da0073e9SAndroid Build Coastguard Worker unfold = nn.Unfold(kernel_size=(1, 3), padding=(1, 1), dilation=(1, 2)) 6773*da0073e9SAndroid Build Coastguard Worker unfold(torch.randn(1, 2, 2, 2)) 6774*da0073e9SAndroid Build Coastguard Worker 6775*da0073e9SAndroid Build Coastguard Worker def test_softmin(self): 6776*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 16) 6777*da0073e9SAndroid Build Coastguard Worker self.assertEqual(F.softmin(x, 1), F.softmax(-x, 1)) 6778*da0073e9SAndroid Build Coastguard Worker self.assertEqual(F.softmin(x, 0), F.softmax(-x, 0)) 6779*da0073e9SAndroid Build Coastguard Worker 6780*da0073e9SAndroid Build Coastguard Worker def test_adaptive_log_softmax(self): 6781*da0073e9SAndroid Build Coastguard Worker # args validation 6782*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 6783*da0073e9SAndroid Build Coastguard Worker _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 15, 15], div_value=2.) 6784*da0073e9SAndroid Build Coastguard Worker 6785*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 6786*da0073e9SAndroid Build Coastguard Worker _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 15, 10], div_value=2.) 6787*da0073e9SAndroid Build Coastguard Worker 6788*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 6789*da0073e9SAndroid Build Coastguard Worker _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 25], div_value=2.) 6790*da0073e9SAndroid Build Coastguard Worker 6791*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "cutoffs should be a sequence of unique,"): 6792*da0073e9SAndroid Build Coastguard Worker _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 20], div_value=2.) 6793*da0073e9SAndroid Build Coastguard Worker 6794*da0073e9SAndroid Build Coastguard Worker # not raise 6795*da0073e9SAndroid Build Coastguard Worker _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 19], div_value=2.) 6796*da0073e9SAndroid Build Coastguard Worker 6797*da0073e9SAndroid Build Coastguard Worker # input shapes 6798*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Input and target should have the same size"): 6799*da0073e9SAndroid Build Coastguard Worker asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.) 6800*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 16) 6801*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([0, 5, 10]) 6802*da0073e9SAndroid Build Coastguard Worker asfm(x, y) 6803*da0073e9SAndroid Build Coastguard Worker 6804*da0073e9SAndroid Build Coastguard Worker # out-of-bound targets 6805*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Target values should be in"): 6806*da0073e9SAndroid Build Coastguard Worker asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.) 6807*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 16) 6808*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([0, 20]) 6809*da0073e9SAndroid Build Coastguard Worker asfm(x, y) 6810*da0073e9SAndroid Build Coastguard Worker 6811*da0073e9SAndroid Build Coastguard Worker # cluster sizes 6812*da0073e9SAndroid Build Coastguard Worker asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.) 
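        # A quick sketch of the shapes the assertions below check (assuming the
        # constructor arguments above: in_features=16, n_classes=20,
        # cutoffs=[5, 10, 15], div_value=2.):
        #   - the head handles the 5 shortlist targets plus one entry per tail
        #     cluster (3 clusters), hence a (5 + 3, 16) weight;
        #   - tail cluster i first projects the 16-dim input down to roughly
        #     in_features / div_value**(i + 1), i.e. hidden sizes 8, 4 and 2,
        #     and each cluster then maps to its 5 targets.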
6813*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 16) 6814*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([0, 17]) 6815*da0073e9SAndroid Build Coastguard Worker 6816*da0073e9SAndroid Build Coastguard Worker self.assertEqual(asfm.head.weight.size(), (5 + 3, 16)) # 5 targets in head, 3 clusters, dimensionality 16 6817*da0073e9SAndroid Build Coastguard Worker self.assertEqual(asfm.tail[0][1].weight.size(), (5, 8)) # 5 targets in this cluster, dimensionality 8 6818*da0073e9SAndroid Build Coastguard Worker self.assertEqual(asfm.tail[1][1].weight.size(), (5, 4)) 6819*da0073e9SAndroid Build Coastguard Worker self.assertEqual(asfm.tail[2][1].weight.size(), (5, 2)) 6820*da0073e9SAndroid Build Coastguard Worker self.assertEqual(asfm(x, y).output.size(), (2, )) 6821*da0073e9SAndroid Build Coastguard Worker 6822*da0073e9SAndroid Build Coastguard Worker # test no_batch_dim support 6823*da0073e9SAndroid Build Coastguard Worker asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.) 6824*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 16) 6825*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([17]) 6826*da0073e9SAndroid Build Coastguard Worker x2 = x.squeeze(0) 6827*da0073e9SAndroid Build Coastguard Worker y2 = y.squeeze(0) 6828*da0073e9SAndroid Build Coastguard Worker self.assertEqual(asfm(x, y).output.squeeze(0), asfm(x2, y2).output) 6829*da0073e9SAndroid Build Coastguard Worker 6830*da0073e9SAndroid Build Coastguard Worker # log_probs actually returns log_proba 6831*da0073e9SAndroid Build Coastguard Worker asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 4, [2], div_value=2.) 6832*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 8) 6833*da0073e9SAndroid Build Coastguard Worker logprob_out = asfm.log_prob(x) 6834*da0073e9SAndroid Build Coastguard Worker 6835*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.exp(logprob_out).data.sum(1), torch.ones(4)) 6836*da0073e9SAndroid Build Coastguard Worker 6837*da0073e9SAndroid Build Coastguard Worker # forward returns the same thing as log_probs 6838*da0073e9SAndroid Build Coastguard Worker for v in [0, 1, 2, 3]: 6839*da0073e9SAndroid Build Coastguard Worker y = torch.full((4,), v, dtype=torch.long) 6840*da0073e9SAndroid Build Coastguard Worker out, loss = asfm(x, y) 6841*da0073e9SAndroid Build Coastguard Worker 6842*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, logprob_out.gather(1, y.unsqueeze(1)).squeeze()) 6843*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loss, F.nll_loss(logprob_out, y)) 6844*da0073e9SAndroid Build Coastguard Worker 6845*da0073e9SAndroid Build Coastguard Worker # predict 6846*da0073e9SAndroid Build Coastguard Worker x = torch.randn(64, 8).abs_() 6847*da0073e9SAndroid Build Coastguard Worker 6848*da0073e9SAndroid Build Coastguard Worker # argmax in shortlist 6849*da0073e9SAndroid Build Coastguard Worker asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 10, [4, 8], div_value=2., head_bias=True) 6850*da0073e9SAndroid Build Coastguard Worker asfm.head.weight.data.abs_() 6851*da0073e9SAndroid Build Coastguard Worker asfm.head.bias.data.abs_() 6852*da0073e9SAndroid Build Coastguard Worker asfm.head.weight.data[asfm.shortlist_size:, :].zero_() 6853*da0073e9SAndroid Build Coastguard Worker 6854*da0073e9SAndroid Build Coastguard Worker out = asfm.predict(x) 6855*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, asfm.log_prob(x).argmax(dim=1)) 6856*da0073e9SAndroid Build Coastguard Worker 6857*da0073e9SAndroid Build Coastguard Worker # 
argmax outside of shortlist 6858*da0073e9SAndroid Build Coastguard Worker asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 10, [4, 8], div_value=2., head_bias=True) 6859*da0073e9SAndroid Build Coastguard Worker asfm.head.weight.data.abs_() 6860*da0073e9SAndroid Build Coastguard Worker asfm.head.bias.data.abs_() 6861*da0073e9SAndroid Build Coastguard Worker asfm.head.weight.data[:asfm.shortlist_size, :].zero_() 6862*da0073e9SAndroid Build Coastguard Worker 6863*da0073e9SAndroid Build Coastguard Worker out = asfm.predict(x) 6864*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, asfm.log_prob(x).argmax(dim=1)) 6865*da0073e9SAndroid Build Coastguard Worker 6866*da0073e9SAndroid Build Coastguard Worker # half of the argmax in shortlist, half in clusters 6867*da0073e9SAndroid Build Coastguard Worker asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 10, [4, 8], div_value=2., head_bias=True) 6868*da0073e9SAndroid Build Coastguard Worker asfm.head.weight.data.abs_() 6869*da0073e9SAndroid Build Coastguard Worker asfm.head.bias.data.abs_() 6870*da0073e9SAndroid Build Coastguard Worker 6871*da0073e9SAndroid Build Coastguard Worker x[:32, :asfm.shortlist_size].zero_() 6872*da0073e9SAndroid Build Coastguard Worker x[32:, asfm.shortlist_size:].zero_() 6873*da0073e9SAndroid Build Coastguard Worker 6874*da0073e9SAndroid Build Coastguard Worker asfm.head.weight.data[:asfm.shortlist_size, asfm.shortlist_size:].zero_() 6875*da0073e9SAndroid Build Coastguard Worker asfm.head.weight.data[asfm.shortlist_size:, :asfm.shortlist_size].zero_() 6876*da0073e9SAndroid Build Coastguard Worker 6877*da0073e9SAndroid Build Coastguard Worker out = asfm.predict(x) 6878*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, asfm.log_prob(x).argmax(dim=1)) 6879*da0073e9SAndroid Build Coastguard Worker 6880*da0073e9SAndroid Build Coastguard Worker def test_cross_entropy_loss(self, dtype=torch.bfloat16): 6881*da0073e9SAndroid Build Coastguard Worker loss_cpu = nn.CrossEntropyLoss().cpu() 6882*da0073e9SAndroid Build Coastguard Worker inputf = torch.randn(15, 10, device="cpu", dtype=torch.float, requires_grad=True) 6883*da0073e9SAndroid Build Coastguard Worker input = inputf.to(dtype).detach().requires_grad_(True) 6884*da0073e9SAndroid Build Coastguard Worker target = torch.empty(15, dtype=torch.long).random_(10) 6885*da0073e9SAndroid Build Coastguard Worker 6886*da0073e9SAndroid Build Coastguard Worker outf = loss_cpu(inputf, target) 6887*da0073e9SAndroid Build Coastguard Worker out = loss_cpu(input, target) 6888*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, outf.to(dtype=dtype), atol=1e-1, rtol=0) 6889*da0073e9SAndroid Build Coastguard Worker 6890*da0073e9SAndroid Build Coastguard Worker outf.backward() 6891*da0073e9SAndroid Build Coastguard Worker out.backward() 6892*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input.grad, inputf.grad.to(dtype=dtype), atol=1e-1, rtol=0) 6893*da0073e9SAndroid Build Coastguard Worker 6894*da0073e9SAndroid Build Coastguard Worker def test_cross_entropy_loss_precision(self): 6895*da0073e9SAndroid Build Coastguard Worker # Regression test for #55657 6896*da0073e9SAndroid Build Coastguard Worker loss_cpu = nn.CrossEntropyLoss().cpu() 6897*da0073e9SAndroid Build Coastguard Worker inputf = torch.randn(128, 2, 768, 768, device="cpu", dtype=torch.float) 6898*da0073e9SAndroid Build Coastguard Worker inputd = inputf.double() 6899*da0073e9SAndroid Build Coastguard Worker target = torch.randint(2, (128, 768, 768), dtype=torch.long) 6900*da0073e9SAndroid Build Coastguard Worker 
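        # The same loss is evaluated once in float32 and once in float64; the
        # comparison below uses exact_dtype=False, so only the numerical values
        # (not the dtypes) have to agree within the default tolerances.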
6901*da0073e9SAndroid Build Coastguard Worker outf = loss_cpu(inputf, target) 6902*da0073e9SAndroid Build Coastguard Worker outd = loss_cpu(inputd, target) 6903*da0073e9SAndroid Build Coastguard Worker self.assertEqual(outf, outd, exact_dtype=False) 6904*da0073e9SAndroid Build Coastguard Worker 6905*da0073e9SAndroid Build Coastguard Worker def test_cross_entropy_loss_zero_div(self): 6906*da0073e9SAndroid Build Coastguard Worker # Test for issue #73165 6907*da0073e9SAndroid Build Coastguard Worker input_1 = torch.rand([5, 0], dtype=torch.float32) 6908*da0073e9SAndroid Build Coastguard Worker input_2 = torch.rand([5, 0], dtype=torch.float32) 6909*da0073e9SAndroid Build Coastguard Worker torch.nn.CrossEntropyLoss()(input_1, input_2) 6910*da0073e9SAndroid Build Coastguard Worker 6911*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") 6912*da0073e9SAndroid Build Coastguard Worker def test_convert_sync_batchnorm(self): 6913*da0073e9SAndroid Build Coastguard Worker module = torch.nn.Sequential( 6914*da0073e9SAndroid Build Coastguard Worker torch.nn.BatchNorm1d(100), 6915*da0073e9SAndroid Build Coastguard Worker torch.nn.InstanceNorm1d(100) 6916*da0073e9SAndroid Build Coastguard Worker ).cuda() 6917*da0073e9SAndroid Build Coastguard Worker 6918*da0073e9SAndroid Build Coastguard Worker # necessary to have an anchor point for comparison, in case the 6919*da0073e9SAndroid Build Coastguard Worker # convert_sync_batchnorm updates in place 6920*da0073e9SAndroid Build Coastguard Worker comp_module = torch.nn.Sequential( 6921*da0073e9SAndroid Build Coastguard Worker torch.nn.BatchNorm1d(100), 6922*da0073e9SAndroid Build Coastguard Worker torch.nn.InstanceNorm1d(100) 6923*da0073e9SAndroid Build Coastguard Worker ).cuda() 6924*da0073e9SAndroid Build Coastguard Worker comp_module.load_state_dict(module.state_dict()) 6925*da0073e9SAndroid Build Coastguard Worker 6926*da0073e9SAndroid Build Coastguard Worker sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module) 6927*da0073e9SAndroid Build Coastguard Worker children = list(sync_bn_module.children()) 6928*da0073e9SAndroid Build Coastguard Worker self.assertEqual(children[0].__class__, torch.nn.SyncBatchNorm) 6929*da0073e9SAndroid Build Coastguard Worker self.assertEqual(children[1].__class__, torch.nn.InstanceNorm1d) 6930*da0073e9SAndroid Build Coastguard Worker 6931*da0073e9SAndroid Build Coastguard Worker for layer, converted_layer in zip(comp_module.children(), sync_bn_module.children()): 6932*da0073e9SAndroid Build Coastguard Worker for key in layer.state_dict().keys(): 6933*da0073e9SAndroid Build Coastguard Worker self.assertEqual(layer.state_dict()[key].device, converted_layer.state_dict()[key].device) 6934*da0073e9SAndroid Build Coastguard Worker self.assertEqual(layer.state_dict()[key], converted_layer.state_dict()[key]) 6935*da0073e9SAndroid Build Coastguard Worker 6936*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "CUDA not available") 6937*da0073e9SAndroid Build Coastguard Worker def test_sync_batchnorm_backward_elemt(self): 6938*da0073e9SAndroid Build Coastguard Worker device = 'cuda' 6939*da0073e9SAndroid Build Coastguard Worker saved_input = torch.rand(2, 3, 2, 1, device=device) 6940*da0073e9SAndroid Build Coastguard Worker grad_output = torch.rand(2, 3, 2, 1, device=device) 6941*da0073e9SAndroid Build Coastguard Worker mean = torch.rand(3, device=device) 6942*da0073e9SAndroid Build Coastguard Worker invstd = torch.rand(3, device=device) 
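        # The per-channel buffers below (weight, sum_dy, sum_dy_xmu) and the
        # int32 count tensor roughly stand in for the reduced statistics that
        # SyncBatchNorm would normally gather across replicas before calling
        # torch.batch_norm_backward_elemt on each one.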
6943*da0073e9SAndroid Build Coastguard Worker weight = torch.rand(3, device=device) 6944*da0073e9SAndroid Build Coastguard Worker sum_dy = torch.rand(3, device=device) 6945*da0073e9SAndroid Build Coastguard Worker sum_dy_xmu = torch.rand(3, device=device) 6946*da0073e9SAndroid Build Coastguard Worker count_tensor = torch.tensor([5, 5, 5], dtype=torch.int32, device=device) 6947*da0073e9SAndroid Build Coastguard Worker 6948*da0073e9SAndroid Build Coastguard Worker gI_contiguous = torch.batch_norm_backward_elemt( 6949*da0073e9SAndroid Build Coastguard Worker grad_output, 6950*da0073e9SAndroid Build Coastguard Worker saved_input, 6951*da0073e9SAndroid Build Coastguard Worker mean, 6952*da0073e9SAndroid Build Coastguard Worker invstd, 6953*da0073e9SAndroid Build Coastguard Worker weight, 6954*da0073e9SAndroid Build Coastguard Worker sum_dy, 6955*da0073e9SAndroid Build Coastguard Worker sum_dy_xmu, 6956*da0073e9SAndroid Build Coastguard Worker count_tensor 6957*da0073e9SAndroid Build Coastguard Worker ) 6958*da0073e9SAndroid Build Coastguard Worker 6959*da0073e9SAndroid Build Coastguard Worker # Test batch_norm_backward_elemt gives the same answer for all 6960*da0073e9SAndroid Build Coastguard Worker # combinations of contiguous as channels_last input 6961*da0073e9SAndroid Build Coastguard Worker for a, b in [ 6962*da0073e9SAndroid Build Coastguard Worker (torch.channels_last, torch.contiguous_format), 6963*da0073e9SAndroid Build Coastguard Worker (torch.contiguous_format, torch.channels_last), 6964*da0073e9SAndroid Build Coastguard Worker (torch.channels_last, torch.channels_last), 6965*da0073e9SAndroid Build Coastguard Worker ]: 6966*da0073e9SAndroid Build Coastguard Worker gI_actual = torch.batch_norm_backward_elemt( 6967*da0073e9SAndroid Build Coastguard Worker grad_output.contiguous(memory_format=a), 6968*da0073e9SAndroid Build Coastguard Worker saved_input.contiguous(memory_format=b), 6969*da0073e9SAndroid Build Coastguard Worker mean, 6970*da0073e9SAndroid Build Coastguard Worker invstd, 6971*da0073e9SAndroid Build Coastguard Worker weight, 6972*da0073e9SAndroid Build Coastguard Worker sum_dy, 6973*da0073e9SAndroid Build Coastguard Worker sum_dy_xmu, 6974*da0073e9SAndroid Build Coastguard Worker count_tensor 6975*da0073e9SAndroid Build Coastguard Worker ) 6976*da0073e9SAndroid Build Coastguard Worker self.assertEqual(gI_actual, gI_contiguous) 6977*da0073e9SAndroid Build Coastguard Worker 6978*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "CUDA not available") 6979*da0073e9SAndroid Build Coastguard Worker def test_sync_batchnorm_accuracy_cuda(self): 6980*da0073e9SAndroid Build Coastguard Worker # The target of this test is to test the functionality and accuracy of 6981*da0073e9SAndroid Build Coastguard Worker # those single-GPU cuda kernels used in SyncBatchNorm 6982*da0073e9SAndroid Build Coastguard Worker # They are: 6983*da0073e9SAndroid Build Coastguard Worker # fwd: torch.batch_norm_stats, torch.batch_norm_gather_stats_with_counts, torch.batch_norm_elemt 6984*da0073e9SAndroid Build Coastguard Worker # bwd: torch.batch_norm_backward_reduce, torch.batch_norm_backward_elemt 6985*da0073e9SAndroid Build Coastguard Worker 6986*da0073e9SAndroid Build Coastguard Worker def _batch_norm_stats(data, memory_format, mean_axes): 6987*da0073e9SAndroid Build Coastguard Worker mean1, _ = torch.batch_norm_stats(data, 1e-5) 6988*da0073e9SAndroid Build Coastguard Worker mean2, _ = torch.batch_norm_stats(data.to(memory_format=memory_format), 1e-5) 6989*da0073e9SAndroid Build 
Coastguard Worker mean_ref = torch.mean(data, mean_axes, keepdim=False) 6990*da0073e9SAndroid Build Coastguard Worker 6991*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mean_ref, mean1) 6992*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mean_ref, mean2) 6993*da0073e9SAndroid Build Coastguard Worker 6994*da0073e9SAndroid Build Coastguard Worker _batch_norm_stats(torch.randn(1, 96, 112, 112, dtype=torch.float, device='cuda'), torch.channels_last, (0, 2, 3)) 6995*da0073e9SAndroid Build Coastguard Worker _batch_norm_stats(torch.randn(1, 96, 112, 112, 112, dtype=torch.float, device='cuda'), torch.channels_last_3d, (0, 2, 3, 4)) 6996*da0073e9SAndroid Build Coastguard Worker 6997*da0073e9SAndroid Build Coastguard Worker def test_flatten(self): 6998*da0073e9SAndroid Build Coastguard Worker tensor_input = torch.randn(2, 1, 2, 3) 6999*da0073e9SAndroid Build Coastguard Worker 7000*da0073e9SAndroid Build Coastguard Worker # Flatten Tensor 7001*da0073e9SAndroid Build Coastguard Worker 7002*da0073e9SAndroid Build Coastguard Worker flatten = nn.Flatten(start_dim=1, end_dim=-1) 7003*da0073e9SAndroid Build Coastguard Worker tensor_output = flatten(tensor_input) 7004*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_output.size(), torch.Size([2, 6])) 7005*da0073e9SAndroid Build Coastguard Worker 7006*da0073e9SAndroid Build Coastguard Worker def test_unflatten(self): 7007*da0073e9SAndroid Build Coastguard Worker tensor_input = torch.randn(2, 50) 7008*da0073e9SAndroid Build Coastguard Worker 7009*da0073e9SAndroid Build Coastguard Worker # Unflatten Tensor (unflattened_size as a tuple of ints and list of ints) 7010*da0073e9SAndroid Build Coastguard Worker 7011*da0073e9SAndroid Build Coastguard Worker for us in ((2, 5, 5), [2, 5, 5]): 7012*da0073e9SAndroid Build Coastguard Worker unflatten = nn.Unflatten(dim=1, unflattened_size=us) 7013*da0073e9SAndroid Build Coastguard Worker tensor_output = unflatten(tensor_input) 7014*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5])) 7015*da0073e9SAndroid Build Coastguard Worker 7016*da0073e9SAndroid Build Coastguard Worker # Unflatten NamedTensor 7017*da0073e9SAndroid Build Coastguard Worker 7018*da0073e9SAndroid Build Coastguard Worker unflatten = nn.Unflatten(dim='features', unflattened_size=(('C', 2), ('H', 5), ('W', 5))) 7019*da0073e9SAndroid Build Coastguard Worker named_tensor_input = tensor_input.refine_names('N', 'features') 7020*da0073e9SAndroid Build Coastguard Worker named_tensor_output = unflatten(named_tensor_input) 7021*da0073e9SAndroid Build Coastguard Worker self.assertEqual(named_tensor_output.size(), torch.Size([2, 2, 5, 5])) 7022*da0073e9SAndroid Build Coastguard Worker 7023*da0073e9SAndroid Build Coastguard Worker def test_unflatten_invalid_arg(self): 7024*da0073e9SAndroid Build Coastguard Worker # Wrong type for unflattened_size (tuple of floats) 7025*da0073e9SAndroid Build Coastguard Worker 7026*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 7027*da0073e9SAndroid Build Coastguard Worker TypeError, 7028*da0073e9SAndroid Build Coastguard Worker r"unflattened_size must be tuple of ints, but found element of type float at pos 2"): 7029*da0073e9SAndroid Build Coastguard Worker nn.Unflatten(dim=1, unflattened_size=(2, 5, 5.0)) 7030*da0073e9SAndroid Build Coastguard Worker 7031*da0073e9SAndroid Build Coastguard Worker # Wrong type for unflattened_size (list of lists and list of tuples) 7032*da0073e9SAndroid Build Coastguard Worker for us in 
([['C', 2], ['W', 5], ['H', 5]], [('C', 2), ('W', 5), ('H', 5)]): 7033*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 7034*da0073e9SAndroid Build Coastguard Worker TypeError, 7035*da0073e9SAndroid Build Coastguard Worker r"unflattened_size must be a tuple of tuples, but found type list"): 7036*da0073e9SAndroid Build Coastguard Worker nn.Unflatten(dim='features', unflattened_size=us) 7037*da0073e9SAndroid Build Coastguard Worker 7038*da0073e9SAndroid Build Coastguard Worker # Wrong type for unflattened_size (tuple of lists) 7039*da0073e9SAndroid Build Coastguard Worker 7040*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 7041*da0073e9SAndroid Build Coastguard Worker TypeError, 7042*da0073e9SAndroid Build Coastguard Worker r"unflattened_size must be tuple of tuples, but found element of type list at pos 0"): 7043*da0073e9SAndroid Build Coastguard Worker nn.Unflatten(dim='features', unflattened_size=(['C', 2], ['W', 5], ['H', 5])) 7044*da0073e9SAndroid Build Coastguard Worker 7045*da0073e9SAndroid Build Coastguard Worker # Wrong type for unflattened_size (tuple of dicts) 7046*da0073e9SAndroid Build Coastguard Worker 7047*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 7048*da0073e9SAndroid Build Coastguard Worker TypeError, 7049*da0073e9SAndroid Build Coastguard Worker r"unflattened_size must be tuple of tuples, but found element of type dict at pos 0"): 7050*da0073e9SAndroid Build Coastguard Worker nn.Unflatten(dim='features', unflattened_size=({'C': 2}, {'W': 5}, {'H': 5})) 7051*da0073e9SAndroid Build Coastguard Worker 7052*da0073e9SAndroid Build Coastguard Worker def test_layer_norm_grads_with_create_graph_flag(self): 7053*da0073e9SAndroid Build Coastguard Worker atol = 1e-5 7054*da0073e9SAndroid Build Coastguard Worker rtol = 1e-3 7055*da0073e9SAndroid Build Coastguard Worker 7056*da0073e9SAndroid Build Coastguard Worker x = torch.randn((4, 4, 16), requires_grad=True) 7057*da0073e9SAndroid Build Coastguard Worker layer_norm = nn.LayerNorm((16,), 1e-5, True) 7058*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 7059*da0073e9SAndroid Build Coastguard Worker layer_norm.weight = torch.nn.Parameter(0.1 * torch.ones_like(layer_norm.weight)) 7060*da0073e9SAndroid Build Coastguard Worker 7061*da0073e9SAndroid Build Coastguard Worker grads1 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=False)[0] 7062*da0073e9SAndroid Build Coastguard Worker grads2 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=True)[0] 7063*da0073e9SAndroid Build Coastguard Worker 7064*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grads1, grads2, rtol=rtol, atol=atol) 7065*da0073e9SAndroid Build Coastguard Worker 7066*da0073e9SAndroid Build Coastguard Worker if TEST_CUDA: 7067*da0073e9SAndroid Build Coastguard Worker x = x.to('cuda') 7068*da0073e9SAndroid Build Coastguard Worker layer_norm = layer_norm.to('cuda') 7069*da0073e9SAndroid Build Coastguard Worker 7070*da0073e9SAndroid Build Coastguard Worker grads1 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=False)[0] 7071*da0073e9SAndroid Build Coastguard Worker grads2 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=True)[0] 7072*da0073e9SAndroid Build Coastguard Worker 7073*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grads1, grads2, rtol=rtol, atol=atol) 7074*da0073e9SAndroid Build Coastguard Worker 7075*da0073e9SAndroid Build Coastguard Worker def test_layer_norm_eps(self): 7076*da0073e9SAndroid Build Coastguard Worker # test 
for https://github.com/pytorch/pytorch/issues/108072 7077*da0073e9SAndroid Build Coastguard Worker x = torch.Tensor([[[2.0, 2.0], [14.0, 14.0]], [[2.0, 2.0], [14.0, 14.0]]]) 7078*da0073e9SAndroid Build Coastguard Worker ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False) 7079*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ln.forward(x), torch.zeros_like(x)) 7080*da0073e9SAndroid Build Coastguard Worker 7081*da0073e9SAndroid Build Coastguard Worker def test_padding_list(self): 7082*da0073e9SAndroid Build Coastguard Worker # Padding can be a list, or tuple (regression test for gh-54452) 7083*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 8, 32, 32) 7084*da0073e9SAndroid Build Coastguard Worker net = torch.nn.ConvTranspose2d(8, 16, kernel_size=3, padding=[3, 3]) 7085*da0073e9SAndroid Build Coastguard Worker y = net(x) 7086*da0073e9SAndroid Build Coastguard Worker 7087*da0073e9SAndroid Build Coastguard Worker net = torch.nn.ConvTranspose2d(8, 16, kernel_size=3, padding=(3, 3)) 7088*da0073e9SAndroid Build Coastguard Worker y = net(x) 7089*da0073e9SAndroid Build Coastguard Worker 7090*da0073e9SAndroid Build Coastguard Worker def test_fractional_max_pool2d_invalid_output_ratio(self): 7091*da0073e9SAndroid Build Coastguard Worker arg_1 = [2, 1] 7092*da0073e9SAndroid Build Coastguard Worker arg_2 = [0.5, 0.5, 0.6] 7093*da0073e9SAndroid Build Coastguard Worker arg_class = torch.nn.FractionalMaxPool2d(kernel_size=arg_1, output_ratio=arg_2,) 7094*da0073e9SAndroid Build Coastguard Worker arg_3_0_tensor = torch.rand([20, 16, 50, 32], dtype=torch.float32) 7095*da0073e9SAndroid Build Coastguard Worker arg_3_0 = arg_3_0_tensor.clone() 7096*da0073e9SAndroid Build Coastguard Worker arg_3 = [arg_3_0,] 7097*da0073e9SAndroid Build Coastguard Worker 7098*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, 7099*da0073e9SAndroid Build Coastguard Worker "fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints."): 7100*da0073e9SAndroid Build Coastguard Worker res = arg_class(*arg_3) 7101*da0073e9SAndroid Build Coastguard Worker 7102*da0073e9SAndroid Build Coastguard Worker def test_max_pool1d_invalid_output_size(self): 7103*da0073e9SAndroid Build Coastguard Worker arg_1 = 3 7104*da0073e9SAndroid Build Coastguard Worker arg_2 = 255 7105*da0073e9SAndroid Build Coastguard Worker arg_3 = False 7106*da0073e9SAndroid Build Coastguard Worker arg_class = torch.nn.MaxPool1d(kernel_size=arg_1, stride=arg_2, return_indices=arg_3) 7107*da0073e9SAndroid Build Coastguard Worker arg_4_0 = torch.as_tensor([[0.3204]]) 7108*da0073e9SAndroid Build Coastguard Worker arg_4 = [arg_4_0,] 7109*da0073e9SAndroid Build Coastguard Worker 7110*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 7111*da0073e9SAndroid Build Coastguard Worker res = arg_class(*arg_4) 7112*da0073e9SAndroid Build Coastguard Worker 7113*da0073e9SAndroid Build Coastguard Worker def test_pickle_module_no_weights_only_warning(self): 7114*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 7115*da0073e9SAndroid Build Coastguard Worker pickle.loads(pickle.dumps(torch.nn.Linear(10, 10))) 7116*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 0) 7117*da0073e9SAndroid Build Coastguard Worker 7118*da0073e9SAndroid Build Coastguard Workerclass TestFusionEval(TestCase): 7119*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 7120*da0073e9SAndroid Build Coastguard Worker 
@given(X=hu.tensor(shapes=((5, 3, 5, 5),), dtype=np.double), 7121*da0073e9SAndroid Build Coastguard Worker running_mean=hu.tensor(shapes=(6,), dtype=np.double), 7122*da0073e9SAndroid Build Coastguard Worker running_var=hu.tensor(shapes=(6,), dtype=np.double)) 7123*da0073e9SAndroid Build Coastguard Worker def test_fuse_module_eval_numerics(self, X, running_mean, running_var): 7124*da0073e9SAndroid Build Coastguard Worker inputs, _ = X 7125*da0073e9SAndroid Build Coastguard Worker 7126*da0073e9SAndroid Build Coastguard Worker iC, oC = inputs.shape[1], len(running_mean[0]) 7127*da0073e9SAndroid Build Coastguard Worker inputs = torch.from_numpy(inputs) 7128*da0073e9SAndroid Build Coastguard Worker kernel_size = (3, 3) 7129*da0073e9SAndroid Build Coastguard Worker 7130*da0073e9SAndroid Build Coastguard Worker conv_ref = torch.nn.Conv2d(iC, oC, bias=True, kernel_size=kernel_size) 7131*da0073e9SAndroid Build Coastguard Worker bn_ref = torch.nn.BatchNorm2d(oC) 7132*da0073e9SAndroid Build Coastguard Worker bn_ref.running_mean = torch.from_numpy(running_mean[0]) 7133*da0073e9SAndroid Build Coastguard Worker bn_ref.running_var = torch.from_numpy(running_var[0]) 7134*da0073e9SAndroid Build Coastguard Worker 7135*da0073e9SAndroid Build Coastguard Worker conv_ref.eval() 7136*da0073e9SAndroid Build Coastguard Worker bn_ref.eval() 7137*da0073e9SAndroid Build Coastguard Worker 7138*da0073e9SAndroid Build Coastguard Worker Y_ref = bn_ref(conv_ref(inputs)) 7139*da0073e9SAndroid Build Coastguard Worker conv_bn_fused = torch.nn.utils.fusion.fuse_conv_bn_eval(conv_ref, 7140*da0073e9SAndroid Build Coastguard Worker bn_ref) 7141*da0073e9SAndroid Build Coastguard Worker Y_hat = conv_bn_fused(inputs) 7142*da0073e9SAndroid Build Coastguard Worker 7143*da0073e9SAndroid Build Coastguard Worker self.assertEqual(Y_ref, Y_hat, msg="Conv+BN fusion results are off") 7144*da0073e9SAndroid Build Coastguard Worker 7145*da0073e9SAndroid Build Coastguard Worker na_bn_ref = torch.nn.BatchNorm2d(oC, affine=False) 7146*da0073e9SAndroid Build Coastguard Worker na_bn_ref.running_mean = torch.from_numpy(running_mean[0]) 7147*da0073e9SAndroid Build Coastguard Worker na_bn_ref.running_var = torch.from_numpy(running_var[0]) 7148*da0073e9SAndroid Build Coastguard Worker na_bn_ref.eval() 7149*da0073e9SAndroid Build Coastguard Worker 7150*da0073e9SAndroid Build Coastguard Worker Y_ref = na_bn_ref(conv_ref(inputs)) 7151*da0073e9SAndroid Build Coastguard Worker conv_na_bn_fused = torch.nn.utils.fusion.fuse_conv_bn_eval(conv_ref, 7152*da0073e9SAndroid Build Coastguard Worker na_bn_ref) 7153*da0073e9SAndroid Build Coastguard Worker Y_hat = conv_na_bn_fused(inputs) 7154*da0073e9SAndroid Build Coastguard Worker 7155*da0073e9SAndroid Build Coastguard Worker self.assertEqual(Y_ref, Y_hat, msg="Conv+BN(non-affine) fusion results are off") 7156*da0073e9SAndroid Build Coastguard Worker 7157*da0073e9SAndroid Build Coastguard Worker 7158*da0073e9SAndroid Build Coastguard Workerclass TestConstantPadNd(TestCase): 7159*da0073e9SAndroid Build Coastguard Worker def test_constant_pad_nd(self): 7160*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[1, 2], [3, 4]]) 7161*da0073e9SAndroid Build Coastguard Worker res = torch.constant_pad_nd(a, [1, 2, 1, 0], 9) 7162*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([ 7163*da0073e9SAndroid Build Coastguard Worker [9, 9, 9, 9, 9], 7164*da0073e9SAndroid Build Coastguard Worker [9, 1, 2, 9, 9], 7165*da0073e9SAndroid Build Coastguard Worker [9, 3, 4, 9, 9] 7166*da0073e9SAndroid Build Coastguard 
Worker ]) 7167*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 7168*da0073e9SAndroid Build Coastguard Worker 7169*da0073e9SAndroid Build Coastguard Worker def test_preserves_memory_format(self): 7170*da0073e9SAndroid Build Coastguard Worker nchw_tensor = torch.rand((1, 2, 5, 3)) 7171*da0073e9SAndroid Build Coastguard Worker nchw_padded = torch.constant_pad_nd(nchw_tensor, [1, 2], 0.5) 7172*da0073e9SAndroid Build Coastguard Worker self.assertTrue(nchw_padded.is_contiguous(memory_format=torch.contiguous_format)) 7173*da0073e9SAndroid Build Coastguard Worker 7174*da0073e9SAndroid Build Coastguard Worker nhwc_tensor = nchw_tensor.contiguous(memory_format=torch.channels_last) 7175*da0073e9SAndroid Build Coastguard Worker nhwc_padded = torch.constant_pad_nd(nhwc_tensor, [1, 2], 0.5) 7176*da0073e9SAndroid Build Coastguard Worker self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last)) 7177*da0073e9SAndroid Build Coastguard Worker 7178*da0073e9SAndroid Build Coastguard Worker 7179*da0073e9SAndroid Build Coastguard Workerclass TestAddRelu(TestCase): 7180*da0073e9SAndroid Build Coastguard Worker def test_add_relu(self): 7181*da0073e9SAndroid Build Coastguard Worker a = torch.rand((7, 11)) 7182*da0073e9SAndroid Build Coastguard Worker b = torch.rand((7, 11)) 7183*da0073e9SAndroid Build Coastguard Worker a = a.float() 7184*da0073e9SAndroid Build Coastguard Worker b = b.float() 7185*da0073e9SAndroid Build Coastguard Worker a = a * -10 7186*da0073e9SAndroid Build Coastguard Worker a = a + 5 7187*da0073e9SAndroid Build Coastguard Worker add_res = a + b 7188*da0073e9SAndroid Build Coastguard Worker relu_res = torch.relu(add_res) 7189*da0073e9SAndroid Build Coastguard Worker add_relu_res = torch._VF._add_relu(a, b) 7190*da0073e9SAndroid Build Coastguard Worker 7191*da0073e9SAndroid Build Coastguard Worker self.assertEqual(add_relu_res, relu_res) 7192*da0073e9SAndroid Build Coastguard Worker 7193*da0073e9SAndroid Build Coastguard Worker def test_add_relu_broadcasting(self): 7194*da0073e9SAndroid Build Coastguard Worker a = torch.rand((1, 32)) 7195*da0073e9SAndroid Build Coastguard Worker b = 1 7196*da0073e9SAndroid Build Coastguard Worker b_scalar = torch.ones(1, 32) 7197*da0073e9SAndroid Build Coastguard Worker res = torch._VF._add_relu(a, b) 7198*da0073e9SAndroid Build Coastguard Worker broadcasted_res = torch._VF._add_relu(a, b_scalar) 7199*da0073e9SAndroid Build Coastguard Worker 7200*da0073e9SAndroid Build Coastguard Worker self.assertEqual(broadcasted_res, res) 7201*da0073e9SAndroid Build Coastguard Worker 7202*da0073e9SAndroid Build Coastguard Worker 7203*da0073e9SAndroid Build Coastguard Workerdef add_test(test, decorator=None): 7204*da0073e9SAndroid Build Coastguard Worker def add(test_name, fn): 7205*da0073e9SAndroid Build Coastguard Worker if hasattr(TestNN, test_name): 7206*da0073e9SAndroid Build Coastguard Worker raise RuntimeError('Found two tests with the same name: ' + test_name) 7207*da0073e9SAndroid Build Coastguard Worker if decorator is not None: 7208*da0073e9SAndroid Build Coastguard Worker fn = decorator(fn) 7209*da0073e9SAndroid Build Coastguard Worker setattr(TestNN, test_name, fn) 7210*da0073e9SAndroid Build Coastguard Worker 7211*da0073e9SAndroid Build Coastguard Worker test_name = test.get_name() 7212*da0073e9SAndroid Build Coastguard Worker if not hasattr(test, 'test_cpu') or test.test_cpu: 7213*da0073e9SAndroid Build Coastguard Worker add(test_name, lambda self, test=test: test(self)) 7214*da0073e9SAndroid Build Coastguard Worker 
cuda_test_name = test_name + '_cuda' 7215*da0073e9SAndroid Build Coastguard Worker # With dtype enabled, it's good enough to test against three floating types 7216*da0073e9SAndroid Build Coastguard Worker kwargs = {} 7217*da0073e9SAndroid Build Coastguard Worker if 'extra_args' in get_function_arglist(test.test_cuda): 7218*da0073e9SAndroid Build Coastguard Worker kwargs['extra_args'] = test.extra_args 7219*da0073e9SAndroid Build Coastguard Worker 7220*da0073e9SAndroid Build Coastguard Worker if 'dtype' in get_function_arglist(test.test_cuda): 7221*da0073e9SAndroid Build Coastguard Worker if tf32_is_not_fp32() and test.with_tf32: 7222*da0073e9SAndroid Build Coastguard Worker 7223*da0073e9SAndroid Build Coastguard Worker def with_tf32_off(self, test=test, kwargs=kwargs): 7224*da0073e9SAndroid Build Coastguard Worker with tf32_off(): 7225*da0073e9SAndroid Build Coastguard Worker test.test_cuda(self, dtype=torch.float, **kwargs) 7226*da0073e9SAndroid Build Coastguard Worker 7227*da0073e9SAndroid Build Coastguard Worker add(cuda_test_name + '_fp32', with_tf32_off) 7228*da0073e9SAndroid Build Coastguard Worker 7229*da0073e9SAndroid Build Coastguard Worker def with_tf32_on(self, test=test, kwargs=kwargs): 7230*da0073e9SAndroid Build Coastguard Worker with tf32_on(self, test.tf32_precision): 7231*da0073e9SAndroid Build Coastguard Worker test.test_cuda(self, dtype=torch.float, **kwargs) 7232*da0073e9SAndroid Build Coastguard Worker 7233*da0073e9SAndroid Build Coastguard Worker add(cuda_test_name + '_tf32', with_tf32_on) 7234*da0073e9SAndroid Build Coastguard Worker else: 7235*da0073e9SAndroid Build Coastguard Worker add(cuda_test_name + '_float', lambda self, 7236*da0073e9SAndroid Build Coastguard Worker test=test, kwargs=kwargs: test.test_cuda(self, dtype=torch.float, **kwargs)) 7237*da0073e9SAndroid Build Coastguard Worker add(cuda_test_name + '_double', lambda self, 7238*da0073e9SAndroid Build Coastguard Worker test=test, kwargs=kwargs: test.test_cuda(self, dtype=torch.double, **kwargs)) 7239*da0073e9SAndroid Build Coastguard Worker 7240*da0073e9SAndroid Build Coastguard Worker def test_half(self, test=test, kwargs=kwargs): 7241*da0073e9SAndroid Build Coastguard Worker test.test_cuda(self, dtype=torch.half, **kwargs) 7242*da0073e9SAndroid Build Coastguard Worker if getattr(test, 'check_half', True): 7243*da0073e9SAndroid Build Coastguard Worker add(cuda_test_name + '_half', test_half) 7244*da0073e9SAndroid Build Coastguard Worker 7245*da0073e9SAndroid Build Coastguard Worker def test_bfloat16(self, test=test, kwargs=kwargs): 7246*da0073e9SAndroid Build Coastguard Worker test.test_cuda(self, dtype=torch.bfloat16, **kwargs) 7247*da0073e9SAndroid Build Coastguard Worker if getattr(test, 'check_bfloat16', True): 7248*da0073e9SAndroid Build Coastguard Worker add(cuda_test_name + '_bfloat16', test_bfloat16) 7249*da0073e9SAndroid Build Coastguard Worker 7250*da0073e9SAndroid Build Coastguard Worker def test_cfloat(self, test=test, kwargs=kwargs): 7251*da0073e9SAndroid Build Coastguard Worker test.test_cuda(self, dtype=torch.cfloat, **kwargs) 7252*da0073e9SAndroid Build Coastguard Worker 7253*da0073e9SAndroid Build Coastguard Worker def test_cdouble(self, test=test, kwargs=kwargs): 7254*da0073e9SAndroid Build Coastguard Worker test.test_cuda(self, dtype=torch.cdouble, **kwargs) 7255*da0073e9SAndroid Build Coastguard Worker if getattr(test, 'check_complex', False): 7256*da0073e9SAndroid Build Coastguard Worker add(cuda_test_name + '_cfloat', test_cfloat) 7257*da0073e9SAndroid Build Coastguard Worker
add(cuda_test_name + '_cdouble', test_cdouble) 7258*da0073e9SAndroid Build Coastguard Worker 7259*da0073e9SAndroid Build Coastguard Worker else: 7260*da0073e9SAndroid Build Coastguard Worker def with_tf32_off(self, test=test, kwargs=kwargs): 7261*da0073e9SAndroid Build Coastguard Worker with tf32_off(): 7262*da0073e9SAndroid Build Coastguard Worker test.test_cuda(self, **kwargs) 7263*da0073e9SAndroid Build Coastguard Worker 7264*da0073e9SAndroid Build Coastguard Worker if tf32_is_not_fp32() and test.with_tf32: 7265*da0073e9SAndroid Build Coastguard Worker add(cuda_test_name + '_fp32', with_tf32_off) 7266*da0073e9SAndroid Build Coastguard Worker 7267*da0073e9SAndroid Build Coastguard Worker def with_tf32_on(self, test=test, kwargs=kwargs): 7268*da0073e9SAndroid Build Coastguard Worker with tf32_on(self, test.tf32_precision): 7269*da0073e9SAndroid Build Coastguard Worker test.test_cuda(self, **kwargs) 7270*da0073e9SAndroid Build Coastguard Worker 7271*da0073e9SAndroid Build Coastguard Worker add(cuda_test_name + '_tf32', with_tf32_on) 7272*da0073e9SAndroid Build Coastguard Worker else: 7273*da0073e9SAndroid Build Coastguard Worker add(cuda_test_name, with_tf32_off) 7274*da0073e9SAndroid Build Coastguard Worker 7275*da0073e9SAndroid Build Coastguard Workerfor test_params in module_tests + new_module_tests: 7276*da0073e9SAndroid Build Coastguard Worker # TODO: CUDA is not implemented yet 7277*da0073e9SAndroid Build Coastguard Worker if 'constructor' not in test_params: 7278*da0073e9SAndroid Build Coastguard Worker name = test_params.pop('module_name') 7279*da0073e9SAndroid Build Coastguard Worker test_params['constructor'] = getattr(nn, name) 7280*da0073e9SAndroid Build Coastguard Worker decorator = test_params.pop('decorator', None) 7281*da0073e9SAndroid Build Coastguard Worker test = NewModuleTest(**test_params) 7282*da0073e9SAndroid Build Coastguard Worker add_test(test, decorator) 7283*da0073e9SAndroid Build Coastguard Worker if 'check_eval' in test_params: 7284*da0073e9SAndroid Build Coastguard Worker # create a new test that is identical but that sets module.training to False 7285*da0073e9SAndroid Build Coastguard Worker desc = test_params.get('desc', None) 7286*da0073e9SAndroid Build Coastguard Worker test_params['desc'] = 'eval' if desc is None else desc + '_eval' 7287*da0073e9SAndroid Build Coastguard Worker 7288*da0073e9SAndroid Build Coastguard Worker def gen_eval_constructor(constructor): 7289*da0073e9SAndroid Build Coastguard Worker def eval_constructor(*args, **kwargs): 7290*da0073e9SAndroid Build Coastguard Worker cons = constructor(*args, **kwargs) 7291*da0073e9SAndroid Build Coastguard Worker cons.training = False 7292*da0073e9SAndroid Build Coastguard Worker return cons 7293*da0073e9SAndroid Build Coastguard Worker eval_constructor.__name__ = constructor.__name__ 7294*da0073e9SAndroid Build Coastguard Worker return eval_constructor 7295*da0073e9SAndroid Build Coastguard Worker 7296*da0073e9SAndroid Build Coastguard Worker test_params['constructor'] = gen_eval_constructor(test_params['constructor']) 7297*da0073e9SAndroid Build Coastguard Worker test = NewModuleTest(**test_params) 7298*da0073e9SAndroid Build Coastguard Worker add_test(test, decorator) 7299*da0073e9SAndroid Build Coastguard Worker if 'check_with_long_tensor' in test_params: 7300*da0073e9SAndroid Build Coastguard Worker fullname = test_params.get('fullname', None) 7301*da0073e9SAndroid Build Coastguard Worker if fullname: 7302*da0073e9SAndroid Build Coastguard Worker test_params['fullname'] = fullname + 
'_with_long_tensor' 7303*da0073e9SAndroid Build Coastguard Worker else: 7304*da0073e9SAndroid Build Coastguard Worker desc = test_params.get('desc', None) 7305*da0073e9SAndroid Build Coastguard Worker test_params['desc'] = 'with_long_tensor' if desc is None else desc + '_with_long_tensor' 7306*da0073e9SAndroid Build Coastguard Worker 7307*da0073e9SAndroid Build Coastguard Worker def double_equivalent_of_long_tensor(size): 7308*da0073e9SAndroid Build Coastguard Worker return torch.randint(-1000, 1000, size=size).double() 7309*da0073e9SAndroid Build Coastguard Worker 7310*da0073e9SAndroid Build Coastguard Worker def apply_to_cons(t): 7311*da0073e9SAndroid Build Coastguard Worker if t.is_floating_point(): 7312*da0073e9SAndroid Build Coastguard Worker if isinstance(t, Parameter): 7313*da0073e9SAndroid Build Coastguard Worker return Parameter(double_equivalent_of_long_tensor(t.size())) 7314*da0073e9SAndroid Build Coastguard Worker elif isinstance(t, torch.Tensor): 7315*da0073e9SAndroid Build Coastguard Worker return double_equivalent_of_long_tensor(t.size()) 7316*da0073e9SAndroid Build Coastguard Worker else: 7317*da0073e9SAndroid Build Coastguard Worker return t 7318*da0073e9SAndroid Build Coastguard Worker 7319*da0073e9SAndroid Build Coastguard Worker def gen_long_tensor_constructor(constructor): 7320*da0073e9SAndroid Build Coastguard Worker def long_tensor_constructor(*args, **kwargs): 7321*da0073e9SAndroid Build Coastguard Worker cons = constructor(*args, **kwargs) 7322*da0073e9SAndroid Build Coastguard Worker cons._apply(apply_to_cons) 7323*da0073e9SAndroid Build Coastguard Worker return cons 7324*da0073e9SAndroid Build Coastguard Worker long_tensor_constructor.__name__ = constructor.__name__ 7325*da0073e9SAndroid Build Coastguard Worker return long_tensor_constructor 7326*da0073e9SAndroid Build Coastguard Worker 7327*da0073e9SAndroid Build Coastguard Worker def gen_long_tensor_input(input_size): 7328*da0073e9SAndroid Build Coastguard Worker def input_func(): 7329*da0073e9SAndroid Build Coastguard Worker return double_equivalent_of_long_tensor(input_size) 7330*da0073e9SAndroid Build Coastguard Worker return input_func 7331*da0073e9SAndroid Build Coastguard Worker 7332*da0073e9SAndroid Build Coastguard Worker def reference_fn(i, p, m): 7333*da0073e9SAndroid Build Coastguard Worker # For bad reasons this would create LongTensors that require gradients 7334*da0073e9SAndroid Build Coastguard Worker # Remove requires_grad to avoid this 7335*da0073e9SAndroid Build Coastguard Worker for p in m.parameters(): 7336*da0073e9SAndroid Build Coastguard Worker p.requires_grad_(False) 7337*da0073e9SAndroid Build Coastguard Worker m._apply(lambda t: t.long()) 7338*da0073e9SAndroid Build Coastguard Worker input = i.long() 7339*da0073e9SAndroid Build Coastguard Worker out = m.forward(input) 7340*da0073e9SAndroid Build Coastguard Worker return out 7341*da0073e9SAndroid Build Coastguard Worker 7342*da0073e9SAndroid Build Coastguard Worker test_params['constructor'] = gen_long_tensor_constructor(test_params['constructor']) 7343*da0073e9SAndroid Build Coastguard Worker test_params['input_fn'] = gen_long_tensor_input(test_params['input_size']) 7344*da0073e9SAndroid Build Coastguard Worker test_params['reference_fn'] = reference_fn 7345*da0073e9SAndroid Build Coastguard Worker test_params['check_forward_only'] = True 7346*da0073e9SAndroid Build Coastguard Worker # Currently we don't support conv2d/conv3d for LongTensor in CUDA 7347*da0073e9SAndroid Build Coastguard Worker test_params['test_cuda'] = False
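        # At this point test_params describes a forward-only variant of the
        # original module test: the constructor and input_fn produce
        # integer-valued double tensors, and reference_fn casts both the module
        # and the input to long before comparing, since LongTensor parameters
        # cannot take part in autograd.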
7348*da0073e9SAndroid Build Coastguard Worker test = NewModuleTest(**test_params) 7349*da0073e9SAndroid Build Coastguard Worker 7350*da0073e9SAndroid Build Coastguard Worker add_test(test, decorator) 7351*da0073e9SAndroid Build Coastguard Worker 7352*da0073e9SAndroid Build Coastguard Workerfor test_params in criterion_tests: 7353*da0073e9SAndroid Build Coastguard Worker if 'constructor' not in test_params: 7354*da0073e9SAndroid Build Coastguard Worker name = test_params.pop('module_name') 7355*da0073e9SAndroid Build Coastguard Worker test_params['constructor'] = getattr(nn, name) 7356*da0073e9SAndroid Build Coastguard Worker test = CriterionTest(**test_params) 7357*da0073e9SAndroid Build Coastguard Worker decorator = test_params.pop('decorator', None) 7358*da0073e9SAndroid Build Coastguard Worker add_test(test, decorator) 7359*da0073e9SAndroid Build Coastguard Worker if 'check_sum_reduction' in test_params: 7360*da0073e9SAndroid Build Coastguard Worker desc = test_params.get('desc', None) 7361*da0073e9SAndroid Build Coastguard Worker test_params['desc'] = 'sum_reduction' if desc is None else desc + '_sum_reduction' 7362*da0073e9SAndroid Build Coastguard Worker 7363*da0073e9SAndroid Build Coastguard Worker def gen_sum_reduction_constructor(constructor): 7364*da0073e9SAndroid Build Coastguard Worker def sum_reduction_constructor(*args, **kwargs): 7365*da0073e9SAndroid Build Coastguard Worker cons = constructor(*args, reduction='sum', **kwargs) 7366*da0073e9SAndroid Build Coastguard Worker return cons 7367*da0073e9SAndroid Build Coastguard Worker sum_reduction_constructor.__name__ = constructor.__name__ 7368*da0073e9SAndroid Build Coastguard Worker return sum_reduction_constructor 7369*da0073e9SAndroid Build Coastguard Worker 7370*da0073e9SAndroid Build Coastguard Worker test_params['constructor'] = gen_sum_reduction_constructor(test_params['constructor']) 7371*da0073e9SAndroid Build Coastguard Worker test = CriterionTest(**test_params) 7372*da0073e9SAndroid Build Coastguard Worker add_test(test, decorator) 7373*da0073e9SAndroid Build Coastguard Worker 7374*da0073e9SAndroid Build Coastguard Worker 7375*da0073e9SAndroid Build Coastguard Workerclass UnpoolingNet(nn.Module): 7376*da0073e9SAndroid Build Coastguard Worker def __init__(self, pool, unpool): 7377*da0073e9SAndroid Build Coastguard Worker super().__init__() 7378*da0073e9SAndroid Build Coastguard Worker self.pool = pool 7379*da0073e9SAndroid Build Coastguard Worker self.unpool = unpool 7380*da0073e9SAndroid Build Coastguard Worker 7381*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 7382*da0073e9SAndroid Build Coastguard Worker return self.unpool(*self.pool(input)) 7383*da0073e9SAndroid Build Coastguard Worker 7384*da0073e9SAndroid Build Coastguard Worker 7385*da0073e9SAndroid Build Coastguard Workeradd_test(NewModuleTest( 7386*da0073e9SAndroid Build Coastguard Worker constructor=lambda: UnpoolingNet( 7387*da0073e9SAndroid Build Coastguard Worker nn.MaxPool1d(2, return_indices=True), 7388*da0073e9SAndroid Build Coastguard Worker nn.MaxUnpool1d(2)), 7389*da0073e9SAndroid Build Coastguard Worker input_size=(1, 1, 4), 7390*da0073e9SAndroid Build Coastguard Worker fullname='MaxUnpool1d_net', 7391*da0073e9SAndroid Build Coastguard Worker default_dtype=torch.double,)) 7392*da0073e9SAndroid Build Coastguard Workeradd_test(NewModuleTest( 7393*da0073e9SAndroid Build Coastguard Worker constructor=lambda: UnpoolingNet( 7394*da0073e9SAndroid Build Coastguard Worker nn.MaxPool2d(2, return_indices=True), 7395*da0073e9SAndroid Build 
Coastguard Worker nn.MaxUnpool2d(2)), 7396*da0073e9SAndroid Build Coastguard Worker input_size=(1, 1, 2, 4), 7397*da0073e9SAndroid Build Coastguard Worker fullname='MaxUnpool2d_net', 7398*da0073e9SAndroid Build Coastguard Worker default_dtype=torch.double,)) 7399*da0073e9SAndroid Build Coastguard Workeradd_test(NewModuleTest( 7400*da0073e9SAndroid Build Coastguard Worker constructor=lambda: UnpoolingNet( 7401*da0073e9SAndroid Build Coastguard Worker nn.MaxPool3d(2, return_indices=True), 7402*da0073e9SAndroid Build Coastguard Worker nn.MaxUnpool3d(2)), 7403*da0073e9SAndroid Build Coastguard Worker input_size=(1, 1, 2, 4, 6), 7404*da0073e9SAndroid Build Coastguard Worker fullname='MaxUnpool3d_net', 7405*da0073e9SAndroid Build Coastguard Worker check_gradgrad=False, 7406*da0073e9SAndroid Build Coastguard Worker default_dtype=torch.double,)) 7407*da0073e9SAndroid Build Coastguard Worker 7408*da0073e9SAndroid Build Coastguard Workeradd_test(NewModuleTest( 7409*da0073e9SAndroid Build Coastguard Worker constructor=lambda: UnpoolingNet( 7410*da0073e9SAndroid Build Coastguard Worker nn.MaxPool1d(2, return_indices=True), 7411*da0073e9SAndroid Build Coastguard Worker nn.MaxUnpool1d(2)), 7412*da0073e9SAndroid Build Coastguard Worker input_size=(1, 4), 7413*da0073e9SAndroid Build Coastguard Worker reference_fn=single_batch_reference_fn, 7414*da0073e9SAndroid Build Coastguard Worker fullname='MaxUnpool1d_net_no_batch_dim', 7415*da0073e9SAndroid Build Coastguard Worker default_dtype=torch.double,)) 7416*da0073e9SAndroid Build Coastguard Workeradd_test(NewModuleTest( 7417*da0073e9SAndroid Build Coastguard Worker constructor=lambda: UnpoolingNet( 7418*da0073e9SAndroid Build Coastguard Worker nn.MaxPool2d(2, return_indices=True), 7419*da0073e9SAndroid Build Coastguard Worker nn.MaxUnpool2d(2)), 7420*da0073e9SAndroid Build Coastguard Worker input_size=(1, 2, 4), 7421*da0073e9SAndroid Build Coastguard Worker reference_fn=single_batch_reference_fn, 7422*da0073e9SAndroid Build Coastguard Worker fullname='MaxUnpool2d_net_no_batch_dim', 7423*da0073e9SAndroid Build Coastguard Worker default_dtype=torch.double,)) 7424*da0073e9SAndroid Build Coastguard Worker 7425*da0073e9SAndroid Build Coastguard Workeradd_test(NewModuleTest( 7426*da0073e9SAndroid Build Coastguard Worker constructor=lambda: UnpoolingNet( 7427*da0073e9SAndroid Build Coastguard Worker nn.MaxPool3d(2, return_indices=True), 7428*da0073e9SAndroid Build Coastguard Worker nn.MaxUnpool3d(2)), 7429*da0073e9SAndroid Build Coastguard Worker input_size=(1, 2, 4, 6), 7430*da0073e9SAndroid Build Coastguard Worker reference_fn=single_batch_reference_fn, 7431*da0073e9SAndroid Build Coastguard Worker fullname='MaxUnpool3d_net_no_batch_dim', 7432*da0073e9SAndroid Build Coastguard Worker check_gradgrad=False, 7433*da0073e9SAndroid Build Coastguard Worker default_dtype=torch.double,)) 7434*da0073e9SAndroid Build Coastguard Worker 7435*da0073e9SAndroid Build Coastguard Workerclass _AdaptiveLogSoftmaxWithLoss(nn.AdaptiveLogSoftmaxWithLoss): 7436*da0073e9SAndroid Build Coastguard Worker def __call__(self, input): 7437*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([0, 1, 4, 8]).to(input.device) 7438*da0073e9SAndroid Build Coastguard Worker return nn.AdaptiveLogSoftmaxWithLoss.__call__(self, input, t).output 7439*da0073e9SAndroid Build Coastguard Worker 7440*da0073e9SAndroid Build Coastguard Workeradd_test(NewModuleTest( 7441*da0073e9SAndroid Build Coastguard Worker constructor=lambda: _AdaptiveLogSoftmaxWithLoss(16, 10, [2, 6]), 7442*da0073e9SAndroid Build 
Coastguard Worker input_size=(4, 16), 7443*da0073e9SAndroid Build Coastguard Worker fullname='AdaptiveLogSoftmax', 7444*da0073e9SAndroid Build Coastguard Worker with_tf32=True, 7445*da0073e9SAndroid Build Coastguard Worker tf32_precision=0.005, 7446*da0073e9SAndroid Build Coastguard Worker default_dtype=torch.double)) 7447*da0073e9SAndroid Build Coastguard Worker 7448*da0073e9SAndroid Build Coastguard Worker 7449*da0073e9SAndroid Build Coastguard Worker# The following are helpers for TestNN.test_affine_* 7450*da0073e9SAndroid Build Coastguard Workerif torch.cuda.is_available(): 7451*da0073e9SAndroid Build Coastguard Worker def device_(): 7452*da0073e9SAndroid Build Coastguard Worker return ['cpu', 'cuda'] 7453*da0073e9SAndroid Build Coastguard Workerelse: 7454*da0073e9SAndroid Build Coastguard Worker def device_(): 7455*da0073e9SAndroid Build Coastguard Worker return ['cpu'] 7456*da0073e9SAndroid Build Coastguard Worker 7457*da0073e9SAndroid Build Coastguard Worker 7458*da0073e9SAndroid Build Coastguard Workerdef angle_rad_(): 7459*da0073e9SAndroid Build Coastguard Worker return [r * math.pi * 2 for r in [0.0, 0.5, 0.25, 0.125, random.random()]] 7460*da0073e9SAndroid Build Coastguard Worker 7461*da0073e9SAndroid Build Coastguard Worker 7462*da0073e9SAndroid Build Coastguard Workerdef axis_vector_(): 7463*da0073e9SAndroid Build Coastguard Worker t = (random.random(), random.random(), random.random()) 7464*da0073e9SAndroid Build Coastguard Worker l = sum(x ** 2 for x in t) ** 0.5 7465*da0073e9SAndroid Build Coastguard Worker 7466*da0073e9SAndroid Build Coastguard Worker return [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), tuple(x / l for x in t)] 7467*da0073e9SAndroid Build Coastguard Worker 7468*da0073e9SAndroid Build Coastguard Worker 7469*da0073e9SAndroid Build Coastguard Workerdef input_size2d_(): 7470*da0073e9SAndroid Build Coastguard Worker return [[1, 1, 3, 5], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 3, 4]] 7471*da0073e9SAndroid Build Coastguard Worker 7472*da0073e9SAndroid Build Coastguard Worker 7473*da0073e9SAndroid Build Coastguard Workerdef output_size2d_(): 7474*da0073e9SAndroid Build Coastguard Worker return [[1, 1, 5, 3], [1, 1, 3, 5], [1, 1, 4, 3], [1, 1, 5, 5], [1, 1, 6, 6]] 7475*da0073e9SAndroid Build Coastguard Worker 7476*da0073e9SAndroid Build Coastguard Worker 7477*da0073e9SAndroid Build Coastguard Workerdef input_size2dsq_(): 7478*da0073e9SAndroid Build Coastguard Worker return [[1, 1, 2, 2], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 6, 6]] 7479*da0073e9SAndroid Build Coastguard Worker 7480*da0073e9SAndroid Build Coastguard Worker 7481*da0073e9SAndroid Build Coastguard Workerdef output_size2dsq_(): 7482*da0073e9SAndroid Build Coastguard Worker return [[1, 1, 2, 2], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 5, 5], [1, 1, 6, 6]] 7483*da0073e9SAndroid Build Coastguard Worker 7484*da0073e9SAndroid Build Coastguard Worker 7485*da0073e9SAndroid Build Coastguard Workerdef input_size3d_(): 7486*da0073e9SAndroid Build Coastguard Worker return [[1, 1, 2, 2, 2], [1, 1, 2, 3, 4], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 3, 4, 5]] 7487*da0073e9SAndroid Build Coastguard Worker 7488*da0073e9SAndroid Build Coastguard Worker 7489*da0073e9SAndroid Build Coastguard Workerdef input_size3dsq_(): 7490*da0073e9SAndroid Build Coastguard Worker return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 6, 6, 6]] 7491*da0073e9SAndroid Build Coastguard Worker 7492*da0073e9SAndroid Build Coastguard Worker 7493*da0073e9SAndroid Build Coastguard Workerdef output_size3dsq_(): 7494*da0073e9SAndroid Build 
Coastguard Worker return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 5, 5, 5], [1, 1, 6, 6, 6]] 7495*da0073e9SAndroid Build Coastguard Worker 7496*da0073e9SAndroid Build Coastguard Worker 7497*da0073e9SAndroid Build Coastguard Workerdef output_size3d_(): 7498*da0073e9SAndroid Build Coastguard Worker return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 3, 4, 5], [1, 1, 4, 3, 2], [1, 1, 5, 5, 5], [1, 1, 6, 6, 6]] 7499*da0073e9SAndroid Build Coastguard Worker 7500*da0073e9SAndroid Build Coastguard Worker 7501*da0073e9SAndroid Build Coastguard Workerdef _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad): 7502*da0073e9SAndroid Build Coastguard Worker input_center = [(x - 1) / 2.0 for x in input_size] 7503*da0073e9SAndroid Build Coastguard Worker output_center = [(x - 1) / 2.0 for x in output_size] 7504*da0073e9SAndroid Build Coastguard Worker 7505*da0073e9SAndroid Build Coastguard Worker s = math.sin(angle_rad) 7506*da0073e9SAndroid Build Coastguard Worker c = math.cos(angle_rad) 7507*da0073e9SAndroid Build Coastguard Worker 7508*da0073e9SAndroid Build Coastguard Worker intrans_ary = np.array([ 7509*da0073e9SAndroid Build Coastguard Worker [1, 0, input_center[2]], 7510*da0073e9SAndroid Build Coastguard Worker [0, 1, input_center[3]], 7511*da0073e9SAndroid Build Coastguard Worker [0, 0, 1], 7512*da0073e9SAndroid Build Coastguard Worker ], dtype=np.float64) 7513*da0073e9SAndroid Build Coastguard Worker 7514*da0073e9SAndroid Build Coastguard Worker inscale_ary = np.array([ 7515*da0073e9SAndroid Build Coastguard Worker [input_center[2], 0, 0], 7516*da0073e9SAndroid Build Coastguard Worker [0, input_center[3], 0], 7517*da0073e9SAndroid Build Coastguard Worker [0, 0, 1], 7518*da0073e9SAndroid Build Coastguard Worker ], dtype=np.float64) 7519*da0073e9SAndroid Build Coastguard Worker 7520*da0073e9SAndroid Build Coastguard Worker rotation_ary = np.array([ 7521*da0073e9SAndroid Build Coastguard Worker [c, -s, 0], 7522*da0073e9SAndroid Build Coastguard Worker [s, c, 0], 7523*da0073e9SAndroid Build Coastguard Worker [0, 0, 1], 7524*da0073e9SAndroid Build Coastguard Worker ], dtype=np.float64) 7525*da0073e9SAndroid Build Coastguard Worker 7526*da0073e9SAndroid Build Coastguard Worker outscale_ary = np.array([ 7527*da0073e9SAndroid Build Coastguard Worker [1.0 / output_center[2], 0, 0], 7528*da0073e9SAndroid Build Coastguard Worker [0, 1.0 / output_center[3], 0], 7529*da0073e9SAndroid Build Coastguard Worker [0, 0, 1], 7530*da0073e9SAndroid Build Coastguard Worker ], dtype=np.float64) 7531*da0073e9SAndroid Build Coastguard Worker 7532*da0073e9SAndroid Build Coastguard Worker outtrans_ary = np.array([ 7533*da0073e9SAndroid Build Coastguard Worker [1, 0, -output_center[2]], 7534*da0073e9SAndroid Build Coastguard Worker [0, 1, -output_center[3]], 7535*da0073e9SAndroid Build Coastguard Worker [0, 0, 1], 7536*da0073e9SAndroid Build Coastguard Worker ], dtype=np.float64) 7537*da0073e9SAndroid Build Coastguard Worker 7538*da0073e9SAndroid Build Coastguard Worker reorder_ary = np.array([ 7539*da0073e9SAndroid Build Coastguard Worker [0, 1, 0], 7540*da0073e9SAndroid Build Coastguard Worker [1, 0, 0], 7541*da0073e9SAndroid Build Coastguard Worker [0, 0, 1], 7542*da0073e9SAndroid Build Coastguard Worker ], dtype=np.float64) 7543*da0073e9SAndroid Build Coastguard Worker 7544*da0073e9SAndroid Build Coastguard Worker transform_ary = np.dot(np.dot(np.dot(np.dot( 7545*da0073e9SAndroid Build Coastguard Worker intrans_ary, 7546*da0073e9SAndroid Build Coastguard Worker inscale_ary), 
7547*da0073e9SAndroid Build Coastguard Worker rotation_ary.T), 7548*da0073e9SAndroid Build Coastguard Worker outscale_ary), 7549*da0073e9SAndroid Build Coastguard Worker outtrans_ary) 7550*da0073e9SAndroid Build Coastguard Worker grid_ary = np.dot(np.dot(np.dot(reorder_ary, rotation_ary.T), outscale_ary), outtrans_ary) 7551*da0073e9SAndroid Build Coastguard Worker 7552*da0073e9SAndroid Build Coastguard Worker transform_tensor = torch.from_numpy(rotation_ary).to(device, torch.float32) 7553*da0073e9SAndroid Build Coastguard Worker transform_tensor = transform_tensor[:2].unsqueeze(0) 7554*da0073e9SAndroid Build Coastguard Worker 7555*da0073e9SAndroid Build Coastguard Worker return transform_tensor, transform_ary, grid_ary 7556*da0073e9SAndroid Build Coastguard Worker 7557*da0073e9SAndroid Build Coastguard Worker 7558*da0073e9SAndroid Build Coastguard Workerdef _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_rad, axis_vector): 7559*da0073e9SAndroid Build Coastguard Worker input_center = [(x - 1) / 2.0 for x in input_size] 7560*da0073e9SAndroid Build Coastguard Worker output_center = [(x - 1) / 2.0 for x in output_size] 7561*da0073e9SAndroid Build Coastguard Worker 7562*da0073e9SAndroid Build Coastguard Worker s = math.sin(angle_rad) 7563*da0073e9SAndroid Build Coastguard Worker c = math.cos(angle_rad) 7564*da0073e9SAndroid Build Coastguard Worker c1 = 1 - c 7565*da0073e9SAndroid Build Coastguard Worker 7566*da0073e9SAndroid Build Coastguard Worker intrans_ary = np.array([ 7567*da0073e9SAndroid Build Coastguard Worker [1, 0, 0, input_center[2]], 7568*da0073e9SAndroid Build Coastguard Worker [0, 1, 0, input_center[3]], 7569*da0073e9SAndroid Build Coastguard Worker [0, 0, 1, input_center[4]], 7570*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 1], 7571*da0073e9SAndroid Build Coastguard Worker ], dtype=np.float64) 7572*da0073e9SAndroid Build Coastguard Worker 7573*da0073e9SAndroid Build Coastguard Worker inscale_ary = np.array([ 7574*da0073e9SAndroid Build Coastguard Worker [input_center[2], 0, 0, 0], 7575*da0073e9SAndroid Build Coastguard Worker [0, input_center[3], 0, 0], 7576*da0073e9SAndroid Build Coastguard Worker [0, 0, input_center[4], 0], 7577*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 1], 7578*da0073e9SAndroid Build Coastguard Worker ], dtype=np.float64) 7579*da0073e9SAndroid Build Coastguard Worker 7580*da0073e9SAndroid Build Coastguard Worker l, m, n = axis_vector 7581*da0073e9SAndroid Build Coastguard Worker scipyRotation_ary = np.array([ 7582*da0073e9SAndroid Build Coastguard Worker [l * l * c1 + c, m * l * c1 - n * s, n * l * c1 + m * s, 0], 7583*da0073e9SAndroid Build Coastguard Worker [l * m * c1 + n * s, m * m * c1 + c, n * m * c1 - l * s, 0], 7584*da0073e9SAndroid Build Coastguard Worker [l * n * c1 - m * s, m * n * c1 + l * s, n * n * c1 + c, 0], 7585*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 1], 7586*da0073e9SAndroid Build Coastguard Worker ], dtype=np.float64) 7587*da0073e9SAndroid Build Coastguard Worker 7588*da0073e9SAndroid Build Coastguard Worker z, y, x = axis_vector 7589*da0073e9SAndroid Build Coastguard Worker torchRotation_ary = np.array([ 7590*da0073e9SAndroid Build Coastguard Worker [x * x * c1 + c, y * x * c1 - z * s, z * x * c1 + y * s, 0], 7591*da0073e9SAndroid Build Coastguard Worker [x * y * c1 + z * s, y * y * c1 + c, z * y * c1 - x * s, 0], 7592*da0073e9SAndroid Build Coastguard Worker [x * z * c1 - y * s, y * z * c1 + x * s, z * z * c1 + c, 0], 7593*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 1], 
7594*da0073e9SAndroid Build Coastguard Worker ], dtype=np.float64) 7595*da0073e9SAndroid Build Coastguard Worker 7596*da0073e9SAndroid Build Coastguard Worker outscale_ary = np.array([ 7597*da0073e9SAndroid Build Coastguard Worker [1.0 / output_center[2], 0, 0, 0], 7598*da0073e9SAndroid Build Coastguard Worker [0, 1.0 / output_center[3], 0, 0], 7599*da0073e9SAndroid Build Coastguard Worker [0, 0, 1.0 / output_center[4], 0], 7600*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 1], 7601*da0073e9SAndroid Build Coastguard Worker ], dtype=np.float64) 7602*da0073e9SAndroid Build Coastguard Worker 7603*da0073e9SAndroid Build Coastguard Worker outtrans_ary = np.array([ 7604*da0073e9SAndroid Build Coastguard Worker [1, 0, 0, -output_center[2]], 7605*da0073e9SAndroid Build Coastguard Worker [0, 1, 0, -output_center[3]], 7606*da0073e9SAndroid Build Coastguard Worker [0, 0, 1, -output_center[4]], 7607*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 1], 7608*da0073e9SAndroid Build Coastguard Worker ], dtype=np.float64) 7609*da0073e9SAndroid Build Coastguard Worker 7610*da0073e9SAndroid Build Coastguard Worker reorder_ary = np.array([ 7611*da0073e9SAndroid Build Coastguard Worker [0, 0, 1, 0], 7612*da0073e9SAndroid Build Coastguard Worker [0, 1, 0, 0], 7613*da0073e9SAndroid Build Coastguard Worker [1, 0, 0, 0], 7614*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 1], 7615*da0073e9SAndroid Build Coastguard Worker ], dtype=np.float64) 7616*da0073e9SAndroid Build Coastguard Worker 7617*da0073e9SAndroid Build Coastguard Worker transform_ary = np.dot(np.dot(np.dot(np.dot( 7618*da0073e9SAndroid Build Coastguard Worker intrans_ary, 7619*da0073e9SAndroid Build Coastguard Worker inscale_ary), 7620*da0073e9SAndroid Build Coastguard Worker np.linalg.inv(scipyRotation_ary)), 7621*da0073e9SAndroid Build Coastguard Worker outscale_ary), 7622*da0073e9SAndroid Build Coastguard Worker outtrans_ary) 7623*da0073e9SAndroid Build Coastguard Worker grid_ary = np.dot(np.dot(np.dot(reorder_ary, np.linalg.inv(scipyRotation_ary)), outscale_ary), outtrans_ary) 7624*da0073e9SAndroid Build Coastguard Worker 7625*da0073e9SAndroid Build Coastguard Worker transform_tensor = torch.from_numpy(torchRotation_ary).to(device, torch.float32) 7626*da0073e9SAndroid Build Coastguard Worker transform_tensor = transform_tensor[:3].unsqueeze(0) 7627*da0073e9SAndroid Build Coastguard Worker 7628*da0073e9SAndroid Build Coastguard Worker return transform_tensor, transform_ary, grid_ary 7629*da0073e9SAndroid Build Coastguard Worker# end TestNN.test_affine_* helpers 7630*da0073e9SAndroid Build Coastguard Worker 7631*da0073e9SAndroid Build Coastguard Worker 7632*da0073e9SAndroid Build Coastguard Workerclass TestNNDeviceType(NNTestCase): 7633*da0073e9SAndroid Build Coastguard Worker def _test_InstanceNorm_general(self, cls, input, device, dtype=torch.float): 7634*da0073e9SAndroid Build Coastguard Worker # default case track_running_stats=False 7635*da0073e9SAndroid Build Coastguard Worker b, c = input.size(0), input.size(1) 7636*da0073e9SAndroid Build Coastguard Worker input_var = input.to(device=device, dtype=dtype).requires_grad_() 7637*da0073e9SAndroid Build Coastguard Worker 7638*da0073e9SAndroid Build Coastguard Worker IN = cls(c, eps=0).to(device, dtype) 7639*da0073e9SAndroid Build Coastguard Worker 7640*da0073e9SAndroid Build Coastguard Worker output = IN(input_var) 7641*da0073e9SAndroid Build Coastguard Worker out_reshaped = output.view(b * c, -1) 7642*da0073e9SAndroid Build Coastguard Worker 7643*da0073e9SAndroid Build Coastguard Worker 
mean = out_reshaped.mean(1) 7644*da0073e9SAndroid Build Coastguard Worker var = out_reshaped.var(1, unbiased=False) 7645*da0073e9SAndroid Build Coastguard Worker 7646*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.abs(mean.data).mean(), 0, atol=1e-5, rtol=0) 7647*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.abs(var.data).mean(), 1, atol=1e-5, rtol=0) 7648*da0073e9SAndroid Build Coastguard Worker 7649*da0073e9SAndroid Build Coastguard Worker # check that eval mode doesn't change behavior 7650*da0073e9SAndroid Build Coastguard Worker grad_out = torch.randn_like(output) 7651*da0073e9SAndroid Build Coastguard Worker res1 = output.data.clone() 7652*da0073e9SAndroid Build Coastguard Worker output.backward(grad_out) 7653*da0073e9SAndroid Build Coastguard Worker grad1 = input_var.grad.data.clone() 7654*da0073e9SAndroid Build Coastguard Worker 7655*da0073e9SAndroid Build Coastguard Worker IN.eval() 7656*da0073e9SAndroid Build Coastguard Worker output = IN(input_var) 7657*da0073e9SAndroid Build Coastguard Worker input_var.grad = None 7658*da0073e9SAndroid Build Coastguard Worker output.backward(grad_out) 7659*da0073e9SAndroid Build Coastguard Worker res2 = output.data 7660*da0073e9SAndroid Build Coastguard Worker grad2 = input_var.grad.data 7661*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 7662*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad1, grad2) 7663*da0073e9SAndroid Build Coastguard Worker 7664*da0073e9SAndroid Build Coastguard Worker # If track_running_stats=True and momentum=1, running_mean/var should be 7665*da0073e9SAndroid Build Coastguard Worker # equal to mean/var of the input (with unbias correction) 7666*da0073e9SAndroid Build Coastguard Worker IN = cls(c, momentum=1, eps=0, track_running_stats=True).to(device, dtype) 7667*da0073e9SAndroid Build Coastguard Worker 7668*da0073e9SAndroid Build Coastguard Worker output = IN(input_var) 7669*da0073e9SAndroid Build Coastguard Worker 7670*da0073e9SAndroid Build Coastguard Worker input_reshaped = input_var.transpose(1, 0).reshape(c, -1) 7671*da0073e9SAndroid Build Coastguard Worker mean = input_reshaped.mean(1) 7672*da0073e9SAndroid Build Coastguard Worker 7673*da0073e9SAndroid Build Coastguard Worker input_reshaped = input_var.transpose(1, 0).reshape(c, b, -1) 7674*da0073e9SAndroid Build Coastguard Worker var = input_reshaped.var(2, unbiased=True)[:, :] 7675*da0073e9SAndroid Build Coastguard Worker 7676*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, atol=1e-5, rtol=0) 7677*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, atol=1e-5, rtol=0) 7678*da0073e9SAndroid Build Coastguard Worker 7679*da0073e9SAndroid Build Coastguard Worker # in eval mode, adding X * std to a channel in input should make the 7680*da0073e9SAndroid Build Coastguard Worker # corresponding channel in output have mean X 7681*da0073e9SAndroid Build Coastguard Worker IN.eval() 7682*da0073e9SAndroid Build Coastguard Worker delta = IN.running_var.sqrt() * torch.arange(c, device=device, dtype=dtype) 7683*da0073e9SAndroid Build Coastguard Worker delta = delta.view(-1, *[1 for _ in range(2, input.dim())]) 7684*da0073e9SAndroid Build Coastguard Worker output = IN(input_var + delta) 7685*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.transpose(0, 1).reshape(c, -1).mean(1), torch.arange(c, dtype=dtype)) 7686*da0073e9SAndroid Build Coastguard Worker 
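    # Illustrative sketch (added for exposition, not invoked by the suite): a
    # minimal, standalone version of the momentum=1 property exercised above.
    # With track_running_stats=True and momentum=1, one training-mode forward
    # pass makes running_mean equal the per-channel mean of that batch; the
    # running variance additionally carries the unbiased (Bessel) correction.
    def _sketch_instance_norm_running_stats(self):
        x = torch.randn(2, 3, 8, 8)
        m = nn.InstanceNorm2d(3, momentum=1, eps=0, track_running_stats=True)
        m(x)  # training mode by default; updates the running statistics
        per_channel = x.transpose(0, 1).reshape(3, -1)
        self.assertEqual(m.running_mean, per_channel.mean(1), atol=1e-5, rtol=0)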
7687*da0073e9SAndroid Build Coastguard Worker def _test_InstanceNorm_cuda_half(self, cls, input, device): 7688*da0073e9SAndroid Build Coastguard Worker # THNN 7689*da0073e9SAndroid Build Coastguard Worker input = input.to(device=device, dtype=torch.half).random_(1, 10).requires_grad_(True) 7690*da0073e9SAndroid Build Coastguard Worker m = cls(input.size(1), affine=True, track_running_stats=True).to(device, torch.half) 7691*da0073e9SAndroid Build Coastguard Worker thnn_output = m(input) 7692*da0073e9SAndroid Build Coastguard Worker thnn_output.sum().backward() 7693*da0073e9SAndroid Build Coastguard Worker thnn_input_grad = input.grad.data.clone() 7694*da0073e9SAndroid Build Coastguard Worker self.assertEqualTypeString(thnn_output, input) 7695*da0073e9SAndroid Build Coastguard Worker # cuDNN 7696*da0073e9SAndroid Build Coastguard Worker if TEST_CUDNN: 7697*da0073e9SAndroid Build Coastguard Worker input.grad = None 7698*da0073e9SAndroid Build Coastguard Worker m = m.float() 7699*da0073e9SAndroid Build Coastguard Worker cudnn_output = m(input) 7700*da0073e9SAndroid Build Coastguard Worker cudnn_output.sum().backward() 7701*da0073e9SAndroid Build Coastguard Worker cudnn_input_grad = input.grad.data.clone() 7702*da0073e9SAndroid Build Coastguard Worker self.assertEqualTypeString(cudnn_output, input) 7703*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cudnn_output, thnn_output, atol=1e-4, rtol=0) 7704*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cudnn_input_grad, thnn_input_grad, atol=1e-3, rtol=0) 7705*da0073e9SAndroid Build Coastguard Worker 7706*da0073e9SAndroid Build Coastguard Worker def _test_LayerNorm_general(self, device, dtype=torch.float): 7707*da0073e9SAndroid Build Coastguard Worker for i in range(2, 6): 7708*da0073e9SAndroid Build Coastguard Worker shape = torch.randint(3, 6, (i,), dtype=torch.long).tolist() 7709*da0073e9SAndroid Build Coastguard Worker x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10) 7710*da0073e9SAndroid Build Coastguard Worker normalized_ndim = random.randint(1, i - 1) # inclusive 7711*da0073e9SAndroid Build Coastguard Worker normalized_shape = shape[-normalized_ndim:] 7712*da0073e9SAndroid Build Coastguard Worker unnormalized_shape = shape[:-normalized_ndim] 7713*da0073e9SAndroid Build Coastguard Worker 7714*da0073e9SAndroid Build Coastguard Worker # test that LN normalizes to mean 0 and stddev 1 7715*da0073e9SAndroid Build Coastguard Worker ln = nn.LayerNorm(normalized_shape, eps=0).to(device, dtype) 7716*da0073e9SAndroid Build Coastguard Worker ln.weight.data.fill_(1) 7717*da0073e9SAndroid Build Coastguard Worker ln.bias.data.fill_(0) 7718*da0073e9SAndroid Build Coastguard Worker output = ln(x) 7719*da0073e9SAndroid Build Coastguard Worker out_reshaped = output.view(*(unnormalized_shape + [-1])) 7720*da0073e9SAndroid Build Coastguard Worker mean = out_reshaped.mean(-1) 7721*da0073e9SAndroid Build Coastguard Worker var = out_reshaped.var(-1, unbiased=False) 7722*da0073e9SAndroid Build Coastguard Worker 7723*da0073e9SAndroid Build Coastguard Worker delta = 1e-1 if (dtype == torch.bfloat16 or dtype == torch.half) else 1e-5 7724*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.abs(mean.data).mean(), 0, atol=delta, rtol=0) 7725*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.abs(var.data).mean(), 1, atol=delta, rtol=0) 7726*da0073e9SAndroid Build Coastguard Worker 7727*da0073e9SAndroid Build Coastguard Worker # test that LN applies weight and bias correctly 7728*da0073e9SAndroid Build 
Coastguard Worker scale, bias = torch.empty(2).uniform_(0.2, 2).tolist() 7729*da0073e9SAndroid Build Coastguard Worker ln.weight.data.fill_(scale) 7730*da0073e9SAndroid Build Coastguard Worker ln.bias.data.fill_(bias) 7731*da0073e9SAndroid Build Coastguard Worker output = ln(x) 7732*da0073e9SAndroid Build Coastguard Worker out_reshaped = output.view(*(unnormalized_shape + [-1])) 7733*da0073e9SAndroid Build Coastguard Worker mean = out_reshaped.mean(-1) 7734*da0073e9SAndroid Build Coastguard Worker var = out_reshaped.var(-1, unbiased=False) 7735*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.abs(mean.data).mean(), bias, atol=delta, rtol=0) 7736*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.abs(var.data).mean(), scale ** 2, atol=delta, rtol=0) 7737*da0073e9SAndroid Build Coastguard Worker 7738*da0073e9SAndroid Build Coastguard Worker bad_norm_shape_input_shape = { 7739*da0073e9SAndroid Build Coastguard Worker (): (), 7740*da0073e9SAndroid Build Coastguard Worker (2, 3): (3,), 7741*da0073e9SAndroid Build Coastguard Worker (2,): (1, 2, 3), 7742*da0073e9SAndroid Build Coastguard Worker (10,): (2, 3), 7743*da0073e9SAndroid Build Coastguard Worker 10: (2, 3), 7744*da0073e9SAndroid Build Coastguard Worker } 7745*da0073e9SAndroid Build Coastguard Worker for norm_shape, input_shape in bad_norm_shape_input_shape.items(): 7746*da0073e9SAndroid Build Coastguard Worker ln = nn.LayerNorm(norm_shape) 7747*da0073e9SAndroid Build Coastguard Worker input = torch.empty(input_shape, device=device, dtype=dtype).uniform_(0, 10) 7748*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: ln(input)) 7749*da0073e9SAndroid Build Coastguard Worker 7750*da0073e9SAndroid Build Coastguard Worker def _test_LayerNorm_cuda_half(self, device): 7751*da0073e9SAndroid Build Coastguard Worker input = torch.empty(2, 3, 3, 2, device=device, dtype=torch.half).random_(1, 10).requires_grad_(True) 7752*da0073e9SAndroid Build Coastguard Worker m = nn.LayerNorm([3, 2]).to(device, torch.half) 7753*da0073e9SAndroid Build Coastguard Worker output = m(input) 7754*da0073e9SAndroid Build Coastguard Worker output.sum().backward() 7755*da0073e9SAndroid Build Coastguard Worker self.assertEqualTypeString(output, input) 7756*da0073e9SAndroid Build Coastguard Worker 7757*da0073e9SAndroid Build Coastguard Worker def _test_LayerNorm_cpu_mixed_dtype(self, device, dtype): 7758*da0073e9SAndroid Build Coastguard Worker for elementwise_affine in [True, False]: 7759*da0073e9SAndroid Build Coastguard Worker # layer norm input shape is normalized to m x n, cpu vectorized on n, 7760*da0073e9SAndroid Build Coastguard Worker # so make sure n exceeds vector length 7761*da0073e9SAndroid Build Coastguard Worker input = torch.empty(2, 3, 11, 3, device=device, dtype=dtype).random_(1, 10) 7762*da0073e9SAndroid Build Coastguard Worker m = nn.LayerNorm([11, 3], elementwise_affine=elementwise_affine).to(device, dtype) 7763*da0073e9SAndroid Build Coastguard Worker 7764*da0073e9SAndroid Build Coastguard Worker # fp32 7765*da0073e9SAndroid Build Coastguard Worker m_fp32 = deepcopy(m).to(device, torch.float) 7766*da0073e9SAndroid Build Coastguard Worker x_fp32 = input.clone().detach().float().requires_grad_() 7767*da0073e9SAndroid Build Coastguard Worker out_fp32 = m_fp32(x_fp32) 7768*da0073e9SAndroid Build Coastguard Worker out_fp32.sum().backward() 7769*da0073e9SAndroid Build Coastguard Worker 7770*da0073e9SAndroid Build Coastguard Worker # bf16/half 7771*da0073e9SAndroid Build Coastguard Worker m_bf16 = 
deepcopy(m) 7772*da0073e9SAndroid Build Coastguard Worker x_bf16 = input.clone().detach().requires_grad_() 7773*da0073e9SAndroid Build Coastguard Worker out_bf16 = m_bf16(x_bf16) 7774*da0073e9SAndroid Build Coastguard Worker out_bf16.sum().backward() 7775*da0073e9SAndroid Build Coastguard Worker 7776*da0073e9SAndroid Build Coastguard Worker # bf16/half mixed type 7777*da0073e9SAndroid Build Coastguard Worker m_mix = deepcopy(m).to(device, torch.float) 7778*da0073e9SAndroid Build Coastguard Worker x_mix = input.clone().detach().requires_grad_() 7779*da0073e9SAndroid Build Coastguard Worker out_mix = m_mix(x_mix) 7780*da0073e9SAndroid Build Coastguard Worker out_mix.sum().backward() 7781*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_fp32.to(dtype=dtype), out_bf16) 7782*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_fp32.to(dtype=dtype), out_mix) 7783*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_fp32.grad.to(dtype=dtype), x_bf16.grad, atol=1e-1, rtol=1e-1) 7784*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_fp32.grad.to(dtype=dtype), x_mix.grad, atol=1e-1, rtol=1e-1) 7785*da0073e9SAndroid Build Coastguard Worker 7786*da0073e9SAndroid Build Coastguard Worker def _test_GroupNorm_general(self, device, dtype=torch.float): 7787*da0073e9SAndroid Build Coastguard Worker good_shape_g = { 7788*da0073e9SAndroid Build Coastguard Worker (1, 2, 3, 4): 2, 7789*da0073e9SAndroid Build Coastguard Worker (2, 3, 10): 3, 7790*da0073e9SAndroid Build Coastguard Worker (3, 1, 1, 1, 2): 1, 7791*da0073e9SAndroid Build Coastguard Worker (2, 6, 4, 2, 2): 3, 7792*da0073e9SAndroid Build Coastguard Worker (1, 256, 1, 1): 32, 7793*da0073e9SAndroid Build Coastguard Worker } 7794*da0073e9SAndroid Build Coastguard Worker for shape_g, grad in product(good_shape_g.items(), [True, False]): 7795*da0073e9SAndroid Build Coastguard Worker shape, g = shape_g 7796*da0073e9SAndroid Build Coastguard Worker x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10) 7797*da0073e9SAndroid Build Coastguard Worker x.requires_grad_(grad) 7798*da0073e9SAndroid Build Coastguard Worker b = shape[0] 7799*da0073e9SAndroid Build Coastguard Worker c = shape[1] 7800*da0073e9SAndroid Build Coastguard Worker 7801*da0073e9SAndroid Build Coastguard Worker # test that GN normalizes to mean 0 and stddev 1 7802*da0073e9SAndroid Build Coastguard Worker gn = nn.GroupNorm(g, c, eps=0).to(device, dtype) 7803*da0073e9SAndroid Build Coastguard Worker gn.weight.data.fill_(1) 7804*da0073e9SAndroid Build Coastguard Worker gn.bias.data.fill_(0) 7805*da0073e9SAndroid Build Coastguard Worker output = gn(x) 7806*da0073e9SAndroid Build Coastguard Worker out_reshaped = output.view(b, g, -1) 7807*da0073e9SAndroid Build Coastguard Worker mean = out_reshaped.mean(-1) 7808*da0073e9SAndroid Build Coastguard Worker var = out_reshaped.var(-1, unbiased=False) 7809*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.abs(mean).mean(), 0, atol=1e-5, rtol=0) 7810*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.abs(var).mean(), 1, atol=1e-5, rtol=0) 7811*da0073e9SAndroid Build Coastguard Worker 7812*da0073e9SAndroid Build Coastguard Worker output.backward(torch.randn_like(output)) 7813*da0073e9SAndroid Build Coastguard Worker if output.is_cuda: 7814*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize() 7815*da0073e9SAndroid Build Coastguard Worker 7816*da0073e9SAndroid Build Coastguard Worker # test that GN applies weight and bias correctly 7817*da0073e9SAndroid Build Coastguard 
Worker scale = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2) 7818*da0073e9SAndroid Build Coastguard Worker bias = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2) 7819*da0073e9SAndroid Build Coastguard Worker gn.weight.data.copy_(scale) 7820*da0073e9SAndroid Build Coastguard Worker gn.bias.data.copy_(bias) 7821*da0073e9SAndroid Build Coastguard Worker output = gn(x) 7822*da0073e9SAndroid Build Coastguard Worker out_reshaped = output.view(b, c, -1) 7823*da0073e9SAndroid Build Coastguard Worker out_normed = (out_reshaped - bias.view(c, 1)) / scale.view(c, 1) 7824*da0073e9SAndroid Build Coastguard Worker out_normed_reshaped = out_normed.view(b, g, -1) 7825*da0073e9SAndroid Build Coastguard Worker mean = out_normed_reshaped.mean(-1) 7826*da0073e9SAndroid Build Coastguard Worker var = out_normed_reshaped.var(-1, unbiased=False) 7827*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.abs(mean).mean(), 0, atol=1e-5, rtol=0) 7828*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.abs(var).mean(), 1, atol=1e-5, rtol=0) 7829*da0073e9SAndroid Build Coastguard Worker 7830*da0073e9SAndroid Build Coastguard Worker bad_shape_g = { 7831*da0073e9SAndroid Build Coastguard Worker (1, 2, 3, 4): 3, 7832*da0073e9SAndroid Build Coastguard Worker (2, 3, 10): 2, 7833*da0073e9SAndroid Build Coastguard Worker (3, 1, 1, 1, 2): 10, 7834*da0073e9SAndroid Build Coastguard Worker (2, 6, 4, 2, 2): 4, 7835*da0073e9SAndroid Build Coastguard Worker } 7836*da0073e9SAndroid Build Coastguard Worker for shape, g in bad_shape_g.items(): 7837*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 7838*da0073e9SAndroid Build Coastguard Worker gn = nn.GroupNorm(g, shape[1]) 7839*da0073e9SAndroid Build Coastguard Worker 7840*da0073e9SAndroid Build Coastguard Worker def _test_GroupNorm_cuda_half(self): 7841*da0073e9SAndroid Build Coastguard Worker input = torch.zeros(2, 4, 3, 2, requires_grad=True).cuda().half().random_(1, 10) 7842*da0073e9SAndroid Build Coastguard Worker m = nn.GroupNorm(2, 4).to("cuda", torch.half) 7843*da0073e9SAndroid Build Coastguard Worker output = m(input) 7844*da0073e9SAndroid Build Coastguard Worker output.sum().backward() 7845*da0073e9SAndroid Build Coastguard Worker self.assertEqualTypeString(output, input) 7846*da0073e9SAndroid Build Coastguard Worker 7847*da0073e9SAndroid Build Coastguard Worker def _test_GroupNorm_cpu_mixed_dtype(self): 7848*da0073e9SAndroid Build Coastguard Worker def helper(self, size, groups, memory_format, dtype): 7849*da0073e9SAndroid Build Coastguard Worker channels = size[1] 7850*da0073e9SAndroid Build Coastguard Worker input = torch.randn(size).cpu().to(dtype=dtype) 7851*da0073e9SAndroid Build Coastguard Worker input_bf1 = input.contiguous(memory_format=memory_format).detach().requires_grad_(True) 7852*da0073e9SAndroid Build Coastguard Worker input_bf2 = input_bf1.clone().detach().requires_grad_(True) 7853*da0073e9SAndroid Build Coastguard Worker input_f = input_bf1.float().detach().requires_grad_(True) 7854*da0073e9SAndroid Build Coastguard Worker m_bf = nn.GroupNorm(groups, channels).cpu().to(dtype=dtype) 7855*da0073e9SAndroid Build Coastguard Worker m_f = deepcopy(m_bf).float() 7856*da0073e9SAndroid Build Coastguard Worker m_f2 = deepcopy(m_f) 7857*da0073e9SAndroid Build Coastguard Worker # bfloat16 input and bfloat16 parameters 7858*da0073e9SAndroid Build Coastguard Worker out = m_bf(input_bf1) 7859*da0073e9SAndroid Build Coastguard Worker # bfloat16 input and float parameters 7860*da0073e9SAndroid Build 
Coastguard Worker out2 = m_f(input_bf2) 7861*da0073e9SAndroid Build Coastguard Worker # float input and float parameters 7862*da0073e9SAndroid Build Coastguard Worker out3 = m_f2(input_f) 7863*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, out2, atol=5e-3, rtol=5e-3) 7864*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out2.float(), out3, atol=5e-3, rtol=5e-3) 7865*da0073e9SAndroid Build Coastguard Worker grad_out = torch.randn(out2.shape).cpu().to(dtype=dtype) 7866*da0073e9SAndroid Build Coastguard Worker grad_out_bf1 = grad_out.contiguous(memory_format=memory_format).detach().requires_grad_(True) 7867*da0073e9SAndroid Build Coastguard Worker grad_out_bf2 = grad_out_bf1.clone().detach().requires_grad_(True) 7868*da0073e9SAndroid Build Coastguard Worker grad_out_f = grad_out_bf2.clone().float().detach().requires_grad_(True) 7869*da0073e9SAndroid Build Coastguard Worker # bfloat16/half input grad and float parameters 7870*da0073e9SAndroid Build Coastguard Worker out2.backward(grad_out_bf2, retain_graph=True) 7871*da0073e9SAndroid Build Coastguard Worker # float input grad and float parameters 7872*da0073e9SAndroid Build Coastguard Worker out3.backward(grad_out_f, retain_graph=True) 7873*da0073e9SAndroid Build Coastguard Worker # bfloat16/half input grad and bfloat16/half parameters 7874*da0073e9SAndroid Build Coastguard Worker out.backward(grad_out_bf1, retain_graph=True) 7875*da0073e9SAndroid Build Coastguard Worker # Need higher tolerances atol=1e-4 and rtol=1e-4 on macos 7876*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_f.weight.grad, m_f2.weight.grad, atol=1e-4, rtol=1e-4) 7877*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_f.bias.grad, m_f2.bias.grad, atol=1e-5, rtol=1e-5) 7878*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input_bf2.grad.float(), input_f.grad, atol=5e-5, rtol=5e-3) 7879*da0073e9SAndroid Build Coastguard Worker # Full bf16/half has lower precision compared with mixed bf16/half and fp32. 7880*da0073e9SAndroid Build Coastguard Worker # Use Amp to keep module parameters in acc dtype, i.e. 
float, for better numerical stability 7881*da0073e9SAndroid Build Coastguard Worker atol = None 7882*da0073e9SAndroid Build Coastguard Worker rtol = None 7883*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 7884*da0073e9SAndroid Build Coastguard Worker atol = 1e-2 7885*da0073e9SAndroid Build Coastguard Worker rtol = 1.2e-1 7886*da0073e9SAndroid Build Coastguard Worker else: 7887*da0073e9SAndroid Build Coastguard Worker assert dtype == torch.half 7888*da0073e9SAndroid Build Coastguard Worker atol = 5e-3 7889*da0073e9SAndroid Build Coastguard Worker rtol = 1.5e-2 7890*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_bf.weight.grad, m_f.weight.grad.to(dtype=dtype), atol=atol, rtol=rtol) 7891*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_bf.bias.grad, m_f.bias.grad.to(dtype=dtype), atol=atol, rtol=rtol) 7892*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input_bf1.grad, input_bf2.grad, atol=atol, rtol=rtol) 7893*da0073e9SAndroid Build Coastguard Worker 7894*da0073e9SAndroid Build Coastguard Worker cl_formats = {4: torch.channels_last, 5: torch.channels_last_3d} 7895*da0073e9SAndroid Build Coastguard Worker for dtype in [torch.bfloat16, torch.half]: 7896*da0073e9SAndroid Build Coastguard Worker for shape, g in [((1, 8, 4, 3), 2), ((1, 8, 3, 4), 4), 7897*da0073e9SAndroid Build Coastguard Worker ((4, 40, 40, 40), 2), ((4, 8, 40, 40), 4), 7898*da0073e9SAndroid Build Coastguard Worker ((1, 8, 40, 40), 4), ((1, 8, 40, 40), 2), 7899*da0073e9SAndroid Build Coastguard Worker ((1, 8, 50, 50), 2), ((1, 8, 50, 50), 4), 7900*da0073e9SAndroid Build Coastguard Worker ((1, 40, 50, 50), 2), ((1, 9, 3, 4, 5), 3), 7901*da0073e9SAndroid Build Coastguard Worker ((1, 60, 10, 10, 10), 3), ((1, 9, 10, 50, 50), 3), 7902*da0073e9SAndroid Build Coastguard Worker ((1, 60, 10, 50, 50), 3), ((1, 8, 65, 55), 2), 7903*da0073e9SAndroid Build Coastguard Worker ((1, 3, 65, 55), 1), ((1, 3, 20, 20), 1)]: 7904*da0073e9SAndroid Build Coastguard Worker for is_cl in [False, True]: 7905*da0073e9SAndroid Build Coastguard Worker format = cl_formats[len(shape)] if is_cl else torch.contiguous_format 7906*da0073e9SAndroid Build Coastguard Worker helper(self, shape, g, format, dtype) 7907*da0073e9SAndroid Build Coastguard Worker 7908*da0073e9SAndroid Build Coastguard Worker def _test_module_empty_inputs(self, module, inputs): 7909*da0073e9SAndroid Build Coastguard Worker for _inp in inputs: 7910*da0073e9SAndroid Build Coastguard Worker _inp.requires_grad_(True) 7911*da0073e9SAndroid Build Coastguard Worker out = module(*inputs) 7912*da0073e9SAndroid Build Coastguard Worker gO = torch.rand_like(out) 7913*da0073e9SAndroid Build Coastguard Worker out.backward(gO) 7914*da0073e9SAndroid Build Coastguard Worker 7915*da0073e9SAndroid Build Coastguard Worker for p in module.parameters(): 7916*da0073e9SAndroid Build Coastguard Worker if p.requires_grad: 7917*da0073e9SAndroid Build Coastguard Worker self.assertEqual(p.grad, torch.zeros_like(p.grad)) 7918*da0073e9SAndroid Build Coastguard Worker 7919*da0073e9SAndroid Build Coastguard Worker for _inp in inputs: 7920*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_inp.grad, torch.zeros_like(_inp)) 7921*da0073e9SAndroid Build Coastguard Worker 7922*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), 7923*da0073e9SAndroid Build Coastguard Worker "Scipy v1.0 and/or numpy not found") 7924*da0073e9SAndroid Build Coastguard Worker @expectedFailureMPS # Unsupported 
Border padding mode https://github.com/pytorch/pytorch/issues/125098 7925*da0073e9SAndroid Build Coastguard Worker @tf32_on_and_off() 7926*da0073e9SAndroid Build Coastguard Worker @bf32_on_and_off() 7927*da0073e9SAndroid Build Coastguard Worker def test_affine_2d_rotate0(self, device): 7928*da0073e9SAndroid Build Coastguard Worker # scipy before 1.0.0 do not support homogeneous coordinate 7929*da0073e9SAndroid Build Coastguard Worker # scipy.ndimage.affine_transform, so we need to skip. 7930*da0073e9SAndroid Build Coastguard Worker input_size = [1, 1, 3, 3] 7931*da0073e9SAndroid Build Coastguard Worker input_ary = np.array(np.random.random(input_size), dtype=np.float32) 7932*da0073e9SAndroid Build Coastguard Worker output_size = [1, 1, 5, 5] 7933*da0073e9SAndroid Build Coastguard Worker angle_rad = 0. 7934*da0073e9SAndroid Build Coastguard Worker 7935*da0073e9SAndroid Build Coastguard Worker transform_tensor, transform_ary, offset = \ 7936*da0073e9SAndroid Build Coastguard Worker _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad) 7937*da0073e9SAndroid Build Coastguard Worker 7938*da0073e9SAndroid Build Coastguard Worker scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform( 7939*da0073e9SAndroid Build Coastguard Worker input_ary[0, 0], 7940*da0073e9SAndroid Build Coastguard Worker transform_ary, 7941*da0073e9SAndroid Build Coastguard Worker offset=offset, 7942*da0073e9SAndroid Build Coastguard Worker output_shape=output_size[2:], 7943*da0073e9SAndroid Build Coastguard Worker order=1, 7944*da0073e9SAndroid Build Coastguard Worker mode='nearest', 7945*da0073e9SAndroid Build Coastguard Worker prefilter=False)) 7946*da0073e9SAndroid Build Coastguard Worker 7947*da0073e9SAndroid Build Coastguard Worker affine_tensor = torch.nn.functional.affine_grid( 7948*da0073e9SAndroid Build Coastguard Worker transform_tensor, 7949*da0073e9SAndroid Build Coastguard Worker torch.Size(output_size), 7950*da0073e9SAndroid Build Coastguard Worker align_corners=True 7951*da0073e9SAndroid Build Coastguard Worker ) 7952*da0073e9SAndroid Build Coastguard Worker 7953*da0073e9SAndroid Build Coastguard Worker gridsample_ary = torch.nn.functional.grid_sample( 7954*da0073e9SAndroid Build Coastguard Worker torch.tensor(input_ary, device=device).to(device), 7955*da0073e9SAndroid Build Coastguard Worker affine_tensor, 7956*da0073e9SAndroid Build Coastguard Worker padding_mode='border', 7957*da0073e9SAndroid Build Coastguard Worker align_corners=True 7958*da0073e9SAndroid Build Coastguard Worker ).to('cpu') 7959*da0073e9SAndroid Build Coastguard Worker 7960*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scipy_ary.mean(), gridsample_ary.mean()) 7961*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) 7962*da0073e9SAndroid Build Coastguard Worker 7963*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), 7964*da0073e9SAndroid Build Coastguard Worker "Scipy v1.0 and/or numpy not found") 7965*da0073e9SAndroid Build Coastguard Worker @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 7966*da0073e9SAndroid Build Coastguard Worker @tf32_on_and_off(0.001) 7967*da0073e9SAndroid Build Coastguard Worker @bf32_on_and_off(0.001) 7968*da0073e9SAndroid Build Coastguard Worker def test_affine_2d_rotate90(self, device): 7969*da0073e9SAndroid Build Coastguard Worker # scipy before 1.0.0 do not support homogeneous 
coordinate 7970*da0073e9SAndroid Build Coastguard Worker # scipy.ndimage.affine_transform, so we need to skip. 7971*da0073e9SAndroid Build Coastguard Worker for input_size2dsq, output_size2dsq in \ 7972*da0073e9SAndroid Build Coastguard Worker itertools.product(input_size2dsq_(), output_size2dsq_()): 7973*da0073e9SAndroid Build Coastguard Worker input_size = input_size2dsq 7974*da0073e9SAndroid Build Coastguard Worker input_ary = np.array(np.random.random(input_size), dtype=np.float32) 7975*da0073e9SAndroid Build Coastguard Worker output_size = output_size2dsq 7976*da0073e9SAndroid Build Coastguard Worker angle_rad = 0.25 * math.pi * 2 7977*da0073e9SAndroid Build Coastguard Worker 7978*da0073e9SAndroid Build Coastguard Worker transform_tensor, transform_ary, offset = \ 7979*da0073e9SAndroid Build Coastguard Worker _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad) 7980*da0073e9SAndroid Build Coastguard Worker 7981*da0073e9SAndroid Build Coastguard Worker scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform( 7982*da0073e9SAndroid Build Coastguard Worker input_ary[0, 0], 7983*da0073e9SAndroid Build Coastguard Worker transform_ary, 7984*da0073e9SAndroid Build Coastguard Worker offset=offset, 7985*da0073e9SAndroid Build Coastguard Worker output_shape=output_size[2:], 7986*da0073e9SAndroid Build Coastguard Worker order=1, 7987*da0073e9SAndroid Build Coastguard Worker mode='nearest', 7988*da0073e9SAndroid Build Coastguard Worker prefilter=True)) 7989*da0073e9SAndroid Build Coastguard Worker 7990*da0073e9SAndroid Build Coastguard Worker if input_size2dsq == output_size2dsq: 7991*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scipy_ary.mean(), input_ary.mean()) 7992*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scipy_ary[0, 0], input_ary[0, 0, 0, -1]) 7993*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scipy_ary[0, -1], input_ary[0, 0, -1, -1]) 7994*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scipy_ary[-1, -1], input_ary[0, 0, -1, 0]) 7995*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scipy_ary[-1, 0], input_ary[0, 0, 0, 0]) 7996*da0073e9SAndroid Build Coastguard Worker 7997*da0073e9SAndroid Build Coastguard Worker affine_tensor = torch.nn.functional.affine_grid( 7998*da0073e9SAndroid Build Coastguard Worker transform_tensor, 7999*da0073e9SAndroid Build Coastguard Worker torch.Size(output_size), 8000*da0073e9SAndroid Build Coastguard Worker align_corners=True 8001*da0073e9SAndroid Build Coastguard Worker ) 8002*da0073e9SAndroid Build Coastguard Worker 8003*da0073e9SAndroid Build Coastguard Worker gridsample_ary = torch.nn.functional.grid_sample( 8004*da0073e9SAndroid Build Coastguard Worker torch.tensor(input_ary, device=device).to(device), 8005*da0073e9SAndroid Build Coastguard Worker affine_tensor, 8006*da0073e9SAndroid Build Coastguard Worker padding_mode='border', 8007*da0073e9SAndroid Build Coastguard Worker align_corners=True 8008*da0073e9SAndroid Build Coastguard Worker ).to('cpu') 8009*da0073e9SAndroid Build Coastguard Worker 8010*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scipy_ary.mean(), gridsample_ary.mean()) 8011*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) 8012*da0073e9SAndroid Build Coastguard Worker 8013*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), 8014*da0073e9SAndroid Build Coastguard Worker "Scipy v1.0 and/or numpy not found") 
8015*da0073e9SAndroid Build Coastguard Worker @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 8016*da0073e9SAndroid Build Coastguard Worker @tf32_on_and_off(0.005) 8017*da0073e9SAndroid Build Coastguard Worker @bf32_on_and_off(0.005) 8018*da0073e9SAndroid Build Coastguard Worker def test_affine_2d_rotate45(self, device): 8019*da0073e9SAndroid Build Coastguard Worker # scipy before 1.0.0 do not support homogeneous coordinate 8020*da0073e9SAndroid Build Coastguard Worker # scipy.ndimage.affine_transform, so we need to skip. 8021*da0073e9SAndroid Build Coastguard Worker input_size = [1, 1, 3, 3] 8022*da0073e9SAndroid Build Coastguard Worker input_ary = np.array(np.zeros(input_size), dtype=np.float32) 8023*da0073e9SAndroid Build Coastguard Worker input_ary[0, 0, 0, :] = 0.5 8024*da0073e9SAndroid Build Coastguard Worker input_ary[0, 0, 2, 2] = 1.0 8025*da0073e9SAndroid Build Coastguard Worker output_size = [1, 1, 3, 3] 8026*da0073e9SAndroid Build Coastguard Worker angle_rad = 0.125 * math.pi * 2 8027*da0073e9SAndroid Build Coastguard Worker 8028*da0073e9SAndroid Build Coastguard Worker transform_tensor, transform_ary, offset = \ 8029*da0073e9SAndroid Build Coastguard Worker _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad) 8030*da0073e9SAndroid Build Coastguard Worker 8031*da0073e9SAndroid Build Coastguard Worker scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform( 8032*da0073e9SAndroid Build Coastguard Worker input_ary[0, 0], 8033*da0073e9SAndroid Build Coastguard Worker transform_ary, 8034*da0073e9SAndroid Build Coastguard Worker offset=offset, 8035*da0073e9SAndroid Build Coastguard Worker output_shape=output_size[2:], 8036*da0073e9SAndroid Build Coastguard Worker order=1, 8037*da0073e9SAndroid Build Coastguard Worker mode='nearest', 8038*da0073e9SAndroid Build Coastguard Worker prefilter=False)) 8039*da0073e9SAndroid Build Coastguard Worker 8040*da0073e9SAndroid Build Coastguard Worker affine_tensor = torch.nn.functional.affine_grid( 8041*da0073e9SAndroid Build Coastguard Worker transform_tensor, 8042*da0073e9SAndroid Build Coastguard Worker torch.Size(output_size), 8043*da0073e9SAndroid Build Coastguard Worker align_corners=True 8044*da0073e9SAndroid Build Coastguard Worker ) 8045*da0073e9SAndroid Build Coastguard Worker 8046*da0073e9SAndroid Build Coastguard Worker gridsample_ary = torch.nn.functional.grid_sample( 8047*da0073e9SAndroid Build Coastguard Worker torch.tensor(input_ary, device=device).to(device), 8048*da0073e9SAndroid Build Coastguard Worker affine_tensor, 8049*da0073e9SAndroid Build Coastguard Worker padding_mode='border', 8050*da0073e9SAndroid Build Coastguard Worker align_corners=True 8051*da0073e9SAndroid Build Coastguard Worker ).to('cpu') 8052*da0073e9SAndroid Build Coastguard Worker 8053*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) 8054*da0073e9SAndroid Build Coastguard Worker 8055*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 8056*da0073e9SAndroid Build Coastguard Worker @largeTensorTest("60GB", "cpu") 8057*da0073e9SAndroid Build Coastguard Worker @largeTensorTest("16GB", "cuda") 8058*da0073e9SAndroid Build Coastguard Worker def test_avg_pool_large_tensor(self, device): 8059*da0073e9SAndroid Build Coastguard Worker # test for https://github.com/pytorch/pytorch/issues/113833 8060*da0073e9SAndroid Build Coastguard Worker a = torch.randn(128, 256, 256, 256, dtype=torch.half, device=device, requires_grad=True) 
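        # 128 * 256 * 256 * 256 = 2**31 elements, so the pooling kernel has to
        # index past the int32 range (presumably the overflow reported in the
        # issue above); the CPU reference below runs in float32 and its gradient
        # is cast back to half for comparison against the CUDA result.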
8061*da0073e9SAndroid Build Coastguard Worker a_cpu = a.detach().cpu().float() 8062*da0073e9SAndroid Build Coastguard Worker m = torch.nn.AvgPool2d(2) 8063*da0073e9SAndroid Build Coastguard Worker o = m(a) 8064*da0073e9SAndroid Build Coastguard Worker a_cpu.requires_grad = True 8065*da0073e9SAndroid Build Coastguard Worker o.sum().backward() 8066*da0073e9SAndroid Build Coastguard Worker o_cpu = m(a_cpu) 8067*da0073e9SAndroid Build Coastguard Worker o_cpu.sum().backward() 8068*da0073e9SAndroid Build Coastguard Worker # workaround for memory usage overhead of assertEqual 8069*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(a.grad.cpu(), a_cpu.grad.half())) 8070*da0073e9SAndroid Build Coastguard Worker 8071*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 8072*da0073e9SAndroid Build Coastguard Worker @largeTensorTest("48GB", "cpu") 8073*da0073e9SAndroid Build Coastguard Worker @largeTensorTest("48GB", "cuda") 8074*da0073e9SAndroid Build Coastguard Worker def test_avg_pool_large_tensor2(self, device): 8075*da0073e9SAndroid Build Coastguard Worker # test for https://github.com/pytorch/pytorch/issues/129785 8076*da0073e9SAndroid Build Coastguard Worker out_size = [2048, 64, 104, 79] 8077*da0073e9SAndroid Build Coastguard Worker size = [2048, 64, 209, 159] 8078*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(size, device=device, requires_grad=True, dtype=torch.float) 8079*da0073e9SAndroid Build Coastguard Worker inp_cpu = inp.detach().cpu() 8080*da0073e9SAndroid Build Coastguard Worker m = torch.nn.AvgPool2d([2, 2], [2, 2], [0, 0], False, True, None) 8081*da0073e9SAndroid Build Coastguard Worker o = m(inp) 8082*da0073e9SAndroid Build Coastguard Worker inp_cpu.requires_grad = True 8083*da0073e9SAndroid Build Coastguard Worker o.sum().backward() 8084*da0073e9SAndroid Build Coastguard Worker o_cpu = m(inp_cpu) 8085*da0073e9SAndroid Build Coastguard Worker o_cpu.sum().backward() 8086*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o.shape, out_size) 8087*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o_cpu.shape, out_size) 8088*da0073e9SAndroid Build Coastguard Worker # reduce memory usage 8089*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inp.grad.sum(), inp_cpu.grad.sum()) 8090*da0073e9SAndroid Build Coastguard Worker 8091*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), 8092*da0073e9SAndroid Build Coastguard Worker "Scipy v1.0 and/or numpy not found") 8093*da0073e9SAndroid Build Coastguard Worker @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 8094*da0073e9SAndroid Build Coastguard Worker @tf32_on_and_off(0.005) 8095*da0073e9SAndroid Build Coastguard Worker @bf32_on_and_off(0.005) 8096*da0073e9SAndroid Build Coastguard Worker def test_affine_2d_rotateRandom(self, device): 8097*da0073e9SAndroid Build Coastguard Worker # scipy before 1.0.0 do not support homogeneous coordinate 8098*da0073e9SAndroid Build Coastguard Worker # scipy.ndimage.affine_transform, so we need to skip. 
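        # For every (angle, input size, output size) combination this checks two
        # things against scipy: that affine_grid reproduces the hand-built
        # output-pixel -> input-coordinate mapping grid_ary (verified point-wise
        # at the end of the loop), and that bilinear grid_sample with 'border'
        # padding matches scipy.ndimage.affine_transform with order=1 and
        # mode='nearest'. The corner sentinels (2, 4, 6, 8) make orientation
        # mistakes easy to spot.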
        for angle_rad, input_size2d, output_size2d in \
                itertools.product(angle_rad_(), input_size2d_(), output_size2d_()):

            input_size = input_size2d
            input_ary = np.array(np.random.random(input_size), dtype=np.float32).round(3)
            output_size = output_size2d

            input_ary[0, 0, 0, 0] = 2
            input_ary[0, 0, 0, -1] = 4
            input_ary[0, 0, -1, 0] = 6
            input_ary[0, 0, -1, -1] = 8

            transform_tensor, transform_ary, grid_ary = \
                _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)

            scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
                input_ary[0, 0],
                transform_ary,
                output_shape=output_size[2:],
                order=1,
                mode='nearest',
                prefilter=False))

            affine_tensor = torch.nn.functional.affine_grid(
                transform_tensor,
                torch.Size(output_size),
                align_corners=True
            )

            gridsample_ary = torch.nn.functional.grid_sample(
                torch.tensor(input_ary, device=device).to(device),
                affine_tensor,
                padding_mode='border',
                align_corners=True
            ).to('cpu')

            affine_tensor = affine_tensor.to('cpu')

            for r in range(affine_tensor.size(1)):
                for c in range(affine_tensor.size(2)):
                    grid_out = np.dot(grid_ary, [r, c, 1])
                    self.assertEqual(affine_tensor[0, r, c], grid_out[:2], exact_dtype=False)

            self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))

    @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
                     "Scipy v1.0 and/or numpy not found")
    @expectedFailureMPS  # aten::grid_sampler_3d not implemented https://github.com/pytorch/pytorch/issues/77764
    @tf32_on_and_off(0.005)
    @bf32_on_and_off(0.005)
    def test_affine_3d_rotateRandom(self, device):
        # scipy before 1.0.0 does not support homogeneous coordinates in
        # scipy.ndimage.affine_transform, so we skip those versions.
        for angle_rad, axis_vector, input_size3d, output_size3d in \
                itertools.product(angle_rad_(), axis_vector_(), input_size3d_(), output_size3d_()):
            input_size = input_size3d
            input_ary = np.array(np.random.random(input_size), dtype=np.float32)
            output_size = output_size3d

            input_ary[0, 0, 0, 0, 0] = 2
            input_ary[0, 0, 0, 0, -1] = 3
            input_ary[0, 0, 0, -1, 0] = 4
            input_ary[0, 0, 0, -1, -1] = 5
            input_ary[0, 0, -1, 0, 0] = 6
            input_ary[0, 0, -1, 0, -1] = 7
            input_ary[0, 0, -1, -1, 0] = 8
            input_ary[0, 0, -1, -1, -1] = 9

            transform_tensor, transform_ary, grid_ary = \
                _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_rad, axis_vector)

            scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
                input_ary[0, 0],
                transform_ary,
                output_shape=output_size[2:],
                order=1,
                mode='nearest',
                prefilter=False))

            affine_tensor = torch.nn.functional.affine_grid(
                transform_tensor,
                torch.Size(output_size),
                align_corners=True
            )

            gridsample_ary = torch.nn.functional.grid_sample(
                torch.tensor(input_ary, device=device).to(device),
                affine_tensor,
                padding_mode='border',
                align_corners=True
            ).to('cpu')

            affine_tensor = affine_tensor.to('cpu')

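            # Each grid entry should equal the first three components of
            # grid_ary @ [i, r, c, 1] (homogeneous voxel coordinates).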
            for i in range(affine_tensor.size(1)):
                for r in range(affine_tensor.size(2)):
                    for c in range(affine_tensor.size(3)):
                        grid_out = np.dot(grid_ary, [i, r, c, 1])
                        self.assertEqual(affine_tensor[0, i, r, c], grid_out[:3], exact_dtype=False)

            self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))

    @onlyCUDA
    @dtypes(torch.float, torch.half)
    def test_batchnorm_large_batch(self, device, dtype):
        bn = nn.BatchNorm2d(1).to(device, dtype)
        data = torch.rand(880801, 1, 1, 1, device=device, dtype=dtype)
        out = bn(data).sum().backward()

    @dtypesIfCUDA(torch.float, torch.double, torch.half, torch.complex128)
    @dtypesIfMPS(torch.float, torch.half, torch.complex64)
    @dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128)
    def test_conv_empty_input(self, device, dtype):
        def help(input, conv, memory_format):
            ref_out = conv(input)
            conv_cl = conv.to(memory_format=memory_format)
            out_cl = conv_cl(input)
            self.assertEqual(ref_out, out_cl)
            input_cl = input.to(memory_format=memory_format)
            out_cl2 = conv(input_cl)
            self.assertEqual(out_cl, out_cl2)
            out_cl3 = conv_cl(input_cl)
            self.assertEqual(out_cl, out_cl3)

        # channels_last case
        input2d = torch.randn((0, 4, 20, 20)).to(device=device, dtype=dtype)
        conv2d = torch.nn.Conv2d(4, 4, 3, 1).to(device=device, dtype=dtype)
        help(input2d, conv2d, torch.channels_last)
        # channels_last_3d case
        input3d = torch.randn((0, 4, 20, 20, 20)).to(device=device, dtype=dtype)
        conv3d = torch.nn.Conv3d(4, 4, 3, 1).to(device=device, dtype=dtype)
        help(input3d, conv3d, torch.channels_last_3d)
        # non-contiguous case
        weight = torch.rand(4, 8, 3, 3)[:, ::2, :, :].to(device=device, dtype=dtype)
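        # Slicing the channel dimension with ::2 yields a non-contiguous weight; the
        # result must match the same convolution run with a contiguous copy of it.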
        bias = torch.rand(4).to(device=device, dtype=dtype)
        out = F.conv2d(input2d, weight, bias, (1, 1), 0, (1, 1), 1)
        weight = weight.contiguous()
        out_ref = F.conv2d(input2d, weight, bias, (1, 1), 0, (1, 1), 1)
        self.assertEqual(out_ref, out)
        # sigfpe reported in https://github.com/pytorch/pytorch/issues/94125
        with self.assertRaises(RuntimeError):
            inp = torch.empty([1, 1, 1, 0], dtype=dtype, device=device)
            weight = torch.empty([1, 0, 1], dtype=dtype, device=device)
            torch._C._nn.slow_conv3d(inp, weight, 1)

        with self.assertRaisesRegex(RuntimeError, re.escape("2D kernel_size expected")):
            torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[], padding=[1, 1], stride=[1, 1],
                                     weight=torch.rand([1, 1]))
        with self.assertRaisesRegex(RuntimeError, re.escape("2D stride expected")):
            torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[1, 1], padding=[1, 1], stride=[],
                                     weight=torch.rand([1, 1]))
        with self.assertRaisesRegex(RuntimeError, re.escape("2D padding expected")):
            torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[1, 1], padding=[], stride=[1, 1],
                                     weight=torch.rand([1, 1]))

    def test_InstanceNorm1d_general(self, device):
        b = random.randint(3, 5)
        c = random.randint(3, 5)
        d = random.randint(8, 10)

        input = torch.rand(b, c, d)
        self._test_InstanceNorm_general(nn.InstanceNorm1d, input, device)

        if self.device_type == 'cuda':
            self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input, device)

    def test_InstanceNorm2d_general(self, device):
        b = random.randint(3, 5)
        c = random.randint(3, 5)
        w = random.randint(3, 6)
        h = random.randint(6, 8)

        input = torch.rand(b, c, h, w)
        self._test_InstanceNorm_general(nn.InstanceNorm2d, input, device)

        if self.device_type == 'cuda':
            self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input, device)

    def test_InstanceNorm3d_general(self, device):
        b = random.randint(3, 5)
        c = random.randint(3, 5)
        w = random.randint(2, 5)
        h = random.randint(2, 5)
        d = random.randint(2, 5)

        input = torch.rand(b, c, h, w, d)
        self._test_InstanceNorm_general(nn.InstanceNorm3d, input, device)

        if self.device_type == 'cuda':
            self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input, device)

    @parametrize_test("instance_norm_cls", [nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d], name_fn=lambda c: c.__name__)
    @parametrize_test("no_batch_dim", [True, False])
    @parametrize_test("affine", [True, False])
    def test_instancenorm_raises_error_if_input_channels_is_not_num_features(self, device, instance_norm_cls, no_batch_dim, affine):
        inst_norm = instance_norm_cls(4, affine=affine)
        size = [2] * inst_norm._get_no_batch_dim()
        if not no_batch_dim:
            size = [3] + size
        t = torch.randn(size)
        if affine:
            with self.assertRaisesRegex(ValueError, "expected input's size at dim="):
                inst_norm(t)
        else:
            with warnings.catch_warnings(record=True) as w:
                inst_norm(t)
            self.assertIn("which is not used because affine=False", str(w[0].message))

    def test_instancenorm_raises_error_if_less_than_one_value_per_channel(self, device):
        x = torch.rand(10)[None, :, None]
        with self.assertRaises(ValueError):
            torch.nn.InstanceNorm1d(10)(x).to(device)

    def test_instancenorm_raises_error_for_single_spatial_element_during_training(self, device):
        BATCH_SIZE = 10
        NUM_CHANNELS = 3
        norms = [torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d]
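        # In training mode these modules need more than one spatial element to compute
        # statistics, so a single-element input should raise; eval mode is exercised below.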
        for i, norm in enumerate(norms):
            m = norm(NUM_CHANNELS, track_running_stats=True)
            m.to(device)

            # Create an appropriately-sized input with a single spatial element.
            input = torch.randn(BATCH_SIZE, NUM_CHANNELS, *[1 for _ in range(i + 1)],
                                device=device)
            with self.assertRaises(ValueError):
                m(input)

            # Single spatial element should be fine in eval.
            m.eval()
            m(input)

    def test_LayerNorm_general(self, device):
        self._test_LayerNorm_general(device)

        if self.device_type == 'cuda' or self.device_type == 'cpu':
            for dtype in [torch.half, torch.bfloat16]:
                self._test_LayerNorm_general(device, dtype=dtype)

        if self.device_type == 'cuda':
            self._test_LayerNorm_cuda_half(device)

        if self.device_type == 'cpu':
            for dtype in [torch.half, torch.bfloat16]:
                self._test_LayerNorm_cpu_mixed_dtype(device, dtype=dtype)

    @onlyNativeDeviceTypes
    def test_LayerNorm_numeric(self, device):
        def layer_norm_ref(X, gamma, beta, normalized_shape, eps):
            feature_size = np.prod(normalized_shape)
            X_view = X.view(-1, feature_size)
            mean = X_view.mean(dim=-1, keepdim=True)
            var = X_view.var(dim=-1, unbiased=False, keepdim=True)
            Y = (X_view - mean) / torch.sqrt(var + eps)
            Y = Y * gamma.view(-1) + beta.view(-1)
            return Y.view(*X.size())

        normalized_shape = [256, 256, 144]
        layer_norm = nn.LayerNorm(normalized_shape).float().to(device)
        X = torch.rand(2, *normalized_shape, dtype=torch.float32,
                       device=device)

        Y = layer_norm(X)
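        # Reference computation: y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta,
        # with the statistics taken over the trailing normalized_shape dimensions.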
        Y_ref = layer_norm_ref(X, layer_norm.weight.data, layer_norm.bias.data,
                               normalized_shape, layer_norm.eps)
        self.assertEqual(Y, Y_ref, rtol=0, atol=1e-5)

        if self.device_type == 'cuda':
            layer_norm.cpu()
            Y_cpu = layer_norm(X.cpu())
            self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5)

    @onlyCPU
    def test_glu_bfloat16(self, device):
        def test_dtype(fn, input, dtype):
            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
            input2 = input.detach().clone().float().requires_grad_(True)
            out = fn(input)
            out.sum().backward()
            out2 = fn(input2)
            out2.sum().backward()
            self.assertEqual(out.dtype, dtype)
            self.assertEqual(input.grad.dtype, dtype)
            self.assertEqual(out, out2, exact_dtype=False)
            self.assertEqual(input.grad, input2.grad, atol=1e-2, rtol=0, exact_dtype=False)

        def func(device):
            return torch.nn.GLU(dim=-1).to(device)

        shapes = [[1, 3, 1, 6], [1, 3, 1, 128], [1, 3, 256, 256]]
        for shape in shapes:
            x = torch.randn(shape, device=device)
            test_dtype(func(device), x, torch.bfloat16)

    @onlyNativeDeviceTypes
    def test_GroupNorm_general(self, device):
        self._test_GroupNorm_general(device)

        if self.device_type == 'cuda':
            self._test_GroupNorm_cuda_half()

        if self.device_type == 'cpu':
            self._test_GroupNorm_cpu_mixed_dtype()

    def test_GroupNorm_raises_error_if_one_value_per_group(self, device):
        x = torch.rand(10)[None, :, None]
        with self.assertRaises(ValueError):
            torch.nn.GroupNorm(10, 10)(x).to(device)

    def test_GroupNorm_empty(self, device):
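        # A zero-batch input should pass through GroupNorm; on CUDA the check is
        # repeated with cuDNN disabled.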
        mod = torch.nn.GroupNorm(2, 4).to(device)
        inp = torch.randn(0, 4, 2, 2, device=device)
        _test_module_empty_input(self, mod, inp)
        if self.device_type == 'cuda' and self.has_cudnn():
            with torch.backends.cudnn.flags(enabled=False):
                _test_module_empty_input(self, mod, inp)

    @onlyCPU
    @dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
    def test_groupnorm_nhwc(self, device, dtype):
        def helper(self, size, groups, memory_format, is_mixed):
            channels = size[1]
            input = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
            input = input.contiguous(memory_format=memory_format)
            input.retain_grad()
            grad = torch.randn(size, dtype=dtype, device=device)
            grad = grad.contiguous(memory_format=memory_format)
            if dtype == torch.bfloat16 and is_mixed:
                gn = nn.GroupNorm(groups, channels).to(device).to(torch.float)
            else:
                gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
            gn.weight.data.uniform_()
            gn.bias.data.uniform_()

            ref_input = input.detach().clone().contiguous(memory_format=torch.contiguous_format).requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous(memory_format=torch.contiguous_format)
            if dtype == torch.bfloat16 and is_mixed:
                ref_gn = nn.GroupNorm(groups, channels).to(device).to(torch.float)
            else:
                ref_gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
            ref_gn.load_state_dict(gn.state_dict())
            out = gn(input)
            out.backward(grad)
            ref_out = ref_gn(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertTrue(ref_out.is_contiguous(memory_format=torch.contiguous_format))
            self.assertEqual(out, ref_out)
            # keeping parameters in bfloat16/half is not recommended, so use loose tolerances
            atol = 5e-4
            rtol = 8e-3

            self.assertEqual(gn.weight.grad, ref_gn.weight.grad, atol=atol, rtol=rtol)
            self.assertEqual(gn.bias.grad, ref_gn.bias.grad, atol=atol, rtol=rtol)
            self.assertEqual(input.grad, ref_input.grad, atol=atol, rtol=rtol)

        for is_mixed in [True, False]:
            helper(self, (4, 8, 10, 10), 4, torch.channels_last, is_mixed)
            helper(self, (2, 30, 9, 9), 3, torch.channels_last, is_mixed)
            helper(self, (4, 8, 40, 40), 4, torch.channels_last, is_mixed)
            helper(self, (4, 40, 40, 40), 2, torch.channels_last, is_mixed)
            helper(self, (2, 30, 50, 50), 3, torch.channels_last, is_mixed)
            helper(self, (2, 60, 50, 50), 3, torch.channels_last, is_mixed)
            helper(self, (2, 9, 7, 11, 15), 3, torch.channels_last_3d, is_mixed)
            helper(self, (2, 9, 7, 200, 15), 3, torch.channels_last_3d, is_mixed)
            helper(self, (2, 60, 7, 200, 15), 3, torch.channels_last_3d, is_mixed)

    @onlyNativeDeviceTypes
    def test_GroupNorm_memory_format(self, device):
        # Tests for regression reported in https://github.com/pytorch/pytorch/issues/92166

        def helper(input_format, grad_format, B=2, C=4, W=4, H=4):
            import copy
            net_orig = torch.nn.GroupNorm(B, C).to(device=device)
            net = copy.deepcopy(net_orig)
            x_orig = torch.rand(B, C, W, H, device=device, requires_grad=True)
            grad_orig = torch.rand(B, C, W, H, device=device)
            x = x_orig.clone().detach().to(memory_format=input_format).requires_grad_(True)
            grad = grad_orig.detach().to(memory_format=grad_format)

            y = net(x)
            y.backward(grad)

            y_orig = net_orig(x_orig)
            y_orig.backward(grad_orig)

            self.assertEqual(y, y_orig)
            self.assertEqual(x.grad, x_orig.grad)

        for input_format in [torch.contiguous_format, torch.channels_last]:
            for grad_format in [torch.contiguous_format, torch.channels_last]:
                helper(input_format, grad_format)

    @onlyNativeDeviceTypes
    def test_GroupNorm_numeric(self, device):
        def group_norm_ref(X, gamma, beta, groups, channels, eps):
            batch_size = X.size()[0]
            X_view = X.view(batch_size, groups, -1)
            mean = X_view.mean(dim=-1, keepdim=True)
            var = X_view.var(dim=-1, unbiased=False, keepdim=True)
            Y = ((X_view - mean) / torch.sqrt(var + eps)).view(
                batch_size, channels, -1)
            Y = Y * gamma.view(channels, 1) + beta.view(channels, 1)
            return Y.view(*X.size())

        batch_size = 1
        groups = 2
        channels = 8
        group_norm = nn.GroupNorm(groups, channels).float().to(device)
        X = torch.rand(batch_size, channels, 256, 256, 72,
                       dtype=torch.float32, device=device)

        Y = group_norm(X)
        Y_ref = group_norm_ref(
            X, group_norm.weight.data, group_norm.bias.data, groups,
            channels, group_norm.eps)
        self.assertEqual(Y, Y_ref, rtol=0, atol=1e-5)

        if self.device_type == 'cuda':
            group_norm.cpu()
            Y_cpu = group_norm(X.cpu())
            self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5)

    @onlyNativeDeviceTypes
    @dtypes(torch.float64, torch.complex128)
    def test_pad(self, device, dtype):
        # Check that errors are raised for invalid circular padding values
        inputs = torch.randn(1, 1, 4, device=device, dtype=dtype, requires_grad=True)
        # Should raise error when trying to wrap around more than once
        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (5, 4), mode='circular'))
        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (3, 6), mode='circular'))
        # Should raise error when negative padding results in negative output shape
        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (-3, -2), mode='circular'))

        # assert that reflection padding errors when pad >= input size
        expected_err_msg = r"Padding size should be less than the corresponding input dimension"
        inputs = torch.randn(1, 1, 2, 3, device=device, dtype=dtype)
        self.assertRaisesRegex(RuntimeError, expected_err_msg,
                               lambda: F.pad(inputs, (1, 1, 3, 0), mode='reflect'))
        inputs = torch.randn(1, 1, 2, device=device, dtype=dtype)
        self.assertRaisesRegex(RuntimeError, expected_err_msg,
                               lambda: F.pad(inputs, (2, 1), mode='reflect'))

        inputs = torch.rand(1, 3, 4, 4, device=device, dtype=dtype)
        # assert that pad doesn't return a view into the input tensor
        for mode in 'constant', 'reflect', 'replicate', 'circular':
            out = F.pad(inputs, (0, 0, 0, 0), mode=mode)
            out.fill_(4)
            self.assertTrue(torch.all(torch.abs(inputs) < 2))

            out = F.pad(inputs, (0, 0, -1, -1), mode=mode)
            out.fill_(4)
            self.assertTrue(torch.all(torch.abs(inputs) < 2))

    @onlyNativeDeviceTypes
    @dtypes(torch.float64, torch.complex128)
    def test_ReplicationPad_empty(self, device, dtype):
        for mod, inp in [
                (torch.nn.ReplicationPad1d(3), torch.randn(0, 3, 10, device=device, dtype=dtype)),
                (torch.nn.ReplicationPad2d(3), torch.randn(0, 3, 10, 10, device=device, dtype=dtype)),
                (torch.nn.ReplicationPad3d(3), torch.randn(0, 3, 10, 10, 10, device=device, dtype=dtype))]:
            _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, 'Expected 2D or 3D'):
            mod = torch.nn.ReplicationPad1d(2)
            inp = torch.randn(3, 0, 10, device=device, dtype=dtype)
            mod(inp)

        with self.assertRaisesRegex(RuntimeError, 'Expected 3D or 4D'):
            mod = torch.nn.ReplicationPad2d((2, 2, 2, 2))
            inp = torch.randn(43, 0, 10, 10, device=device, dtype=dtype)
            mod(inp)

        with self.assertRaisesRegex(RuntimeError, 'Expected 4D or 5D'):
            mod = torch.nn.ReplicationPad3d((2, 2, 2, 2, 2, 2))
            inp = torch.randn(3, 0, 10, 10, 10, device=device, dtype=dtype)
            mod(inp)

        with self.assertRaisesRegex(RuntimeError, 'padding size is expected to be 2'):
            torch._C._nn.replication_pad1d(torch.randn([2]), padding=[])

        with self.assertRaisesRegex(RuntimeError, 'padding size is expected to be 4'):
            torch._C._nn.replication_pad2d(torch.randn([2]), padding=[])

        with self.assertRaisesRegex(RuntimeError, 'padding size is expected to be 6'):
            torch._C._nn.replication_pad3d(torch.randn([2]), padding=[])

    @expectedFailureMPS  # TODO(hvaara): Investigate as possible bug.
    def test_ReplicationPad1d_large(self, device):
        shapes = ([2, 65736, 4], [65736, 2, 4])
        pl, pr = 3, 4
        for shape in shapes:
            x = torch.randn(shape, device=device, requires_grad=True)
            model = torch.nn.ReplicationPad1d((pl, pr))

            # forward
            out = model(x)
            self.assertEqual(out[:, :, pl : -pr], x)

            left_padding = out[:, :, : pl]
            self.assertEqual(left_padding, x[:, :, :1].expand_as(left_padding))
            right_padding = out[:, :, -pr :]
            self.assertEqual(right_padding, x[:, :, -1:].expand_as(right_padding))

            # backward
            g = torch.randn_like(out)
            out.backward(g)
            self.assertEqual(x.grad[:, :, 1 : -1], g[:, :, pl + 1 : -pr - 1])

            self.assertEqual(x.grad[:, :, 0], g[:, :, : pl + 1].sum(-1))
            self.assertEqual(x.grad[:, :, -1], g[:, :, -pr - 1:].sum(-1))

    @expectedFailureMPS  # TODO(hvaara): Investigate as possible bug.
    def test_ReplicationPad2d_large(self, device):
        shapes = ([2, 65736, 4, 4], [65736, 2, 4, 4])
        pl, pr, pt, pb = 3, 4, 5, 6
        for shape in shapes:
            x = torch.randn(shape, device=device, requires_grad=True)
            model = torch.nn.ReplicationPad2d((pl, pr, pt, pb))

            # forward center, edge
            out = model(x)
            self.assertEqual(out[:, :, pt : -pb, pl : -pr], x)

            left_padding = out[:, :, pt : -pb, : pl]
            self.assertEqual(left_padding, x[:, :, :, :1].expand_as(left_padding))
            right_padding = out[:, :, pt : -pb, -pr :]
            self.assertEqual(right_padding, x[:, :, :, -1:].expand_as(right_padding))
            top_padding = out[:, :, : pt, pl : -pr]
            self.assertEqual(top_padding, x[:, :, :1, :].expand_as(top_padding))
            bottom_padding = out[:, :, -pb : , pl : -pr]
            self.assertEqual(bottom_padding, x[:, :, -1:, :].expand_as(bottom_padding))

            # forward corner
            tl_padding = out[:, :, : pt + 1, : pl + 1]
            self.assertEqual(tl_padding, x[:, :, :1, :1].expand_as(tl_padding))
            tr_padding = out[:, :, : pt + 1, -pr - 1:]
            self.assertEqual(tr_padding, x[:, :, :1, -1:].expand_as(tr_padding))
            bl_padding = out[:, :, -pb - 1:, : pl + 1]
            self.assertEqual(bl_padding, x[:, :, -1:, :1].expand_as(bl_padding))
            br_padding = out[:, :, -pb - 1:, -pr - 1:]
            self.assertEqual(br_padding, x[:, :, -1:, -1:].expand_as(br_padding))

            # backward center, edge
            g = torch.randn_like(out)
            out.backward(g)
            self.assertEqual(x.grad[:, :, 1:-1, 1:-1], g[:, :, pt + 1 : -pb - 1, pl + 1 : -pr - 1])

            self.assertEqual(x.grad[:, :, 1:-1, 0], g[:, :, pt + 1 : -pb - 1, : pl + 1].sum(-1))
            self.assertEqual(x.grad[:, :, 1:-1, -1], g[:, :, pt + 1 : -pb - 1, -pr - 1 :].sum(-1))
            self.assertEqual(x.grad[:, :, 0, 1:-1], g[:, :, : pt + 1, pl + 1 : -pr - 1].sum(-2))
            self.assertEqual(x.grad[:, :, -1, 1:-1], g[:, :, -pb - 1 :, pl + 1 : -pr - 1].sum(-2))

            # backward corner
            self.assertEqual(x.grad[:, :, 0, 0], g[:, :, : pt + 1, : pl + 1].sum((-2, -1)))
            self.assertEqual(x.grad[:, :, 0, -1], g[:, :, : pt + 1, -pr - 1 :].sum((-2, -1)))
            self.assertEqual(x.grad[:, :, -1, 0], g[:, :, -pb - 1 :, : pl + 1].sum((-2, -1)))
            self.assertEqual(x.grad[:, :, -1, -1], g[:, :, -pb - 1 :, -pr - 1 :].sum((-2, -1)))

    @largeTensorTest("6GB")
    def test_ReplicationPad3d_large(self, device):
        shapes = ([1, 65736, 2, 2, 2], [65736, 1, 2, 2, 2])
        pl, pr, pt, pbt, pf, pbk = 3, 4, 5, 6, 7, 8

        for shape in shapes:
            x = torch.randn(shape, device=device, requires_grad=True)
            model = torch.nn.ReplicationPad3d((pl, pr, pt, pbt, pf, pbk))

            # forward center
            out = model(x)
            self.assertEqual(out[:, :, pf : -pbk, pt : -pbt, pl : -pr], x)

            # backward center
            g = torch.randn_like(out)
            out.backward(g)
            self.assertEqual(x.grad[:, :, 1:-1, 1:-1, 1:-1], g[:, :, pf + 1 : -pbk - 1, pt + 1 : -pbt - 1, pl + 1 : -pr - 1])

    @onlyNativeDeviceTypes
    def test_Bilinear_empty(self, device):
        mod = torch.nn.Bilinear(20, 30, 40).to(device)
        inp1 = torch.randn(0, 10, 20, requires_grad=True, device=device)
        inp2 = torch.randn(0, 10, 30, requires_grad=True, device=device)

        output = mod(inp1, inp2)
        output.sum().backward()

        self.assertEqual(inp1, torch.zeros_like(inp1))
        self.assertEqual(inp2, torch.zeros_like(inp2))

        self.assertEqual(inp1.grad, torch.zeros_like(inp1))
        self.assertEqual(inp2.grad, torch.zeros_like(inp2))

    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
    @onlyNativeDeviceTypes
    def test_TransformerEncoderLayer_empty(self, device):
        for training in (True, False):
            for batch_first, input_shape in [(True, (0, 10, 512)),
                                             (False, (10, 0, 512))]:
                input = torch.rand(*input_shape, device=device, dtype=torch.double)
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=512, nhead=8, batch_first=batch_first, dtype=torch.double).to(device)
                if not training:
                    encoder_layer = encoder_layer.eval()
                    with torch.no_grad():
                        _test_module_empty_input(self, encoder_layer, input, check_size=False, inference=True)
                    if batch_first and not TEST_WITH_CROSSREF:
                        with torch.no_grad():
                            # A NestedTensor with no tensors inside it doesn't have dim 3 (or dim
                            # 2, for that matter) so it can't hit the fast path, nor can we give a
                            # result.
                            with self.assertRaisesRegex(
                                    AssertionError, 'MultiheadAttention does not support NestedTensor outside'):
                                nt = torch.nested.nested_tensor([], device=device)
                                _test_module_empty_input(self, encoder_layer, nt, check_size=False, inference=True)

                            nt = torch.nested.nested_tensor([torch.rand(0, 512, device=device, dtype=torch.double)], device=device)
                            _test_module_empty_input(self, encoder_layer, nt, check_size=False, inference=True)
                else:
                    _test_module_empty_input(self, encoder_layer, input, check_size=False)

    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
    @onlyNativeDeviceTypes
    def test_TransformerEncoder_empty(self, device):
        for batch_first, input_shape in [(True, (0, 10, 512)),
                                         (False, (10, 0, 512))]:
            input = torch.rand(*input_shape, device=device, dtype=torch.double)
            encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=batch_first, dtype=torch.double).to(device)
            transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6).to(device)
            _test_module_empty_input(self, transformer_encoder, input, check_size=False)

    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
    @onlyNativeDeviceTypes
    def test_TransformerDecoderLayer_empty(self, device):
        for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)),
                                                     (False, (10, 0, 512), (20, 0, 512))]:
            memory = torch.rand(*memory_shape, device=device, dtype=torch.double)
            tgt = torch.rand(*tgt_shape, requires_grad=True, device=device, dtype=torch.double)
            decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=batch_first, dtype=torch.double).to(device)
            self._test_module_empty_inputs(decoder_layer, [tgt, memory])

    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
    @onlyNativeDeviceTypes
    def test_TransformerDecoder_empty(self, device):
        for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)),
                                                     (False, (10, 0, 512), (20, 0, 512))]:
            memory = torch.rand(*memory_shape, device=device, dtype=torch.double)
            tgt = torch.rand(*tgt_shape, requires_grad=True, device=device, dtype=torch.double)
            decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=batch_first, dtype=torch.double).to(device)
            transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6).to(device)
            self._test_module_empty_inputs(transformer_decoder, [tgt, memory])

    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
    @onlyNativeDeviceTypes
    def test_Transformer_empty(self, device):
        for batch_first, src_shape, tgt_shape in [(True, (10, 0, 512), (20, 0, 512))]:
            transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12, dtype=torch.double).to(device)
            src = torch.rand(*src_shape, requires_grad=True, device=device, dtype=torch.double)
            tgt = torch.rand(*tgt_shape, requires_grad=True, device=device, dtype=torch.double)
            self._test_module_empty_inputs(transformer_model, [src, tgt])

    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.complex64)
    def test_ReflectionPad_empty(self, device, dtype):
        for mod, inp in [
                (torch.nn.ReflectionPad1d(2), torch.randn(0, 3, 10, device=device, dtype=dtype)),
                (torch.nn.ReflectionPad2d(2), torch.randn(0, 3, 10, 10, device=device, dtype=dtype)),
                (torch.nn.ReflectionPad3d(3), torch.randn(0, 3, 10, 10, 10, device=device, dtype=dtype))]:
            _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, '2D or 3D'):
            mod = torch.nn.ReflectionPad1d(2)
            inp = torch.randn(3, 0, 10, device=device, dtype=dtype)
            mod(inp)

        with self.assertRaisesRegex(RuntimeError, '3D or 4D'):
            mod = torch.nn.ReflectionPad2d(2)
            inp = torch.randn(3, 0, 10, 10, device=device, dtype=dtype)
            mod(inp)

        with self.assertRaisesRegex(RuntimeError, '4D or 5D'):
            mod = torch.nn.ReflectionPad3d(3)
            inp = torch.randn(3, 0, 10, 10, 10, device=device, dtype=dtype)
            mod(inp)

    @onlyCUDA  # Test if CPU and GPU results match
    def test_ReflectionPad2d_large(self, device):
        shapes = ([2, 65736, 6, 6], [65736, 2, 6, 6])
        pad = (1, 2, 3, 4)
        for shape in shapes:
            x = torch.randn(shape, device=device, requires_grad=True)
            ref_x = x.detach().cpu().requires_grad_()

            out = F.pad(x, pad, mode='reflect')
            ref_out = F.pad(ref_x, pad, mode='reflect')

            self.assertEqual(out, ref_out)

            g = torch.randn_like(out)
            ref_g = g.cpu()

            out.backward(g)
            ref_out.backward(ref_g)

            self.assertEqual(x.grad, ref_x.grad)

    @onlyNativeDeviceTypes
    def test_LocalResponseNorm_empty(self, device):
        mod = torch.nn.LocalResponseNorm(2).to(device)
        inp = torch.ones(0, 5, 24, 24, device=device)
_test_module_empty_input(self, mod, inp, check_size=False) 8810*da0073e9SAndroid Build Coastguard Worker 8811*da0073e9SAndroid Build Coastguard Worker @onlyCUDA # Test if CPU and GPU results match 8812*da0073e9SAndroid Build Coastguard Worker def test_ReflectionPad3d_large(self, device): 8813*da0073e9SAndroid Build Coastguard Worker shapes = ([2, 1000, 7, 7, 7], [1000, 2, 7, 7, 7]) 8814*da0073e9SAndroid Build Coastguard Worker pad = (1, 2, 3, 4, 5, 6) 8815*da0073e9SAndroid Build Coastguard Worker for shape in shapes: 8816*da0073e9SAndroid Build Coastguard Worker x = torch.randn(shape, device=device, requires_grad=True) 8817*da0073e9SAndroid Build Coastguard Worker ref_x = x.detach().cpu().requires_grad_() 8818*da0073e9SAndroid Build Coastguard Worker 8819*da0073e9SAndroid Build Coastguard Worker out = F.pad(x, pad, mode='reflect') 8820*da0073e9SAndroid Build Coastguard Worker ref_out = F.pad(ref_x, pad, mode='reflect') 8821*da0073e9SAndroid Build Coastguard Worker 8822*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, ref_out) 8823*da0073e9SAndroid Build Coastguard Worker 8824*da0073e9SAndroid Build Coastguard Worker g = torch.randn_like(out) 8825*da0073e9SAndroid Build Coastguard Worker ref_g = g.cpu() 8826*da0073e9SAndroid Build Coastguard Worker 8827*da0073e9SAndroid Build Coastguard Worker out.backward(g) 8828*da0073e9SAndroid Build Coastguard Worker ref_out.backward(ref_g) 8829*da0073e9SAndroid Build Coastguard Worker 8830*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, ref_x.grad) 8831*da0073e9SAndroid Build Coastguard Worker 8832*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 8833*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 8834*da0073e9SAndroid Build Coastguard Worker def test_MarginLoss_empty(self, device, dtype): 8835*da0073e9SAndroid Build Coastguard Worker for mod, x, y in [ 8836*da0073e9SAndroid Build Coastguard Worker (torch.nn.MultiMarginLoss().to(device), 8837*da0073e9SAndroid Build Coastguard Worker torch.randn(0, 10, requires_grad=True, device=device, dtype=dtype), 8838*da0073e9SAndroid Build Coastguard Worker torch.ones(0, device=device).type(torch.long)), 8839*da0073e9SAndroid Build Coastguard Worker (torch.nn.MultiLabelMarginLoss().to(device), 8840*da0073e9SAndroid Build Coastguard Worker torch.randn(0, 10, requires_grad=True, device=device, dtype=dtype), 8841*da0073e9SAndroid Build Coastguard Worker torch.ones(0, 10, device=device).type(torch.long))]: 8842*da0073e9SAndroid Build Coastguard Worker 8843*da0073e9SAndroid Build Coastguard Worker out = mod(x, y) 8844*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 8845*da0073e9SAndroid Build Coastguard Worker 8846*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.zeros_like(x)) 8847*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, torch.zeros_like(x)) 8848*da0073e9SAndroid Build Coastguard Worker 8849*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Expected'): 8850*da0073e9SAndroid Build Coastguard Worker x = torch.randn(0, requires_grad=True, device=device, dtype=dtype) 8851*da0073e9SAndroid Build Coastguard Worker y = torch.ones(10, device=device).type(torch.long) 8852*da0073e9SAndroid Build Coastguard Worker mod(x, y) 8853*da0073e9SAndroid Build Coastguard Worker 8854*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Expected'): 8855*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 0, requires_grad=True, 
device=device, dtype=dtype) 8856*da0073e9SAndroid Build Coastguard Worker y = torch.ones(10, 0, device=device).type(torch.long) 8857*da0073e9SAndroid Build Coastguard Worker mod(x, y) 8858*da0073e9SAndroid Build Coastguard Worker 8859*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 8860*da0073e9SAndroid Build Coastguard Worker def test_MarginLoss_warnings(self, device): 8861*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Linear(128, 22, device=device) 8862*da0073e9SAndroid Build Coastguard Worker loss = torch.nn.MultiMarginLoss() 8863*da0073e9SAndroid Build Coastguard Worker x = torch.rand((56, 128), device=device) 8864*da0073e9SAndroid Build Coastguard Worker targets = torch.randint(22, (56,), device=device) 8865*da0073e9SAndroid Build Coastguard Worker f = io.StringIO() 8866*da0073e9SAndroid Build Coastguard Worker with contextlib.redirect_stderr(f): 8867*da0073e9SAndroid Build Coastguard Worker out = model(x) 8868*da0073e9SAndroid Build Coastguard Worker l = loss(out, targets) 8869*da0073e9SAndroid Build Coastguard Worker l.backward() 8870*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(f.getvalue()) == 0) 8871*da0073e9SAndroid Build Coastguard Worker 8872*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 8873*da0073e9SAndroid Build Coastguard Worker def test_Unfold_empty(self, device): 8874*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(0, 3, 3, 4, device=device) 8875*da0073e9SAndroid Build Coastguard Worker unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device) 8876*da0073e9SAndroid Build Coastguard Worker _test_module_empty_input(self, unfold, inp, check_size=False) 8877*da0073e9SAndroid Build Coastguard Worker 8878*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Expected 3D or 4D'): 8879*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(3, 0, 3, 4, device=device) 8880*da0073e9SAndroid Build Coastguard Worker unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device) 8881*da0073e9SAndroid Build Coastguard Worker unfold(inp) 8882*da0073e9SAndroid Build Coastguard Worker 8883*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 8884*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 8885*da0073e9SAndroid Build Coastguard Worker @tf32_on_and_off(0.005) 8886*da0073e9SAndroid Build Coastguard Worker def test_rnn_fused(self, device, dtype): 8887*da0073e9SAndroid Build Coastguard Worker 8888*da0073e9SAndroid Build Coastguard Worker def copy_rnn(rnn1, rnn2): 8889*da0073e9SAndroid Build Coastguard Worker for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights): 8890*da0073e9SAndroid Build Coastguard Worker for x, y in zip(x_layer, y_layer): 8891*da0073e9SAndroid Build Coastguard Worker x.data.copy_(y.data) 8892*da0073e9SAndroid Build Coastguard Worker 8893*da0073e9SAndroid Build Coastguard Worker def check_rnn_grads(rnn1, rnn2): 8894*da0073e9SAndroid Build Coastguard Worker for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights): 8895*da0073e9SAndroid Build Coastguard Worker for x, y in zip(x_layer, y_layer): 8896*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, y.grad, atol=5e-5, rtol=0) 8897*da0073e9SAndroid Build Coastguard Worker 8898*da0073e9SAndroid Build Coastguard Worker input_size = 10 8899*da0073e9SAndroid Build Coastguard Worker hidden_size = 6 8900*da0073e9SAndroid Build Coastguard Worker num_layers = 2 8901*da0073e9SAndroid Build Coastguard Worker seq_length = 7 8902*da0073e9SAndroid Build Coastguard Worker batch = 6 
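        # Layout note (added for clarity): with the default batch_first=False the inputs
        # below are (seq_len, batch, input_size) and the hidden state is
        # (num_layers, batch, hidden_size); for LSTM, hx is a (h_0, c_0) tuple, which is
        # why the LSTM branch builds a pair of tensors.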
        input_val = torch.randn(seq_length, batch, input_size, dtype=dtype)
        grad_output = torch.randn(seq_length, batch, hidden_size, dtype=dtype)
        hx_val = torch.randn(num_layers, batch, hidden_size, dtype=dtype)
        grad_hy = torch.randn(num_layers, batch, hidden_size, dtype=dtype)
        with torch.backends.cudnn.flags(enabled=False, allow_tf32=None):
            for module in (nn.GRU, nn.LSTM):
                for bias in (True, False):
                    rnn = module(input_size, hidden_size, num_layers, bias=bias).to(dtype)
                    rnn_device = module(input_size, hidden_size, num_layers, bias=bias).to(device, dtype)
                    copy_rnn(rnn, rnn_device)

                    is_lstm = isinstance(rnn, nn.LSTM)
                    if is_lstm:
                        hx = (hx_val.clone().requires_grad_(True),
                              hx_val.clone().add(1).requires_grad_(True))
                        hx_device = (hx_val.clone().to(device).requires_grad_(True),
                                     hx_val.clone().to(device).add(1).requires_grad_(True))
                    else:
                        hx = hx_val.clone().requires_grad_(True)
                        hx_device = hx_val.clone().to(device).requires_grad_(True)

                    inp = input_val.clone().requires_grad_(True)
                    inp_cu = input_val.clone().to(device).requires_grad_(True)
                    output1, hy1 = rnn(inp, hx)
                    output2, hy2 = rnn_device(inp_cu, hx_device)
                    if is_lstm:
                        torch.autograd.backward(
                            [output1, hy1[0], hy1[1]], [grad_output, grad_hy, grad_hy + 1]
                        )
                        torch.autograd.backward(
                            [output2, hy2[0], hy2[1]],
                            [grad_output.to(device), grad_hy.to(device), (grad_hy + 1).to(device)]
                        )
                    else:
                        torch.autograd.backward([output1, hy1], [grad_output, grad_hy])
                        torch.autograd.backward([output2, hy2], [grad_output.to(device), grad_hy.to(device)])

                    self.assertEqual(output1, output2)
                    self.assertEqual(hy1, hy2)

                    check_rnn_grads(rnn, rnn_device)
                    self.assertEqual(inp.grad, inp_cu.grad)
                    if is_lstm:
                        self.assertEqual(hx[0].grad, hx_device[0].grad)
                        self.assertEqual(hx[1].grad, hx_device[1].grad)
                    else:
                        self.assertEqual(hx.grad, hx_device.grad)

    @dtypesIfMPS(torch.float)
    @dtypes(torch.double)
    def test_BatchNorm_empty(self, device, dtype):
        mod = torch.nn.BatchNorm2d(3).to(device)
        inp = torch.randn(0, 3, 2, 2, device=device, dtype=dtype)
        _test_module_empty_input(self, mod, inp)
        if self.device_type == 'cuda' and self.has_cudnn():
            with torch.backends.cudnn.flags(enabled=False):
                _test_module_empty_input(self, mod, inp)

        self.assertEqual(mod.running_mean, torch.tensor([0., 0, 0], device=device))
        self.assertEqual(mod.running_var, torch.tensor([1., 1, 1], device=device))
        self.assertEqual(mod.weight.grad, torch.tensor([0., 0, 0], device=device))
        self.assertEqual(mod.bias.grad, torch.tensor([0., 0, 0], device=device))

    @onlyCUDA
    @largeTensorTest('16GB')
    def test_prelu_backward_32bit_indexing(self, device):
        m = torch.nn.PReLU().cuda().half()
        input_ = torch.ones((1024, 1024, 1024, 2), dtype=torch.half, device=device)
        output = m(input_)
        output.backward(input_)

    def test_linear_empty(self, device):
        mod = torch.nn.Linear(7, 7).to(device)
        inp = torch.randn(0, 7, device=device)
        _test_module_empty_input(self, mod, inp)

    def test_one_hot(self, device):
        # cuda throws device assert for invalid data
        # xla ignores out of bound indices
        if self.device_type not in ('cuda', 'mps', 'xla'):
            with self.assertRaises(RuntimeError):
                torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)

            with self.assertRaises(RuntimeError):
                torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)

        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
        expected = torch.tensor([[0, 0, 0, 1, 0],
                                 [0, 0, 0, 0, 1],
                                 [0, 1, 0, 0, 0],
                                 [1, 0, 0, 0, 0]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
        expected = torch.tensor([[0, 0, 0, 1, 0],
                                 [0, 0, 0, 0, 1],
                                 [0, 1, 0, 0, 0],
                                 [1, 0, 0, 0, 0]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
        expected = torch.tensor([[0, 0, 0, 1, 0, 0],
                                 [0, 0, 0, 0, 1, 0],
                                 [0, 1, 0, 0, 0, 0],
                                 [1, 0, 0, 0, 0, 0]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
        expected = torch.tensor([[[0, 0, 0, 1, 0],
                                  [0, 0, 0, 0, 1]],
                                 [[0, 1, 0, 0, 0],
                                  [1, 0, 0, 0, 0]]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
        expected = torch.tensor([0, 0, 0, 0, 1], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
        expected = torch.empty([4, 0, 100], dtype=torch.long)
        self.assertEqual(t, expected)

        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))

        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)
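
    # A note on the one_hot cases above (added commentary, not a new test): when
    # num_classes is omitted or passed as -1, F.one_hot infers the width from the data as
    # input.max() + 1, e.g. [3, 4, 1, 0] -> 5 columns, which is why the first two expected
    # tensors are identical; an explicit num_classes (such as 6) only adds extra zero columns.
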
    @expectedFailureMPS  # NotImplementedError: aten::rrelu_with_noise https://github.com/pytorch/pytorch/issues/77764
    def test_nn_empty(self, device):
        # One-off tests to ensure scalars from nn.yaml are properly applied
        def verify_scalars(input, output):
            self.assertEqual(input.shape, output.shape)
            self.assertEqual(0, output.numel())

        for input_shape in [(0), (0, 2)]:
            for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid,
                           torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid,
                           torch.nn.Tanh]:
                input = torch.randn(input_shape, device=device, requires_grad=True)
                m = module()
                output = m(input)
                verify_scalars(input, output)

    @expectedFailureMPS  # NotImplementedError: aten::rrelu_with_noise https://github.com/pytorch/pytorch/issues/77764
    def test_nn_scalars(self, device):
        # One-off tests to ensure scalars from nn.yaml are properly applied
        def verify_scalars(input, output):
            if input.dim() == 0:
                self.assertEqual((), output.shape)
            else:
                self.assertNotEqual((), output.shape)
            output.sum().backward()
            self.assertEqual(input.shape, input.grad.shape)

        for input_shape in [(5, 6), ()]:
            for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid,
                           torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid,
                           torch.nn.Tanh]:
                input = torch.randn(input_shape, device=device, requires_grad=True)
                m = module()
                output = m(input)
                verify_scalars(input, output)
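
    # Context for the reduction test below (commentary added for clarity): with
    # reduction='none' the loss keeps the elementwise shape of its inputs, so a 0-dim
    # input produces a 0-dim loss, while 'mean' and 'sum' always collapse the result to a
    # scalar regardless of the input shape.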
    def test_nn_scalars_reductions(self, device):
        # One-off tests to ensure scalars from nn.yaml are properly applied
        def verify_reduction_scalars(input, reduction, output):
            if reduction != 'none' or input.dim() == 0:
                self.assertEqual((), output.shape)
            else:
                self.assertNotEqual((), output.shape)
            output.sum().backward()
            self.assertEqual(input.shape, input.grad.shape)

        for input_shape in [(5, 6), ()]:
            for reduction in ['none', 'mean', 'sum']:
                for module in [torch.nn.BCELoss, torch.nn.L1Loss, torch.nn.MSELoss,
                               torch.nn.SmoothL1Loss, torch.nn.SoftMarginLoss]:
                    input = torch.randn(input_shape, device=device, requires_grad=True)
                    target = torch.empty(input_shape, device=device).random_(2)
                    sigmoid = nn.Sigmoid()

                    input = torch.randn(input_shape, device=device, requires_grad=True)
                    m = module(reduction=reduction)
                    output = m(sigmoid(input), target)
                    verify_reduction_scalars(input, reduction, output)

    # verify that bogus reduction strings are errors
    @onlyNativeDeviceTypes
    def test_invalid_reduction_strings(self, device):
        input = torch.randn(3, 5, requires_grad=True, device=device)
        cinput = torch.randn(3, 5, requires_grad=True, device=device, dtype=torch.cfloat)
        target = torch.tensor([1, 0, 4], device=device)
        var = torch.ones(size=input.size(), requires_grad=True, device=device)

        for reduction in ['none', 'invalid']:
            def v(fn):
                if reduction == 'invalid':
                    self.assertRaises(ValueError, lambda: fn())
                else:
                    fn()

            v(lambda: F.nll_loss(input, target, reduction=reduction))
            v(lambda: F.cross_entropy(input, target, reduction=reduction))

            v(lambda: F.kl_div(input, input, reduction=reduction))
            v(lambda: F.huber_loss(input, input, reduction=reduction))
            v(lambda: F.smooth_l1_loss(input, input, reduction=reduction))
            v(lambda: F.l1_loss(input, input, reduction=reduction))
            v(lambda: F.l1_loss(cinput, cinput, reduction=reduction))
            v(lambda: F.mse_loss(input, input, reduction=reduction))
            v(lambda: F.hinge_embedding_loss(input, input, reduction=reduction))
            v(lambda: F.poisson_nll_loss(input, input, reduction=reduction))
            v(lambda: F.gaussian_nll_loss(input, input, var, reduction=reduction))
            v(lambda: F.binary_cross_entropy(torch.sigmoid(input), input.gt(0).to(torch.get_default_dtype()), reduction=reduction))
            v(lambda: F.binary_cross_entropy_with_logits(input, input, reduction=reduction))

            zeros = torch.zeros_like(input).to(torch.int64)
            v(lambda: F.multilabel_soft_margin_loss(input, zeros, reduction=reduction))

            v(lambda: F.triplet_margin_loss(input, input, input, reduction=reduction))
            v(lambda: F.triplet_margin_with_distance_loss(input, input, input, reduction=reduction))
            v(lambda: F.margin_ranking_loss(input, input, input.sign(), reduction=reduction))
            v(lambda: F.cosine_embedding_loss(input, input, input[:, 0].sign(), reduction=reduction))

            log_probs = torch.randn(50, 16, 20, requires_grad=True, device=device).log_softmax(2)
            targets = torch.randint(1, 20, (16, 30), dtype=torch.long, device=device)
            input_lengths = torch.full((16,), 50, dtype=torch.long, device=device)
            target_lengths = torch.randint(10, 30, (16,), dtype=torch.long, device=device)
            v(lambda: F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction=reduction))

            # FIXME: should we allow derivatives on these?
            v(lambda: F.soft_margin_loss(input, input.sign().detach(), reduction=reduction))

    @onlyNativeDeviceTypes
    def test_smooth_l1_loss_vs_huber_loss(self, device):
        def _make_test_tensor(shape, contiguous=True):
            if contiguous:
                test_tensor = torch.randn(shape, device=device)
            else:
                # Select every other element in the innermost dimension to
                # make it non-contiguous.
                doubled_shape = list(shape)
                doubled_shape[-1] *= 2
                test_tensor = torch.randn(doubled_shape, device=device)
                test_tensor = test_tensor[..., ::2]
            return test_tensor

        def _test_smooth_l1_loss_vs_huber_loss_helper(input, target, beta, require_equal):
            for reduction in ['mean', 'sum', 'none']:
                smooth_l1 = torch.nn.SmoothL1Loss(beta=beta, reduction=reduction)
                # beta hyper-parameter is called delta for Huber
                huber = torch.nn.HuberLoss(delta=beta, reduction=reduction)
                smooth_l1_loss = smooth_l1(input, target)
                huber_loss = huber(input, target)

                if require_equal:
                    self.assertEqual(smooth_l1_loss, huber_loss)
                else:
                    # Huber loss differs from smooth L1 loss by a factor of beta (huber = beta * smooth_l1).
                    self.assertEqual(smooth_l1_loss * beta, huber_loss)

        def _test_smooth_l1_loss_vs_huber_loss_multi_input_helper(beta, require_equal):
            # Test the non-vectorized case.
            shape = (2, 2)
            _test_smooth_l1_loss_vs_huber_loss_helper(input=_make_test_tensor(shape),
                                                      target=_make_test_tensor(shape),
                                                      beta=beta,
                                                      require_equal=require_equal)

            # Test the vectorized case (innermost dim > 32).
            shape = (64, 64)
            _test_smooth_l1_loss_vs_huber_loss_helper(input=_make_test_tensor(shape),
                                                      target=_make_test_tensor(shape),
                                                      beta=beta,
                                                      require_equal=require_equal)

            # Test the non-contiguous case.
            _test_smooth_l1_loss_vs_huber_loss_helper(input=_make_test_tensor(shape, contiguous=False),
                                                      target=_make_test_tensor(shape, contiguous=False),
                                                      beta=beta,
                                                      require_equal=require_equal)

        def test_equal_when_beta_is_one():
            _test_smooth_l1_loss_vs_huber_loss_multi_input_helper(beta=1.0, require_equal=True)

        def test_unequal_when_beta_is_less_than_one():
            _test_smooth_l1_loss_vs_huber_loss_multi_input_helper(beta=0.5, require_equal=False)

        def test_unequal_when_beta_is_greater_than_one():
            _test_smooth_l1_loss_vs_huber_loss_multi_input_helper(beta=1.5, require_equal=False)

        test_equal_when_beta_is_one()
        test_unequal_when_beta_is_less_than_one()
        test_unequal_when_beta_is_greater_than_one()

    @onlyCPU
    def test_smooth_l1_loss_bfloat16(self, device):
        def test_dtype(fn, input, target, dtype):
            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
            input2 = input.detach().clone().float().requires_grad_(True)
            target = target.detach().clone().to(dtype=dtype)
            target2 = target.detach().clone().float()
            out = fn(input, target)
            out.sum().backward()
            out2 = fn(input2, target2)
            out2.sum().backward()
            self.assertEqual(out.dtype, dtype)
            self.assertEqual(input.grad.dtype, dtype)
            self.assertEqual(out, out2, exact_dtype=False)
            self.assertEqual(input.grad, input2.grad, exact_dtype=False)

        def func(device):
            return nn.SmoothL1Loss().to(device=device)

        shapes = [[1, 3, 1, 6], [1, 3, 1, 128], [1, 3, 128, 128]]
        for shape in shapes:
            x = torch.randn(shape, device=device, requires_grad=True)
            t = torch.randn(shape, device=device)
            test_dtype(func(device), x, t, torch.bfloat16)
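
    # Background sketch for the smooth-L1 / Huber comparisons above (commentary, not a test):
    # with d = input - target, SmoothL1(d; beta) = 0.5 * d**2 / beta for |d| < beta and
    # |d| - 0.5 * beta otherwise, while Huber(d; delta=beta) = 0.5 * d**2 for |d| < beta and
    # beta * (|d| - 0.5 * beta) otherwise, so huber == beta * smooth_l1 everywhere and the
    # two losses coincide exactly when beta == 1.
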
    # We don't want to make propagating NaN a hard requirement on ops, but for
    # these easy ones, we should make them do so.
    # MPS: NotImplementedError: aten::rrelu_with_noise_ https://github.com/pytorch/pytorch/issues/77764
    # MPS: NotImplementedError: aten::hardshrink.out https://github.com/pytorch/pytorch/issues/77764
    @expectedFailureMPS
    def test_nonlinearity_propagate_nan(self, device):
        def test(nonlinearity, *args, **kwargs):
            x = torch.tensor([nan], device=device)
            fn = getattr(F, nonlinearity)
            try:
                self.assertTrue(math.isnan(fn(x, *args, **kwargs).item()))
            except Exception as e:
                if 'not implemented' not in str(e):
                    raise

        test('relu')
        test('relu', inplace=True)
        test('relu6')
        test('elu')
        test('selu')
        test('celu')
        test('rrelu')
        test('rrelu', inplace=True)
        test('hardtanh')
        test('tanh')
        test('sigmoid')
        test('logsigmoid')
        test('hardshrink')
        test('tanhshrink')
        test('softsign')
        test('softmin', 0)
        test('softmax', 0)
        test('log_softmax', 0)
        test('leaky_relu', 0.2)
        test('threshold', 3, 2)
        test('threshold', 3, 2, inplace=True)
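
    # Why the NaN checks above matter (explanatory note): NaN compares false against every
    # value, so a nonlinearity implemented purely with comparisons (e.g. a clamp written as
    # "x if x > 0 else 0") would silently map NaN to a finite number; the test asserts that
    # each elementwise op keeps returning NaN instead.
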
    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
    @parametrize_test("mode", ["nearest-exact", "nearest"])
    def test_upsamplingNearest1d(self, device, mode):
        # Forward AD does not support XLA because XLA tensors don't have storage
        check_forward_ad = torch.device(device).type != 'xla'

        m = nn.Upsample(size=4, mode=mode)
        in_t = torch.ones(1, 1, 2, device=device, dtype=torch.double)
        in_uint8_t = torch.ones(1, 1, 2, dtype=torch.uint8, device=device)
        with warnings.catch_warnings(record=True) as w:
            out_t = m(in_t)
            out_uint8_t = m(in_uint8_t)
        self.assertEqual(torch.ones(1, 1, 4, device=device, dtype=torch.double), out_t.data)
        self.assertEqual(torch.ones(1, 1, 4, dtype=torch.uint8, device=device), out_uint8_t.data)

        # Checks upsampling
        input = torch.randn(1, 1, 2, requires_grad=True, device=device, dtype=torch.double)
        gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_forward_ad=check_forward_ad)
        gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_fwd_over_rev=check_forward_ad)

        # Checks downsampling
        input = torch.randn(1, 1, 20, requires_grad=True, device=device, dtype=torch.double)
        gradcheck(lambda x: F.interpolate(x, 11, mode=mode), [input], check_forward_ad=check_forward_ad)
        gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_fwd_over_rev=check_forward_ad)

        # consistency CUDA/CPU check
        if torch.device(device).type == 'cuda':
            input_cuda = torch.randn(1, 1, 20, device=device, dtype=torch.double)
            input_cpu = input_cuda.cpu()
            output_cuda = F.interpolate(input_cuda, 4, mode=mode)
            output_cpu = F.interpolate(input_cpu, 4, mode=mode)
            self.assertEqual(output_cuda.cpu(), output_cpu)

            output_cuda = F.interpolate(input_cuda, 24, mode=mode)
            output_cpu = F.interpolate(input_cpu, 24, mode=mode)
            self.assertEqual(output_cuda.cpu(), output_cpu)

    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
    def test_upsamplingNearest1d_correctness(self, device, isize, osize):
        # Here we check if output matches OpenCV's INTER_NEAREST-like result
        in_t = torch.arange(isize, dtype=torch.float, device=device).unsqueeze(0).unsqueeze(0)
        out_t = F.interpolate(
            in_t, size=(osize, ), recompute_scale_factor=False, mode="nearest"
        )
        # compute expected output as OpenCV
        expected_out = torch.zeros(osize, dtype=torch.float).unsqueeze(0).unsqueeze(0)
        scale = 1.0 * isize / osize
        for o in range(osize):
            i_f32 = o * scale
            i = int(i_f32)
            expected_out[0, 0, o] = in_t[0, 0, i]
        expected_out = expected_out.to(device=device)
        self.assertEqual(out_t, expected_out)

    def test_upsamplingNearestExact1d_rescale(self, device):
        # Checks https://github.com/pytorch/pytorch/issues/62237
        isize = 20
        in_t = torch.arange(isize, dtype=torch.float, device=device).unsqueeze(0).unsqueeze(0)
        # for s in [1.00001, 0.99999]:  # 0.9999 case is broken
        # See issue: https://github.com/pytorch/pytorch/issues/62396
        for s in [1.00001, ]:
            out_t = F.interpolate(
                in_t, scale_factor=s, recompute_scale_factor=False, mode="nearest-exact"
            )
            expected_out = in_t
            self.assertEqual(out_t, expected_out, msg=f"scale: {s}")

        # checks data duplication if output_size == 2 * input_size
        # for s in [2.00001, 1.99999]:  # 1.99999 case is broken
        # See issue: https://github.com/pytorch/pytorch/issues/62396
        for s in [2.00001, ]:
            out_t = F.interpolate(
                in_t, scale_factor=s, recompute_scale_factor=False, mode="nearest-exact"
            )
            # input is [[[0, 1, 2, 3, ..., 19]]]
            # expected out is [[[0, 0, 1, 1, 2, 2, ..., 19, 19]]]
            expected_out = in_t.repeat_interleave(2, dim=-1)
            self.assertEqual(out_t, expected_out)

    @skipIfMps  # Partially passes https://github.com/pytorch/pytorch/issues/134430
    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
    def test_upsamplingNearestExact1d_correctness(self, device, isize, osize):
        # Here we check if output matches Scikit-Image/Scipy-like result
        # Checks https://github.com/pytorch/pytorch/issues/34808
        in_t = torch.arange(isize, dtype=torch.float, device=device).unsqueeze(0).unsqueeze(0)
        out_t = F.interpolate(
            in_t, size=(osize, ), recompute_scale_factor=False, mode="nearest-exact"
        )
        # compute expected output as scikit-image/scipy
        expected_out = torch.zeros(osize, dtype=torch.float).unsqueeze(0).unsqueeze(0)
        scale = 1.0 * isize / osize
        for o in range(osize):
            i_f32 = (o + 0.5) * scale
            i = int(i_f32)
            expected_out[0, 0, o] = in_t[0, 0, i]
        expected_out = expected_out.to(device=device)
        self.assertEqual(out_t, expected_out)
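
    # Index mapping used by the two correctness tests above (explanatory note): "nearest"
    # reads source index floor(o * isize / osize) while "nearest-exact" reads
    # floor((o + 0.5) * isize / osize). For isize=20, osize=11 the scale is ~1.818, so
    # output index 1 samples input 1 under "nearest" but input 2 under "nearest-exact",
    # which is the OpenCV-vs-scikit-image difference the comments refer to.
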
    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    @parametrize_test("mode", ["nearest", "nearest-exact"])
    def test_upsamplingNearest2d(self, device, memory_format, mode):
        # Forward AD does not support XLA because XLA tensors don't have storage
        check_forward_ad = torch.device(device).type != 'xla'

        in_t = torch.ones(1, 2, 2, 2, device=device, dtype=torch.double).contiguous(memory_format=memory_format)
        in_uint8_t = torch.ones(1, 2, 2, 2, dtype=torch.uint8, device=device).contiguous(memory_format=memory_format)
        with warnings.catch_warnings(record=True) as w:
            out_t = F.interpolate(in_t, size=4, mode=mode)
            out_uint8_t = F.interpolate(in_uint8_t, size=4, mode=mode)
            self.assertEqual(len(w), 0)
        self.assertEqual(torch.ones(1, 2, 4, 4, device=device, dtype=torch.double), out_t)
        self.assertEqual(torch.ones(1, 2, 4, 4, dtype=torch.uint8, device=device), out_uint8_t)
        # Assert that memory format is carried through to the output
        self.assertTrue(out_t.is_contiguous(memory_format=memory_format))

        # test forward when input's height is not same as width
        in_t = torch.ones(1, 2, 2, 1, device=device, dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
        out_t = F.interpolate(in_t, size=(4, 2), mode=mode)
        self.assertEqual(torch.ones(1, 2, 4, 2, device=device, dtype=torch.double), out_t)
        self.assertTrue(out_t.is_contiguous(memory_format=memory_format))

        out_t.backward(torch.randn_like(out_t))
        self.assertTrue(in_t.grad.is_contiguous(memory_format=memory_format))

        # test backward when input's height is not same as width
        input = torch.ones(
            1, 2, 2, 1, requires_grad=True, device=device,
            dtype=torch.double).contiguous(memory_format=memory_format)
        gradcheck(lambda x: F.interpolate(x, size=(4, 2), mode=mode), [input], check_forward_ad=check_forward_ad)
        gradgradcheck(lambda x: F.interpolate(x, size=(4, 2), mode=mode), [input], check_fwd_over_rev=check_forward_ad)

        input = torch.randn(
            1, 2, 2, 2, requires_grad=True, device=device,
            dtype=torch.double).contiguous(memory_format=memory_format)
        self.assertEqual(
            F.interpolate(input, 4, mode=mode),
            F.interpolate(input, scale_factor=2, mode=mode))
        gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_forward_ad=check_forward_ad)
        gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_fwd_over_rev=check_forward_ad)

        # Assert that cpu and cuda handle channels_last memory format in the same way
        # https://github.com/pytorch/pytorch/issues/54590
        if torch.device(device).type == 'cuda':
            for shapes, scale_factor in product([
                    (2, 2, 3, 4), (2, 3, 4, 5), (3, 1, 2, 2), (1, 5, 3, 2)
            ], [0.5, 1.5, 2]):
                a_cuda = torch.randn(
                    *shapes, device=device,
                    dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
                a_cpu = a_cuda.detach().cpu().requires_grad_()

                out_cuda = F.interpolate(a_cuda, scale_factor=scale_factor, mode=mode)
                out_cpu = F.interpolate(a_cpu, scale_factor=scale_factor, mode=mode)

                self.assertEqual(out_cpu.cuda(), out_cuda)

                g_cuda = torch.randn_like(out_cuda)
                g_cpu = g_cuda.cpu()

                out_cuda.backward(g_cuda)
                out_cpu.backward(g_cpu)

                self.assertEqual(a_cuda.grad, a_cpu.grad)

    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
    def test_upsamplingNearest2d_correctness(self, device, memory_format, isize, osize):
        # Here we check if output matches OpenCV's INTER_NEAREST-like result
        in_t = torch.arange(isize * isize, dtype=torch.float, device=device).reshape(1, 1, isize, isize)
        in_t = in_t.contiguous(memory_format=memory_format)
        out_t = F.interpolate(
            in_t, size=(osize, osize), recompute_scale_factor=False, mode="nearest"
        )
        # compute expected output as OpenCV
        expected_out = torch.zeros(1, 1, osize, osize, dtype=torch.float)
        scale = 1.0 * isize / osize
        for o1 in range(osize):
            i1_f32 = o1 * scale
            i1 = int(i1_f32)
            for o2 in range(osize):
                i2_f32 = o2 * scale
                i2 = int(i2_f32)
                expected_out[0, 0, o1, o2] = in_t[0, 0, i1, i2]
        expected_out = expected_out.to(device=device)
        self.assertEqual(out_t, expected_out)

    @skipIfMps  # Partially passes https://github.com/pytorch/pytorch/issues/134430
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
    def test_upsamplingNearestExact2d_correctness(self, device, memory_format, isize, osize):
        # Here we check if output matches Scikit-Image/Scipy-like result
        # Checks https://github.com/pytorch/pytorch/issues/34808
        in_t = torch.arange(isize * isize, dtype=torch.float, device=device).reshape(1, 1, isize, isize)
        in_t = in_t.contiguous(memory_format=memory_format)
        out_t = F.interpolate(
            in_t, size=(osize, osize), recompute_scale_factor=False, mode="nearest-exact"
        )
        # compute expected output as Scikit-Image/Scipy
        expected_out = torch.zeros(1, 1, osize, osize, dtype=torch.float)
        scale = 1.0 * isize / osize
        for o1 in range(osize):
            i1_f32 = (o1 + 0.5) * scale
            i1 = int(i1_f32)
            for o2 in range(osize):
                i2_f32 = (o2 + 0.5) * scale
                i2 = int(i2_f32)
                expected_out[0, 0, o1, o2] = in_t[0, 0, i1, i2]
        expected_out = expected_out.to(device=device)
        self.assertEqual(out_t, expected_out)

    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d])
    @parametrize_test("mode", ["nearest", "nearest-exact"])
    def test_upsamplingNearest3d(self, device, memory_format, mode):
        # Forward AD does not support XLA because XLA tensors don't have storage
        check_forward_ad = torch.device(device).type != 'xla'

        m = nn.Upsample(size=4, mode=mode)
        in_t = torch.ones(1, 2, 2, 2, 2, device=device, dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
        in_uint8_t = torch.ones(
            1, 2, 2, 2, 2, dtype=torch.uint8, device=device
        ).contiguous(memory_format=memory_format)
        with warnings.catch_warnings(record=True) as w:
            out_t = m(in_t)
            out_uint8_t = m(in_uint8_t)
        expected_output = torch.ones(1, 2, 4, 4, 4, device=device, dtype=torch.double)
        self.assertEqual(expected_output, out_t)
        self.assertEqual(expected_output.to(torch.uint8), out_uint8_t)
        # Assert that memory format is carried through to the output
        self.assertTrue(out_t.is_contiguous(memory_format=memory_format))
        out_t.backward(torch.randn_like(out_t))
        self.assertTrue(in_t.grad.is_contiguous(memory_format=memory_format))

        input = torch.randn(
            1, 2, 2, 2, 2, requires_grad=True, device=device, dtype=torch.double
        ).contiguous(memory_format=memory_format)
        gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_forward_ad=check_forward_ad)
        gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_fwd_over_rev=check_forward_ad)
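
        # Note on channels_last_3d (commentary added for clarity): this memory format only
        # permutes strides into NDHWC order, it does not change the logical shape, and
        # is_contiguous(memory_format=...) checks exactly that stride ordering, which is
        # what the asserts on out_t and in_t.grad above verify.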

        # Assert that cpu and cuda handle channels_last memory format in the same way
        # https://github.com/pytorch/pytorch/issues/54590
        if torch.device(device).type == 'cuda':
            a = torch.ones(
                2, 2, 2, 3, 4, device=device, requires_grad=True, dtype=torch.double
            ).contiguous(memory_format=torch.channels_last_3d)
            # make the data asymmetric; ensure that cuda/cpu handle channels_last appropriately.
            a[1][1][1][2][2] = a[1][1][1][2][3] = 0

            out_cuda = torch.nn.functional.interpolate(a, scale_factor=2, mode=mode)
            out_cpu = torch.nn.functional.interpolate(a.to('cpu'), scale_factor=2, mode=mode)
            self.assertEqual(out_cpu, out_cuda.to('cpu'))

            gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [a], check_forward_ad=check_forward_ad)
            gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [a], check_fwd_over_rev=check_forward_ad)

            gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [a.to('cuda')], check_forward_ad=check_forward_ad)
            gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [a.to('cuda')], check_fwd_over_rev=check_forward_ad)

    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d])
    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
    def test_upsamplingNearest3d_correctness(self, device, memory_format, isize, osize):
        # Here we check if output matches OpenCV's INTER_NEAREST-like result
        in_t = torch.arange(isize * isize * isize, dtype=torch.float, device=device)
        in_t = in_t.reshape(1, 1, isize, isize, isize)
        in_t = in_t.contiguous(memory_format=memory_format)
        out_t = F.interpolate(
            in_t, size=(osize, osize, osize), recompute_scale_factor=False, mode="nearest"
        )
        # compute expected output as OpenCV
        expected_out = torch.zeros(1, 1, osize, osize, osize, dtype=torch.float)
        scale = 1.0 * isize / osize
        for o1 in range(osize):
            i1_f32 = o1 * scale
            i1 = int(i1_f32)
            for o2 in range(osize):
                i2_f32 = o2 * scale
                i2 = int(i2_f32)
                for o3 in range(osize):
                    i3_f32 = o3 * scale
                    i3 = int(i3_f32)
                    expected_out[0, 0, o1, o2, o3] = in_t[0, 0, i1, i2, i3]
        expected_out = expected_out.to(device=device)
        self.assertEqual(out_t, expected_out)

    @expectedFailureMPS  # NotImplementedError: aten::_upsample_nearest_exact3d.out https://github.com/pytorch/pytorch/issues/77764
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d])
    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
    def test_upsamplingNearestExact3d_correctness(self, device, memory_format, isize, osize):
        # Here we check if output matches Scikit-Image/Scipy-like result
        # Checks https://github.com/pytorch/pytorch/issues/34808
        in_t = torch.arange(isize * isize * isize, dtype=torch.float, device=device)
        in_t = in_t.reshape(1, 1, isize, isize, isize)
        in_t = in_t.contiguous(memory_format=memory_format)
        out_t = F.interpolate(
            in_t, size=(osize, osize, osize), recompute_scale_factor=False, mode="nearest-exact"
        )
        # compute expected output as Scikit-Image/Scipy
        expected_out = torch.zeros(1, 1, osize, osize, osize, dtype=torch.float)
        scale = 1.0 * isize / osize
        for o1 in range(osize):
            i1_f32 = (o1 + 0.5) * scale
            i1 = int(i1_f32)
            for o2 in range(osize):
                i2_f32 = (o2 + 0.5) * scale
                i2 = int(i2_f32)
                for o3 in range(osize):
                    i3_f32 = (o3 + 0.5) * scale
                    i3 = int(i3_f32)
                    expected_out[0, 0, o1, o2, o3] = in_t[0, 0, i1, i2, i3]
        expected_out = expected_out.to(device=device)
        self.assertEqual(out_t, expected_out)

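    # Memory-format sketch (illustrative, mirrors the assertions in the tests around
    # here): interpolate is expected to preserve the input's memory format, e.g.
    #   x = torch.ones(2, 3, 8, 8).contiguous(memory_format=torch.channels_last)
    #   y = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
    #   y.is_contiguous(memory_format=torch.channels_last)  # expected: True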
    @parametrize_test("antialias", [True, False])
    @parametrize_test("align_corners", [True, False])
    @parametrize_test("mode", ["bilinear", "bicubic"])
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    @onlyNativeDeviceTypes
    def test_upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory_format):
        # Forward AD does not support XLA because XLA tensors don't have storage
        check_forward_ad = torch.device(device).type != 'xla'

        kwargs = dict(mode=mode, align_corners=align_corners, antialias=antialias)
        # test float scale factor up & downsampling
        for scale_factor in [0.5, 1.5, 2]:
            in_t = torch.ones(
                2, 3, 8, 8, device=device,
                dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
            out_size = int(math.floor(in_t.shape[-1] * scale_factor))
            with warnings.catch_warnings(record=True) as w:
                out_t = F.interpolate(in_t, scale_factor=scale_factor, **kwargs)
            expected_out = torch.ones(2, 3, out_size, out_size, device=device, dtype=torch.double)
            self.assertEqual(expected_out, out_t)
            # Assert that memory format is carried through to the output
            self.assertTrue(out_t.is_contiguous(memory_format=memory_format))
            out_t.backward(torch.randn_like(out_t))
            self.assertTrue(in_t.grad.is_contiguous(memory_format=memory_format))

            if torch.device(device).type == 'cuda':
                # Bilinear backward is nondeterministic because of atomicAdd usage
                nondet_tol = 1e-5
            else:
                nondet_tol = 0.0

            input = torch.randn(
                2, 3, 8, 8, device=device,
                dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
            gradcheck(
                lambda x: F.interpolate(x, out_size, **kwargs),
                [input],
                check_forward_ad=check_forward_ad, nondet_tol=nondet_tol
            )
            gradgradcheck(
                lambda x: F.interpolate(x, out_size, **kwargs),
                [input],
                check_fwd_over_rev=check_forward_ad, nondet_tol=nondet_tol
            )

            # Assert that cpu and cuda give same results
            if torch.device(device).type == 'cuda':
                for shapes in [
                    (2, 2, 3, 4), (2, 3, 4, 5), (3, 1, 2, 2), (1, 5, 3, 2)
                ]:
                    a_cuda = torch.randn(
                        *shapes, device=device, dtype=torch.double
                    ).contiguous(memory_format=memory_format).requires_grad_()
                    a_cpu = a_cuda.detach().cpu().requires_grad_()

                    with warnings.catch_warnings(record=True):
                        out_cuda = F.interpolate(a_cuda, scale_factor=scale_factor, **kwargs)
                        out_cpu = F.interpolate(a_cpu, scale_factor=scale_factor, **kwargs)

                    self.assertEqual(out_cpu, out_cuda.cpu())

                    g_cuda = torch.randn_like(out_cuda)
                    g_cpu = g_cuda.cpu()

                    out_cuda.backward(g_cuda)
                    out_cpu.backward(g_cpu)

                    self.assertEqual(a_cuda.grad, a_cpu.grad)

    @parametrize_test("antialias", [True, False])
    @parametrize_test("num_channels", [3, 5])
    @parametrize_test("mode", ["nearest", "nearest-exact", "bilinear", "bicubic"])
    @parametrize_test("dtype", integral_types() + floating_types())
    @onlyNativeDeviceTypes
    def test_upsamplingBiMode2d_nonsupported_dtypes(self, device, antialias, num_channels, mode, dtype):
        x = torch.ones(1, num_channels, 32, 32, dtype=dtype, device=device)

        should_raise_runtime_error = True

        if "nearest" in mode:
            if antialias:
                raise SkipTest("Nearest mode does not have antialiasing")
            if dtype in (torch.uint8, ) + floating_types():
                should_raise_runtime_error = False

        elif mode in ("bilinear", "bicubic"):
            if dtype in floating_types() or (device == "cpu" and dtype == torch.uint8):
                should_raise_runtime_error = False

        if should_raise_runtime_error:
            with self.assertRaisesRegex(RuntimeError, "not implemented for"):
                F.interpolate(x, (12, 12), mode=mode, antialias=antialias)
        else:
            _ = F.interpolate(x, (12, 12), mode=mode, antialias=antialias)

    @expectedFailureMPS  # NotImplementedError: aten::_upsample_bilinear2d_aa.out https://github.com/pytorch/pytorch/issues/77764
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format):
        # NOTE: We expand the batch dim such that `b*c` is above the maximum
        # size of CUDA grid z-dimension (2**16)
        shape = [23000, 3, 8, 8]
        t_in = torch.arange(3 * 8 * 8, dtype=torch.float, device=device).reshape(1, *shape[1:])
        t_in = t_in.expand(shape)
        t_in = t_in.contiguous(memory_format=memory_format)
        # This expected result is obtained using PIL.Image.resize
        # for c in range(3):
        #   a_in = t_in.numpy()[0, c, ...]
        #   pil_in = Image.fromarray(a_in)
        #   pil_out = pil_in.resize((2, 2), resample=Image.LINEAR)
        expected_out = torch.tensor([
            17.035713, 20.25, 42.75, 45.964287, 81.03572, 84.25,
            106.75, 109.96428, 145.0357, 148.25, 170.75, 173.9643
        ], device=device, dtype=t_in.dtype).reshape(1, 3, 2, 2)
        t_out = F.interpolate(t_in, size=(2, 2), mode="bilinear", align_corners=False, antialias=True)
        self.assertEqual(expected_out.expand([*shape[:2], 2, 2]), t_out)

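    # Antialias note (illustrative): with antialias=True and a large downscale
    # (8x8 -> 2x2 in the test above), each output pixel is an area-weighted average
    # over its full input support, which is what makes the result line up with
    # PIL.Image.resize; with antialias=False, bilinear would only blend the two
    # nearest source pixels per axis.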
    # Partially passes. NotImplementedError: aten::upsample_bicubic2d.out https://github.com/pytorch/pytorch/issues/77764
    @skipIfMps
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    @parametrize_test("mode", ["bilinear", "bicubic"])
    @parametrize_test("antialias", [True, False])
    @parametrize_test("align_corners", [True, False])
    @parametrize_test("num_channels", [3, 5])
    @parametrize_test("output_size", [32, 600])
    @parametrize_test("check_as_unsqueezed_3d_tensor", [True, False])
    @parametrize_test("non_contig", [False, "sliced", "restrided"])
    @parametrize_test("batch_size", [1, 5])
    def test_upsamplingBiMode2d_consistency(
        self,
        device,
        memory_format,
        mode,
        antialias,
        align_corners,
        num_channels,
        output_size,
        check_as_unsqueezed_3d_tensor,
        non_contig,
        batch_size,
    ):
        # Check output value consistency between resized input_uint8 and resized input_float
        if torch.device(device).type == "cuda":
            raise SkipTest("CUDA implementation is not yet supporting uint8")

        torch.manual_seed(0)

        # - input range is set to [30, 220] for bicubic mode, because the bicubic kernel may create
        #   [intermediate] values outside of the [0, 255] range, which need
        #   to be clipped in uint8 path, but not in float path. This isn't
        #   an issue with bilinear kernel.
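        #   For instance, a bicubic kernel has negative lobes, so interpolating values
        #   near 255 can overshoot above 255 (and values near 0 can undershoot below 0);
        #   keeping inputs in [30, 220] leaves enough headroom that clipping never kicks in.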
        input_range = (30, 220) if mode == "bicubic" else (0, 256)
        input_ui8 = torch.randint(*input_range, size=(batch_size, num_channels, 400, 400), dtype=torch.uint8, device=device)
        input_ui8 = input_ui8.contiguous(memory_format=memory_format)

        if non_contig == "sliced":
            input_ui8 = input_ui8[:, :, 10:-10, 10:-10]
        elif non_contig == "restrided":
            input_ui8 = input_ui8[:, :, ::2, ::2]

        if batch_size == 1 and check_as_unsqueezed_3d_tensor:
            input_ui8 = input_ui8[0, ...]
            input_ui8 = input_ui8[None, ...]

        input_f32 = input_ui8.float()

        output_f32 = F.interpolate(
            input_f32, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=antialias
        ).round().clip(0, 255)
        output_ui8 = F.interpolate(
            input_ui8, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=antialias
        )

        if non_contig is False:
            self.assertTrue(input_ui8.is_contiguous(memory_format=memory_format))

        # FIXME if-clause shows the current behaviour which is definitely unexpected.
        # Ideally we want to fix it such that both the ui8 and f32 outputs are also channels_last
        # See for more details: https://github.com/pytorch/pytorch/pull/100373
        if batch_size == 1 and check_as_unsqueezed_3d_tensor and memory_format == torch.channels_last:
            self.assertTrue(output_ui8.is_contiguous())
            self.assertTrue(output_f32.is_contiguous())
        else:
            self.assertTrue(output_ui8.is_contiguous(memory_format=memory_format))
            self.assertTrue(output_f32.is_contiguous(memory_format=memory_format))

        if mode == "bilinear":
            torch.testing.assert_close(output_f32, output_ui8.float(), rtol=0, atol=1)
        else:
            diff = (output_f32 - output_ui8.float()).abs()
            self.assertLess(diff.max(), 15)

            threshold = 2
            percent = 3
            self.assertLess((diff > threshold).float().mean(), percent / 100)

            threshold = 5
            percent = 1
            self.assertLess((diff > threshold).float().mean(), percent / 100)

            self.assertLess(diff.mean(), 0.4)

    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    @parametrize_test("align_corners", [True, False])
    @parametrize_test("input_size, output_size", [(399, 437), (403, 377)])
    def test_upsamplingBiLinear2d_consistency_interp_size_bug(self, device, memory_format, align_corners, input_size, output_size):
        # Non-regression test for https://github.com/pytorch/pytorch/pull/101403
        if torch.device(device).type == "cuda":
            raise SkipTest("CUDA implementation is not yet supporting uint8")

        mode = "bilinear"
        input_ui8 = torch.randint(0, 256, size=(1, 3, input_size, input_size), dtype=torch.uint8, device=device)
        input_ui8 = input_ui8.contiguous(memory_format=memory_format)
        input_f32 = input_ui8.float()

        output_f32 = F.interpolate(
            input_f32, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=False
        ).round().to(torch.uint8)
        output_ui8 = F.interpolate(
            input_ui8, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=False
        )
        torch.testing.assert_close(output_f32, output_ui8, atol=1, rtol=0)

    @expectedFailureMPS  # NotImplementedError: aten::upsample_bicubic2d.out https://github.com/pytorch/pytorch/issues/77764
    def test_upsamplingBicubic2d_correctness(self, device):
        # test output against known input: align_corners=False result must match opencv
        in_t = torch.arange(8., device=device).view(1, 2, 2, 2)
        expected_out_t = torch.tensor(
            [[[[-0.31641, 0.01562, 0.56250, 0.89453],
               [0.34766, 0.67969, 1.22656, 1.55859],
               [1.44141, 1.77344, 2.32031, 2.65234],
               [2.10547, 2.43750, 2.98438, 3.31641]],

              [[3.68359, 4.01562, 4.56250, 4.89453],
               [4.34766, 4.67969, 5.22656, 5.55859],
               [5.44141, 5.77344, 6.32031, 6.65234],
               [6.10547, 6.43750, 6.98438, 7.31641]]]], device=device)
        out_t = F.interpolate(in_t, scale_factor=2, mode='bicubic', align_corners=False)
        torch.set_printoptions(precision=5)
        self.assertEqual(out_t, expected_out_t, atol=1e-5, rtol=0)

    @expectedFailureMPS  # NotImplementedError: aten::_upsample_bicubic2d_aa.out https://github.com/pytorch/pytorch/issues/77764
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    def test_upsamplingBicubic2d_aa_correctness(self, device, memory_format):
        t_in = torch.arange(3 * 8 * 8, dtype=torch.float, device=device).reshape(1, 3, 8, 8)
        t_in = t_in.contiguous(memory_format=memory_format)
        # This expected result is obtained using PIL.Image.resize
        # for c in range(3):
        #   a_in = t_in.numpy()[0, c, ...]
        #   pil_in = Image.fromarray(a_in)
        #   pil_out = pil_in.resize((2, 2), resample=Image.BICUBIC)
        expected_out = torch.tensor([
            15.1205635, 18.760439, 44.23956, 47.879436, 79.12056, 82.76044,
            108.23956, 111.87944, 143.12057, 146.76044, 172.23956, 175.87943
        ], device=device, dtype=t_in.dtype).reshape(1, 3, 2, 2)
        t_out = F.interpolate(t_in, size=(2, 2), mode="bicubic", align_corners=False, antialias=True)
        self.assertEqual(expected_out, t_out)

    @expectedFailureMPS  # NotImplementedError: aten::upsample_trilinear3d.out https://github.com/pytorch/pytorch/issues/77764
    @parametrize_test("align_corners", [True, False])
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d])
    def test_upsamplingTrilinear3d(self, device, align_corners, memory_format):
        kwargs = dict(mode='trilinear', align_corners=align_corners)

        # test float scale factor up & downsampling
        for scale_factor in [0.5, 1.5, 2]:
            m = nn.Upsample(scale_factor=scale_factor, **kwargs)
            in_t = torch.ones(1, 2, 4, 4, 4, device=device, dtype=torch.double)
            in_t = in_t.contiguous(memory_format=memory_format).requires_grad_()
            out_size = int(math.floor(in_t.shape[-1] * scale_factor))
            with warnings.catch_warnings(record=True) as w:
                out_t = m(in_t)
            expected_out = torch.ones(1, 2, out_size, out_size, out_size, device=device, dtype=torch.double)
            self.assertEqual(expected_out, out_t)
            # Assert that memory format is carried through to the output
            self.assertTrue(out_t.is_contiguous(memory_format=memory_format))

            grad_out = torch.randn_like(out_t).contiguous(memory_format=memory_format)
            in_t.grad = None
            out_t.backward(grad_out)
            grad_in = in_t.grad
            self.assertTrue(grad_in.is_contiguous(memory_format=memory_format))

            if memory_format == torch.channels_last_3d:
                # check if grad inputs CF and CL match
                in_t.grad = None
                out_t.backward(grad_out.contiguous())
                self.assertEqual(in_t.grad, grad_in)

            input = torch.randn(1, 2, 4, 4, 4, requires_grad=True, dtype=torch.double)
            self.assertEqual(
                F.interpolate(input, (out_size, out_size, out_size), **kwargs),
                F.interpolate(input, scale_factor=scale_factor, **kwargs))
            gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])
            gradgradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])

    @onlyCUDA
    @dtypes(torch.half)
    @largeTensorTest('40GB')
    def test_upsampling_64bit_indexing_channels_last(self, device, dtype):
        x = torch.rand((32, 64, 512, 512), dtype=dtype, device=device)
        out = torch.nn.functional.interpolate(x.to(memory_format=torch.channels_last), scale_factor=2, mode='nearest')
        out_ref = torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
        del x
        self.assertTrue(torch.allclose(out, out_ref))

    @onlyCUDA
    @dtypes(torch.half)
    @largeTensorTest('40GB')
    def test_replicatepad_64bit_indexing(self, device, dtype):
        conv = torch.nn.Conv1d(128, 128, 3, 1, 1, padding_mode="replicate", device=device, dtype=dtype)
        x = torch.randn(size=(256 * 448 * 2, 128, 96), dtype=dtype, device=device)
        y = conv(x)
        torch.mean(y).backward()

    @onlyCUDA
    @dtypes(torch.half)
    @largeTensorTest('40GB')
    def test_upsamplingnearest2d_backward_64bit_indexing(self, device, dtype):
        x = torch.randn(size=(36, 128, 512, 512), device=device, dtype=dtype).requires_grad_()
        y = F.interpolate(x, scale_factor=2, mode="nearest")
        y.backward(torch.randn_like(y))

    def _slow_masked_softmax(self, input, mask):
        exp = torch.exp(input)
        exp = exp * mask
        s = exp.sum(dim=3, keepdim=True).expand(exp.size())
        return exp / s
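
    # Mask-convention sketch (illustrative): torch._masked_softmax treats True entries
    # as positions to mask *out*, while _slow_masked_softmax above expects a
    # multiplicative 0/1 keep-mask, which is why the tests below invert the mask
    # (mask = ~mask) before calling the slow reference, e.g.
    #   native = torch._masked_softmax(x, bool_mask, dim, mask_type)      # True = drop
    #   slow   = self._slow_masked_softmax(x, (~bool_mask).float())       # 1.0  = keep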

    def test_masked_softmax_mask_types(self, device):
        # Test that mask type 0 (LxL attention mask), mask type 1 (BxL padding mask),
        # and mask type 2 (generic BxHxLxL mask) are processed correctly on the
        # fast path and the results match explicit slow calculation.
        sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]

        for (B, num_heads, L) in sizes:

            # mask_type == 0 => attention mask of shape LxL
            src_mask_orig = torch.randint(0, 2, (L, L)).bool()
            src_mask = src_mask_orig.reshape(1, 1, L, L).expand(B, num_heads, L, L).bool()

            # mask_type == 1 => padding mask of shape BxL
            src_key_padding_mask_orig = torch.randint(0, 2, (B, L)).bool()
            src_key_padding_mask = src_key_padding_mask_orig.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool()

            # mask_type == 2 => shape BxHxLxL
            generic_mask = torch.randint(0, 2, (B, num_heads, L, L)).bool()
            masks = [(src_mask_orig, src_mask, 0),
                     (src_key_padding_mask_orig, src_key_padding_mask, 1),
                     (generic_mask, generic_mask, 2)
                     ]
            for dim in [0, 3]:
                for mask_orig, mask, mask_type in masks:
                    if (self.device_type == "cuda") and (num_heads % 2) and (mask_type == 1):
                        # CUDA path doesn't support padding mask when the number of heads is odd
                        continue
                    input = torch.randn((B, num_heads, L, L))
                    if (self.device_type == "cuda"):
                        input = input.cuda()
                        mask = mask.cuda()
                        mask_orig = mask_orig.cuda()
                    native_res = torch._masked_softmax(input, mask_orig, dim, mask_type)
                    mask = ~mask

                    def slow_masked_softmax(input, mask):
                        exp = torch.exp(input)
                        exp = exp * mask
                        s = exp.sum(dim=dim, keepdim=True).expand(exp.size())
                        return exp / s

                    pt_res = slow_masked_softmax(input, mask)
                    pt_res = torch.nan_to_num(pt_res)

                    mask_not = mask.logical_not()
                    # In result, should only fill the entirely masked out rows since those are non-deterministic (*may* be 0)
                    # Converts rows with all True's to False
                    mask_out = mask_not.all(dim, keepdim=True).expand(mask_not.shape)
                    self.assertEqual(
                        pt_res.masked_fill(mask_out, 0),
                        native_res.masked_fill(mask_out, 0),
                        exact_dtype=True
                    )

    @onlyCUDA
    @gcIfJetson
    def test_masked_softmax_devices_parity(self):
        # Test that softmax with mask type 0 (LxL attention mask), mask type 1 (BxL padding mask),
        # and mask type 2 (BxHxLxL generic mask) gives the same result on CPU and on CUDA.

        sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]
        for (B, num_heads, L) in sizes:
            # mask_type == 0 => attention mask of shape LxL
            src_mask = torch.randint(0, 2, (L, L)).bool()
            # mask_type == 1 => padding mask of shape BxL
            src_key_padding_mask = torch.randint(0, 2, (B, L)).bool()
            # mask_type == 2 => generic mask of shape BxHxLxL
            generic_mask = torch.randint(0, 2, (B, num_heads, L, L)).bool()
            masks = [(src_mask, 0), (src_key_padding_mask, 1), (generic_mask, 2)]
            input = torch.randn((B, num_heads, L, L))
            for dim in [0, 3]:
                for mask, mask_type in masks:
                    if (num_heads % 2) and (mask_type == 1):
                        # CUDA path doesn't support padding mask when the number of heads is odd
                        continue

                    def softmax_on_device(mask, input, device):
                        # Compute softmax on a given device
                        input_device = input.to(device)
                        mask_device = mask.to(device)
                        softmax_res = torch._masked_softmax(input_device, mask_device, dim, mask_type)
                        if mask_type == 0:
                            mask_expanded = mask_device.reshape(1, 1, L, L).expand(B, num_heads, L, L).bool()
                        elif mask_type == 1:
                            mask_expanded = mask_device.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool()
                        else:
                            mask_expanded = mask_device
                        # In result, should only fill the entirely masked out rows since those are non-deterministic (*may* be 0)
                        # Fill rows with all True's with 0
                        mask_out = mask_expanded.all(dim, keepdim=True).expand(mask_expanded.shape)
                        softmax_res = softmax_res.masked_fill(mask_out, 0)
                        return softmax_res

                    cpu_res = softmax_on_device(mask, input, "cpu")
                    cuda_res = softmax_on_device(mask, input, "cuda")
                    self.assertEqual(cpu_res, cuda_res, exact_dtype=True)

    def test_masked_softmax(self, device):
        sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]
        for (B, num_heads, L) in sizes:
            for dim in [0, 3]:
                input = torch.randn((B, num_heads, L, L))
                mask = torch.randint(0, 2, (B, L))
                mask = mask.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool()
                mask_type = 1   # BxL => src_key_padding_mask
                if (self.device_type == "cuda"):
                    input = input.cuda()
                    mask = mask.cuda()
                native_res = torch._masked_softmax(input, mask, dim, mask_type)
                mask = ~mask

                def slow_masked_softmax(input, mask):
                    exp = torch.exp(input)
                    exp = exp * mask
                    s = exp.sum(dim=dim, keepdim=True).expand(exp.size())
                    return exp / s

                pt_res = slow_masked_softmax(input, mask)
                pt_res = torch.nan_to_num(pt_res)

                mask_not = mask.logical_not()
                # In result, should only fill the entirely masked out rows since those are non-deterministic (*may* be 0)
                # Converts rows with all True's to False
                mask_out = mask_not.all(dim, keepdim=True).expand(mask_not.shape)
                self.assertEqual(
                    pt_res.masked_fill(mask_out, 0),
                    native_res.masked_fill(mask_out, 0),
                    exact_dtype=True
                )

    @dtypes(torch.bfloat16, torch.half)
    @precisionOverride({torch.bfloat16: 2e-2, torch.half: 3e-3})
    def test_masked_softmax_lowp(self, dtype):
        sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]
        for (B, num_heads, L) in sizes:
            for dim in [0, 3]:
                input_lowp = torch.randn((B, num_heads, L, L), dtype=dtype).requires_grad_()
                input_ref = input_lowp.float().detach().requires_grad_()
                mask = torch.randint(0, 2, (B, L))
                mask = mask.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool()

                for mask_type in [1, 2]:
                    res_ref = torch._masked_softmax(input_ref, mask, dim, mask_type)
                    res = torch._masked_softmax(input_lowp, mask, dim, mask_type)
                    self.assertEqual(res_ref.to(dtype), res)

                    grad_lowp = torch.randn_like(res_ref).to(dtype=dtype)
                    grad_ref = grad_lowp.float()

                    res_ref.backward(grad_ref)
                    res.backward(grad_lowp)
                    self.assertEqual(input_ref.grad.to(dtype), input_lowp.grad)

    def _test_masked_softmax_helper(self, input, dim, mask, mask_type):
        input_ref = input.detach().clone().requires_grad_()
        result = torch._masked_softmax(input, mask, dim, mask_type)

        expected = torch._softmax(input_ref.masked_fill(mask, float('-inf')), dim, False)
        grad = torch.randn_like(expected).to(dtype=expected.dtype)

        result.backward(grad)
        expected.backward(grad)

        # Make sure the optional argument works as well
        if dim == input.dim() - 1:
            input_ref_default = input.detach().clone().requires_grad_()
            result_default = torch._masked_softmax(input_ref_default, mask, None, mask_type)
            result_default.backward(grad)
            self.assertEqual(result, result_default)
            self.assertEqual(input.grad, input_ref_default.grad)

        # In result, should only fill the entirely masked out rows since those are non-deterministic (*may* be 0)
        # Converts rows with all True's to False
        mask_out = mask.all(dim, keepdim=True).expand(mask.shape)
        self.assertEqual(result.masked_fill(mask_out, 0), expected.masked_fill(mask_out, 0))

        self.assertEqual(input.grad, torch.nan_to_num(input_ref.grad))
        self.assertEqual(input.grad, input.grad.masked_fill(mask, 0.0))

    def test_masked_softmax_grad(self, device):
        shapes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]
        for shape in shapes:
            dims = [0, len(shape) - 1] if len(shape) > 0 else [0]
            for dim in dims:
                for mask_type in [1, 2]:  # 1 = BxL => src_key_padding_mask
                    input = torch.randn(shape, requires_grad=True)
                    mask = torch.randint(0, 2, shape).bool()
                    if (self.device_type == "cuda"):
                        input = input.cuda().detach().requires_grad_()
                        mask = mask.cuda()
                    self._test_masked_softmax_helper(input, dim, mask, mask_type)

    # In this test, the forward pass is expected to produce nan's because when dim=0, we only have unspecified values
    def test_masked_softmax_forward_with_nans(self, device):
        dim = 0
        shapes = [(4, 5), (50, 100), (1500, 1200)]
        for (x, y) in shapes:
            for mask_type in [1, 2]:  # 1 = BxL => src_key_padding_mask
                input = torch.randn((x, y), requires_grad=True)
                mask = torch.tensor([i % 2 for i in range(y)]).expand((x, y)).bool()
                if (self.device_type == "cuda"):
                    input = input.cuda().detach().requires_grad_()
                    mask = mask.cuda()
                self._test_masked_softmax_helper(input, dim, mask, mask_type)

    @onlyCUDA
    def test_masked_softmax_transformer_layout(self, device):
        B = 211
        num_heads = 16
        L = 42
        input = torch.randn((B, num_heads, L, L))
        dim = input.dim() - 1
        mask = torch.randint(0, 2, (B, L))
        mask_type = 1   # BxL => src_key_padding_mask
        if (self.device_type == "cuda"):
            input = input.cuda()
            mask = mask.cuda()
        mask = mask.bool()
        native_res = torch._masked_softmax(input, mask, dim, mask_type)
        mask = mask.reshape(B, 1, 1, L).expand(B, num_heads, L, L)
        mask = ~mask
        mask = mask.float()

        pt_res = self._slow_masked_softmax(input, mask)
        self.assertEqual(pt_res, native_res, exact_dtype=True)

    @onlyCUDA
    def test_masked_softmax_TxT_layout(self, device):
        B = 211
        num_heads = 16
        L = 42
        input = torch.randn((B, num_heads, L, L))
        dim = input.dim() - 1
        mask = torch.randint(0, 2, (L, L))
        mask_type = 0   # LxL => src_mask
        if (self.device_type == "cuda"):
            input = input.cuda()
            mask = mask.cuda()
        mask = mask.bool()
        native_res = torch._masked_softmax(input, mask, dim, mask_type)
        mask = mask.expand(B, num_heads, L, L)
        mask = ~mask
        mask = mask.float()

        pt_res = self._slow_masked_softmax(input, mask)
        self.assertEqual(pt_res, native_res, exact_dtype=True)

    @onlyCPU
    @dtypes(torch.bfloat16, torch.half)
    def test_log_softmax_cpu(self, device, dtype):
        for dim in [0, 1]:
            inputf = torch.rand(200, 200, device=device, dtype=torch.float, requires_grad=True)
            input = inputf.to(dtype).detach().requires_grad_(True)
            outf = F.log_softmax(inputf, dim=dim)
            out = F.log_softmax(input, dim=dim)
            self.assertEqual(out, outf.to(dtype=dtype), atol=0.1, rtol=0)

            out.sum().backward()
            outf.sum().backward()
            self.assertEqual(input.grad, inputf.grad.to(dtype), atol=0.1, rtol=0)

    @onlyCPU
    @dtypes(torch.bfloat16, torch.half)
    def test_softmax_cpu(self, device, dtype):
        for dim in [0, 1]:
            inputf = torch.rand(200, 200, device=device, dtype=torch.float, requires_grad=True)
            input = inputf.to(dtype).detach().requires_grad_(True)
            outf = F.softmax(inputf, dim=dim)
            out = F.softmax(input, dim=dim)
            self.assertEqual(out, outf.to(dtype), atol=1e-3, rtol=0)

            out.sum().backward()
            outf.sum().backward()
            self.assertEqual(input.grad, inputf.grad.to(dtype), atol=1e-3, rtol=0)

    @dtypesIfCUDA(torch.half, torch.float)
    @dtypes(torch.float)
    def test_softmax_results(self, device, dtype):
        # Non-even sizes and non-zero shifts test fallback paths in vectorized kernel
        # Note: dim1 > 1024 is needed to exercise the vectorized (non-persistent) path, (16, 30576) is BERT-esque
        sizes = [(0, 10), (32, 20), (10, 0), (31, 20), (32, 21), (31, 23), (32, 1536), (31, 2048), (33, 2049), (16, 30576)]
        shifts = [(0, 0), (1, 0), (0, 1), (1, 1)]
        for fn in [F.softmax, F.log_softmax]:
            for size in sizes:
                for shift in shifts:
                    input = torch.rand(size, device=device, dtype=dtype)
                    # Note: With the largest tests we can hit upper limit of fp16 when we
                    # sum, so scale the input down to stay in a nicer range.
                    if dtype == torch.float16:
                        input = input / 100.
10193*da0073e9SAndroid Build Coastguard Worker input = input[shift[0]:, shift[1]:] 10194*da0073e9SAndroid Build Coastguard Worker # Note; Don't want to bprop back through slice op 10195*da0073e9SAndroid Build Coastguard Worker input = input.detach().requires_grad_(True) 10196*da0073e9SAndroid Build Coastguard Worker ref_input = input.clone().cpu().detach().requires_grad_(True) 10197*da0073e9SAndroid Build Coastguard Worker for dim in [0, 1]: 10198*da0073e9SAndroid Build Coastguard Worker ref_output = fn(ref_input, dtype=torch.float, dim=dim) 10199*da0073e9SAndroid Build Coastguard Worker output = fn(input, dtype=torch.float, dim=dim) 10200*da0073e9SAndroid Build Coastguard Worker grad_output = torch.rand(size, device=device, dtype=dtype) 10201*da0073e9SAndroid Build Coastguard Worker grad_output = grad_output[shift[0]:, shift[1]:] 10202*da0073e9SAndroid Build Coastguard Worker ref_grad_output = grad_output.clone().cpu().detach() 10203*da0073e9SAndroid Build Coastguard Worker grad_input, = torch.autograd.grad(output, input, grad_outputs=(grad_output), create_graph=True) 10204*da0073e9SAndroid Build Coastguard Worker ref_grad_input, = torch.autograd.grad(ref_output, ref_input, 10205*da0073e9SAndroid Build Coastguard Worker grad_outputs=(ref_grad_output), create_graph=True) 10206*da0073e9SAndroid Build Coastguard Worker grad_input.sum().backward() 10207*da0073e9SAndroid Build Coastguard Worker ref_grad_input.sum().backward() 10208*da0073e9SAndroid Build Coastguard Worker 10209*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, ref_output) 10210*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_input, ref_grad_input) 10211*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input.grad, ref_input.grad) 10212*da0073e9SAndroid Build Coastguard Worker 10213*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 10214*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.half) 10215*da0073e9SAndroid Build Coastguard Worker @largeTensorTest("20GB") 10216*da0073e9SAndroid Build Coastguard Worker @largeTensorTest("64GB", "cpu") 10217*da0073e9SAndroid Build Coastguard Worker def test_warp_softmax_64bit_indexing(self, device, dtype): 10218*da0073e9SAndroid Build Coastguard Worker def run_test(*shape): 10219*da0073e9SAndroid Build Coastguard Worker x = torch.randn(shape, device="cuda", dtype=torch.float16, requires_grad=True) 10220*da0073e9SAndroid Build Coastguard Worker y = F.log_softmax(x, dim=-1, dtype=dtype) 10221*da0073e9SAndroid Build Coastguard Worker y.backward(y) 10222*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 10223*da0073e9SAndroid Build Coastguard Worker xx = x.cpu().requires_grad_() 10224*da0073e9SAndroid Build Coastguard Worker yy = F.log_softmax(xx.float(), dim=-1).to(dtype) 10225*da0073e9SAndroid Build Coastguard Worker yy.backward(yy) 10226*da0073e9SAndroid Build Coastguard Worker # workaround to reduce memory usage vs. 
self.assertEqual, see #84944
10227*da0073e9SAndroid Build Coastguard Worker             rtol, atol = torch.testing._comparison.get_tolerances(dtype, rtol=None, atol=None)
10228*da0073e9SAndroid Build Coastguard Worker             self.assertTrue(torch.allclose(y.cpu(), yy, rtol=rtol, atol=atol))
10229*da0073e9SAndroid Build Coastguard Worker             # x is half
10230*da0073e9SAndroid Build Coastguard Worker             rtol, _ = torch.testing._comparison.get_tolerances(torch.half, rtol=None, atol=None)
10231*da0073e9SAndroid Build Coastguard Worker             self.assertTrue(torch.allclose(x.grad.cpu(), xx.grad, rtol=rtol, atol=1e-3))
10232*da0073e9SAndroid Build Coastguard Worker
10233*da0073e9SAndroid Build Coastguard Worker         run_test(1100000000, 2) # Illegal memory access https://github.com/pytorch/pytorch/issues/52715
10234*da0073e9SAndroid Build Coastguard Worker         run_test(2200000000, 1) # invalid configuration argument https://github.com/pytorch/pytorch/issues/52716
10235*da0073e9SAndroid Build Coastguard Worker
10236*da0073e9SAndroid Build Coastguard Worker     @onlyCUDA
10237*da0073e9SAndroid Build Coastguard Worker     @dtypes(torch.half)
10238*da0073e9SAndroid Build Coastguard Worker     @largeTensorTest("20GB")
10239*da0073e9SAndroid Build Coastguard Worker     @largeTensorTest("2GB", "cpu")
10240*da0073e9SAndroid Build Coastguard Worker     @precisionOverride({torch.half: 0.001})
10241*da0073e9SAndroid Build Coastguard Worker     def test_softmax_64bit_indexing(self, device, dtype):
10242*da0073e9SAndroid Build Coastguard Worker         def run_test(*shape):
10243*da0073e9SAndroid Build Coastguard Worker             x = torch.ones(shape, device=device, dtype=dtype, requires_grad=True)
10244*da0073e9SAndroid Build Coastguard Worker             y = F.log_softmax(x, dim=-1, dtype=dtype)
10245*da0073e9SAndroid Build Coastguard Worker             y.backward(y)
10246*da0073e9SAndroid Build Coastguard Worker             self.assertEqual(y[0], y[-1])
10247*da0073e9SAndroid Build Coastguard Worker             self.assertEqual(x.grad[0], x.grad[-1])
10248*da0073e9SAndroid Build Coastguard Worker
10249*da0073e9SAndroid Build Coastguard Worker         run_test(1024 * 256 + 1, 8192) # https://github.com/pytorch/pytorch/issues/84144
10250*da0073e9SAndroid Build Coastguard Worker
10251*da0073e9SAndroid Build Coastguard Worker
10252*da0073e9SAndroid Build Coastguard Worker     @dtypes(torch.float)
10253*da0073e9SAndroid Build Coastguard Worker     @dtypesIfCUDA(torch.float, torch.half)
10254*da0073e9SAndroid Build Coastguard Worker     def test_log_softmax_big(self, device, dtype):
10255*da0073e9SAndroid Build Coastguard Worker         def _test_helper(shape):
10256*da0073e9SAndroid Build Coastguard Worker             # generate a tensor with big numbers that are exactly representable in dtype
10257*da0073e9SAndroid Build Coastguard Worker             # and are at a constant offset from a tensor with small numbers
10258*da0073e9SAndroid Build Coastguard Worker             # the log_softmax of the small and big tensors should be equal
10259*da0073e9SAndroid Build Coastguard Worker             x_small = torch.randint(100, shape, dtype=dtype, device=device)
10260*da0073e9SAndroid Build Coastguard Worker             offset = 1.5e3 if dtype == torch.half else 1e7
10261*da0073e9SAndroid Build Coastguard Worker             x_big = x_small + offset
10262*da0073e9SAndroid Build Coastguard Worker             self.assertEqual(F.log_softmax(x_small, -1), F.log_softmax(x_big, -1))
10263*da0073e9SAndroid Build Coastguard Worker         _test_helper((16, 4))
10264*da0073e9SAndroid Build Coastguard Worker         if self.device_type == 'cuda':
10265*da0073e9SAndroid Build Coastguard Worker             # test non-persistent softmax kernel
10266*da0073e9SAndroid Build Coastguard Worker
_test_helper((4, 1536)) 10267*da0073e9SAndroid Build Coastguard Worker 10268*da0073e9SAndroid Build Coastguard Worker def test_save_lstm_compatibility(self, device): 10269*da0073e9SAndroid Build Coastguard Worker # Test that saving an LSTM in PyTorch 1.7 and older can still be 10270*da0073e9SAndroid Build Coastguard Worker # loaded in newer versions of PyTorch. 10271*da0073e9SAndroid Build Coastguard Worker model = nn.LSTM(2, 3) 10272*da0073e9SAndroid Build Coastguard Worker x = torch.randn(32, 5, 2) 10273*da0073e9SAndroid Build Coastguard Worker expected = model(x) 10274*da0073e9SAndroid Build Coastguard Worker 10275*da0073e9SAndroid Build Coastguard Worker # Get a state dict for PyTorch 1.7 LSTM. Before PyTorch 1.8, proj_size 10276*da0073e9SAndroid Build Coastguard Worker # didn't exist. 10277*da0073e9SAndroid Build Coastguard Worker assert model.proj_size == 0 10278*da0073e9SAndroid Build Coastguard Worker state_dict = model.__dict__ 10279*da0073e9SAndroid Build Coastguard Worker del state_dict['proj_size'] 10280*da0073e9SAndroid Build Coastguard Worker 10281*da0073e9SAndroid Build Coastguard Worker # load a model 10282*da0073e9SAndroid Build Coastguard Worker loaded_model = nn.LSTM(2, 3) 10283*da0073e9SAndroid Build Coastguard Worker loaded_model.__setstate__(state_dict) 10284*da0073e9SAndroid Build Coastguard Worker result = loaded_model(x) 10285*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected) 10286*da0073e9SAndroid Build Coastguard Worker 10287*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 10288*da0073e9SAndroid Build Coastguard Worker @tf32_on_and_off(0.005) 10289*da0073e9SAndroid Build Coastguard Worker def test_grid_sample_large(self, device): 10290*da0073e9SAndroid Build Coastguard Worker def issue_35202(): 10291*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.rand(1, 1, 480, 640, dtype=torch.float, device=device, requires_grad=True) 10292*da0073e9SAndroid Build Coastguard Worker coords = torch.tensor([[-10059144, 67680944], [67680944, 67680944]], dtype=torch.float, device=device) 10293*da0073e9SAndroid Build Coastguard Worker coords = coords.unsqueeze(0).unsqueeze(0).repeat(1, 1, 1, 1) 10294*da0073e9SAndroid Build Coastguard Worker result = torch.nn.functional.grid_sample(input_tensor, coords) 10295*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, torch.tensor([[[[0., 0.]]]], dtype=torch.float, device=device)) 10296*da0073e9SAndroid Build Coastguard Worker result.backward(torch.ones_like(result)) 10297*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize() 10298*da0073e9SAndroid Build Coastguard Worker issue_35202() 10299*da0073e9SAndroid Build Coastguard Worker 10300*da0073e9SAndroid Build Coastguard Worker def issue_24823_1(dtype): 10301*da0073e9SAndroid Build Coastguard Worker image = torch.arange(27, 0, -1, dtype=dtype, device=device).view(1, 1, 3, 3, 3) 10302*da0073e9SAndroid Build Coastguard Worker image.requires_grad_() 10303*da0073e9SAndroid Build Coastguard Worker grid = torch.nn.functional.affine_grid( 10304*da0073e9SAndroid Build Coastguard Worker torch.tensor([[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]], dtype=dtype, device=device), 10305*da0073e9SAndroid Build Coastguard Worker (1, 1, 3, 3, 3)) 10306*da0073e9SAndroid Build Coastguard Worker grid[:, 1, 1, 1, 0] = float('inf') 10307*da0073e9SAndroid Build Coastguard Worker result = torch.nn.functional.grid_sample(image, grid, padding_mode='zeros') 10308*da0073e9SAndroid Build Coastguard Worker tol_override = {'atol': 0.005, 'rtol': 0} if dtype 
== torch.half else {} 10309*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, torch.tensor([[[[[27., 26., 25.], [24., 23., 22.], [21., 20., 19.]], 10310*da0073e9SAndroid Build Coastguard Worker [[18., 17., 16.], [15., 0., 13.], [12., 11., 10.]], 10311*da0073e9SAndroid Build Coastguard Worker [[9., 8., 7.], [6., 5., 4.], [3., 2., 1.]]]]], 10312*da0073e9SAndroid Build Coastguard Worker device=device, dtype=dtype), **tol_override) 10313*da0073e9SAndroid Build Coastguard Worker result.backward(torch.ones_like(result)) 10314*da0073e9SAndroid Build Coastguard Worker expected_grad = torch.ones_like(image) 10315*da0073e9SAndroid Build Coastguard Worker expected_grad[0, 0, 1, 1, 1] = 0 10316*da0073e9SAndroid Build Coastguard Worker self.assertEqual(image.grad, expected_grad, atol=0.005, rtol=0) 10317*da0073e9SAndroid Build Coastguard Worker issue_24823_1(torch.half) 10318*da0073e9SAndroid Build Coastguard Worker issue_24823_1(torch.float) 10319*da0073e9SAndroid Build Coastguard Worker issue_24823_1(torch.double) 10320*da0073e9SAndroid Build Coastguard Worker 10321*da0073e9SAndroid Build Coastguard Worker def issue_24823_2(): 10322*da0073e9SAndroid Build Coastguard Worker param = torch.tensor([[[-1.0e+20, 0.0, 0.0], [0.0, -1.0e+20, 0.0]]], dtype=torch.float, device=device) 10323*da0073e9SAndroid Build Coastguard Worker img = torch.zeros((1, 1, 4, 4), dtype=torch.float, device=device, requires_grad=True) 10324*da0073e9SAndroid Build Coastguard Worker grid = torch.nn.functional.affine_grid(param, img.size()) 10325*da0073e9SAndroid Build Coastguard Worker result = torch.nn.functional.grid_sample(img, grid) 10326*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, torch.zeros(1, 1, 4, 4, device=device, dtype=torch.float)) 10327*da0073e9SAndroid Build Coastguard Worker result.backward(torch.ones_like(result)) 10328*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize() 10329*da0073e9SAndroid Build Coastguard Worker issue_24823_2() 10330*da0073e9SAndroid Build Coastguard Worker 10331*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 10332*da0073e9SAndroid Build Coastguard Worker @largeTensorTest(lambda self, device, dtype: 10333*da0073e9SAndroid Build Coastguard Worker # Compute sum of the large tensor sizes: 10334*da0073e9SAndroid Build Coastguard Worker # (im.numel() + small_image.numel() + small_image.grad.numel() + 10335*da0073e9SAndroid Build Coastguard Worker # large_view.grad.numel()) * sizeof(dtype) 10336*da0073e9SAndroid Build Coastguard Worker 32769 * (65536 + 3 * 65536 / 128) * 10337*da0073e9SAndroid Build Coastguard Worker torch.tensor([], dtype=dtype).element_size()) 10338*da0073e9SAndroid Build Coastguard Worker def test_grid_sample_large_index_2d(self, device, dtype): 10339*da0073e9SAndroid Build Coastguard Worker # Test 64-bit indexing with grid_sample (gh-41656) 10340*da0073e9SAndroid Build Coastguard Worker # Try accessing the corners, there should be no segfault 10341*da0073e9SAndroid Build Coastguard Worker coords = torch.tensor([[[-1., -1.], 10342*da0073e9SAndroid Build Coastguard Worker [+1., -1.]], 10343*da0073e9SAndroid Build Coastguard Worker 10344*da0073e9SAndroid Build Coastguard Worker [[-1., +1.], 10345*da0073e9SAndroid Build Coastguard Worker [+1., +1.]]], device=device, dtype=dtype) 10346*da0073e9SAndroid Build Coastguard Worker coords = coords.expand(1, 2, 2, 2) 10347*da0073e9SAndroid Build Coastguard Worker im = torch.zeros([1, 1, 32769, 65536], device=device, dtype=dtype) 10348*da0073e9SAndroid Build 
Coastguard Worker 10349*da0073e9SAndroid Build Coastguard Worker # Compare sampling with large strides to the same op on a contiguous tensor 10350*da0073e9SAndroid Build Coastguard Worker coords = torch.rand(1, 4, 4, 2, device=device, dtype=dtype) 10351*da0073e9SAndroid Build Coastguard Worker large_view = im[..., 127::128] 10352*da0073e9SAndroid Build Coastguard Worker small_image = torch.rand_like(large_view) 10353*da0073e9SAndroid Build Coastguard Worker large_view[...] = small_image 10354*da0073e9SAndroid Build Coastguard Worker large_view.requires_grad, small_image.requires_grad = True, True 10355*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 10356*da0073e9SAndroid Build Coastguard Worker sum(i * s for i, s in zip(large_view.size(), large_view.stride())) >= 2 ** 31, 10357*da0073e9SAndroid Build Coastguard Worker msg="View must use 64-bit indexing") 10358*da0073e9SAndroid Build Coastguard Worker for mode, padding_mode, align_corners in itertools.product( 10359*da0073e9SAndroid Build Coastguard Worker ('nearest', 'bilinear', 'bicubic'), ('zeros', 'border', 'reflection'), (True, False)): 10360*da0073e9SAndroid Build Coastguard Worker a = F.grid_sample( 10361*da0073e9SAndroid Build Coastguard Worker small_image, coords, mode=mode, 10362*da0073e9SAndroid Build Coastguard Worker padding_mode=padding_mode, align_corners=align_corners) 10363*da0073e9SAndroid Build Coastguard Worker a.sum().backward() 10364*da0073e9SAndroid Build Coastguard Worker 10365*da0073e9SAndroid Build Coastguard Worker b = F.grid_sample( 10366*da0073e9SAndroid Build Coastguard Worker large_view, coords, mode=mode, 10367*da0073e9SAndroid Build Coastguard Worker padding_mode=padding_mode, align_corners=align_corners) 10368*da0073e9SAndroid Build Coastguard Worker b.sum().backward() 10369*da0073e9SAndroid Build Coastguard Worker 10370*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 10371*da0073e9SAndroid Build Coastguard Worker self.assertEqual(small_image.grad, large_view.grad) 10372*da0073e9SAndroid Build Coastguard Worker 10373*da0073e9SAndroid Build Coastguard Worker small_image.grad.zero_() 10374*da0073e9SAndroid Build Coastguard Worker large_view.grad.zero_() 10375*da0073e9SAndroid Build Coastguard Worker 10376*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 10377*da0073e9SAndroid Build Coastguard Worker @largeTensorTest(lambda self, device, dtype: 10378*da0073e9SAndroid Build Coastguard Worker # Compute sum of the large tensor sizes: 10379*da0073e9SAndroid Build Coastguard Worker # (im.numel() + small_image.numel() + small_image.grad.numel() + 10380*da0073e9SAndroid Build Coastguard Worker # large_view.grad.numel()) * sizeof(dtype) 10381*da0073e9SAndroid Build Coastguard Worker 2 * 32769 * (32768 + 3 * 32768 / 128) * 10382*da0073e9SAndroid Build Coastguard Worker torch.tensor([], dtype=dtype).element_size()) 10383*da0073e9SAndroid Build Coastguard Worker def test_grid_sample_large_index_3d(self, device, dtype): 10384*da0073e9SAndroid Build Coastguard Worker # Test 64-bit indexing with grid_sample (gh-41656) 10385*da0073e9SAndroid Build Coastguard Worker # Try accessing the corners, there should be no segfault 10386*da0073e9SAndroid Build Coastguard Worker coords = torch.full((1, 2, 2, 2, 3), 1., device=device, dtype=dtype) 10387*da0073e9SAndroid Build Coastguard Worker im = torch.zeros([1, 1, 2, 32769, 32768], device=device, dtype=dtype) 10388*da0073e9SAndroid Build Coastguard Worker 10389*da0073e9SAndroid Build Coastguard Worker result = F.grid_sample(im, 
coords, align_corners=False) 10390*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, torch.zeros((1, 1, 2, 2, 2), device=device, dtype=dtype)) 10391*da0073e9SAndroid Build Coastguard Worker 10392*da0073e9SAndroid Build Coastguard Worker # Compare sampling with large strides to the same op on a contiguous tensor 10393*da0073e9SAndroid Build Coastguard Worker coords = torch.rand(1, 1, 4, 4, 3, device=device, dtype=dtype) 10394*da0073e9SAndroid Build Coastguard Worker large_view = im[..., 127::128] 10395*da0073e9SAndroid Build Coastguard Worker small_image = torch.rand_like(large_view) 10396*da0073e9SAndroid Build Coastguard Worker large_view[...] = small_image 10397*da0073e9SAndroid Build Coastguard Worker small_image.requires_grad, large_view.requires_grad = True, True 10398*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 10399*da0073e9SAndroid Build Coastguard Worker sum(i * s for i, s in zip(large_view.size(), large_view.stride())) >= 2 ** 31, 10400*da0073e9SAndroid Build Coastguard Worker msg="View must use 64-bit indexing") 10401*da0073e9SAndroid Build Coastguard Worker for mode, padding_mode, align_corners in itertools.product( 10402*da0073e9SAndroid Build Coastguard Worker ('nearest', 'bilinear'), ('zeros', 'border', 'reflection'), (True, False)): 10403*da0073e9SAndroid Build Coastguard Worker a = F.grid_sample( 10404*da0073e9SAndroid Build Coastguard Worker small_image, coords, mode=mode, 10405*da0073e9SAndroid Build Coastguard Worker padding_mode=padding_mode, align_corners=align_corners) 10406*da0073e9SAndroid Build Coastguard Worker a.sum().backward() 10407*da0073e9SAndroid Build Coastguard Worker 10408*da0073e9SAndroid Build Coastguard Worker b = F.grid_sample( 10409*da0073e9SAndroid Build Coastguard Worker large_view, coords, mode=mode, 10410*da0073e9SAndroid Build Coastguard Worker padding_mode=padding_mode, align_corners=align_corners) 10411*da0073e9SAndroid Build Coastguard Worker b.sum().backward() 10412*da0073e9SAndroid Build Coastguard Worker 10413*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 10414*da0073e9SAndroid Build Coastguard Worker self.assertEqual(small_image.grad, large_view.grad) 10415*da0073e9SAndroid Build Coastguard Worker 10416*da0073e9SAndroid Build Coastguard Worker small_image.grad.zero_() 10417*da0073e9SAndroid Build Coastguard Worker large_view.grad.zero_() 10418*da0073e9SAndroid Build Coastguard Worker 10419*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 10420*da0073e9SAndroid Build Coastguard Worker def test_grid_sample_half_precision(self): 10421*da0073e9SAndroid Build Coastguard Worker def helper(shape_in, shape_out, align_corners): 10422*da0073e9SAndroid Build Coastguard Worker for mode in ('bilinear', 'nearest', 'bicubic'): 10423*da0073e9SAndroid Build Coastguard Worker if len(shape_in) != 4 and mode == 'bicubic': 10424*da0073e9SAndroid Build Coastguard Worker continue 10425*da0073e9SAndroid Build Coastguard Worker data = torch.randn(shape_in, device='cuda', dtype=torch.half) 10426*da0073e9SAndroid Build Coastguard Worker grid = torch.rand(shape_out, device='cuda', dtype=torch.half) * 2.0 - 1.0 10427*da0073e9SAndroid Build Coastguard Worker 10428*da0073e9SAndroid Build Coastguard Worker out_half = F.grid_sample(data, grid, mode=mode, padding_mode='zeros', align_corners=align_corners) 10429*da0073e9SAndroid Build Coastguard Worker out_double = F.grid_sample(data.double(), grid.double(), mode=mode, padding_mode='zeros', 10430*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners) 
10431*da0073e9SAndroid Build Coastguard Worker 10432*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_half, out_double.half(), msg=f"grid_sample with mode = {mode} doesn't match") 10433*da0073e9SAndroid Build Coastguard Worker 10434*da0073e9SAndroid Build Coastguard Worker helper((32, 64, 16, 16), (32, 8, 8, 2), True) 10435*da0073e9SAndroid Build Coastguard Worker helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), True) 10436*da0073e9SAndroid Build Coastguard Worker helper((32, 64, 16, 16), (32, 8, 8, 2), False) 10437*da0073e9SAndroid Build Coastguard Worker helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), False) 10438*da0073e9SAndroid Build Coastguard Worker 10439*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 10440*da0073e9SAndroid Build Coastguard Worker def test_grid_sample_bfloat16_precision(self): 10441*da0073e9SAndroid Build Coastguard Worker def helper(shape_in, shape_out, align_corners): 10442*da0073e9SAndroid Build Coastguard Worker for mode in ('bilinear', 'nearest', 'bicubic'): 10443*da0073e9SAndroid Build Coastguard Worker if len(shape_in) != 4 and mode == 'bicubic': 10444*da0073e9SAndroid Build Coastguard Worker continue 10445*da0073e9SAndroid Build Coastguard Worker data = torch.randn(shape_in, device='cuda', dtype=torch.bfloat16) 10446*da0073e9SAndroid Build Coastguard Worker grid = torch.rand(shape_out, device='cuda', dtype=torch.bfloat16) * 2.0 - 1.0 10447*da0073e9SAndroid Build Coastguard Worker 10448*da0073e9SAndroid Build Coastguard Worker out_half = F.grid_sample(data, grid, mode=mode, padding_mode='zeros', align_corners=align_corners) 10449*da0073e9SAndroid Build Coastguard Worker out_double = F.grid_sample(data.double(), grid.double(), mode=mode, padding_mode='zeros', 10450*da0073e9SAndroid Build Coastguard Worker align_corners=align_corners) 10451*da0073e9SAndroid Build Coastguard Worker 10452*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_half, out_double.bfloat16(), msg=f"grid_sample with mode = {mode} doesn't match") 10453*da0073e9SAndroid Build Coastguard Worker 10454*da0073e9SAndroid Build Coastguard Worker helper((32, 64, 16, 16), (32, 8, 8, 2), True) 10455*da0073e9SAndroid Build Coastguard Worker helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), True) 10456*da0073e9SAndroid Build Coastguard Worker helper((32, 64, 16, 16), (32, 8, 8, 2), False) 10457*da0073e9SAndroid Build Coastguard Worker helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), False) 10458*da0073e9SAndroid Build Coastguard Worker 10459*da0073e9SAndroid Build Coastguard Worker def _test_gumbel_softmax_st_shapes(self, device, dtype, shape, dim, count_expected): 10460*da0073e9SAndroid Build Coastguard Worker logits = torch.randn(shape, dtype=torch.float, device=device) 10461*da0073e9SAndroid Build Coastguard Worker logits = logits.to(dtype) 10462*da0073e9SAndroid Build Coastguard Worker 10463*da0073e9SAndroid Build Coastguard Worker y_draw = F.gumbel_softmax(logits, hard=True, dim=dim) 10464*da0073e9SAndroid Build Coastguard Worker 10465*da0073e9SAndroid Build Coastguard Worker # All values positive 10466*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual(y_draw.min(), 0) 10467*da0073e9SAndroid Build Coastguard Worker # Shape unchanged 10468*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y_draw.shape == logits.shape) 10469*da0073e9SAndroid Build Coastguard Worker # One choice per draw 10470*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y_draw.sum(), count_expected, atol=torch.finfo(y_draw.dtype).eps, rtol=0) 10471*da0073e9SAndroid Build 
Coastguard Worker
10472*da0073e9SAndroid Build Coastguard Worker     def _test_gumbel_softmax_straight_through(self, device, dtype):
10473*da0073e9SAndroid Build Coastguard Worker         num_draws = 100
10474*da0073e9SAndroid Build Coastguard Worker
10475*da0073e9SAndroid Build Coastguard Worker         logits = torch.tensor([[0.2, 0.8, 0.1]], device=device)
10476*da0073e9SAndroid Build Coastguard Worker         logits = logits.reshape([1, 3])
10477*da0073e9SAndroid Build Coastguard Worker         logits = logits.to(dtype).requires_grad_()
10478*da0073e9SAndroid Build Coastguard Worker         probs = logits.softmax(dim=-1)
10479*da0073e9SAndroid Build Coastguard Worker
10480*da0073e9SAndroid Build Coastguard Worker         counts = torch.zeros_like(logits)
10481*da0073e9SAndroid Build Coastguard Worker         for _ in range(num_draws):
10482*da0073e9SAndroid Build Coastguard Worker             y_draw = F.gumbel_softmax(logits, hard=True)
10483*da0073e9SAndroid Build Coastguard Worker             counts = counts + y_draw
10484*da0073e9SAndroid Build Coastguard Worker
10485*da0073e9SAndroid Build Coastguard Worker         # All values positive
10486*da0073e9SAndroid Build Coastguard Worker         self.assertGreaterEqual(y_draw.min(), 0)
10487*da0073e9SAndroid Build Coastguard Worker         # Each experiment should result in 1 draw.
10488*da0073e9SAndroid Build Coastguard Worker         self.assertEqual(counts.sum(), num_draws, atol=torch.finfo(counts.dtype).eps, rtol=0)
10489*da0073e9SAndroid Build Coastguard Worker
10490*da0073e9SAndroid Build Coastguard Worker         # check that the results are asymptotically as expected.
10491*da0073e9SAndroid Build Coastguard Worker         expected = probs * num_draws
10492*da0073e9SAndroid Build Coastguard Worker         # z is approximately N(0, 1) if the counts are unbiased
10493*da0073e9SAndroid Build Coastguard Worker         z = (counts - expected) / (expected * (1 - probs)).sqrt()
10494*da0073e9SAndroid Build Coastguard Worker         # A (lazy) approximate 99% two-sided test:
10495*da0073e9SAndroid Build Coastguard Worker         # a failure occurs with probability ~0.01 if the sampler is unbiased
10496*da0073e9SAndroid Build Coastguard Worker         self.assertLess(z.abs().max().item(), 2.58)
10497*da0073e9SAndroid Build Coastguard Worker
10498*da0073e9SAndroid Build Coastguard Worker     def _test_gumbel_softmax_grad(self, device, dtype):
10499*da0073e9SAndroid Build Coastguard Worker         # "hard" and "not hard" should propagate the same gradient.
10500*da0073e9SAndroid Build Coastguard Worker         logits_soft = torch.zeros(10, 10, dtype=dtype, device=device, requires_grad=True)
10501*da0073e9SAndroid Build Coastguard Worker         logits_hard = torch.zeros(10, 10, dtype=dtype, device=device, requires_grad=True)
10502*da0073e9SAndroid Build Coastguard Worker
10503*da0073e9SAndroid Build Coastguard Worker         seed = torch.random.get_rng_state()
10504*da0073e9SAndroid Build Coastguard Worker         y_soft = F.gumbel_softmax(logits_soft, hard=False)
10505*da0073e9SAndroid Build Coastguard Worker         torch.random.set_rng_state(seed)
10506*da0073e9SAndroid Build Coastguard Worker         y_hard = F.gumbel_softmax(logits_hard, hard=True)
10507*da0073e9SAndroid Build Coastguard Worker
10508*da0073e9SAndroid Build Coastguard Worker         y_soft.sum().backward()
10509*da0073e9SAndroid Build Coastguard Worker         y_hard.sum().backward()
10510*da0073e9SAndroid Build Coastguard Worker
10511*da0073e9SAndroid Build Coastguard Worker         # 2eps = 1x addition + 1x subtraction.
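        # (Editor's note, not in the upstream file: F.gumbel_softmax(..., hard=True)
        #  currently forms its return value as y_hard - y_soft.detach() + y_soft,
        #  i.e. one addition and one subtraction on top of the soft sample, which
        #  is what the 2 * eps tolerance below accounts for.)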
10512*da0073e9SAndroid Build Coastguard Worker tol = 2 * torch.finfo(dtype).eps 10513*da0073e9SAndroid Build Coastguard Worker self.assertEqual(logits_soft.grad, logits_hard.grad, atol=tol, rtol=0) 10514*da0073e9SAndroid Build Coastguard Worker 10515*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.float, torch.double) 10516*da0073e9SAndroid Build Coastguard Worker @dtypesIfMPS(torch.float) 10517*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 10518*da0073e9SAndroid Build Coastguard Worker def test_gumbel_softmax(self, device, dtype): 10519*da0073e9SAndroid Build Coastguard Worker self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5], dim=0, count_expected=1) 10520*da0073e9SAndroid Build Coastguard Worker self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5], dim=-1, count_expected=1) 10521*da0073e9SAndroid Build Coastguard Worker self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5, 4], dim=1, count_expected=5) 10522*da0073e9SAndroid Build Coastguard Worker self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5, 4, 3], dim=1, count_expected=5 * 3) 10523*da0073e9SAndroid Build Coastguard Worker self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5, 4, 3], dim=-1, count_expected=5 * 4) 10524*da0073e9SAndroid Build Coastguard Worker self._test_gumbel_softmax_straight_through(device, dtype) 10525*da0073e9SAndroid Build Coastguard Worker self._test_gumbel_softmax_grad(device, dtype) 10526*da0073e9SAndroid Build Coastguard Worker 10527*da0073e9SAndroid Build Coastguard Worker def _test_rnn_retain_variables(self, device, dtype): 10528*da0073e9SAndroid Build Coastguard Worker rnns = [nn.LSTM(10, 20, num_layers=2).to(device, dtype), 10529*da0073e9SAndroid Build Coastguard Worker nn.GRU(10, 20, num_layers=2).to(device, dtype), 10530*da0073e9SAndroid Build Coastguard Worker nn.RNN(10, 20, num_layers=2).to(device, dtype)] 10531*da0073e9SAndroid Build Coastguard Worker for rnn in rnns: 10532*da0073e9SAndroid Build Coastguard Worker input = torch.randn(5, 6, 10, device=device, dtype=dtype, requires_grad=True) 10533*da0073e9SAndroid Build Coastguard Worker output = rnn(input) 10534*da0073e9SAndroid Build Coastguard Worker output[0].sum().backward(retain_graph=True) 10535*da0073e9SAndroid Build Coastguard Worker grads = [input.grad.data.clone()] + [p.grad.data.clone() for p in rnn.parameters()] 10536*da0073e9SAndroid Build Coastguard Worker for _ in range(4): 10537*da0073e9SAndroid Build Coastguard Worker rnn.zero_grad() 10538*da0073e9SAndroid Build Coastguard Worker input.grad.data.zero_() 10539*da0073e9SAndroid Build Coastguard Worker output[0].sum().backward(retain_graph=True) 10540*da0073e9SAndroid Build Coastguard Worker grads2 = [input.grad.data] + [p.grad.data for p in rnn.parameters()] 10541*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grads, grads2) 10542*da0073e9SAndroid Build Coastguard Worker 10543*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.float, torch.double) 10544*da0073e9SAndroid Build Coastguard Worker @dtypesIfMPS(torch.half, torch.float) 10545*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 10546*da0073e9SAndroid Build Coastguard Worker def test_rnn_retain_variables(self, device, dtype): 10547*da0073e9SAndroid Build Coastguard Worker self._test_rnn_retain_variables(device, dtype) 10548*da0073e9SAndroid Build Coastguard Worker 10549*da0073e9SAndroid Build Coastguard Worker if self.device_type == 'cuda' and self.has_cudnn(): 10550*da0073e9SAndroid 
Build Coastguard Worker             with torch.backends.cudnn.flags(enabled=False):
10551*da0073e9SAndroid Build Coastguard Worker                 self._test_rnn_retain_variables(device, dtype)
10552*da0073e9SAndroid Build Coastguard Worker
10553*da0073e9SAndroid Build Coastguard Worker     @onlyCUDA
10554*da0073e9SAndroid Build Coastguard Worker     @dtypes(torch.double)
10555*da0073e9SAndroid Build Coastguard Worker     def test_lstmcell_backward_only_one_output_grad(self, device, dtype):
10556*da0073e9SAndroid Build Coastguard Worker         # checks that undefined gradients don't hamper the backward
10557*da0073e9SAndroid Build Coastguard Worker         # see #11872
10558*da0073e9SAndroid Build Coastguard Worker         l = torch.nn.LSTMCell(2, 3).to(device).to(dtype=dtype)
10559*da0073e9SAndroid Build Coastguard Worker         s = torch.randn(1, 2, device=device, dtype=dtype, requires_grad=True)
10560*da0073e9SAndroid Build Coastguard Worker         for i in range(2):
10561*da0073e9SAndroid Build Coastguard Worker             out = l(s)[i]
10562*da0073e9SAndroid Build Coastguard Worker             out.sum().backward()
10563*da0073e9SAndroid Build Coastguard Worker             self.assertFalse(s.grad is None or s.grad.abs().sum().item() == 0)
10564*da0073e9SAndroid Build Coastguard Worker
10565*da0073e9SAndroid Build Coastguard Worker     def _test_rnn_mod(self, mod, inp):
10566*da0073e9SAndroid Build Coastguard Worker         def flatten_out(mod, inp):
10567*da0073e9SAndroid Build Coastguard Worker             out = mod(inp)
10568*da0073e9SAndroid Build Coastguard Worker             return tuple([t if isinstance(t, torch.Tensor) else tt for t in out for tt in t])
10569*da0073e9SAndroid Build Coastguard Worker         gradcheckfunc = partial(flatten_out, mod)
10570*da0073e9SAndroid Build Coastguard Worker         with torch.backends.cudnn.flags(enabled=False):
10571*da0073e9SAndroid Build Coastguard Worker             gradcheck(gradcheckfunc, inp, check_batched_grad=False)
10572*da0073e9SAndroid Build Coastguard Worker             gradgradcheck(gradcheckfunc, inp, check_batched_grad=False)
10573*da0073e9SAndroid Build Coastguard Worker
10574*da0073e9SAndroid Build Coastguard Worker         if inp.is_cuda and not TEST_WITH_ROCM:
10575*da0073e9SAndroid Build Coastguard Worker             # Assert that we have a good error message around unsupported CuDNN double backward
10576*da0073e9SAndroid Build Coastguard Worker             # NB: we trigger double backward using .backward() instead of autograd.grad due to
10577*da0073e9SAndroid Build Coastguard Worker             # https://github.com/pytorch/pytorch/issues/37874
10578*da0073e9SAndroid Build Coastguard Worker             with torch.backends.cudnn.flags(enabled=True):
10579*da0073e9SAndroid Build Coastguard Worker                 result = gradcheckfunc(inp)
10580*da0073e9SAndroid Build Coastguard Worker                 result[0].sum().backward(create_graph=True)
10581*da0073e9SAndroid Build Coastguard Worker                 grad0 = next(mod.parameters()).grad
10582*da0073e9SAndroid Build Coastguard Worker                 with self.assertRaisesRegex(RuntimeError,
10583*da0073e9SAndroid Build Coastguard Worker                                             "please disable the CuDNN backend temporarily"):
10584*da0073e9SAndroid Build Coastguard Worker                     grad0.sum().backward()
10585*da0073e9SAndroid Build Coastguard Worker
10586*da0073e9SAndroid Build Coastguard Worker                 # Here we avoid the backward(create_graph=True) memory leak
10587*da0073e9SAndroid Build Coastguard Worker                 # described in https://github.com/pytorch/pytorch/issues/7343
10588*da0073e9SAndroid Build Coastguard Worker                 for param in mod.parameters():
10589*da0073e9SAndroid Build Coastguard Worker                     param.grad = None
10590*da0073e9SAndroid Build Coastguard Worker                 inp.grad = None
10591*da0073e9SAndroid Build Coastguard Worker
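    # --------------------------------------------------------------------------
    # Editor's illustrative sketch -- NOT part of the upstream test suite.
    # It restates, in a minimal standalone form, the pattern used by
    # _test_rnn_mod above: the (output, hidden state) structure returned by an
    # RNN module is flattened into a plain tuple of tensors so that gradcheck /
    # gradgradcheck can differentiate through every output.  The method name and
    # the sizes are illustrative only; this method is not collected as a test.
    # --------------------------------------------------------------------------
    def _example_rnn_gradcheck_sketch(self):
        mod = nn.RNN(2, 3).to(torch.double)  # gradcheck needs double precision
        inp = torch.rand(1, 2, 2, dtype=torch.double, requires_grad=True)

        def flat(x):
            out, hx = mod(x)   # nn.RNN returns (output, h_n)
            return out, hx     # gradcheck accepts a tuple of tensor outputs

        # check_batched_grad=False mirrors the calls in _test_rnn_mod above
        gradcheck(flat, (inp,), check_batched_grad=False)
        gradgradcheck(flat, (inp,), check_batched_grad=False)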
10592*da0073e9SAndroid Build Coastguard Worker # Merge into OpInfo? 10593*da0073e9SAndroid Build Coastguard Worker @skipMeta # LSTM cell reuses output which was resized 10594*da0073e9SAndroid Build Coastguard Worker @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 10595*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 10596*da0073e9SAndroid Build Coastguard Worker def test_LSTM_grad_and_gradgrad(self, device, dtype): 10597*da0073e9SAndroid Build Coastguard Worker hsize = 4 10598*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(1, 3, hsize, device=device, dtype=dtype, requires_grad=True) 10599*da0073e9SAndroid Build Coastguard Worker for bias in [True, False]: 10600*da0073e9SAndroid Build Coastguard Worker mod = torch.nn.LSTM(hsize, hsize, bias=bias).to(device).to(dtype) 10601*da0073e9SAndroid Build Coastguard Worker self._test_rnn_mod(mod, inp) 10602*da0073e9SAndroid Build Coastguard Worker 10603*da0073e9SAndroid Build Coastguard Worker @skipMeta # GRU cell reuses output which was resized 10604*da0073e9SAndroid Build Coastguard Worker @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 10605*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 10606*da0073e9SAndroid Build Coastguard Worker def test_GRU_grad_and_gradgrad(self, device, dtype): 10607*da0073e9SAndroid Build Coastguard Worker hsize = 4 10608*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(1, 3, hsize, device=device, dtype=dtype, requires_grad=True) 10609*da0073e9SAndroid Build Coastguard Worker for bias in [True, False]: 10610*da0073e9SAndroid Build Coastguard Worker mod = torch.nn.GRU(hsize, hsize, bias=bias).to(device).to(dtype) 10611*da0073e9SAndroid Build Coastguard Worker self._test_rnn_mod(mod, inp) 10612*da0073e9SAndroid Build Coastguard Worker 10613*da0073e9SAndroid Build Coastguard Worker @skipMeta 10614*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.bfloat16) 10615*da0073e9SAndroid Build Coastguard Worker @onlyCPU 10616*da0073e9SAndroid Build Coastguard Worker def test_LSTM_differentiable_backward_using_oneDNN(self, dtype): 10617*da0073e9SAndroid Build Coastguard Worker batch = 10 10618*da0073e9SAndroid Build Coastguard Worker seq_len = 12 10619*da0073e9SAndroid Build Coastguard Worker input = 3 10620*da0073e9SAndroid Build Coastguard Worker Net = nn.LSTM(input, 3, 20, batch_first=True) 10621*da0073e9SAndroid Build Coastguard Worker import copy 10622*da0073e9SAndroid Build Coastguard Worker Net_clone = copy.deepcopy(Net) 10623*da0073e9SAndroid Build Coastguard Worker x = torch.rand(batch, seq_len, input) 10624*da0073e9SAndroid Build Coastguard Worker x1 = x.clone().requires_grad_(True) 10625*da0073e9SAndroid Build Coastguard Worker x2 = x.clone().requires_grad_(True) 10626*da0073e9SAndroid Build Coastguard Worker 10627*da0073e9SAndroid Build Coastguard Worker torch._C._set_mkldnn_enabled(False) 10628*da0073e9SAndroid Build Coastguard Worker out1, _ = Net(x1) 10629*da0073e9SAndroid Build Coastguard Worker der_out1 = torch.autograd.grad(out1, x1, 10630*da0073e9SAndroid Build Coastguard Worker grad_outputs=torch.ones_like(out1), 10631*da0073e9SAndroid Build Coastguard Worker retain_graph=True, 10632*da0073e9SAndroid Build Coastguard Worker create_graph=True)[0] 10633*da0073e9SAndroid Build Coastguard Worker loss1 = der_out1.sum() 10634*da0073e9SAndroid Build Coastguard Worker loss1.backward(retain_graph=True) 10635*da0073e9SAndroid Build Coastguard Worker 10636*da0073e9SAndroid Build Coastguard Worker 
torch._C._set_mkldnn_enabled(True) 10637*da0073e9SAndroid Build Coastguard Worker out2, _ = Net(x2) 10638*da0073e9SAndroid Build Coastguard Worker der_out2 = torch.autograd.grad(out2, x2, 10639*da0073e9SAndroid Build Coastguard Worker grad_outputs=torch.ones_like(out2), 10640*da0073e9SAndroid Build Coastguard Worker retain_graph=True, 10641*da0073e9SAndroid Build Coastguard Worker create_graph=True)[0] 10642*da0073e9SAndroid Build Coastguard Worker loss2 = der_out2.sum() 10643*da0073e9SAndroid Build Coastguard Worker loss2.backward(retain_graph=True) 10644*da0073e9SAndroid Build Coastguard Worker assert torch.allclose(der_out1, der_out2) 10645*da0073e9SAndroid Build Coastguard Worker assert torch.allclose(x1.grad, x2.grad) 10646*da0073e9SAndroid Build Coastguard Worker 10647*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 10648*da0073e9SAndroid Build Coastguard Worker def test_upsamplingNearest1d_launch_config(self, device): 10649*da0073e9SAndroid Build Coastguard Worker m = nn.Upsample(scale_factor=2) 10650*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(2**25, 1, 1, device=device) 10651*da0073e9SAndroid Build Coastguard Worker out = m(inp) 10652*da0073e9SAndroid Build Coastguard Worker inp_ref = inp.cpu() 10653*da0073e9SAndroid Build Coastguard Worker out_ref = m(inp_ref) 10654*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out) 10655*da0073e9SAndroid Build Coastguard Worker 10656*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 10657*da0073e9SAndroid Build Coastguard Worker def test_upsamplingNearest2d_launch_config(self, device): 10658*da0073e9SAndroid Build Coastguard Worker m = nn.Upsample(scale_factor=2) 10659*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(2**25, 1, 1, 1, device=device) 10660*da0073e9SAndroid Build Coastguard Worker out = m(inp) 10661*da0073e9SAndroid Build Coastguard Worker inp_ref = inp.cpu() 10662*da0073e9SAndroid Build Coastguard Worker out_ref = m(inp_ref) 10663*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out) 10664*da0073e9SAndroid Build Coastguard Worker 10665*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 10666*da0073e9SAndroid Build Coastguard Worker @gcIfJetson 10667*da0073e9SAndroid Build Coastguard Worker def test_upsamplingNearest3d_launch_config(self, device): 10668*da0073e9SAndroid Build Coastguard Worker m = nn.Upsample(scale_factor=2) 10669*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(2**25, 1, 1, 1, 1, device=device) 10670*da0073e9SAndroid Build Coastguard Worker out = m(inp) 10671*da0073e9SAndroid Build Coastguard Worker inp_ref = inp.cpu() 10672*da0073e9SAndroid Build Coastguard Worker out_ref = m(inp_ref) 10673*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out) 10674*da0073e9SAndroid Build Coastguard Worker 10675*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure 10676*da0073e9SAndroid Build Coastguard Worker @skipIfRocm 10677*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 10678*da0073e9SAndroid Build Coastguard Worker def test_upsamplingNearest2d_launch_fail(self, device): 10679*da0073e9SAndroid Build Coastguard Worker m = nn.Upsample(scale_factor=2) 10680*da0073e9SAndroid Build Coastguard Worker # launch grid_y == 2**16 (larger than maximum y-dimension limit 65535) 10681*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(1, 1, 2**15, 2**8, device=device) 10682*da0073e9SAndroid Build Coastguard Worker out = m(inp) 10683*da0073e9SAndroid Build Coastguard Worker 10684*da0073e9SAndroid Build Coastguard Worker 
@onlyCUDA 10685*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfNotRocm 10686*da0073e9SAndroid Build Coastguard Worker def test_upsamplingNearest2d_launch_rocm(self, device): 10687*da0073e9SAndroid Build Coastguard Worker # test_upsamplingNearest2d_launch_fail should run OK on ROCm 10688*da0073e9SAndroid Build Coastguard Worker m = nn.Upsample(scale_factor=2) 10689*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(1, 1, 2**15, 2**8, device=device) 10690*da0073e9SAndroid Build Coastguard Worker out = m(inp) 10691*da0073e9SAndroid Build Coastguard Worker 10692*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 10693*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfCudnnVersionLessThan(7600) 10694*da0073e9SAndroid Build Coastguard Worker def test_CTCLoss_cudnn(self, device): 10695*da0073e9SAndroid Build Coastguard Worker def _helper(zero_infinity): 10696*da0073e9SAndroid Build Coastguard Worker target_lengths = [30, 25, 20] 10697*da0073e9SAndroid Build Coastguard Worker input_lengths = [50, 50, 50] 10698*da0073e9SAndroid Build Coastguard Worker targets = torch.randint(1, 15, (sum(target_lengths),), dtype=torch.int) 10699*da0073e9SAndroid Build Coastguard Worker log_probs = torch.randn(50, 3, 15, dtype=torch.float, device=device).log_softmax(2).requires_grad_() 10700*da0073e9SAndroid Build Coastguard Worker 10701*da0073e9SAndroid Build Coastguard Worker log_probs_ref = log_probs.detach().clone().requires_grad_() 10702*da0073e9SAndroid Build Coastguard Worker 10703*da0073e9SAndroid Build Coastguard Worker with torch.backends.cudnn.flags(enabled=True): 10704*da0073e9SAndroid Build Coastguard Worker res = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, zero_infinity=zero_infinity) 10705*da0073e9SAndroid Build Coastguard Worker res.backward() 10706*da0073e9SAndroid Build Coastguard Worker 10707*da0073e9SAndroid Build Coastguard Worker expected = ctcloss_reference(log_probs, targets.cuda(), input_lengths, target_lengths).float() 10708*da0073e9SAndroid Build Coastguard Worker 10709*da0073e9SAndroid Build Coastguard Worker with torch.backends.cudnn.flags(enabled=False): 10710*da0073e9SAndroid Build Coastguard Worker res2 = torch.nn.functional.ctc_loss(log_probs_ref, targets.cuda().long(), input_lengths, target_lengths, 10711*da0073e9SAndroid Build Coastguard Worker zero_infinity=zero_infinity) 10712*da0073e9SAndroid Build Coastguard Worker res2.backward() 10713*da0073e9SAndroid Build Coastguard Worker 10714*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 10715*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res2, res) 10716*da0073e9SAndroid Build Coastguard Worker self.assertEqual(log_probs.grad, log_probs_ref.grad) 10717*da0073e9SAndroid Build Coastguard Worker 10718*da0073e9SAndroid Build Coastguard Worker _helper(zero_infinity=True) 10719*da0073e9SAndroid Build Coastguard Worker _helper(zero_infinity=False) 10720*da0073e9SAndroid Build Coastguard Worker 10721*da0073e9SAndroid Build Coastguard Worker def _CTCLoss_gen_losses(self, device, input_length, vocab_size, target_length, reduction, use_module_form): 10722*da0073e9SAndroid Build Coastguard Worker batch_size = 1 10723*da0073e9SAndroid Build Coastguard Worker log_probs = torch.randn(input_length, batch_size, vocab_size, dtype=torch.float, device=device) \ 10724*da0073e9SAndroid Build Coastguard Worker .log_softmax(2).requires_grad_() 10725*da0073e9SAndroid Build Coastguard Worker targets = torch.randint(low=1, high=vocab_size - 1, size=(batch_size, target_length), 
10726*da0073e9SAndroid Build Coastguard Worker dtype=torch.int, device=device) 10727*da0073e9SAndroid Build Coastguard Worker input_lengths = batch_size * [input_length] 10728*da0073e9SAndroid Build Coastguard Worker target_lengths = batch_size * [target_length] 10729*da0073e9SAndroid Build Coastguard Worker 10730*da0073e9SAndroid Build Coastguard Worker log_probs_no_bd = log_probs.squeeze(1).detach().clone().requires_grad_() 10731*da0073e9SAndroid Build Coastguard Worker targets_no_bd = targets.squeeze(0).detach().clone() 10732*da0073e9SAndroid Build Coastguard Worker input_lengths_no_bd = torch.tensor(input_length) 10733*da0073e9SAndroid Build Coastguard Worker target_lengths_no_bd = torch.tensor(target_length) 10734*da0073e9SAndroid Build Coastguard Worker 10735*da0073e9SAndroid Build Coastguard Worker # currently only length 2 and 1 right now, but left flexible for additional potential cases 10736*da0073e9SAndroid Build Coastguard Worker log_probs_refs = [log_probs.detach().clone().requires_grad_() for _ in range(2)] 10737*da0073e9SAndroid Build Coastguard Worker log_probs_no_bd_refs = [log_probs_no_bd.detach().clone().requires_grad_() for _ in range(1)] 10738*da0073e9SAndroid Build Coastguard Worker 10739*da0073e9SAndroid Build Coastguard Worker losses = [] 10740*da0073e9SAndroid Build Coastguard Worker losses_no_bd = [] 10741*da0073e9SAndroid Build Coastguard Worker 10742*da0073e9SAndroid Build Coastguard Worker has_cuda = torch.cuda.is_available() 10743*da0073e9SAndroid Build Coastguard Worker has_cudnn = has_cuda and 'cuda' in device and self.has_cudnn() 10744*da0073e9SAndroid Build Coastguard Worker # cudnn requires a cpu target 10745*da0073e9SAndroid Build Coastguard Worker if has_cuda and has_cudnn: 10746*da0073e9SAndroid Build Coastguard Worker targets = targets.cpu() 10747*da0073e9SAndroid Build Coastguard Worker targets_no_bd = targets_no_bd.cpu() 10748*da0073e9SAndroid Build Coastguard Worker 10749*da0073e9SAndroid Build Coastguard Worker ctc_loss = ( 10750*da0073e9SAndroid Build Coastguard Worker nn.CTCLoss(reduction=reduction, zero_infinity=True) 10751*da0073e9SAndroid Build Coastguard Worker if use_module_form 10752*da0073e9SAndroid Build Coastguard Worker else partial(torch.nn.functional.ctc_loss, reduction=reduction, zero_infinity=True) 10753*da0073e9SAndroid Build Coastguard Worker ) 10754*da0073e9SAndroid Build Coastguard Worker 10755*da0073e9SAndroid Build Coastguard Worker with torch.backends.cudnn.flags(enabled=has_cudnn): 10756*da0073e9SAndroid Build Coastguard Worker # batched case. log_probs.shape = (T, N, C), targets = (N, S), input_lengths/target_lengths = (N,) 10757*da0073e9SAndroid Build Coastguard Worker losses.append(ctc_loss(log_probs_refs[0], targets, input_lengths, target_lengths)) 10758*da0073e9SAndroid Build Coastguard Worker # batched case. input.shape = (T, N, C), targets = (S,), input_lengths/target_lengths = (N,) 10759*da0073e9SAndroid Build Coastguard Worker losses.append(ctc_loss(log_probs_refs[1], targets_no_bd, input_lengths, target_lengths)) 10760*da0073e9SAndroid Build Coastguard Worker # unbatched case. 
input.shape = (T, C), targets = (S,), input_lengths/target_lengths = (N,) 10761*da0073e9SAndroid Build Coastguard Worker losses_no_bd.append(ctc_loss(log_probs_no_bd_refs[0], targets_no_bd, 10762*da0073e9SAndroid Build Coastguard Worker input_lengths_no_bd, target_lengths_no_bd)) 10763*da0073e9SAndroid Build Coastguard Worker 10764*da0073e9SAndroid Build Coastguard Worker for loss in losses + losses_no_bd: 10765*da0073e9SAndroid Build Coastguard Worker loss.backward() 10766*da0073e9SAndroid Build Coastguard Worker 10767*da0073e9SAndroid Build Coastguard Worker return losses, losses_no_bd, log_probs_refs, log_probs_no_bd_refs 10768*da0073e9SAndroid Build Coastguard Worker 10769*da0073e9SAndroid Build Coastguard Worker def _assertEqual_list(self, expected, list_to_compare, atol=None, rtol=None): 10770*da0073e9SAndroid Build Coastguard Worker for ele in list_to_compare: 10771*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, ele, atol=atol, rtol=rtol) 10772*da0073e9SAndroid Build Coastguard Worker 10773*da0073e9SAndroid Build Coastguard Worker @expectedFailureMPS # NotImplementedError: aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764 10774*da0073e9SAndroid Build Coastguard Worker @parametrize_test("reduction", ['none', 'mean', 'sum']) 10775*da0073e9SAndroid Build Coastguard Worker @parametrize_test("use_module_form", [True, False]) 10776*da0073e9SAndroid Build Coastguard Worker def test_CTCLoss_no_batch_dim(self, device, reduction, use_module_form): 10777*da0073e9SAndroid Build Coastguard Worker input_length = 40 10778*da0073e9SAndroid Build Coastguard Worker vocab_size = 3 10779*da0073e9SAndroid Build Coastguard Worker target_length = 12 10780*da0073e9SAndroid Build Coastguard Worker 10781*da0073e9SAndroid Build Coastguard Worker args = self._CTCLoss_gen_losses(device, input_length, vocab_size, target_length, reduction, use_module_form) 10782*da0073e9SAndroid Build Coastguard Worker losses, losses_no_bd, log_probs_refs, log_probs_no_bd_refs = args 10783*da0073e9SAndroid Build Coastguard Worker 10784*da0073e9SAndroid Build Coastguard Worker # test output values 10785*da0073e9SAndroid Build Coastguard Worker self._assertEqual_list(losses[0], losses[1:], atol=1e-4, rtol=0) 10786*da0073e9SAndroid Build Coastguard Worker self._assertEqual_list(losses[0].squeeze(0), losses_no_bd, atol=1e-4, rtol=0) 10787*da0073e9SAndroid Build Coastguard Worker 10788*da0073e9SAndroid Build Coastguard Worker # test gradient values 10789*da0073e9SAndroid Build Coastguard Worker self._assertEqual_list(log_probs_refs[0].grad, [t.grad for t in log_probs_refs[1:]], atol=1e-4, rtol=0) 10790*da0073e9SAndroid Build Coastguard Worker self._assertEqual_list( 10791*da0073e9SAndroid Build Coastguard Worker log_probs_refs[0].grad.squeeze(1), 10792*da0073e9SAndroid Build Coastguard Worker [t.grad for t in log_probs_no_bd_refs], 10793*da0073e9SAndroid Build Coastguard Worker atol=1e-4, 10794*da0073e9SAndroid Build Coastguard Worker rtol=0, 10795*da0073e9SAndroid Build Coastguard Worker ) 10796*da0073e9SAndroid Build Coastguard Worker 10797*da0073e9SAndroid Build Coastguard Worker # checking the output's shape 10798*da0073e9SAndroid Build Coastguard Worker # batch dim case should be (N,). 
no batch dim case should be () 10799*da0073e9SAndroid Build Coastguard Worker self._assertEqual_list((1,) if reduction == 'none' else (), [loss.shape for loss in losses]) 10800*da0073e9SAndroid Build Coastguard Worker self._assertEqual_list((), [loss.shape for loss in losses_no_bd]) 10801*da0073e9SAndroid Build Coastguard Worker 10802*da0073e9SAndroid Build Coastguard Worker # checking the gradient's shape 10803*da0073e9SAndroid Build Coastguard Worker # batch dim case should have shape (T, N, C). no batch dim case should have shape (T, C) 10804*da0073e9SAndroid Build Coastguard Worker self._assertEqual_list((input_length, 1, vocab_size), [t.grad.shape for t in log_probs_refs]) 10805*da0073e9SAndroid Build Coastguard Worker self._assertEqual_list((input_length, vocab_size), [t.grad.shape for t in log_probs_no_bd_refs]) 10806*da0073e9SAndroid Build Coastguard Worker 10807*da0073e9SAndroid Build Coastguard Worker def _ordered_sequence(self, device, dtype): 10808*da0073e9SAndroid Build Coastguard Worker """Create ordered list of random sequences""" 10809*da0073e9SAndroid Build Coastguard Worker seqs = [torch.empty(random.randint(1, 6), device=device, dtype=dtype) 10810*da0073e9SAndroid Build Coastguard Worker for _ in range(5)] 10811*da0073e9SAndroid Build Coastguard Worker seqs = [s.random_(-128, 128) for s in seqs] 10812*da0073e9SAndroid Build Coastguard Worker ordered = sorted(seqs, key=len, reverse=True) 10813*da0073e9SAndroid Build Coastguard Worker return ordered 10814*da0073e9SAndroid Build Coastguard Worker 10815*da0073e9SAndroid Build Coastguard Worker def _padded_sequence(self, device, dtype): 10816*da0073e9SAndroid Build Coastguard Worker """Create Tensor of random padded sequences""" 10817*da0073e9SAndroid Build Coastguard Worker ordered = self._ordered_sequence(device, dtype) 10818*da0073e9SAndroid Build Coastguard Worker lengths = [len(i) for i in ordered] 10819*da0073e9SAndroid Build Coastguard Worker padded_tensor = rnn_utils.pad_sequence(ordered) 10820*da0073e9SAndroid Build Coastguard Worker return padded_tensor, lengths 10821*da0073e9SAndroid Build Coastguard Worker 10822*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 10823*da0073e9SAndroid Build Coastguard Worker def test_device_mask(self, device): 10824*da0073e9SAndroid Build Coastguard Worker for enforce_sorted in [True, False]: 10825*da0073e9SAndroid Build Coastguard Worker padded, lengths = self._padded_sequence('cpu', torch.float) 10826*da0073e9SAndroid Build Coastguard Worker packed = rnn_utils.pack_padded_sequence( 10827*da0073e9SAndroid Build Coastguard Worker padded, lengths, enforce_sorted=enforce_sorted) 10828*da0073e9SAndroid Build Coastguard Worker self.assertFalse(packed.is_cuda) 10829*da0073e9SAndroid Build Coastguard Worker packed = packed.to(device) 10830*da0073e9SAndroid Build Coastguard Worker self.assertTrue(packed.is_cuda) 10831*da0073e9SAndroid Build Coastguard Worker unpacked, _ = rnn_utils.pad_packed_sequence(packed) 10832*da0073e9SAndroid Build Coastguard Worker self.assertTrue(unpacked.is_cuda) 10833*da0073e9SAndroid Build Coastguard Worker self.assertEqual(unpacked.dtype, torch.float) 10834*da0073e9SAndroid Build Coastguard Worker 10835*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 10836*da0073e9SAndroid Build Coastguard Worker def test_overwrite_module_params_on_conversion_cpu_device(self, device): 10837*da0073e9SAndroid Build Coastguard Worker # Test that under the current default settings 10838*da0073e9SAndroid Build Coastguard Worker # 
(`torch.__future__.get_overwrite_module_params_on_conversion() == False`), 10839*da0073e9SAndroid Build Coastguard Worker # a view to a module's parameters is not pointing to the same storage as 10840*da0073e9SAndroid Build Coastguard Worker # its base variable after converting the module to a different device. 10841*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(20, 10) 10842*da0073e9SAndroid Build Coastguard Worker mw = m.weight[:] 10843*da0073e9SAndroid Build Coastguard Worker m.to(device) 10844*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 10845*da0073e9SAndroid Build Coastguard Worker # Without using `torch.no_grad()`, this will leak CUDA memory. 10846*da0073e9SAndroid Build Coastguard Worker # (Issue is filed at https://github.com/pytorch/pytorch/issues/21875) 10847*da0073e9SAndroid Build Coastguard Worker mw[0][0] = 5 10848*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mw[0][0].device.type == "cpu") 10849*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mw._base[0][0].device.type == "cuda") 10850*da0073e9SAndroid Build Coastguard Worker 10851*da0073e9SAndroid Build Coastguard Worker try: 10852*da0073e9SAndroid Build Coastguard Worker torch.__future__.set_overwrite_module_params_on_conversion(True) 10853*da0073e9SAndroid Build Coastguard Worker 10854*da0073e9SAndroid Build Coastguard Worker # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`, 10855*da0073e9SAndroid Build Coastguard Worker # a view to a module's parameters is still pointing to the same storage as 10856*da0073e9SAndroid Build Coastguard Worker # its base variable after converting the module to a different device. 10857*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(20, 10) 10858*da0073e9SAndroid Build Coastguard Worker mw = m.weight[:] 10859*da0073e9SAndroid Build Coastguard Worker m.to(device) 10860*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 10861*da0073e9SAndroid Build Coastguard Worker mw[0][0] = 5 10862*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mw[0][0] == mw._base[0][0]) 10863*da0073e9SAndroid Build Coastguard Worker 10864*da0073e9SAndroid Build Coastguard Worker # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`, 10865*da0073e9SAndroid Build Coastguard Worker # `cpu_module.to("cuda")` doesn't preserve previous references to 10866*da0073e9SAndroid Build Coastguard Worker # `cpu_module`'s parameters or gradients. 
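            # With the overwrite flag enabled, `m.to(device)` allocates fresh
            # parameter and gradient tensors on the destination device, so
            # `weight_ref` and `weight_grad_ref` captured below keep pointing at
            # the old CPU tensors; hence the device mismatch asserted afterwards.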
            m = nn.Linear(20, 10)
            m.weight.grad = torch.randn(10, 20)
            weight_ref = m.weight
            weight_grad_ref = m.weight.grad
            m.to(device)
            self.assertNotEqual(weight_ref.device, m.weight.device)
            self.assertNotEqual(weight_grad_ref.device, m.weight.grad.device)
        finally:
            torch.__future__.set_overwrite_module_params_on_conversion(False)

    @onlyCUDA
    @dtypes(torch.half, torch.float)
    def test_softmax(self, device, dtype):
        input = torch.rand(32, 100, device=device, dtype=dtype, requires_grad=True)
        inputf = input.to(torch.float).detach().requires_grad_(True)
        out = F.softmax(input, dim=-1, dtype=torch.float)
        outf = F.softmax(inputf, dim=-1)
        # should be bitwise equal
        self.assertEqual(out, outf, atol=0, rtol=0)
        gO = torch.empty_like(outf).uniform_()
        out.backward(gO)
        outf.backward(gO)
        # should be bitwise equal
        self.assertEqual(input.grad, inputf.grad.to(dtype), atol=0, rtol=0)

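    # Sketch of the property exercised above (illustrative, not an extra assertion): for
    # a reduced-precision input `x`, F.softmax(x, dim=-1, dtype=torch.float) upcasts
    # before the reduction, so it is expected to match F.softmax(x.float(), dim=-1)
    # bitwise, while the gradient flowing back to `x` is cast back to x.dtype.
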
    def _test_batchnorm_grad(self, device, dtype=torch.double):
        bs, n_feat, size_feat = 4, 5, 6
        input = torch.arange(bs * n_feat * size_feat, device=device,
                             requires_grad=True, dtype=dtype).view(bs, n_feat, size_feat)
        weight = torch.arange(1, n_feat + 1, device=device, requires_grad=True, dtype=dtype)
        bias = torch.arange(n_feat, device=device, requires_grad=True, dtype=dtype)
        running_mean = 1 - torch.arange(n_feat, device=device, dtype=dtype)
        running_var = 2 * torch.arange(n_feat, device=device, dtype=dtype)
        for training in [False, True]:
            _assertGradAndGradgradChecks(self, F.batch_norm, (input, running_mean, running_var, weight, bias,
                                                              training, 0.1, 0.0001))

    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
    def test_batchnorm_grad(self, device):
        self._test_batchnorm_grad(device)

        if self.device_type == 'cuda' and self.has_cudnn():
            with torch.backends.cudnn.flags(enabled=False):
                self._test_batchnorm_grad(device)

    @onlyCUDA
    def test_layernorm_half_precision(self):
        width = 128
        input = torch.rand(1, 5, width, device="cuda", dtype=torch.half) * 0.1
        normalized_shape = (width,)
        weight = torch.ones(width, device="cuda", dtype=torch.half)
        bias = torch.zeros(width, device="cuda", dtype=torch.half)
        eps = 1e-5

        output_fp16 = torch.layer_norm(input, normalized_shape, weight, bias, eps)
        output_fp32 = torch.layer_norm(input.float(), normalized_shape, weight.float(), bias.float(), eps).half()
        self.assertEqual(output_fp16, output_fp32, atol=0, rtol=0)

    @onlyCUDA
    def test_layernorm_weight_bias(self):
        width = 128
        input = torch.rand(1, 5, width, device="cuda", dtype=torch.float32) * 0.1
        normalized_shape = (width,)
        data = torch.randn(width, device="cuda", dtype=torch.float32)
        weight = torch.ones(width, device="cuda", dtype=torch.float32)
        bias = torch.zeros(width, device="cuda", dtype=torch.float32)
        eps = 1e-5

        out_none_weight = torch.layer_norm(input, normalized_shape, None, data, eps)
        out_one_weight = torch.layer_norm(input, normalized_shape, weight, data, eps)
        self.assertEqual(out_none_weight, out_one_weight)

        out_none_bias = torch.layer_norm(input, normalized_shape, data, None, eps)
        out_zero_bias = torch.layer_norm(input, normalized_shape, data, bias, eps)
        self.assertEqual(out_none_bias, out_zero_bias)

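    # Reading of the two checks above (sketch): torch.layer_norm treats weight=None as an
    # implicit all-ones weight and bias=None as an implicit all-zeros bias, e.g.
    #   torch.layer_norm(x, (width,), None, b, eps) == torch.layer_norm(x, (width,), torch.ones(width), b, eps)
    #   torch.layer_norm(x, (width,), w, None, eps) == torch.layer_norm(x, (width,), w, torch.zeros(width), eps)
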
    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
    def test_hardsigmoid_grad(self, device):
        inputs = (torch.randn(4, 16, 16, device=device, dtype=torch.double) - 0.5) * 10
        inputs.requires_grad = True
        self.assertTrue(gradcheck(F.hardsigmoid, (inputs,)))

    # currently fails on XLA
    @onlyNativeDeviceTypes
    def test_hardswish_grad(self, device):
        inputs = (torch.randn(4, 16, 16, device=device, dtype=torch.double) - 0.5) * 10
        inputs.requires_grad = True
        self.assertTrue(gradcheck(F.hardswish, (inputs,)))

    def _test_batchnorm_eval(self, ndim, device, dtype, module_dtype=None):
        module_dtype = module_dtype or dtype
        module = nn.BatchNorm1d(3).to(device, module_dtype)
        module.eval()

        data = torch.rand([3] * ndim, device=device, dtype=dtype, requires_grad=True)
        grad = torch.rand([3] * ndim, device=device, dtype=dtype)

        # 1st pass
        res1 = module(data)
        res1.backward(grad)
        grad1 = data.grad.clone()

        # 2nd pass
        if data.grad is not None:
            data.grad.data.zero_()

        res2 = module(data)
        res2.backward(grad)
        grad2 = data.grad.clone()
        self.assertEqual(res1, res2)
        self.assertEqual(grad1, grad2)

        # track_running_stats=False
        module = nn.BatchNorm1d(3, track_running_stats=False).to(device, module_dtype)

        data = torch.rand(4, 3, device=device, dtype=dtype, requires_grad=True)
        grad = torch.rand(4, 3, device=device, dtype=dtype)

        # 1st pass
        res1 = module(data)
        res1.backward(grad)
        grad1 = data.grad.clone()

        # set eval
        module.eval()

        # 2nd pass
        if data.grad is not None:
            data.grad.data.zero_()

        res2 = module(data)
        res2.backward(grad)
        grad2 = data.grad.clone()
        self.assertEqual(res1, res2)
        self.assertEqual(grad1, grad2)

    @dtypes(torch.float)
    @dtypesIfCUDA(torch.float, torch.bfloat16)
    def test_batchnorm_eval(self, device, dtype):
        self._test_batchnorm_eval(2, device, dtype)
        self._test_batchnorm_eval(3, device, dtype)

        if self.device_type == 'cuda' and self.has_cudnn():
            with torch.backends.cudnn.flags(enabled=False):
                self._test_batchnorm_eval(2, device, dtype)
                self._test_batchnorm_eval(3, device, dtype)

    @onlyCUDA
    @dtypes(torch.bfloat16, torch.half)
    def test_batchnorm_eval_mixed(self, device, dtype):
        # Test bfloat16 input with float module
        self._test_batchnorm_eval(2, device, dtype, torch.float)
        self._test_batchnorm_eval(3, device, dtype, torch.float)

        if self.device_type == 'cuda' and self.has_cudnn():
            with torch.backends.cudnn.flags(enabled=False):
                self._test_batchnorm_eval(2, device, dtype, torch.float)
                self._test_batchnorm_eval(3, device, dtype, torch.float)

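    # Sketch of what _test_batchnorm_eval relies on: in eval() mode BatchNorm normalizes
    # with the stored running statistics rather than the batch statistics, so two forward
    # passes over the same `data` are expected to produce identical outputs and identical
    # input gradients (no running-stat update happens between the passes).
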
    def _test_batchnorm_affine(self, ndim, device, dtype, module_dtype=None):
        # Compare affine against no-op weights and bias
        module_dtype = module_dtype or dtype
        module = nn.BatchNorm1d(3, affine=False).to(device, module_dtype)
        module_affine = nn.BatchNorm1d(3, affine=True).to(device, module_dtype)
        with torch.no_grad():
            module_affine.weight.fill_(1.0)
            module_affine.bias.zero_()

        data = torch.rand([3] * ndim, device=device, dtype=dtype, requires_grad=True)
        grad = torch.ones_like(data, requires_grad=False)

        # With weights all ones and bias all zeros
        res1 = module_affine(data)
        res1.backward(grad)
        grad1 = data.grad.clone()
        data.grad.zero_()

        # Without any weights or bias
        res2 = module(data)
        res2.backward(grad)
        grad2 = data.grad

        self.assertEqual(res1, res2)
        self.assertEqual(grad1, grad2)

    @dtypes(torch.float)
    @dtypesIfCUDA(torch.float, torch.bfloat16)
    def test_batchnorm_affine(self, device, dtype):
        self._test_batchnorm_affine(2, device, dtype)
        self._test_batchnorm_affine(3, device, dtype)

        if self.device_type == 'cuda' and self.has_cudnn():
            with torch.backends.cudnn.flags(enabled=False):
                self._test_batchnorm_affine(2, device, dtype)
                self._test_batchnorm_affine(3, device, dtype)

    @onlyCUDA
    @dtypes(torch.bfloat16, torch.half)
    def test_batchnorm_affine_mixed(self, device, dtype):
        cudnn_enabled = [False]
        if self.device_type == 'cuda' and self.has_cudnn():
            # TODO: Test fails with cudnn, see gh-62034
            # cudnn_enabled = [False, True]
            pass

        # Test bfloat16 input with float module
        for enabled in cudnn_enabled:
            with torch.backends.cudnn.flags(enabled=enabled):
                self._test_batchnorm_affine(2, device, dtype, torch.float)
                self._test_batchnorm_affine(3, device, dtype, torch.float)

    def _test_batchnorm_simple_average(self, device, dtype, module_dtype=None):
        module_dtype = module_dtype or dtype
        module = nn.BatchNorm1d(3, momentum=None).to(dtype=module_dtype, device=device)
        zeros = torch.zeros(3, dtype=module_dtype, device=device)
        ones = torch.ones(3, dtype=module_dtype, device=device)
        self.assertEqual(module.running_mean, zeros)
        self.assertEqual(module.running_var, ones)

        data1 = torch.rand(4, 3, dtype=dtype, device=device)
        data2 = torch.rand(4, 3, dtype=dtype, device=device)

        # 1st pass
        res1 = module(data1)
        running_mean1 = module.running_mean.clone()
        running_var1 = module.running_var.clone()
        self.assertNotEqual(running_mean1, zeros)
        self.assertNotEqual(running_var1, ones)

        # reset stats
        module.reset_running_stats()
        self.assertEqual(module.running_mean, zeros)
        self.assertEqual(module.running_var, ones)

        # 2nd pass
        res2 = module(data2)
        running_mean2 = module.running_mean.clone()
        running_var2 = module.running_var.clone()
        self.assertNotEqual(running_mean2, zeros)
        self.assertNotEqual(running_var2, ones)

        # reset stats
        module.reset_running_stats()
        self.assertEqual(module.running_mean, zeros)
        self.assertEqual(module.running_var, ones)

        # 3rd (combined) pass
        res3 = module(data1)
        res4 = module(data2)
        self.assertEqual(res3, res1)
        self.assertEqual(res4, res2)
        self.assertEqual(module.running_mean, (running_mean1 + running_mean2) / 2)
        self.assertEqual(module.running_var, (running_var1 + running_var2) / 2)

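    # Sketch of the convention being checked: with momentum=None, BatchNorm keeps a
    # cumulative (simple) average of the batch statistics instead of an exponential one,
    # i.e. after processing batches b_1..b_n from a reset state,
    #   running_mean == (mean(b_1) + ... + mean(b_n)) / n
    # which is why the combined pass above must end with the average of the two
    # single-pass statistics.
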
    @dtypes(torch.float)
    @dtypesIfCUDA(torch.float, torch.bfloat16)
    def test_batchnorm_simple_average(self, device, dtype):
        self._test_batchnorm_simple_average(device, dtype)

        if self.device_type == 'cuda' and self.has_cudnn():
            with torch.backends.cudnn.flags(enabled=False):
                self._test_batchnorm_simple_average(device, dtype)

    @onlyCUDA
    @dtypes(torch.bfloat16, torch.half)
    def test_batchnorm_simple_average_mixed(self, device, dtype):
        self._test_batchnorm_simple_average(device, dtype, torch.float)

        if self.device_type == 'cuda' and self.has_cudnn():
            with torch.backends.cudnn.flags(enabled=False):
                self._test_batchnorm_simple_average(device, dtype, torch.float)

    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double)
    def test_grid_sample_nan_inf(self, device, dtype):
        input = torch.zeros([1, 1, 3, 3], device=device, dtype=dtype)
        grid = torch.tensor([[[[nan, 0], [0, inf]]]], device=device, dtype=dtype)
        for padding_mode in ('reflection', 'border', 'zeros'):
            sample = torch.nn.functional.grid_sample(input=input, grid=grid, mode='nearest',
                                                     padding_mode=padding_mode, align_corners=False)
            self.assertEqual(sample, torch.zeros([1, 1, 1, 2], device=device, dtype=dtype))

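    # Note on the check above (illustrative): the input is all zeros, so regardless of how
    # non-finite grid coordinates (nan/inf) are resolved by the padding mode, grid_sample
    # is expected to return zeros rather than propagating garbage; the test pins that
    # behaviour for 'reflection', 'border' and 'zeros' padding with mode='nearest'.
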
    @expectedFailureMPS  # NotImplementedError aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764
    def test_CTCLoss_empty_target(self, device):
        target_lengths = [0, 0, 0]
        input_lengths = [50, 50, 50]
        targets = torch.randint(1, 15, (0,), dtype=torch.long, device=device)
        log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
        loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
        self.assertTrue((loss >= 0).all().item())
        self.assertEqual(-log_probs.sum(0)[:, 0], loss)

        target_lengths = [0, 9, 0]
        input_lengths = [50, 50, 50]
        targets = torch.randint(1, 15, (9,), dtype=torch.long, device=device)
        log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
        loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
        self.assertTrue((loss >= 0).all().item())
        self.assertEqual(-log_probs.sum(0)[[0, 2], 0], loss[[0, 2]])

    # Merge into OpInfo?
    @skipCUDAIf(True, """Test is flaky on Linux and Windows, typical error message:
                         https://github.com/pytorch/pytorch/issues/34870""")
    @expectedFailureMPS  # NotImplementedError aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764
    def test_ctc_loss(self, device):
        batch_size = 64
        num_labels = 101
        target_length = 15
        gradcheck_input_size = 10

        ZERO_NONE = 0
        ZERO_SOME = 1
        ZERO_ALL = 2

        # input_length, vary_lengths, zero_lengths
        tests = [(150, False, ZERO_NONE),
                 (150, True, ZERO_NONE),
                 (50, True, ZERO_SOME),
                 (50, True, ZERO_ALL)]

        if 'cuda' in device:
            tests += [(50, False, ZERO_NONE),
                      (50, True, ZERO_NONE),
                      (150, True, ZERO_SOME),
                      (150, True, ZERO_ALL)]

        for input_length, vary_lengths, zero_mode in tests:
            targets = torch.randint(1, num_labels, (batch_size, target_length),
                                    device=device, dtype=torch.long)
            x = torch.randn(gradcheck_input_size, dtype=torch.double, device=device, requires_grad=True)
            tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1,
                                       device=device)
            input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
                              if vary_lengths or i == 0 else input_length) for i in range(batch_size)]
            if zero_mode == ZERO_ALL:
                target_lengths = [0 for _ in range(batch_size)]
            else:
                target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
                                   if vary_lengths else target_length) for _ in range(batch_size)]
                if zero_mode == ZERO_SOME:
                    idxes = torch.randint(0, batch_size, (10,))
                    for i in idxes:
                        target_lengths[i] = 0

            def ctc_after_softmax(x):
                x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels]
                          .view(input_length, batch_size, num_labels))
                log_probs = torch.log_softmax(x_full, 2)
                return torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)

            gradcheck(ctc_after_softmax, [x])

    @onlyCUDA
    @skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
    @skipCUDAIfCudnnVersionLessThan(7600)
    def test_ctc_loss_cudnn(self, device):
        batch_size = 16
        input_length = 30
        num_labels = 101
        target_length = 15
        targets = torch.randint(1, num_labels, (batch_size * target_length,),
                                device='cuda', dtype=torch.long)
        log_probs = torch.log_softmax(torch.randn(input_length, batch_size, num_labels, device='cuda', dtype=torch.float), 2)
        log_probs.requires_grad_()

        input_lengths = batch_size * [input_length]
        target_lengths = batch_size * [target_length]
        grad_out = torch.randn(batch_size, device='cuda', dtype=torch.float)
        with torch.backends.cudnn.flags(enabled=False):
            loss_native = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
            grad_native, = torch.autograd.grad(loss_native, log_probs, grad_out)
        loss_cudnn = torch.nn.functional.ctc_loss(log_probs, targets.to('cpu', torch.int32),
                                                  input_lengths, target_lengths, reduction='none')
        self.assertTrue("Cudnn" in str(loss_cudnn.grad_fn))
        grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
        self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)

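    # Background for the check above (sketch, not a guarantee of the dispatch rules): the
    # cuDNN CTC kernel is only eligible under specific conditions (e.g. int32 target and
    # length arguments with float32 log-probabilities). Asserting "Cudnn" in
    # str(loss_cudnn.grad_fn) is how the test confirms that path was actually taken,
    # while the reference path is forced by disabling cuDNN; the two backends must then
    # agree on the gradient within the given tolerance.
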
    @onlyCUDA
    @skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
    @skipCUDAIfCudnnVersionLessThan(8000)
    def test_ctc_loss_cudnn_tensor(self, device):
        batch_size = 16
        input_length = 30
        num_labels = 101
        target_length = 15
        targets = torch.randint(1, num_labels, (batch_size * target_length,),
                                device='cuda', dtype=torch.long)
        log_probs = torch.log_softmax(torch.randn(input_length, batch_size, num_labels, device='cuda', dtype=torch.float), 2)
        log_probs.requires_grad_()

        input_lengths = batch_size * [input_length]
        input_lengths = torch.linspace(start=15, end=input_length, steps=batch_size, dtype=torch.long, device='cuda')
        target_lengths = torch.tensor(batch_size * [target_length], dtype=torch.long, device='cuda')
        grad_out = torch.randn(batch_size, device='cuda', dtype=torch.float)
        with torch.backends.cudnn.flags(enabled=False):
            loss_native = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
            grad_native, = torch.autograd.grad(loss_native, log_probs, grad_out)
        loss_cudnn = torch.nn.functional.ctc_loss(log_probs,
                                                  targets.to('cuda', torch.int32),
                                                  input_lengths.to('cuda', torch.int32),
                                                  target_lengths.to('cuda', torch.int32),
                                                  reduction='none')
        self.assertTrue("Cudnn" in str(loss_cudnn.grad_fn))
        grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
        self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)

    @expectedFailureMPS  # RuntimeError: LSTM with projections is not currently supported with MPS.
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @dtypes(torch.float)
    @tf32_on_and_off(0.005)
    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
    def test_variable_sequence(self, device, dtype):
        def pad(var, length):
            if var.size(0) == length:
                return var
            return torch.cat([var, var.new_zeros(length - var.size(0), *var.size()[1:])])

        def maybe_index_tuple(maybe_tuple_of_tensors, index):
            if maybe_tuple_of_tensors is None:
                return None
            return tuple(maybe_tuple_of_tensors[j][:, index:index + 1, :].contiguous()
                         for j in range(2))

        def check_lengths(lengths, enforce_sorted, use_default_hiddens, proj_size):
            input_size = 3
            hidden_size = 4
            num_layers = 2
            bidirectional = True

            max_length = max(lengths)
            x_leaf = torch.randn(max_length, len(lengths), input_size, device=device,
                                 dtype=dtype, requires_grad=True)
            num_directions = 2 if bidirectional else 1
            lstm = nn.LSTM(input_size, hidden_size, bidirectional=bidirectional,
                           num_layers=num_layers, proj_size=proj_size).to(device, dtype)
            lstm2 = deepcopy(lstm).to(device, dtype)
            x = x_leaf

            hidden0 = None
            if not use_default_hiddens:
                real_hidden_size = hidden_size if proj_size == 0 else proj_size
                hidden0 = (torch.randn(num_directions * num_layers, len(lengths), real_hidden_size,
                                       device=device, dtype=dtype),
                           torch.randn(num_directions * num_layers, len(lengths), hidden_size,
                                       device=device, dtype=dtype))

            # Compute sequences separately
            seq_outs = []
            seq_hiddens = []
            for i, l in enumerate(lengths):
                hidden_i = maybe_index_tuple(hidden0, i)
                out, hid = lstm2(x[:l, i:i + 1], hidden_i)
                out_pad = pad(out, max_length)
                seq_outs.append(out_pad)
                seq_hiddens.append(hid)
            seq_out = torch.cat(seq_outs, 1)
            seq_hidden = tuple(torch.cat(hids, 1) for hids in zip(*seq_hiddens))

            # Use packed format
            packed = rnn_utils.pack_padded_sequence(x, lengths, enforce_sorted=enforce_sorted)
            packed_out, packed_hidden = lstm(packed, hidden0)
            unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed_out)

            # Check forward
            prec = dtype2prec_DONTUSE[dtype]
            self.assertEqual(packed_hidden, seq_hidden, atol=prec, rtol=0)
            self.assertEqual(unpacked, seq_out, atol=prec, rtol=0)
            self.assertEqual(unpacked_len, lengths, atol=prec, rtol=0)

            # Check backward
            seq_out.sum().backward()
            grad_x = x_leaf.grad.data.clone()
            x_leaf.grad.data.zero_()
            unpacked.sum().backward()

            self.assertEqual(x_leaf.grad, grad_x, atol=dtype2prec_DONTUSE[dtype], rtol=0)
            for p1, p2 in zip(lstm.parameters(), lstm2.parameters()):
                prec = dtype2prec_DONTUSE[dtype]
                if dtype == torch.float16:
                    prec = 4e-2
                self.assertEqual(p1.grad, p2.grad, atol=prec, rtol=0)

        tests = [
            # enforce_sorted, lengths
            [True, [5]],
            [False, [5]],
            [True, [10, 10, 6, 2, 2, 1, 1]],
            [False, [10, 10, 6, 2, 2, 1, 1]],
            [False, [2, 1, 3, 2, 10, 5, 3]],
        ]

        for enforce_sorted, seq_lens, in tests:
            for use_default_hiddens in (True, False):
                for proj_size in [0, 2]:
                    check_lengths(seq_lens, enforce_sorted, use_default_hiddens, proj_size)

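    # What check_lengths establishes (sketch): running a packed, padded batch through an
    # LSTM must match running each sequence separately through an identical copy of the
    # module, both for outputs/hidden states and for the gradients that reach the input
    # and the parameters. Roughly:
    #   packed = rnn_utils.pack_padded_sequence(x, lengths, enforce_sorted=False)
    #   packed_out, packed_hidden = lstm(packed, hidden0)
    #   unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed_out)
    # with `unpacked` equal (up to dtype precision) to the concatenation of the
    # per-sequence runs padded back to the maximum length.
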
    def _test_batchnorm_update_stats(self, device, dtype=torch.float):
        module = nn.BatchNorm1d(3).to(device, dtype)

        data = torch.rand(4, 3, device=device, dtype=dtype)

        # training pass
        old_running_mean = module.running_mean.clone()
        old_running_var = module.running_var.clone()
        old_num_batches_tracked = module.num_batches_tracked.clone()
        module(data)
        self.assertNotEqual(old_running_mean, module.running_mean)
        self.assertNotEqual(old_running_var, module.running_var)
        self.assertEqual(old_num_batches_tracked + 1, module.num_batches_tracked)

        # eval pass
        module.eval()
        old_running_mean = module.running_mean.clone()
        old_running_var = module.running_var.clone()
        old_num_batches_tracked = module.num_batches_tracked.clone()
        module(data)
        self.assertEqual(old_running_mean, module.running_mean)
        self.assertEqual(old_running_var, module.running_var)
        self.assertEqual(old_num_batches_tracked, module.num_batches_tracked)

    def test_batchnorm_update_stats(self, device):
        self._test_batchnorm_update_stats(device)

        if self.device_type == 'cuda' and self.has_cudnn():
            with torch.backends.cudnn.flags(enabled=False):
                self._test_batchnorm_update_stats(device)

    @onlyCPU
    @dtypes(torch.bfloat16, torch.float16)
    def test_activations_bfloat16_half_cpu(self, device, dtype):
        def test_helper(fn, device, inp_dims, prec=None):
            torch.manual_seed(37)
            # bfloat16/half compute
            fn = fn.to(dtype=dtype)
            input = torch.randn(inp_dims, dtype=dtype, device=device, requires_grad=True)
            out = fn(input)
            grad_input = torch.randn_like(out, dtype=dtype, device=device)
            out.backward(grad_input)

            # fp32 compute
            input2 = input.detach().clone().float().requires_grad_(True)
            out2 = fn.float()(input2)
            grad_input2 = grad_input.detach().clone().float()
            out2.backward(grad_input2)

            self.assertEqual(out.dtype, dtype)
            self.assertEqual(input.grad.dtype, dtype)
            self.assertEqual(out, out2.to(dtype=dtype), atol=prec, rtol=prec)
            self.assertEqual(input.grad.data, input2.grad.data.to(dtype=dtype), atol=prec, rtol=prec)

        shapes = [[1, 3, 1, 6], [1, 3, 1, 128], [1, 3, 256, 256]]
        for shape in shapes:
            test_helper(torch.nn.LogSigmoid(), device, shape)
            test_helper(torch.nn.Hardsigmoid(), device, shape)
            test_helper(torch.nn.Hardshrink(), device, shape)
            test_helper(torch.nn.Softshrink(), device, shape)
            test_helper(torch.nn.Hardswish(), device, shape)
            test_helper(torch.nn.Softplus(), device, shape)
            test_helper(torch.nn.SiLU(), device, shape)
            test_helper(torch.nn.Hardtanh(), device, shape)
            test_helper(torch.nn.Mish(), device, shape)
            test_helper(torch.nn.ELU(), device, shape)
            test_helper(torch.nn.PReLU(), device, shape)
            test_helper(torch.nn.GLU(), device, shape, prec=1e-2)
            test_helper(torch.nn.Threshold(0.1, 20), device, shape)
            test_helper(torch.nn.GELU(), device, shape)
            test_helper(torch.nn.Hardtanh(), device, shape)
            test_helper(torch.nn.LeakyReLU(), device, shape)

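    # Pattern used by test_helper above (illustrative): run the module once in reduced
    # precision and once in fp32 on the same data, then compare the low-precision results
    # against the fp32 reference downcast to the low-precision dtype; `prec` loosens the
    # tolerance for ops (e.g. GLU) whose reduced-precision error is expected to be larger.
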
    @onlyCUDA
    def test_activations_bfloat16(self, device):
        _test_bfloat16_ops(self, torch.nn.ReLU(), device, inp_dims=(5), prec=1e-2)
        _test_bfloat16_ops(self, torch.nn.Threshold(0.1, 20), device, inp_dims=(5), prec=1e-2)
        _test_bfloat16_ops(self, torch.nn.ELU(), device, inp_dims=(5), prec=1e-2)
        _test_bfloat16_ops(self, torch.nn.Softplus(), device, inp_dims=(5), prec=1e-2)
        _test_bfloat16_ops(self, torch.nn.Hardshrink(), device, inp_dims=(5), prec=1e-2)
        _test_bfloat16_ops(self, torch.nn.Softshrink(), device, inp_dims=(5), prec=1e-2)
        _test_bfloat16_ops(self, torch.nn.LeakyReLU(), device, inp_dims=(5), prec=1e-2)

    @onlyNativeDeviceTypes
    def test_softmax_bfloat16(self, device):
        for dim in [0, 1, 2, 3]:
            _test_bfloat16_ops(self, torch.nn.Softmax(dim=dim), device, inp_dims=(16, 33, 15, 16), prec=1e-2)
            # test softmax with large input value which causes exp() to overflow
            _test_bfloat16_ops(self, torch.nn.Softmax(dim=dim), device, inp_dims=(16, 33, 15, 16), prec=0.05, scale_factor=1000.0)

    def test_nll_loss_mismatched_batch(self, device):
        x = torch.randn((10, 3), requires_grad=True, device=device)
        # t should have size (10,)
        t = torch.zeros((3,), dtype=torch.int64, device=device)
        with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
            F.nll_loss(x, t)

    def test_nll_loss_out_of_bounds_ignore_index(self, device):
        x = torch.randn(6, 3, requires_grad=True, device=device)
        t = torch.tensor([0, 1, 255, 0, 1, 2], dtype=torch.int64, device=device)
        for reduction in ['mean', 'none']:
            F.nll_loss(x, t, ignore_index=255, reduction=reduction).sum().backward()

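    # Semantics relied on above (sketch): rows whose target equals ignore_index are
    # dropped from the NLL loss, so a class index that is out of range for the input
    # (255 here, with only 3 classes) must never be read; both the forward value and the
    # backward pass are expected to succeed, with no gradient contribution from the
    # ignored rows.
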
    def test_nll_loss_invalid_target_dim(self, device):
        x = torch.randn((10, 3), device=device)
        t = torch.zeros((10, 2), dtype=torch.int64, device=device)
        with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
            F.nll_loss(x, t)

    def test_nll_loss_invalid_weights(self, device):
        x = torch.randn((10, 3), device=device)
        t = torch.empty(10, dtype=torch.int64, device=device).random_(0, 3)
        invalid_weights = [
            torch.randn(4, device=device),
            torch.randn(1, 3, device=device),
        ]
        msg = "weight tensor should be defined either for all 3 classes or no classes"
        for weight in invalid_weights:
            with self.assertRaisesRegex(RuntimeError, msg):
                F.nll_loss(x, t, weight=weight)

    # Ref: https://github.com/pytorch/pytorch/issue/85005
    @onlyCUDA
    @largeTensorTest("120GB", "cpu")
    @largeTensorTest("45GB", "cuda")
    @parametrize_test("reduction", ("none", "mean", "sum"))
    def test_nll_loss_large_tensor(self, device, reduction):
        shape = [int(2 ** 16), int(2 ** 16) + 1]

        input = torch.randn(shape, device=device, dtype=torch.float32, requires_grad=True)
        labels = torch.randint(shape[0], (shape[0],), dtype=torch.long, device=device)

        out = F.nll_loss(input, labels, reduction=reduction)

        with torch.no_grad():
            input_cpu = input.cpu().float().requires_grad_()
            labels_cpu = labels.cpu()
        out_cpu = F.nll_loss(input_cpu, labels_cpu, reduction=reduction)
        # workaround to reduce memory usage vs. self.assertEqual, see #84944
        rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
        if reduction == "sum":
            orig_rtol, orig_atol = rtol, atol
            rtol, atol = 7 * rtol, 3 * atol
        with torch.no_grad():
            self.assertTrue(torch.allclose(out.cpu(), out_cpu, rtol=rtol, atol=atol))
        if reduction == "sum":
            rtol, atol = orig_rtol, orig_atol

        if reduction != "none":
            out.backward()
            out_cpu.backward()
            with torch.no_grad():
                self.assertTrue(torch.allclose(input.grad.cpu(), input_cpu.grad, rtol=rtol, atol=atol))

    # Ref: https://github.com/pytorch/pytorch/issue/108345
    @onlyCUDA
    @largeTensorTest("20GB", "cpu")
    @largeTensorTest("20GB", "cuda")
    @parametrize_test("reduction", ("none", "mean", "sum"))
    def test_cross_entropy_64bit(self, device, reduction):
        labels = torch.zeros(190, 50, dtype=torch.long, device=device)
        logits = torch.ones(190, 229000, 50, dtype=torch.float, device=device)
        loss = torch.nn.functional.cross_entropy(logits, labels)
        loss_cpu = torch.nn.functional.cross_entropy(logits.cpu(), labels.cpu())
        print(logits.numel(), labels.numel(), loss.numel())
        self.assertTrue(torch.allclose(loss_cpu, loss.cpu(), rtol=1e-4, atol=1e-4))

    def _nll_loss_helper(self, input_size, reduction, expected, device):
        input = torch.rand(input_size, requires_grad=True, device=device)
        num_channels = input_size[1]
        target_size = (input_size[0], ) + tuple(input_size[2:])
        target = torch.randint(num_channels, target_size, device=device)

        output = F.nll_loss(input, target, reduction=reduction)
        self.assertEqual(output, expected, exact_dtype=False)

        output.sum().backward()
        self.assertEqual(input.grad.size(), input.size())

    def test_nll_loss_empty_tensor_reduction_none(self, device):
        self._nll_loss_helper([0, 3], "none", torch.empty([0], device=device), device)
        self._nll_loss_helper([0, 3, 5, 7], "none", torch.empty([0, 5, 7], device=device), device)
        self._nll_loss_helper([2, 3, 0, 7], "none", torch.empty([2, 0, 7], device=device), device)
        self._nll_loss_helper([2, 3, 5, 0], "none", torch.empty([2, 5, 0], device=device), device)
        self._nll_loss_helper([2, 3, 5, 7, 0], "none", torch.empty([2, 5, 7, 0], device=device), device)

    @expectedFailureMPS  # RuntimeError: [srcBuf length] > 0 INTERNAL ASSERT FAILED https://github.com/pytorch/pytorch/issues/134431
    def test_nll_loss_empty_tensor_reduction_mean(self, device):
        nan = torch.tensor(float('nan'), device=device)
        self._nll_loss_helper([0, 3], "mean", nan, device)
        self._nll_loss_helper([0, 3, 5, 7], "mean", nan, device)
        self._nll_loss_helper([2, 3, 0, 7], "mean", nan, device)
        self._nll_loss_helper([2, 3, 5, 0], "mean", nan, device)
        self._nll_loss_helper([2, 3, 5, 7, 0], "mean", nan, device)

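    # Convention pinned down by these empty-input tests (sketch): with no valid elements,
    # reduction="none" yields an empty tensor of the expected shape, reduction="mean"
    # yields NaN (a 0/0 average), and reduction="sum" (next test) yields 0.
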
    def test_nll_loss_empty_tensor_reduction_sum(self, device):
        zero = torch.tensor(0, device=device)
        self._nll_loss_helper([0, 3], "sum", zero, device)
        self._nll_loss_helper([0, 3, 5, 7], "sum", zero, device)
        self._nll_loss_helper([2, 3, 0, 7], "sum", zero, device)
        self._nll_loss_helper([2, 3, 5, 0], "sum", zero, device)
        self._nll_loss_helper([2, 3, 5, 7, 0], "sum", zero, device)

    @expectedFailureMPS  # AssertionError: Expected nan but got 0.0.
    def test_nll_loss_total_weight_is_zero(self, device):

        def helper(input_size):
            input = torch.ones(input_size, requires_grad=True, device=device)
            num_channels = input_size[1]
            target_size = (input_size[0], ) + tuple(input_size[2:])
            target = torch.zeros(target_size, dtype=torch.long, device=device)
            weight = torch.zeros([num_channels], device=device)
            self.assertEqual(F.nll_loss(input, target, weight, reduction="sum").item(), 0.)
            self.assertEqual(F.nll_loss(input, target, weight, reduction="mean").item(), float("nan"))
            self.assertEqual(F.nll_loss(input, target, weight, reduction="none"), torch.zeros(target.shape, device=device))

        helper([2, 3])
        helper([2, 3, 5, 7])
        helper([2, 3, 5, 7, 9])
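    # Illustrative sketch (not part of the original test suite; helper name is
    # hypothetical): with per-class weights, the "mean" reduction divides by the
    # total weight of the selected targets, so an all-zero weight vector yields
    # 0 / 0 = nan, while "sum" simply yields 0.
    def _sketch_nll_loss_zero_total_weight(self):
        log_probs = F.log_softmax(torch.randn(2, 3), dim=1)
        target = torch.zeros(2, dtype=torch.long)
        weight = torch.zeros(3)
        self.assertEqual(F.nll_loss(log_probs, target, weight, reduction="sum").item(), 0.)
        self.assertTrue(math.isnan(F.nll_loss(log_probs, target, weight, reduction="mean").item()))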
    @expectedFailureMPS  # AssertionError: Expected nan but got 0.0.
    def test_nll_loss_all_ignored(self, device):

        def helper(input_size):
            input = torch.ones(input_size, device=device)
            num_channels = input_size[1]
            target_size = (input_size[0], ) + tuple(input_size[2:])
            target = torch.zeros(target_size, dtype=torch.long, device=device)
            self.assertEqual(F.nll_loss(input, target, ignore_index=0, reduction="sum").item(), 0)
            self.assertEqual(F.nll_loss(input, target, ignore_index=0, reduction="mean").item(), float("nan"))
            self.assertEqual(F.nll_loss(input, target, ignore_index=0, reduction="none"), torch.zeros(target.shape, device=device))

        helper([2, 3])
        helper([2, 3, 5, 7])
        helper([2, 3, 5, 7, 9])

    def test_nll_loss_byte_target_matches_long(self, device):
        N, C = 10, 4
        input = torch.randn(N, C, device=device, requires_grad=True)
        target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)

        def compute_result_and_gradient(reduction, target_dtype):
            input_ = input.detach()
            input_.requires_grad_()

            prob = F.log_softmax(input_, dim=-1)
            loss = nn.NLLLoss(reduction=reduction)
            result = loss(prob, target.to(target_dtype))
            result.sum().backward()

            return result, input_.grad

        for reduction in ["none", "mean", "sum"]:
            result_long, grad_long = compute_result_and_gradient(reduction, torch.long)
            result_byte, grad_byte = compute_result_and_gradient(reduction, torch.uint8)
            self.assertEqual(result_long, result_byte)
            self.assertEqual(grad_long, grad_byte)

    @onlyCUDA
    @skipIfRocm
    @dtypes(torch.float16, torch.float32)
    def test_cross_entropy_loss_2d_out_of_bounds_class_index(self, device, dtype):
        # Test for issue #117532
        # Run in a different process to prevent the device-side assert from affecting other tests
        stderr = TestCase.runWithPytorchAPIUsageStderr(f"""\
#!/usr/bin/env python3

import torch
import torch.nn.functional as F
from torch.testing._internal.common_utils import (run_tests, TestCase)

class TestThatContainsCUDAAssert(TestCase):
    def test_cross_entropy_loss_2d_out_of_bounds_class_index(self):
        device = '{str(device)}'
        dtype = {str(dtype).strip("'")}
        ignore_index = 255
        b = 10
        n_classes = 3
        w = 768
        h = 1024
        pred = torch.randn(b, n_classes, w, h, dtype=dtype, device=device)
        labels = torch.zeros(b, w, h, dtype=torch.int64, device=device)
        labels[5, 200, 200] = ignore_index
        # Set invalid class index
        labels[5, 200, 200] = 254

        x = F.cross_entropy(
            pred, labels, reduction="none", ignore_index=ignore_index
        )
        torch.cuda.synchronize()


if __name__ == '__main__':
    run_tests()
        """)
        self.assertIn('CUDA error: device-side assert triggered', stderr)

    def test_cross_entropy_loss_prob_target_all_reductions(self, device):
        # Test with k-dimensional loss.
        for k in range(5):
            N, C = 5, 4
            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
            target = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
            weight = torch.randn(C, device=device).abs()

            for reduction, w in product(['none', 'mean', 'sum'], [None, weight]):
                m = torch.nn.CrossEntropyLoss(weight=w, reduction=reduction)
                output = m(input, target)
                output_ref = loss_reference_fns['CrossEntropyLoss'](
                    input, target, reduction=reduction, weight=w)
                self.assertEqual(output, output_ref)
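    # Illustrative sketch (not part of the original test suite; helper name is
    # hypothetical): for probability ("soft") targets, cross entropy is the
    # negative dot product of the target distribution with the log-probabilities,
    # averaged over the batch, which is what the reference comparison above relies on.
    def _sketch_cross_entropy_prob_target_definition(self):
        logits = torch.randn(4, 5, dtype=torch.double)
        target = torch.randn(4, 5, dtype=torch.double).softmax(dim=1)
        manual = -(target * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
        torch.testing.assert_close(F.cross_entropy(logits, target), manual)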
    def test_cross_entropy_loss_prob_target_unit_weights(self, device):
        # Test with k-dimensional loss.
        for k in range(5):
            N, C = 5, 4
            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
            target = torch.randn(N, C, *other_dims, device=device, requires_grad=True)

            for reduction in ['none', 'mean', 'sum']:
                # Ensure result with unit weights is equivalent to result without weights.
                m = torch.nn.CrossEntropyLoss(reduction=reduction)
                unit_weight = torch.ones(C, device=device, dtype=target.dtype)
                m_unit = torch.nn.CrossEntropyLoss(weight=unit_weight, reduction=reduction)
                output = m(input, target)
                output_unit = m_unit(input, target)
                self.assertEqual(output, output_unit)

    @parametrize_test('reduction', ['none', 'mean', 'sum'])
    @parametrize_test('weighted', [False, True])
    def test_cross_entropy_loss_prob_target_no_batch_dim(self, device, reduction, weighted):
        C = 5
        input = torch.randn(C, device=device).log_softmax(dim=-1)
        target = torch.randn(C, device=device).softmax(dim=-1)
        weight = torch.randn(C, device=device) if weighted else None
        m = nn.CrossEntropyLoss(reduction=reduction, weight=weight)
        loss_no_batch = m(input, target)
        loss_batch = m(input.unsqueeze(0), target.unsqueeze(0))
        if reduction == 'none':
            loss_batch = loss_batch.squeeze(0)
        self.assertEqual(loss_no_batch, loss_batch)
    def test_cross_entropy_loss_index_target_unit_weights(self, device):
        # Test with k-dimensional loss.
        for k in range(5):
            N, C = 5, 4
            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
            target = torch.empty(N, *other_dims, dtype=torch.long, device=device).random_(0, C)

            for reduction in ['none', 'mean', 'sum']:
                # Ensure result with unit weights is equivalent to result without weights.
                m = torch.nn.CrossEntropyLoss(reduction=reduction)
                unit_weight = torch.ones(C, device=device, dtype=input.dtype)
                m_unit = torch.nn.CrossEntropyLoss(weight=unit_weight, reduction=reduction)
                output = m(input, target)
                output_unit = m_unit(input, target)
                self.assertEqual(output, output_unit)
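    # Illustrative sketch (not part of the original test suite; helper name is
    # hypothetical): in the unweighted case, an index target and its one-hot
    # probability counterpart produce the same cross entropy (weighted soft and
    # hard targets differ, see issue #61309 referenced below).
    def _sketch_one_hot_target_equivalence(self):
        logits = torch.randn(4, 5, dtype=torch.double)
        target = torch.randint(5, (4,))
        one_hot = F.one_hot(target, num_classes=5).to(logits.dtype)
        torch.testing.assert_close(
            F.cross_entropy(logits, target),
            F.cross_entropy(logits, one_hot),
        )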
    def test_cross_entropy_loss_one_hot_target(self, device):
        # Test with k-dimensional loss.
        for k in range(5):
            N, C = 5, 4
            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
            target = torch.empty(N, *other_dims, dtype=torch.long, device=device).random_(0, C)
            weight = torch.randn(C, device=device).abs()

            # Get one-hot representation of the target.
            target_one_hot = F.one_hot(target, num_classes=C).to(input.dtype)
            # Need to put the C dim at index 1.
            target_one_hot = target_one_hot.permute(0, -1, *range(1, target_one_hot.dim() - 1))

            for reduction, w in product(['none', 'mean', 'sum'], [None, weight]):
                # Skip this case for now because soft and hard label CE are not consistent
                # in the way they apply class weights (see issue #61309).
                if reduction == 'mean' and weight is not None:
                    continue

                # Ensure loss computed with class indices matches loss
                # computed with one-hot class probs.
                m = torch.nn.CrossEntropyLoss(weight=w, reduction=reduction)
                output = m(input, target)
                output_one_hot = m(input, target_one_hot)
                self.assertEqual(output, output_one_hot)

    def test_cross_entropy_label_smoothing_errors(self, device):
        N, C = 3, 4
        input_args = [
            (torch.randn((N, C), device=device), torch.arange(0, C, device=device)),
            (torch.randn((N, C), device=device), torch.randn(N, C, device=device))
        ]
        for input_arg in input_args:
            loss = nn.CrossEntropyLoss(label_smoothing=1.2)
            with self.assertRaisesRegex(RuntimeError,
                                        r"label_smoothing must be between 0\.0"):
                loss(*input_arg)

    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
    @set_default_dtype(torch.double)
    def test_cross_entropy_label_smoothing_consistent_index_target_and_probs(self, device):
        N, C = 10, 4
        ks = range(5)
        reductions = ['none', 'mean', 'sum']
        label_smoothings = [0.05, 0.15]

        for k, reduction, label_smoothing in product(ks, reductions, label_smoothings):
            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
            target = torch.empty(N, *other_dims, dtype=torch.long, device=device).random_(0, C)

            # construct a target probability that should give the same result as label_smoothing
            target_proba = F.one_hot(target, num_classes=C)
            # Need to put the C dim at index 1.
            target_proba = target_proba.permute(0, -1, *range(1, target_proba.dim() - 1))
            target_mask = (target_proba == 1)
            target_proba = target_proba.to(dtype=input.dtype)

            # y_k^ls = y_k * (1 - label_smoothing) + label_smoothing / n_classes
            # Get one-hot representation of the target.
            target_proba.masked_fill_(target_mask, 1 - label_smoothing + label_smoothing / C)
            target_proba.masked_fill_(~target_mask, label_smoothing / C)

            loss = nn.CrossEntropyLoss(reduction=reduction)
            output_with_prob = loss(input, target_proba)

            loss = nn.CrossEntropyLoss(
                reduction=reduction, label_smoothing=label_smoothing)
            output_with_index = loss(input, target)

            self.assertEqual(output_with_prob, output_with_index,
                             rtol=1e-07, atol=1e-05)

    def test_cross_entropy_label_smoothing_with_probs(self, device):
        N, C = 10, 4
        ks = range(5)
        reductions = ['none', 'mean', 'sum']
        label_smoothings = [0.05, 0.15]

        # Test with k-dimensional loss.
        for k, label_smoothing in product(ks, label_smoothings):
            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
            target = F.log_softmax(torch.randn(N, C, *other_dims, device=device), dim=1)

            for reduction in reductions:
                # use with label_smoothing
                loss = nn.CrossEntropyLoss(reduction=reduction, label_smoothing=label_smoothing)
                output_with_smoothing = loss(input, target)

                # manually smoothing target
                # class_proba^ls = class_proba * (1 - label_smoothing) +
                #                  label_smoothing / n_classes
                target_with_smoothing = target * (1 - label_smoothing) + label_smoothing / C
                loss = nn.CrossEntropyLoss(reduction=reduction)
                output_with_manual_smoothing = loss(input, target_with_smoothing)

                self.assertEqual(output_with_smoothing, output_with_manual_smoothing)
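    # Illustrative sketch (not part of the original test suite; helper name is
    # hypothetical): the smoothing tests above rely on the identity
    #   y_smoothed = y_one_hot * (1 - eps) + eps / C
    # i.e. cross entropy with label_smoothing=eps on class indices matches plain
    # cross entropy on the manually smoothed probability target.
    def _sketch_label_smoothing_identity(self):
        eps, N, C = 0.1, 4, 5
        logits = torch.randn(N, C, dtype=torch.double)
        target = torch.randint(C, (N,))
        smoothed = F.one_hot(target, num_classes=C).to(logits.dtype) * (1 - eps) + eps / C
        torch.testing.assert_close(
            F.cross_entropy(logits, target, label_smoothing=eps),
            F.cross_entropy(logits, smoothed),
        )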
    def test_cross_entropy_label_smoothing_weight_ignore_indices(self, device):
        reductions = ['none', 'sum', 'mean']
        label_smoothings = [0.05, 0.15]

        wgt = torch.tensor([0.3, 0.6], device=device)
        inp1 = torch.tensor([[0.3, 0.4], [1, 2]], device=device)
        inp2 = torch.tensor([[0.3, 0.6], [1, 2]], device=device)

        targ_default_ignore_index = torch.tensor([-100, 1], device=device)
        targ_negative_ignore_index = torch.tensor([-2, 1], device=device)
        targ_positive_ignore_index = torch.tensor([2, 1], device=device)

        for reduction, label_smoothing, weight in product(reductions, label_smoothings, (None, wgt)):
            def check_equal(loss, inp_targ_1, inp_targ_2):
                inp1, targ1 = inp_targ_1
                inp2, targ2 = inp_targ_2
                l1 = loss(inp1, targ1)
                l2 = loss(inp2, targ2)
                self.assertEqual(l1, l2)

            # Default ignore_index
            loss = nn.CrossEntropyLoss(reduction=reduction,
                                       label_smoothing=label_smoothing,
                                       weight=weight)
            check_equal(loss, (inp1, targ_default_ignore_index), (inp2, targ_default_ignore_index))
            if reduction != 'none':
                # Check that we correctly tally the denominator for `mean`
                # i.e. we don't count the ignored_idx at all.
                check_equal(loss, (inp1, targ_default_ignore_index), (inp2[1:], targ_default_ignore_index[1:]))

            # negative ignore_index
            loss = nn.CrossEntropyLoss(reduction=reduction,
                                       label_smoothing=label_smoothing,
                                       ignore_index=-2,
                                       weight=weight)
            check_equal(loss, (inp1, targ_negative_ignore_index), (inp2, targ_negative_ignore_index))
            if reduction != 'none':
                # Check that we correctly tally the denominator for `mean`
                # i.e. we don't count the ignored_idx at all.
                check_equal(loss, (inp1, targ_negative_ignore_index), (inp2[1:], targ_negative_ignore_index[1:]))

            # positive ignore_index
            loss = nn.CrossEntropyLoss(reduction=reduction,
                                       label_smoothing=label_smoothing,
                                       ignore_index=2,
                                       weight=weight)
            check_equal(loss, (inp1, targ_positive_ignore_index), (inp2, targ_positive_ignore_index))
            if reduction != 'none':
                # Check that we correctly tally the denominator for `mean`
                # i.e. we don't count the ignored_idx at all.
                check_equal(loss, (inp1, targ_positive_ignore_index), (inp2[1:], targ_positive_ignore_index[1:]))
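    # Illustrative sketch (not part of the original test suite; helper name is
    # hypothetical): with ignore_index, the "mean" reduction divides only by the
    # number (or total weight) of non-ignored targets, so dropping the ignored
    # rows entirely gives the same result, as the checks above assert.
    def _sketch_ignore_index_mean_denominator(self):
        logits = torch.randn(3, 4, dtype=torch.double)
        target = torch.tensor([-100, 1, 2])  # the first row uses the default ignore_index
        full = F.cross_entropy(logits, target, reduction="mean")
        kept = F.cross_entropy(logits[1:], target[1:], reduction="mean")
        torch.testing.assert_close(full, kept)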
    # Ref: https://github.com/pytorch/pytorch/issues/85005
    @onlyCUDA
    @largeTensorTest("45GB", "cpu")
    @largeTensorTest("70GB", "cuda")
    @parametrize_test("reduction", ("none", "mean", "sum"))
    def test_cross_entropy_large_tensor(self, device, reduction):
        logits = torch.randn(int(2 ** 16), int(2 ** 16) + 1, dtype=torch.float32, device='cuda', requires_grad=True)
        labels = torch.zeros(logits.size(0), dtype=torch.long, device='cuda')
        loss = F.cross_entropy(logits, labels, reduction=reduction)
        if reduction != "none":
            loss.backward()

        with torch.no_grad():
            logits_cpu = logits.cpu().detach().requires_grad_()
            labels_cpu = labels.cpu().detach()
        loss_cpu = F.cross_entropy(logits_cpu, labels_cpu, reduction=reduction)
        if reduction != "none":
            loss_cpu.backward()

        # workaround to reduce memory usage vs. self.assertEqual, see #84944
        rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
        self.assertTrue(torch.allclose(loss.cpu(), loss_cpu, rtol=rtol, atol=atol))
        if reduction != "none":
            self.assertTrue(torch.allclose(logits.grad.cpu(), logits_cpu.grad, rtol=rtol, atol=atol))

    def test_smoothl1loss_backward_zero_beta(self, device):
        input = torch.randn(300, 256, requires_grad=True, device=device)
        target = input.detach()

        loss = F.smooth_l1_loss(input, target, beta=0.0, reduction='sum')
        loss.backward()

        grad_max_abs = input.grad.abs().max().item()
        self.assertLessEqual(grad_max_abs, 1.0)

    def test_softshrink_negative(self, device):
        input = torch.randn(5, device=device, requires_grad=True)
        m = torch.nn.Softshrink(-1)
        with self.assertRaisesRegex(RuntimeError,
                                    r'lambda must be greater or equal to 0, but found to be -1\.'):
            m(input)
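    # Illustrative sketch (not part of the original test suite; helper name is
    # hypothetical): fold is the adjoint of unfold and sums overlapping patch
    # contributions, so with non-overlapping patches fold(unfold(x)) reconstructs
    # x exactly; the fold test below exercises the same operator via gradcheck.
    def _sketch_fold_unfold_roundtrip(self):
        x = torch.randn(1, 3, 4, 6, dtype=torch.double)
        patches = F.unfold(x, kernel_size=(2, 2), stride=(2, 2))
        torch.testing.assert_close(
            F.fold(patches, output_size=(4, 6), kernel_size=(2, 2), stride=(2, 2)), x)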
    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
    def test_fold(self, device):
        def test_dtype(fn, input, dtype):
            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
            input2 = input.detach().clone().float().requires_grad_(True)
            out = fn(input)
            out.sum().backward()
            out2 = fn(input2)
            out2.sum().backward()
            self.assertEqual(out.dtype, dtype)
            self.assertEqual(input.grad.dtype, dtype)
            self.assertEqual(out, out2.to(dtype=dtype), atol=0.05, rtol=0)
            self.assertEqual(input.grad, input2.grad.to(dtype=dtype))

        def func(x):
            return F.fold(x, output_size=(4, 5), kernel_size=(2, 2))

        seeds = (44, 83, 71, 25, 999)
        for sd in seeds:
            torch.manual_seed(sd)
            x = torch.randn(1, 12, 12, device=device, requires_grad=True, dtype=torch.double)
            gradcheck(func, [x], check_forward_ad=True)
            gradgradcheck(func, [x], check_fwd_over_rev=True)
            if device == 'cpu':
                test_dtype(func, x, torch.bfloat16)

    def test_logsigmoid_out(self, device):
        # this isn't actually documented, but was broken previously:
        # https://github.com/pytorch/pytorch/issues/36499
        x = torch.randn(2, 3, device=device).t()
        empty_out = torch.randn(0, device=device)
        self.assertEqual(F.logsigmoid(x), F.logsigmoid(x, out=empty_out))

        noncontig_out = torch.randn(2, 3, device=device).t()
        self.assertEqual(F.logsigmoid(x), F.logsigmoid(x, out=noncontig_out))

    # Check that clip_grad_norm_ raises an error if the total norm of the
    # parameters' gradients is non-finite
    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
    def test_clip_grad_norm_error_if_nonfinite(self, device):
        norms_pos = [0.1, 1, 2, 3.5, inf]
        norms_neg = [-0.1, -1, -2, -3.5]
        norms_except_0 = norms_pos + norms_neg
        norms_all = norms_except_0 + [0]

        # Each entry in test_cases has the following values, in this order:
        #
        # grad_only_one_elem    If True, only one element of the parameter's
        #                       gradient is set to the scalar grad, and the
        #                       rest of the elements are 0. If False, all grad
        #                       elements are equal to the scalar.
        #
        # prefix_finite_grad_param  If True, prefix a parameter that has a grad
        #                           of 1.
        #
        # scalars           Scalars to use as the parameter's grad, through
        #                   multiplication
        #
        # norms_nonfinite   Norm types that should produce nonfinite total norm
        #
        # norms_finite      Norm types that should produce finite total norm
        test_cases = [
            # Test errors from an infinite grad
            (False, False, [inf, -inf], norms_except_0, [0]),
            (False, True, [inf, -inf], norms_pos, norms_neg + [0]),
            (True, False, [inf, -inf], norms_pos, norms_neg + [0]),
            (True, True, [inf, -inf], norms_pos, norms_neg + [0]),

            # Test errors from a NaN grad
            (False, False, [nan], norms_except_0, [0]),
            (False, True, [nan], norms_except_0, [0]),
            (True, False, [nan], norms_except_0, [0]),
            (True, True, [nan], norms_except_0, [0]),

            # Test a grad that should never error
            (False, False, [2e22, -2e22], [], norms_all),
            (False, True, [2e22, -2e22], [], norms_all),
            (True, False, [2e22, -2e22], [], norms_all),
            (True, True, [2e22, -2e22], [], norms_all),

            # Test a grad that will overflow to inf for only some norm orders
            (False, False, [2e200, -2e200], [3.5, 2, -2, -3.5], [inf, 1, 0.1, 0, -1, -0.1]),
            (False, True, [2e200, -2e200], [3.5, 2], norms_neg + [inf, 1, 0.1, 0]),
            (True, False, [2e200, -2e200], [3.5, 2], norms_neg + [inf, 1, 0.1, 0]),
            (True, True, [2e200, -2e200], [3.5, 2], norms_neg + [inf, 1, 0.1, 0]),
        ]

        def gen_parameters(scalar, grad_only_one_elem, prefix_finite_grad_param):
            param = torch.ones(10, dtype=torch.float64, device=device, requires_grad=True)

            if grad_only_one_elem:
                param[1].mul(scalar).sum().backward()
            else:
                param.mul(scalar).sum().backward()

            if prefix_finite_grad_param:
                prefix_param = torch.ones(1, dtype=torch.float64, device=device, requires_grad=True)
                prefix_param.mul(1).sum().backward()
                parameters = [prefix_param, param]
            else:
                parameters = [param]

            return parameters

        def run_test_case(norm_type, error_if_nonfinite, scalar, grad_only_one_elem, prefix_finite_grad_param, is_norm_nonfinite):
            msg = (
                f'norm_type: {norm_type}, '
                f'error_if_nonfinite: {error_if_nonfinite}, '
                f'scalar: {scalar}, '
                f'grad_only_one_elem: {grad_only_one_elem}, '
                f'prefix_finite_grad_param: {prefix_finite_grad_param}, '
                f'is_norm_nonfinite: {is_norm_nonfinite}')

            parameters = gen_parameters(scalar, grad_only_one_elem, prefix_finite_grad_param)

            # Should only throw an error if the total norm is expected to be
            # nonfinite and `error_if_nonfinite=True`
            if is_norm_nonfinite and error_if_nonfinite:
                error_msg = f'The total norm of order {float(norm_type)} for gradients'

                grads_before = [p.grad.clone() for p in parameters]

                with self.assertRaisesRegex(RuntimeError, error_msg, msg=msg):
                    clip_grad_norm_(parameters, 1, norm_type=norm_type, error_if_nonfinite=True)

                # Grad should not change if error is thrown
                grads_after = [p.grad for p in parameters]
                self.assertEqual(grads_before, grads_after, msg=msg)
            else:
                clip_grad_norm_(parameters, 1, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite)

        for grad_only_one_elem, prefix_finite_grad_param, scalars, norms_nonfinite, norms_finite in test_cases:
            for error_if_nonfinite in [False, True]:
                for norm_type, scalar in product(norms_nonfinite, scalars):
                    run_test_case(norm_type, error_if_nonfinite, scalar, grad_only_one_elem, prefix_finite_grad_param, True)
                for norm_type, scalar in product(norms_finite, scalars):
                    run_test_case(norm_type, error_if_nonfinite, scalar, grad_only_one_elem, prefix_finite_grad_param, False)
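    # Illustrative sketch (not part of the original test suite; helper name is
    # hypothetical): clip_grad_norm_ only raises when the total norm is
    # non-finite *and* error_if_nonfinite=True; with the flag off the call
    # simply proceeds, which is the contract the test above exercises.
    def _sketch_clip_grad_norm_nonfinite_flag(self):
        p = torch.ones(3, requires_grad=True)
        p.grad = torch.tensor([1.0, float('nan'), 1.0])
        with self.assertRaisesRegex(RuntimeError, 'non-finite'):
            clip_grad_norm_([p], max_norm=1.0, error_if_nonfinite=True)
        clip_grad_norm_([p], max_norm=1.0, error_if_nonfinite=False)  # no error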
    @onlyCUDA
    @deviceCountAtLeast(2)
    @parametrize_test('foreach', (False, True))
    def test_clip_grad_norm_multi_device(self, devices, foreach):
        class TestModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer1 = nn.Linear(10, 10)
                self.layer2 = nn.Linear(10, 10)

        test_model = TestModel()
        test_model.layer1.to(devices[0])
        test_model.layer2.to(devices[1])
        ref_model = TestModel().to(devices[0])
        for norm_type in [2., math.inf]:
            for p in test_model.parameters():
                p.grad = torch.ones_like(p)
            for p in ref_model.parameters():
                p.grad = torch.ones_like(p)
            norm = clip_grad_norm_(test_model.parameters(), 0.5, norm_type=norm_type, foreach=foreach)
            expected = clip_grad_norm_(ref_model.parameters(), 0.5, norm_type=norm_type, foreach=foreach)
            self.assertEqual(norm, expected)
            for p, pe in zip(test_model.parameters(), ref_model.parameters()):
                self.assertEqual(p.grad.to(devices[0]), pe.grad)

    def test_elu_inplace_overlap(self, device):
        dtype = torch.bfloat16 if device != 'mps:0' else torch.float16
        x = torch.randn((1, 6), dtype=dtype, device=device).expand((6, 6))
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            F.elu(x, inplace=True)
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            F.elu_(x)

    # Merge into OpInfo?
    @onlyNativeDeviceTypes
    def test_elu_inplace_with_neg_alpha(self, device):
        a = torch.tensor([-1., 1.], device=device, requires_grad=True)
        b = torch.nn.functional.elu_(a.clone(), alpha=-2)
        with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
            b.backward(torch.ones(2, device=device))

        a = torch.tensor([-1., 1.], device=device, requires_grad=True)
        b = torch.nn.functional.celu_(a.clone(), alpha=-2)
        with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
            b.backward(torch.ones(2, device=device))

    @expectedFailureMeta  # https://github.com/pytorch/pytorch/issues/54897
    def test_hardswish_inplace_overlap(self, device):
        x = torch.randn((1, 6), device=device).expand((6, 6))
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            F.hardswish(x, inplace=True)

    def test_silu_inplace_overlap(self, device):
        x = torch.randn((1, 6), device=device).expand((6, 6))
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            F.silu(x, inplace=True)

    @onlyNativeDeviceTypes
    def test_mish_inplace_overlap(self, device):
        x = torch.randn((1, 6), device=device).expand((6, 6))
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            F.mish(x, inplace=True)

    def test_softplus_inplace_overlap(self, device):
        x = torch.randn((1, 6), device=device).expand((6, 6))
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            F.softplus(x, out=x)
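    # Illustrative sketch (not part of the original test suite; helper name is
    # hypothetical): expand() returns a view in which many output elements share
    # one memory location, so in-place ops on it are rejected with
    # "unsupported operation", as the *_inplace_overlap tests here check;
    # materializing a copy first works.
    def _sketch_inplace_on_expanded_view(self):
        x = torch.randn(1, 6).expand(6, 6)
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            F.silu(x, inplace=True)
        F.silu(x.contiguous(), inplace=True)  # fine once the data no longer overlaps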
    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
    def test_softplus_low_threshold(self, device):
        # Ensure gradients are computed correctly with a low threshold.
        model = torch.nn.Softplus(threshold=1).double()
        input = torch.tensor(0.9, device=device, dtype=torch.double,
                             requires_grad=True)
        output = model(input)
        torch.autograd.gradcheck(model, input)

    def test_softshrink_inplace_overlap(self, device):
        x = torch.randn((1, 6), device=device).expand((6, 6))
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            F.softshrink(x, out=x)

    def test_leaky_relu_inplace_overlap(self, device):
        x = torch.randn((1, 6), device=device).expand((6, 6))
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            F.leaky_relu(x, inplace=True)
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            F.leaky_relu_(x)

    # Merge into OpInfo?
    @expectedFailureMPS  # NotImplementedError: aten::rrelu_with_noise_ https://github.com/pytorch/pytorch/issues/77764
    def test_leaky_relu_inplace_with_neg_slope(self, device):
        a = torch.tensor([-1., 1.], device=device, requires_grad=True)
        b = torch.nn.functional.leaky_relu_(a.clone(), -2)
        with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
            b.backward(torch.ones(2, device=device))

        a = torch.tensor([-1., 1.], device=device, requires_grad=True)
        b = torch.nn.functional.rrelu_(a.clone(), -5.0, 1.0)
        with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
            b.backward(torch.ones(2, device=device))

    # Merge into OpInfo?
    def test_leaky_relu_inplace_with_zero_slope(self, device):
        a = torch.tensor([-2., 0., 2.], device=device, requires_grad=True)
        b = torch.nn.functional.leaky_relu_(a.clone(), 0.0)
        b.backward(torch.ones(3, device=device))
        expected = torch.tensor([0., 0., 1.], device=device)
        self.assertEqual(a.grad, expected)

        dtype = torch.bfloat16 if device != 'mps:0' else torch.float16
        a_bf16 = torch.tensor([-2., 0., 2.], device=device, dtype=dtype, requires_grad=True)
        b_bf16 = torch.nn.functional.leaky_relu_(a_bf16.clone(), 0.0)
        b_bf16.backward(torch.ones(3, device=device))
        expected_bf16 = torch.tensor([0., 0., 1.], device=device, dtype=dtype)
        self.assertEqual(a_bf16.grad, expected_bf16)
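    # Illustrative sketch (not part of the original test suite; helper name is
    # hypothetical): Softshrink with parameter lambda maps x to x - lambda for
    # x > lambda, to x + lambda for x < -lambda, and to 0 otherwise; the
    # hard-coded table in the test below is this closed form at the default
    # lambda of 0.5.
    def _sketch_softshrink_closed_form(self):
        x = torch.randn(100, dtype=torch.double)
        lambd = 0.5
        manual = torch.where(x > lambd, x - lambd,
                             torch.where(x < -lambd, x + lambd, torch.zeros_like(x)))
        torch.testing.assert_close(F.softshrink(x, lambd), manual)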
    @onlyCPU
    def test_softshrink(self, device):
        x = torch.tensor([[1.21, 0.56, 0.5001, 0.4999, 1.2357, -0.4999, -0.5001, -1.154,
                           0.254, -0.24, -0.225, 0.104, 0.002, -0.001, 0.0574, 1.2344,
                           0.1748, -0.1797, -0.8125, 0.2051, -1.1328, 1.2344, -0.1562, 2.3554,
                           -0.1953, 0.0304, -0.3613, -1.3047, 1.0312, 0.1436, -0.6953, 0.5664,
                           -0.5820, -0.3301, 0.8203, 0.6133, 0.5938],
                          [-0.8203, -1.2344, -0.5234, 2.5312, -0.4551, -0.6875, -1.5547, -0.2217,
                           -0.3027, 2.6406, 1.3047, 0.2344, -1.6719, 0.2773, -1.3516, 3.4575,
                           0.4414, 0.2656, 2.1094, -1.5156, 1.2344, -0.4336, 0.6797, -3.5486,
                           0.9766, -0.4062, 1.4844, 0.7500, -1.7578, 0.7461, 1.6094, 8.5458,
                           0.3730, -0.3477, -1.0625, 0.3848, 0.0557]], device=device)
        expected = torch.tensor([[0.71, 0.06, 0.0001, 0., 0.7357, 0., -0.0001, -0.654,
                                  0., 0., 0., 0., 0., 0., 0., 0.7344,
                                  0., 0., -0.3125, 0., -0.6328, 0.7344, 0., 1.8554,
                                  0., 0., 0., -0.8047, 0.5312, 0., -0.1953, 0.0664,
                                  -0.0820, 0.0, 0.3203, 0.1133, 0.0938],
                                 [-0.3203, -0.7344, -0.0234, 2.0312, 0.0, -0.1875, -1.0547, 0.,
                                  0.0, 2.1406, 0.8047, 0., -1.1719, 0., -0.8516, 2.9575,
                                  0., 0., 1.6094, -1.0156, 0.7344, 0., 0.1797, -3.0486,
                                  0.4766, 0., 0.9844, 0.2500, -1.2578, 0.2461, 1.1094, 8.0458,
                                  0., 0., -0.5625, 0., 0.]])
        softshrink = torch.nn.Softshrink()
        out = softshrink(x)
        self.assertEqual(out, expected, atol=1e-2, rtol=0)

    def test_threshold_inplace_overlap(self, device):
        # Inplace threshold is okay, because it is idempotent
        x = torch.randn((1, 6), device=device).expand((6, 6))
        F.threshold(x, 0.5, 0.5, inplace=True)
        F.threshold_(x, 0.5, 0.5)
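    # Illustrative sketch (not part of the original test suite; helper name is
    # hypothetical): with an explicit Euclidean distance the triplet loss is
    # max(d(a, p) - d(a, n) + margin, 0), which is the quantity the parity tests
    # below compare between the functional and module forms.
    def _sketch_triplet_margin_with_distance_formula(self):
        a, p, n = (torch.randn(5, 10, dtype=torch.double) for _ in range(3))
        d = nn.PairwiseDistance()
        manual = torch.clamp(d(a, p) - d(a, n) + 1.0, min=0).mean()
        torch.testing.assert_close(
            F.triplet_margin_with_distance_loss(a, p, n, distance_function=d, margin=1.0),
            manual)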
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_with_distance_loss( 12242*da0073e9SAndroid Build Coastguard Worker a, p, n, **kwargs), (anchor, positive, negative))) 12243*da0073e9SAndroid Build Coastguard Worker self.assertTrue(gradcheck(lambda a, p, n: loss_op(a, p, n), 12244*da0073e9SAndroid Build Coastguard Worker (anchor, positive, negative))) 12245*da0073e9SAndroid Build Coastguard Worker 12246*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 12247*da0073e9SAndroid Build Coastguard Worker def test_triplet_margin_with_distance_loss(self, device): 12248*da0073e9SAndroid Build Coastguard Worker # Test for parity between `nn.TripletMarginWithDistanceLoss` and 12249*da0073e9SAndroid Build Coastguard Worker # `F.triplet_margin_with_distance_loss`. 12250*da0073e9SAndroid Build Coastguard Worker 12251*da0073e9SAndroid Build Coastguard Worker pairwise_distance = nn.PairwiseDistance() 12252*da0073e9SAndroid Build Coastguard Worker 12253*da0073e9SAndroid Build Coastguard Worker def cosine_distance(x, y): 12254*da0073e9SAndroid Build Coastguard Worker return 1.0 - F.cosine_similarity(x, y) 12255*da0073e9SAndroid Build Coastguard Worker 12256*da0073e9SAndroid Build Coastguard Worker distance_functions = (pairwise_distance, cosine_distance, 12257*da0073e9SAndroid Build Coastguard Worker lambda x, y: 1.0 - F.cosine_similarity(x, y)) 12258*da0073e9SAndroid Build Coastguard Worker 12259*da0073e9SAndroid Build Coastguard Worker reductions = ('mean', 'none', 'sum') 12260*da0073e9SAndroid Build Coastguard Worker margins = (1.0, 1.5, 0.5) 12261*da0073e9SAndroid Build Coastguard Worker swaps = (True, False) 12262*da0073e9SAndroid Build Coastguard Worker 12263*da0073e9SAndroid Build Coastguard Worker for distance_fn, reduction, margin, swap \ 12264*da0073e9SAndroid Build Coastguard Worker in itertools.product(distance_functions, reductions, margins, swaps): 12265*da0073e9SAndroid Build Coastguard Worker anchor = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double) 12266*da0073e9SAndroid Build Coastguard Worker positive = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double) 12267*da0073e9SAndroid Build Coastguard Worker negative = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double) 12268*da0073e9SAndroid Build Coastguard Worker 12269*da0073e9SAndroid Build Coastguard Worker # Test backward 12270*da0073e9SAndroid Build Coastguard Worker self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_with_distance_loss( 12271*da0073e9SAndroid Build Coastguard Worker a, p, n, distance_function=distance_fn, reduction=reduction, margin=margin, swap=swap), 12272*da0073e9SAndroid Build Coastguard Worker (anchor, positive, negative))) 12273*da0073e9SAndroid Build Coastguard Worker loss_op = nn.TripletMarginWithDistanceLoss(distance_function=distance_fn, 12274*da0073e9SAndroid Build Coastguard Worker reduction=reduction, margin=margin, swap=swap) 12275*da0073e9SAndroid Build Coastguard Worker self.assertTrue(gradcheck(lambda a, p, n: loss_op( 12276*da0073e9SAndroid Build Coastguard Worker a, p, n), (anchor, positive, negative))) 12277*da0073e9SAndroid Build Coastguard Worker traced_loss_op = torch.jit.trace(loss_op, (anchor, positive, negative)) 12278*da0073e9SAndroid Build Coastguard Worker self.assertTrue(gradcheck(lambda a, p, n: traced_loss_op( 12279*da0073e9SAndroid Build Coastguard Worker a, p, n), (anchor, positive, negative))) 12280*da0073e9SAndroid Build Coastguard Worker 12281*da0073e9SAndroid Build Coastguard Worker # 
Test forward parity 12282*da0073e9SAndroid Build Coastguard Worker functional = F.triplet_margin_with_distance_loss(anchor, positive, negative, 12283*da0073e9SAndroid Build Coastguard Worker distance_function=distance_fn, 12284*da0073e9SAndroid Build Coastguard Worker reduction=reduction, margin=margin, swap=swap) 12285*da0073e9SAndroid Build Coastguard Worker modular = loss_op(anchor, positive, negative) 12286*da0073e9SAndroid Build Coastguard Worker traced = traced_loss_op(anchor, positive, negative) 12287*da0073e9SAndroid Build Coastguard Worker self.assertEqual(functional, modular, atol=1e-6, rtol=1e-6) 12288*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced, modular, atol=1e-6, rtol=1e-6) 12289*da0073e9SAndroid Build Coastguard Worker 12290*da0073e9SAndroid Build Coastguard Worker @dtypesIfMPS(torch.cfloat, torch.float) 12291*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.cfloat, torch.cdouble, torch.float) 12292*da0073e9SAndroid Build Coastguard Worker def test_to_complex(self, device, dtype): 12293*da0073e9SAndroid Build Coastguard Worker m = nn.Linear(3, 5).to(device) 12294*da0073e9SAndroid Build Coastguard Worker self.assertIs(m, m.to(device)) 12295*da0073e9SAndroid Build Coastguard Worker m.to(dtype) 12296*da0073e9SAndroid Build Coastguard Worker self.assertIs(m.weight.dtype, dtype) 12297*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 12298*da0073e9SAndroid Build Coastguard Worker # Trigger warning 12299*da0073e9SAndroid Build Coastguard Worker m.to(torch.cfloat) 12300*da0073e9SAndroid Build Coastguard Worker # Check warning occurs 12301*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 1) 12302*da0073e9SAndroid Build Coastguard Worker self.assertTrue("Complex modules are a new feature" in str(w[-1].message)) 12303*da0073e9SAndroid Build Coastguard Worker 12304*da0073e9SAndroid Build Coastguard Worker @skipMeta 12305*da0073e9SAndroid Build Coastguard Worker @dtypesIfMPS(torch.float32) 12306*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.float64) 12307*da0073e9SAndroid Build Coastguard Worker def test_module_to_empty(self, device, dtype): 12308*da0073e9SAndroid Build Coastguard Worker class MyModule(nn.Module): 12309*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_features, out_features, device=None, dtype=None): 12310*da0073e9SAndroid Build Coastguard Worker super().__init__() 12311*da0073e9SAndroid Build Coastguard Worker factory_kwargs = {"device": device, "dtype": dtype} 12312*da0073e9SAndroid Build Coastguard Worker self.weight = nn.Parameter(torch.randn(in_features, out_features, **factory_kwargs)) 12313*da0073e9SAndroid Build Coastguard Worker 12314*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12315*da0073e9SAndroid Build Coastguard Worker return x @ self.weight 12316*da0073e9SAndroid Build Coastguard Worker 12317*da0073e9SAndroid Build Coastguard Worker # Test meta module instantiation. 12318*da0073e9SAndroid Build Coastguard Worker input = torch.randn(5, 10, device=device, dtype=dtype) 12319*da0073e9SAndroid Build Coastguard Worker m = MyModule(10, 1, device='meta', dtype=dtype) 12320*da0073e9SAndroid Build Coastguard Worker m(input) 12321*da0073e9SAndroid Build Coastguard Worker 12322*da0073e9SAndroid Build Coastguard Worker # Test empty meta module error with torch.nn.Module.to(). 
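        # For context, the user-facing workflow this test exercises looks roughly like the
        # sketch below (illustrative only, not executed here; the Linear layer and sizes are
        # placeholders):
        #
        #     with torch.device('meta'):
        #         m = nn.Linear(10, 1)                   # parameters live on the meta device, no data
        #     m.to_empty(device='cpu')                   # materialize uninitialized storage on a real device
        #     torch.nn.init.kaiming_uniform_(m.weight)   # then (re)initialize the parameters in place
        #
        # Calling plain .to() on a meta module is expected to fail, as checked next.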
12323*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 12324*da0073e9SAndroid Build Coastguard Worker NotImplementedError, 12325*da0073e9SAndroid Build Coastguard Worker re.escape( 12326*da0073e9SAndroid Build Coastguard Worker "Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() " 12327*da0073e9SAndroid Build Coastguard Worker "instead of torch.nn.Module.to() when moving module from meta to a different " 12328*da0073e9SAndroid Build Coastguard Worker "device." 12329*da0073e9SAndroid Build Coastguard Worker ), 12330*da0073e9SAndroid Build Coastguard Worker ): 12331*da0073e9SAndroid Build Coastguard Worker m.to(device) 12332*da0073e9SAndroid Build Coastguard Worker 12333*da0073e9SAndroid Build Coastguard Worker # Test materializing meta module on a real device. 12334*da0073e9SAndroid Build Coastguard Worker m.to_empty(device=device) 12335*da0073e9SAndroid Build Coastguard Worker m(input) 12336*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 12337*da0073e9SAndroid Build Coastguard Worker torch.nn.init.kaiming_uniform_(m.weight) 12338*da0073e9SAndroid Build Coastguard Worker m(input) 12339*da0073e9SAndroid Build Coastguard Worker 12340*da0073e9SAndroid Build Coastguard Worker # Test creating meta module from materialized module. 12341*da0073e9SAndroid Build Coastguard Worker m.to_empty(device='meta') 12342*da0073e9SAndroid Build Coastguard Worker m(input) 12343*da0073e9SAndroid Build Coastguard Worker 12344*da0073e9SAndroid Build Coastguard Worker def test_module_to_empty_non_recursive(self, device): 12345*da0073e9SAndroid Build Coastguard Worker class Layer(nn.Module): 12346*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_features, out_features): 12347*da0073e9SAndroid Build Coastguard Worker super().__init__() 12348*da0073e9SAndroid Build Coastguard Worker self.weight = nn.Parameter(torch.randn(in_features, out_features)) 12349*da0073e9SAndroid Build Coastguard Worker self.register_buffer('buf', torch.randn(out_features)) 12350*da0073e9SAndroid Build Coastguard Worker 12351*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12352*da0073e9SAndroid Build Coastguard Worker return x @ self.weight + self.buf 12353*da0073e9SAndroid Build Coastguard Worker 12354*da0073e9SAndroid Build Coastguard Worker class MyModule(nn.Module): 12355*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_features, out_features): 12356*da0073e9SAndroid Build Coastguard Worker super().__init__() 12357*da0073e9SAndroid Build Coastguard Worker self.weight = nn.Parameter(torch.randn(in_features, out_features)) 12358*da0073e9SAndroid Build Coastguard Worker self.register_buffer('buf1', torch.randn(out_features)) 12359*da0073e9SAndroid Build Coastguard Worker self.layer = Layer(out_features, out_features) 12360*da0073e9SAndroid Build Coastguard Worker 12361*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12362*da0073e9SAndroid Build Coastguard Worker return self.layer(x @ self.weight + self.buf1) 12363*da0073e9SAndroid Build Coastguard Worker 12364*da0073e9SAndroid Build Coastguard Worker with torch.device('meta'): 12365*da0073e9SAndroid Build Coastguard Worker m = MyModule(3, 5) 12366*da0073e9SAndroid Build Coastguard Worker 12367*da0073e9SAndroid Build Coastguard Worker m.to_empty(device=device, recurse=False) 12368*da0073e9SAndroid Build Coastguard Worker 12369*da0073e9SAndroid Build Coastguard Worker # params/buffers of parent should have been materialized on device 12370*da0073e9SAndroid Build Coastguard 
Worker self.assertTrue(not m.weight.is_meta) 12371*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not m.buf1.is_meta) 12372*da0073e9SAndroid Build Coastguard Worker 12373*da0073e9SAndroid Build Coastguard Worker # parameters/buffers of children submodules should still be on meta 12374*da0073e9SAndroid Build Coastguard Worker for p in (*m.layer.parameters(), *m.layer.buffers()): 12375*da0073e9SAndroid Build Coastguard Worker self.assertTrue(p.is_meta) 12376*da0073e9SAndroid Build Coastguard Worker 12377*da0073e9SAndroid Build Coastguard Worker @skipMeta 12378*da0073e9SAndroid Build Coastguard Worker def test_skip_init(self, device): 12379*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(1) 12380*da0073e9SAndroid Build Coastguard Worker m_initialized = torch.nn.Linear(5, 1) 12381*da0073e9SAndroid Build Coastguard Worker m_initialized.to(device) 12382*da0073e9SAndroid Build Coastguard Worker 12383*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(1) 12384*da0073e9SAndroid Build Coastguard Worker m_uninitialized = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1, device=device) 12385*da0073e9SAndroid Build Coastguard Worker 12386*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_initialized.weight.device, m_uninitialized.weight.device) 12387*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.allclose(m_initialized.weight, m_uninitialized.weight)) 12388*da0073e9SAndroid Build Coastguard Worker 12389*da0073e9SAndroid Build Coastguard Worker @skipIfRocm(msg='See https://github.com/pytorch/pytorch/issues/135150') 12390*da0073e9SAndroid Build Coastguard Worker @skipIfMps # TODO(hvaara): Investigate as possible bug. macOS 13 passes, while 14 and 15 fails. 12391*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 12392*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.double, torch.float, torch.half) 12393*da0073e9SAndroid Build Coastguard Worker def test_transformerencoderlayer(self, device, dtype): 12394*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half: 12395*da0073e9SAndroid Build Coastguard Worker self.skipTest("Skip on ROCM due to Flash Attention tolerances") 12396*da0073e9SAndroid Build Coastguard Worker # this is a deterministic test for TransformerEncoderLayer 12397*da0073e9SAndroid Build Coastguard Worker d_model = 4 12398*da0073e9SAndroid Build Coastguard Worker nhead = 2 12399*da0073e9SAndroid Build Coastguard Worker dim_feedforward = 16 12400*da0073e9SAndroid Build Coastguard Worker dropout = 0.0 12401*da0073e9SAndroid Build Coastguard Worker bsz = 2 12402*da0073e9SAndroid Build Coastguard Worker 12403*da0073e9SAndroid Build Coastguard Worker atol = 1e-5 12404*da0073e9SAndroid Build Coastguard Worker rtol = 1e-7 12405*da0073e9SAndroid Build Coastguard Worker if "cuda" in device: 12406*da0073e9SAndroid Build Coastguard Worker atol = 1e-3 12407*da0073e9SAndroid Build Coastguard Worker rtol = 1e-2 12408*da0073e9SAndroid Build Coastguard Worker 12409*da0073e9SAndroid Build Coastguard Worker def _test(training, batch_first, atol, rtol): 12410*da0073e9SAndroid Build Coastguard Worker def perm_fn(x): 12411*da0073e9SAndroid Build Coastguard Worker return x.transpose(1, 0) if batch_first else x 12412*da0073e9SAndroid Build Coastguard Worker 12413*da0073e9SAndroid Build Coastguard Worker model = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, 12414*da0073e9SAndroid Build Coastguard Worker batch_first=batch_first, device=device, 
dtype=dtype) 12415*da0073e9SAndroid Build Coastguard Worker 12416*da0073e9SAndroid Build Coastguard Worker if not training: 12417*da0073e9SAndroid Build Coastguard Worker assert dropout == 0 12418*da0073e9SAndroid Build Coastguard Worker model = model.eval() 12419*da0073e9SAndroid Build Coastguard Worker 12420*da0073e9SAndroid Build Coastguard Worker # set constant weights of the model 12421*da0073e9SAndroid Build Coastguard Worker for idx, p in enumerate(model.parameters()): 12422*da0073e9SAndroid Build Coastguard Worker x = p.data 12423*da0073e9SAndroid Build Coastguard Worker sz = x.view(-1).size(0) 12424*da0073e9SAndroid Build Coastguard Worker shape = x.shape 12425*da0073e9SAndroid Build Coastguard Worker x = torch.cos(torch.arange(0, sz).float().view(shape)) 12426*da0073e9SAndroid Build Coastguard Worker p.data.copy_(x) 12427*da0073e9SAndroid Build Coastguard Worker 12428*da0073e9SAndroid Build Coastguard Worker # deterministic input 12429*da0073e9SAndroid Build Coastguard Worker encoder_input = torch.tensor([[[20., 30., 40., 50.]]], device=device, dtype=dtype) 12430*da0073e9SAndroid Build Coastguard Worker result = model(encoder_input) 12431*da0073e9SAndroid Build Coastguard Worker ref_output = torch.tensor([[[2.258703, 0.127985, -0.697881, 0.170862]]], device=device, dtype=dtype) 12432*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, ref_output.shape) 12433*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol) 12434*da0073e9SAndroid Build Coastguard Worker # 0 values are NOT masked. This shouldn't mask anything. 12435*da0073e9SAndroid Build Coastguard Worker mask = torch.tensor([[0]], device=device) == 1 12436*da0073e9SAndroid Build Coastguard Worker # TODO: enable fast path for calls with a mask! 
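            # Note on src_key_padding_mask semantics (explanatory comment): a True entry marks a
            # padding position that attention ignores, so for example
            #     mask = torch.tensor([[False, True]])
            # would mask out the second token of the (single) batch element, while an all-False
            # mask is equivalent to passing no mask at all.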
12437*da0073e9SAndroid Build Coastguard Worker result = model(encoder_input, src_key_padding_mask=mask) 12438*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, ref_output.shape) 12439*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol) 12440*da0073e9SAndroid Build Coastguard Worker mask = torch.tensor([[1]], device=device) == 1 12441*da0073e9SAndroid Build Coastguard Worker result = model(encoder_input, src_key_padding_mask=mask) 12442*da0073e9SAndroid Build Coastguard Worker fast_path_device = result.is_cuda or result.is_cpu 12443*da0073e9SAndroid Build Coastguard Worker result = result.cpu().detach().numpy() 12444*da0073e9SAndroid Build Coastguard Worker # Non Fast Paths 12445*da0073e9SAndroid Build Coastguard Worker if training or not batch_first or TEST_WITH_CROSSREF or not fast_path_device: 12446*da0073e9SAndroid Build Coastguard Worker # We changed the semantics on the non-fast path so that fully masked-out rows return 12447*da0073e9SAndroid Build Coastguard Worker # 0 from attention, thus NaNs should no longer be present and the output should be nonzero 12448*da0073e9SAndroid Build Coastguard Worker # due to skip connections 12449*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not np.isnan(result).any()) 12450*da0073e9SAndroid Build Coastguard Worker else: 12451*da0073e9SAndroid Build Coastguard Worker # Fast Paths 12452*da0073e9SAndroid Build Coastguard Worker self.assertTrue(np.isnan(result).all()) 12453*da0073e9SAndroid Build Coastguard Worker 12454*da0073e9SAndroid Build Coastguard Worker 12455*da0073e9SAndroid Build Coastguard Worker # deterministic input 12456*da0073e9SAndroid Build Coastguard Worker encoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]], 12457*da0073e9SAndroid Build Coastguard Worker [[5., 6., 7., 8.]]], device=device, dtype=dtype)) 12458*da0073e9SAndroid Build Coastguard Worker result = model(encoder_input) 12459*da0073e9SAndroid Build Coastguard Worker ref_output = perm_fn(torch.tensor([[[2.272644, 0.119035, -0.691669, 0.153486]], 12460*da0073e9SAndroid Build Coastguard Worker [[2.272644, 0.119035, -0.691669, 0.153486]]], device=device, dtype=dtype)) 12461*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, ref_output.shape) 12462*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol) 12463*da0073e9SAndroid Build Coastguard Worker # all 0s, which means no masking 12464*da0073e9SAndroid Build Coastguard Worker mask = torch.tensor([[0, 0]], device=device) == 1 12465*da0073e9SAndroid Build Coastguard Worker result = model(encoder_input, src_key_padding_mask=mask) 12466*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, ref_output.shape) 12467*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol) 12468*da0073e9SAndroid Build Coastguard Worker mask = torch.tensor([[1, 0]], device=device) == 1 12469*da0073e9SAndroid Build Coastguard Worker result = model(encoder_input, src_key_padding_mask=mask) 12470*da0073e9SAndroid Build Coastguard Worker ref_output = perm_fn(torch.tensor([[[2.301516, 0.092249, -0.679101, 0.103088]], 12471*da0073e9SAndroid Build Coastguard Worker [[2.301516, 0.092249, -0.679101, 0.103088]]], device=device, dtype=dtype)) 12472*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, ref_output.shape) 12473*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(result, ref_output,
atol=atol, rtol=rtol) 12474*da0073e9SAndroid Build Coastguard Worker 12475*da0073e9SAndroid Build Coastguard Worker # deterministic input 12476*da0073e9SAndroid Build Coastguard Worker encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891], 12477*da0073e9SAndroid Build Coastguard Worker [0.5387, 0.1655, 0.3565, 0.0471]], 12478*da0073e9SAndroid Build Coastguard Worker [[0.8335, 0.2799, 0.5031, 0.2947], 12479*da0073e9SAndroid Build Coastguard Worker [0.1402, 0.0318, 0.7636, 0.1346]], 12480*da0073e9SAndroid Build Coastguard Worker [[0.6333, 0.9344, 0.1376, 0.9938], 12481*da0073e9SAndroid Build Coastguard Worker [0.8924, 0.2872, 0.6692, 0.2944]], 12482*da0073e9SAndroid Build Coastguard Worker [[0.9897, 0.6915, 0.3154, 0.1733], 12483*da0073e9SAndroid Build Coastguard Worker [0.8645, 0.3513, 0.3064, 0.0767]], 12484*da0073e9SAndroid Build Coastguard Worker [[0.8117, 0.2366, 0.4838, 0.7881], 12485*da0073e9SAndroid Build Coastguard Worker [0.3718, 0.4945, 0.9511, 0.0864]]], device=device, dtype=dtype)) 12486*da0073e9SAndroid Build Coastguard Worker result = model(encoder_input) 12487*da0073e9SAndroid Build Coastguard Worker ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249], 12488*da0073e9SAndroid Build Coastguard Worker [2.427987, 0.021213, -0.602496, -0.084103]], 12489*da0073e9SAndroid Build Coastguard Worker [[2.424689, 0.019155, -0.604793, -0.085672], 12490*da0073e9SAndroid Build Coastguard Worker [2.413863, 0.022211, -0.612486, -0.072490]], 12491*da0073e9SAndroid Build Coastguard Worker [[2.433774, 0.021598, -0.598343, -0.087548], 12492*da0073e9SAndroid Build Coastguard Worker [2.425104, 0.019748, -0.604515, -0.084839]], 12493*da0073e9SAndroid Build Coastguard Worker [[2.436185, 0.022682, -0.596625, -0.087261], 12494*da0073e9SAndroid Build Coastguard Worker [2.433556, 0.021891, -0.598509, -0.086832]], 12495*da0073e9SAndroid Build Coastguard Worker [[2.416246, 0.017512, -0.610712, -0.082961], 12496*da0073e9SAndroid Build Coastguard Worker [2.422901, 0.024187, -0.606178, -0.074929]]], device=device, dtype=dtype)) 12497*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, ref_output.shape) 12498*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol) 12499*da0073e9SAndroid Build Coastguard Worker 12500*da0073e9SAndroid Build Coastguard Worker # all 0 12501*da0073e9SAndroid Build Coastguard Worker mask = torch.zeros([2, 5], device=device) == 1 12502*da0073e9SAndroid Build Coastguard Worker result = model(encoder_input, src_key_padding_mask=mask) 12503*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, ref_output.shape) 12504*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol) 12505*da0073e9SAndroid Build Coastguard Worker mask[0, 1] = 1 12506*da0073e9SAndroid Build Coastguard Worker mask[1, 3] = 1 12507*da0073e9SAndroid Build Coastguard Worker mask[1, 4] = 1 12508*da0073e9SAndroid Build Coastguard Worker result = model(encoder_input, src_key_padding_mask=mask) 12509*da0073e9SAndroid Build Coastguard Worker ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642], 12510*da0073e9SAndroid Build Coastguard Worker [2.428811, 0.021445, -0.601912, -0.084252]], 12511*da0073e9SAndroid Build Coastguard Worker [[2.425009, 0.019155, -0.604566, -0.085899], 12512*da0073e9SAndroid Build Coastguard Worker [2.415408, 0.02249 , -0.611415, -0.073]], 12513*da0073e9SAndroid Build Coastguard Worker 
[[2.434199, 0.021682, -0.598039, -0.087699], 12514*da0073e9SAndroid Build Coastguard Worker [2.42598, 0.019941, -0.603896, -0.085091]], 12515*da0073e9SAndroid Build Coastguard Worker [[2.436457, 0.022736, -0.59643 , -0.08736], 12516*da0073e9SAndroid Build Coastguard Worker [2.434021, 0.022093, -0.598179, -0.08679]], 12517*da0073e9SAndroid Build Coastguard Worker [[2.416531, 0.017498, -0.610513, -0.083181], 12518*da0073e9SAndroid Build Coastguard Worker [2.4242, 0.024653, -0.605266, -0.074959]]], device=device, dtype=dtype)) 12519*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, ref_output.shape) 12520*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol) 12521*da0073e9SAndroid Build Coastguard Worker 12522*da0073e9SAndroid Build Coastguard Worker # NestedTensor is only supported for the fast path 12523*da0073e9SAndroid Build Coastguard Worker # currently, which won't be used if training. 12524*da0073e9SAndroid Build Coastguard Worker if (batch_first and not training and 12525*da0073e9SAndroid Build Coastguard Worker ('cuda' in str(device) or 'cpu' in str(device)) and not TEST_WITH_CROSSREF): 12526*da0073e9SAndroid Build Coastguard Worker encoder_input[0][-1] = torch.zeros_like(encoder_input[0][1]) 12527*da0073e9SAndroid Build Coastguard Worker mask = torch.zeros(encoder_input.shape[:-1], device=device, dtype=torch.bool) 12528*da0073e9SAndroid Build Coastguard Worker mask[0][-1] = True 12529*da0073e9SAndroid Build Coastguard Worker 12530*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.nested_tensor([encoder_input[0][:-1], encoder_input[1]], device=device) 12531*da0073e9SAndroid Build Coastguard Worker result = model(nt) 12532*da0073e9SAndroid Build Coastguard Worker ref_output = torch.tensor( 12533*da0073e9SAndroid Build Coastguard Worker [ 12534*da0073e9SAndroid Build Coastguard Worker [ 12535*da0073e9SAndroid Build Coastguard Worker [2.4268184, 0.02042419, -0.603311, -0.08476824], 12536*da0073e9SAndroid Build Coastguard Worker [2.423306, 0.01889652, -0.6057701, -0.08519465], 12537*da0073e9SAndroid Build Coastguard Worker [2.431538, 0.02078694, -0.5999354, -0.08746159], 12538*da0073e9SAndroid Build Coastguard Worker [2.4348664, 0.02212971, -0.5975677, -0.08733892], 12539*da0073e9SAndroid Build Coastguard Worker [2.423133, 0.02097577, -0.60594773, -0.08113337], 12540*da0073e9SAndroid Build Coastguard Worker ], 12541*da0073e9SAndroid Build Coastguard Worker [ 12542*da0073e9SAndroid Build Coastguard Worker [2.4279876, 0.02121329, -0.60249615, -0.08410317], 12543*da0073e9SAndroid Build Coastguard Worker [2.4138637, 0.02221113, -0.6124869, -0.07249016], 12544*da0073e9SAndroid Build Coastguard Worker [2.4251041, 0.01974815, -0.6045152, -0.08483928], 12545*da0073e9SAndroid Build Coastguard Worker [2.4335563, 0.0218913, -0.59850943, -0.08683228], 12546*da0073e9SAndroid Build Coastguard Worker [2.4229012, 0.02418739, -0.6061784, -0.07492948], 12547*da0073e9SAndroid Build Coastguard Worker ], 12548*da0073e9SAndroid Build Coastguard Worker ], 12549*da0073e9SAndroid Build Coastguard Worker device=device, dtype=dtype 12550*da0073e9SAndroid Build Coastguard Worker ) 12551*da0073e9SAndroid Build Coastguard Worker result = result.to_padded_tensor(0) 12552*da0073e9SAndroid Build Coastguard Worker ref_output[0][-1] = torch.zeros_like( 12553*da0073e9SAndroid Build Coastguard Worker ref_output[0][-1], device=device, dtype=dtype 12554*da0073e9SAndroid Build Coastguard Worker ) 12555*da0073e9SAndroid Build Coastguard 
Worker result[0][-1] = torch.zeros_like( 12556*da0073e9SAndroid Build Coastguard Worker result[0][-1], device=device, dtype=dtype 12557*da0073e9SAndroid Build Coastguard Worker ) 12558*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) 12559*da0073e9SAndroid Build Coastguard Worker if 'cuda' in device: 12560*da0073e9SAndroid Build Coastguard Worker if dtype == torch.float: 12561*da0073e9SAndroid Build Coastguard Worker atol = 2e-4 12562*da0073e9SAndroid Build Coastguard Worker rtol = 4e-3 12563*da0073e9SAndroid Build Coastguard Worker else: 12564*da0073e9SAndroid Build Coastguard Worker atol = 7e-4 12565*da0073e9SAndroid Build Coastguard Worker rtol = 2e-2 12566*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol) 12567*da0073e9SAndroid Build Coastguard Worker else: 12568*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(result, ref_output) 12569*da0073e9SAndroid Build Coastguard Worker 12570*da0073e9SAndroid Build Coastguard Worker 12571*da0073e9SAndroid Build Coastguard Worker for batch_first in (True, False): 12572*da0073e9SAndroid Build Coastguard Worker for training in (True, False): 12573*da0073e9SAndroid Build Coastguard Worker if training: 12574*da0073e9SAndroid Build Coastguard Worker cm = contextlib.nullcontext() 12575*da0073e9SAndroid Build Coastguard Worker else: 12576*da0073e9SAndroid Build Coastguard Worker # Fast path requires inference mode. 12577*da0073e9SAndroid Build Coastguard Worker cm = torch.no_grad() 12578*da0073e9SAndroid Build Coastguard Worker with cm: 12579*da0073e9SAndroid Build Coastguard Worker _test(batch_first=batch_first, training=training, atol=atol, rtol=rtol) 12580*da0073e9SAndroid Build Coastguard Worker 12581*da0073e9SAndroid Build Coastguard Worker @onlyCPU 12582*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 12583*da0073e9SAndroid Build Coastguard Worker def test_transformerencoderlayer_fast_path(self, device, dtype): 12584*da0073e9SAndroid Build Coastguard Worker """ 12585*da0073e9SAndroid Build Coastguard Worker Test transformer fast path on CPU with different valid mask types and shapes 12586*da0073e9SAndroid Build Coastguard Worker """ 12587*da0073e9SAndroid Build Coastguard Worker d_model = 512 12588*da0073e9SAndroid Build Coastguard Worker nhead = 8 12589*da0073e9SAndroid Build Coastguard Worker batch_size = 32 12590*da0073e9SAndroid Build Coastguard Worker src_len = 10 12591*da0073e9SAndroid Build Coastguard Worker 12592*da0073e9SAndroid Build Coastguard Worker model = torch.nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True, 12593*da0073e9SAndroid Build Coastguard Worker device=device, dtype=dtype, dropout=0) 12594*da0073e9SAndroid Build Coastguard Worker model.eval() 12595*da0073e9SAndroid Build Coastguard Worker 12596*da0073e9SAndroid Build Coastguard Worker # Batched inputs 12597*da0073e9SAndroid Build Coastguard Worker src = torch.rand(batch_size, src_len, 512, dtype=dtype) 12598*da0073e9SAndroid Build Coastguard Worker 12599*da0073e9SAndroid Build Coastguard Worker # Attention mask of shape (src_len, src_len) 12600*da0073e9SAndroid Build Coastguard Worker src_mask = torch.zeros(src_len, src_len).to(torch.bool) 12601*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 12602*da0073e9SAndroid Build Coastguard Worker model(src, src_mask=src_mask) 12603*da0073e9SAndroid Build Coastguard Worker 12604*da0073e9SAndroid Build Coastguard Worker # Padding mask of shape 
(batch_size, src_len) 12605*da0073e9SAndroid Build Coastguard Worker src_key_padding_mask = torch.zeros(batch_size, src_len).to(torch.bool) 12606*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 12607*da0073e9SAndroid Build Coastguard Worker model(src, src_key_padding_mask=src_key_padding_mask) 12608*da0073e9SAndroid Build Coastguard Worker 12609*da0073e9SAndroid Build Coastguard Worker # Provide both masks 12610*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 12611*da0073e9SAndroid Build Coastguard Worker model(src, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask) 12612*da0073e9SAndroid Build Coastguard Worker 12613*da0073e9SAndroid Build Coastguard Worker 12614*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 12615*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.float) 12616*da0073e9SAndroid Build Coastguard Worker def test_transformerencoderlayer_gelu(self, device, dtype): 12617*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half: 12618*da0073e9SAndroid Build Coastguard Worker self.skipTest("Skip on ROCM due to Flash Attention tolerances") 12619*da0073e9SAndroid Build Coastguard Worker # this is a deterministic test for TransformerEncoderLayer with gelu activation 12620*da0073e9SAndroid Build Coastguard Worker d_model = 4 12621*da0073e9SAndroid Build Coastguard Worker nhead = 2 12622*da0073e9SAndroid Build Coastguard Worker dim_feedforward = 16 12623*da0073e9SAndroid Build Coastguard Worker dropout = 0.0 12624*da0073e9SAndroid Build Coastguard Worker bsz = 2 12625*da0073e9SAndroid Build Coastguard Worker 12626*da0073e9SAndroid Build Coastguard Worker atol = 0 12627*da0073e9SAndroid Build Coastguard Worker rtol = 1e-5 12628*da0073e9SAndroid Build Coastguard Worker if "cuda" in device: 12629*da0073e9SAndroid Build Coastguard Worker atol = 1e-3 12630*da0073e9SAndroid Build Coastguard Worker rtol = 1e-2 12631*da0073e9SAndroid Build Coastguard Worker 12632*da0073e9SAndroid Build Coastguard Worker def _test(activation, batch_first, training): 12633*da0073e9SAndroid Build Coastguard Worker def perm_fn(x): 12634*da0073e9SAndroid Build Coastguard Worker return x.transpose(1, 0) if batch_first else x 12635*da0073e9SAndroid Build Coastguard Worker 12636*da0073e9SAndroid Build Coastguard Worker model = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, 12637*da0073e9SAndroid Build Coastguard Worker activation, batch_first=batch_first, device=device, dtype=dtype) 12638*da0073e9SAndroid Build Coastguard Worker if not training: 12639*da0073e9SAndroid Build Coastguard Worker assert dropout == 0 12640*da0073e9SAndroid Build Coastguard Worker model = model.eval() 12641*da0073e9SAndroid Build Coastguard Worker 12642*da0073e9SAndroid Build Coastguard Worker # set constant weights of the model 12643*da0073e9SAndroid Build Coastguard Worker for idx, p in enumerate(model.parameters()): 12644*da0073e9SAndroid Build Coastguard Worker x = p.data 12645*da0073e9SAndroid Build Coastguard Worker sz = x.view(-1).size(0) 12646*da0073e9SAndroid Build Coastguard Worker shape = x.shape 12647*da0073e9SAndroid Build Coastguard Worker x = torch.cos(torch.arange(0, sz).float().view(shape)) 12648*da0073e9SAndroid Build Coastguard Worker p.data.copy_(x) 12649*da0073e9SAndroid Build Coastguard Worker 12650*da0073e9SAndroid Build Coastguard Worker # deterministic input 12651*da0073e9SAndroid Build Coastguard Worker encoder_input = torch.tensor([[[20., 30., 40., 
50.]]], device=device, dtype=dtype) 12652*da0073e9SAndroid Build Coastguard Worker result = model(encoder_input) 12653*da0073e9SAndroid Build Coastguard Worker ref_output = torch.tensor([[[2.249815, 0.131006, -0.702199, 0.177868]]], device=device, dtype=dtype) 12654*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(result, ref_output, rtol=rtol, atol=atol) 12655*da0073e9SAndroid Build Coastguard Worker 12656*da0073e9SAndroid Build Coastguard Worker # deterministic input 12657*da0073e9SAndroid Build Coastguard Worker encoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]], 12658*da0073e9SAndroid Build Coastguard Worker [[5., 6., 7., 8.]]], device=device, dtype=dtype)) 12659*da0073e9SAndroid Build Coastguard Worker result = model(encoder_input) 12660*da0073e9SAndroid Build Coastguard Worker ref_output = perm_fn(torch.tensor([[[2.264103, 0.121417, -0.696012, 0.159724]], 12661*da0073e9SAndroid Build Coastguard Worker [[2.264103, 0.121417, -0.696012, 0.159724]]], device=device, dtype=dtype)) 12662*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(result, ref_output, rtol=rtol, atol=atol) 12663*da0073e9SAndroid Build Coastguard Worker 12664*da0073e9SAndroid Build Coastguard Worker # deterministic input 12665*da0073e9SAndroid Build Coastguard Worker encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891], 12666*da0073e9SAndroid Build Coastguard Worker [0.5387, 0.1655, 0.3565, 0.0471]], 12667*da0073e9SAndroid Build Coastguard Worker [[0.8335, 0.2799, 0.5031, 0.2947], 12668*da0073e9SAndroid Build Coastguard Worker [0.1402, 0.0318, 0.7636, 0.1346]], 12669*da0073e9SAndroid Build Coastguard Worker [[0.6333, 0.9344, 0.1376, 0.9938], 12670*da0073e9SAndroid Build Coastguard Worker [0.8924, 0.2872, 0.6692, 0.2944]], 12671*da0073e9SAndroid Build Coastguard Worker [[0.9897, 0.6915, 0.3154, 0.1733], 12672*da0073e9SAndroid Build Coastguard Worker [0.8645, 0.3513, 0.3064, 0.0767]], 12673*da0073e9SAndroid Build Coastguard Worker [[0.8117, 0.2366, 0.4838, 0.7881], 12674*da0073e9SAndroid Build Coastguard Worker [0.3718, 0.4945, 0.9511, 0.0864]]], device=device, dtype=dtype)) 12675*da0073e9SAndroid Build Coastguard Worker result = model(encoder_input) 12676*da0073e9SAndroid Build Coastguard Worker ref_output = perm_fn(torch.tensor([[[2.42163188, 0.03227153, -0.60714219, -0.05908082], 12677*da0073e9SAndroid Build Coastguard Worker [2.42151276, 0.03302179, -0.60722523, -0.05762651]], 12678*da0073e9SAndroid Build Coastguard Worker [[2.41926761, 0.02974034, -0.60879519, -0.0621269], 12679*da0073e9SAndroid Build Coastguard Worker [2.41626395, 0.03539356, -0.61087842, -0.04978623]], 12680*da0073e9SAndroid Build Coastguard Worker [[2.42382808, 0.03218872, -0.6055963, -0.06073591], 12681*da0073e9SAndroid Build Coastguard Worker [2.41983477, 0.03085259, -0.60840145, -0.06046414]], 12682*da0073e9SAndroid Build Coastguard Worker [[2.42500749, 0.03328855, -0.60476388, -0.0595334], 12683*da0073e9SAndroid Build Coastguard Worker [2.4237977, 0.03290575, -0.60561789, -0.05940082]], 12684*da0073e9SAndroid Build Coastguard Worker [[2.41383916, 0.02686345, -0.61256377, -0.06380707], 12685*da0073e9SAndroid Build Coastguard Worker [2.42000277, 0.03800944, -0.60824798, -0.04754947]]], device=device, dtype=dtype)) 12686*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(result, ref_output, rtol=rtol, atol=atol) 12687*da0073e9SAndroid Build Coastguard Worker for activation, batch_first, training in product(('gelu', F.gelu, nn.GELU()), (True, False), (True, False)): 
12688*da0073e9SAndroid Build Coastguard Worker # Fast path requires inference mode. 12689*da0073e9SAndroid Build Coastguard Worker if training: 12690*da0073e9SAndroid Build Coastguard Worker cm = contextlib.nullcontext() 12691*da0073e9SAndroid Build Coastguard Worker else: 12692*da0073e9SAndroid Build Coastguard Worker cm = torch.no_grad() 12693*da0073e9SAndroid Build Coastguard Worker with cm: 12694*da0073e9SAndroid Build Coastguard Worker _test(activation=activation, batch_first=batch_first, training=training) 12695*da0073e9SAndroid Build Coastguard Worker 12696*da0073e9SAndroid Build Coastguard Worker @skipIfMps # RuntimeError: foreach=True was passed, but can't use the foreach API on mps tensors 12697*da0073e9SAndroid Build Coastguard Worker @parametrize_test('foreach', (False, True)) 12698*da0073e9SAndroid Build Coastguard Worker def test_clip_grad_value(self, foreach, device): 12699*da0073e9SAndroid Build Coastguard Worker if torch.device(device).type == 'xla' and foreach: 12700*da0073e9SAndroid Build Coastguard Worker raise SkipTest('foreach not supported on XLA') 12701*da0073e9SAndroid Build Coastguard Worker 12702*da0073e9SAndroid Build Coastguard Worker l = nn.Linear(10, 10).to(device) 12703*da0073e9SAndroid Build Coastguard Worker clip_value = 2.5 12704*da0073e9SAndroid Build Coastguard Worker 12705*da0073e9SAndroid Build Coastguard Worker grad_w, grad_b = torch.arange(-50., 50, device=device).view(10, 10).div_(5), torch.ones(10, device=device).mul_(2) 12706*da0073e9SAndroid Build Coastguard Worker for grad_list in [[grad_w, grad_b], [grad_w, None]]: 12707*da0073e9SAndroid Build Coastguard Worker for p, g in zip(l.parameters(), grad_list): 12708*da0073e9SAndroid Build Coastguard Worker p._grad = g.clone().view_as(p.data) if g is not None else g 12709*da0073e9SAndroid Build Coastguard Worker 12710*da0073e9SAndroid Build Coastguard Worker clip_grad_value_(l.parameters(), clip_value, foreach=foreach) 12711*da0073e9SAndroid Build Coastguard Worker for p in filter(lambda p: p.grad is not None, l.parameters()): 12712*da0073e9SAndroid Build Coastguard Worker self.assertLessEqual(p.grad.data.max(), clip_value) 12713*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual(p.grad.data.min(), -clip_value) 12714*da0073e9SAndroid Build Coastguard Worker 12715*da0073e9SAndroid Build Coastguard Worker # Should accept a single Tensor as input 12716*da0073e9SAndroid Build Coastguard Worker p1, p2 = torch.randn(10, 10, device=device), torch.randn(10, 10, device=device) 12717*da0073e9SAndroid Build Coastguard Worker g = torch.arange(-50., 50, device=device).view(10, 10).div_(5) 12718*da0073e9SAndroid Build Coastguard Worker p1._grad = g.clone() 12719*da0073e9SAndroid Build Coastguard Worker p2._grad = g.clone() 12720*da0073e9SAndroid Build Coastguard Worker clip_grad_value_(p1, clip_value, foreach=foreach) 12721*da0073e9SAndroid Build Coastguard Worker clip_grad_value_([p2], clip_value, foreach=foreach) 12722*da0073e9SAndroid Build Coastguard Worker self.assertEqual(p1.grad, p2.grad) 12723*da0073e9SAndroid Build Coastguard Worker 12724*da0073e9SAndroid Build Coastguard Worker @skipIfMps # TypeError: the MPS framework doesn't support float64 12725*da0073e9SAndroid Build Coastguard Worker @parametrize_test('foreach', (False, True)) 12726*da0073e9SAndroid Build Coastguard Worker @parametrize_test('norm_type', (0.5, 1.5, 2, 4, 'inf')) 12727*da0073e9SAndroid Build Coastguard Worker def test_clip_grad_norm(self, norm_type, foreach, device): 12728*da0073e9SAndroid Build Coastguard Worker if 
torch.device(device).type == 'xla' and foreach: 12729*da0073e9SAndroid Build Coastguard Worker raise SkipTest('foreach not supported on XLA') 12730*da0073e9SAndroid Build Coastguard Worker 12731*da0073e9SAndroid Build Coastguard Worker l = nn.Linear(10, 10).to(device) 12732*da0073e9SAndroid Build Coastguard Worker max_norm = 2 12733*da0073e9SAndroid Build Coastguard Worker 12734*da0073e9SAndroid Build Coastguard Worker def compute_norm(norm_type): 12735*da0073e9SAndroid Build Coastguard Worker norm_type = float(norm_type) 12736*da0073e9SAndroid Build Coastguard Worker if norm_type != inf: 12737*da0073e9SAndroid Build Coastguard Worker total_norm = 0 12738*da0073e9SAndroid Build Coastguard Worker for p in l.parameters(): 12739*da0073e9SAndroid Build Coastguard Worker total_norm += p.grad.data.abs().pow(norm_type).sum() 12740*da0073e9SAndroid Build Coastguard Worker return pow(total_norm, 1. / norm_type) 12741*da0073e9SAndroid Build Coastguard Worker else: 12742*da0073e9SAndroid Build Coastguard Worker return max(p.grad.data.abs().max() for p in l.parameters()) 12743*da0073e9SAndroid Build Coastguard Worker 12744*da0073e9SAndroid Build Coastguard Worker def compare_scaling(grads): 12745*da0073e9SAndroid Build Coastguard Worker p_scale = [p.grad.data.div(g).view(-1) for p, g in zip(l.parameters(), grads)] 12746*da0073e9SAndroid Build Coastguard Worker scale = torch.cat(p_scale) 12747*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scale.std(), 0) 12748*da0073e9SAndroid Build Coastguard Worker return scale[0] 12749*da0073e9SAndroid Build Coastguard Worker 12750*da0073e9SAndroid Build Coastguard Worker grads = torch.arange(1., 101, device=device).view(10, 10), torch.ones(10, device=device).div(1000) 12751*da0073e9SAndroid Build Coastguard Worker for p, g in zip(l.parameters(), grads): 12752*da0073e9SAndroid Build Coastguard Worker p._grad = g.clone().view_as(p.data) 12753*da0073e9SAndroid Build Coastguard Worker norm_before = compute_norm(norm_type) 12754*da0073e9SAndroid Build Coastguard Worker norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type, foreach=foreach) 12755*da0073e9SAndroid Build Coastguard Worker norm_after = compute_norm(norm_type) 12756*da0073e9SAndroid Build Coastguard Worker self.assertEqual(norm, norm_before) 12757*da0073e9SAndroid Build Coastguard Worker self.assertEqual(norm_after, max_norm) 12758*da0073e9SAndroid Build Coastguard Worker self.assertLessEqual(norm_after, norm_before) 12759*da0073e9SAndroid Build Coastguard Worker compare_scaling(grads) 12760*da0073e9SAndroid Build Coastguard Worker 12761*da0073e9SAndroid Build Coastguard Worker # Small gradients should be left unchanged 12762*da0073e9SAndroid Build Coastguard Worker grads = torch.rand(10, 10, device=device).div(10000), torch.ones(10, device=device).div(500) 12763*da0073e9SAndroid Build Coastguard Worker for p, g in zip(l.parameters(), grads): 12764*da0073e9SAndroid Build Coastguard Worker p.grad.data.copy_(g) 12765*da0073e9SAndroid Build Coastguard Worker norm_before = compute_norm(norm_type) 12766*da0073e9SAndroid Build Coastguard Worker norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type, foreach=foreach) 12767*da0073e9SAndroid Build Coastguard Worker norm_after = compute_norm(norm_type) 12768*da0073e9SAndroid Build Coastguard Worker self.assertEqual(norm, norm_before) 12769*da0073e9SAndroid Build Coastguard Worker self.assertEqual(norm_before, norm_after) 12770*da0073e9SAndroid Build Coastguard Worker self.assertLessEqual(norm_after, max_norm) 
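        # For reference, the training-loop pattern validated here is roughly (illustrative
        # sketch, not executed by this test; `model` and `optimizer` are placeholders):
        #
        #     loss.backward()
        #     total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
        #     optimizer.step()
        #
        # clip_grad_norm_ returns the total gradient norm computed *before* clipping and
        # rescales all gradients by the same factor, which is what compare_scaling checks.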
12771*da0073e9SAndroid Build Coastguard Worker scale = compare_scaling(grads) 12772*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scale, 1) 12773*da0073e9SAndroid Build Coastguard Worker 12774*da0073e9SAndroid Build Coastguard Worker # Should accept a single Tensor as input 12775*da0073e9SAndroid Build Coastguard Worker p1, p2 = torch.randn(10, 10, device=device), torch.randn(10, 10, device=device) 12776*da0073e9SAndroid Build Coastguard Worker g = torch.arange(1., 101, device=device).view(10, 10) 12777*da0073e9SAndroid Build Coastguard Worker p1._grad = g.clone() 12778*da0073e9SAndroid Build Coastguard Worker p2._grad = g.clone() 12779*da0073e9SAndroid Build Coastguard Worker clip_grad_norm_(p1, max_norm, norm_type=norm_type, foreach=foreach) 12780*da0073e9SAndroid Build Coastguard Worker clip_grad_norm_([p2], max_norm, norm_type=norm_type, foreach=foreach) 12781*da0073e9SAndroid Build Coastguard Worker self.assertEqual(p1.grad, p2.grad) 12782*da0073e9SAndroid Build Coastguard Worker 12783*da0073e9SAndroid Build Coastguard Worker # reference issue: https://github.com/pytorch/pytorch/issues/111484 12784*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 12785*da0073e9SAndroid Build Coastguard Worker @largeTensorTest("42GB", "cuda") 12786*da0073e9SAndroid Build Coastguard Worker def test_softmax_forward_64bit_indexing(self, device): 12787*da0073e9SAndroid Build Coastguard Worker batch_size = 70 12788*da0073e9SAndroid Build Coastguard Worker seq_len = 2048 12789*da0073e9SAndroid Build Coastguard Worker vocab_size = 50000 12790*da0073e9SAndroid Build Coastguard Worker 12791*da0073e9SAndroid Build Coastguard Worker shift_labels = torch.zeros(batch_size, seq_len - 1, dtype=torch.long, device=device) 12792*da0073e9SAndroid Build Coastguard Worker logits = torch.ones(batch_size, seq_len - 1, vocab_size, dtype=torch.float16, device=device) 12793*da0073e9SAndroid Build Coastguard Worker loss_fct = torch.nn.CrossEntropyLoss(reduction="none") 12794*da0073e9SAndroid Build Coastguard Worker nll = loss_fct(logits.permute(0, 2, 1), shift_labels).float() 12795*da0073e9SAndroid Build Coastguard Worker rtol, atol = torch.testing._comparison.get_tolerances(torch.float16, rtol=None, atol=None) 12796*da0073e9SAndroid Build Coastguard Worker self.assertEqual(nll, torch.ones_like(nll) * torch.log(torch.tensor(vocab_size)), rtol=rtol, atol=atol) 12797*da0073e9SAndroid Build Coastguard Worker 12798*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 12799*da0073e9SAndroid Build Coastguard Worker @largeTensorTest("20GB", "cuda") 12800*da0073e9SAndroid Build Coastguard Worker def test_softmax_backward_64bit_indexing(self, device): 12801*da0073e9SAndroid Build Coastguard Worker for numel in (2147483650, 2147483650 + 1): 12802*da0073e9SAndroid Build Coastguard Worker x = torch.empty([1, 1, numel], device=device, dtype=torch.float16) 12803*da0073e9SAndroid Build Coastguard Worker x.fill_(1.0 / numel) 12804*da0073e9SAndroid Build Coastguard Worker out = torch._softmax_backward_data(x, x, 2, x.dtype) 12805*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out[0, 0, 0], 1 / numel) 12806*da0073e9SAndroid Build Coastguard Worker 12807*da0073e9SAndroid Build Coastguard Worker # reference issue: https://github.com/pytorch/pytorch/issues/68248 12808*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 12809*da0073e9SAndroid Build Coastguard Worker def test_adaptiveavg_pool1d_shmem(self, device): 12810*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 256, 1, 5000, 
device=device).to(memory_format=torch.channels_last) 12811*da0073e9SAndroid Build Coastguard Worker x_cpu = x.cpu() 12812*da0073e9SAndroid Build Coastguard Worker x_cpu.requires_grad_() 12813*da0073e9SAndroid Build Coastguard Worker x.requires_grad_() 12814*da0073e9SAndroid Build Coastguard Worker y = torch.nn.functional.adaptive_avg_pool2d(x, (1, 256)) 12815*da0073e9SAndroid Build Coastguard Worker y_cpu = torch.nn.functional.adaptive_avg_pool2d(x_cpu, (1, 256)) 12816*da0073e9SAndroid Build Coastguard Worker grad = torch.randn_like(y) 12817*da0073e9SAndroid Build Coastguard Worker grad_cpu = grad.cpu() 12818*da0073e9SAndroid Build Coastguard Worker y.backward(grad) 12819*da0073e9SAndroid Build Coastguard Worker y_cpu.backward(grad_cpu) 12820*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, x_cpu.grad) 12821*da0073e9SAndroid Build Coastguard Worker 12822*da0073e9SAndroid Build Coastguard Worker @skipMeta 12823*da0073e9SAndroid Build Coastguard Worker @expectedFailureMPS # NotImplementedError: aten::channel_shuffle https://github.com/pytorch/pytorch/issues/77764 12824*da0073e9SAndroid Build Coastguard Worker def test_channel_shuffle(self, device): 12825*da0073e9SAndroid Build Coastguard Worker # 3D tensor 12826*da0073e9SAndroid Build Coastguard Worker x = torch.tensor( 12827*da0073e9SAndroid Build Coastguard Worker [[[1, 2], 12828*da0073e9SAndroid Build Coastguard Worker [5, 6], 12829*da0073e9SAndroid Build Coastguard Worker [9, 10], 12830*da0073e9SAndroid Build Coastguard Worker [13, 14], 12831*da0073e9SAndroid Build Coastguard Worker ]], device=device 12832*da0073e9SAndroid Build Coastguard Worker ) 12833*da0073e9SAndroid Build Coastguard Worker y_ref = torch.tensor( 12834*da0073e9SAndroid Build Coastguard Worker [[[1, 2], 12835*da0073e9SAndroid Build Coastguard Worker [9, 10], 12836*da0073e9SAndroid Build Coastguard Worker [5, 6], 12837*da0073e9SAndroid Build Coastguard Worker [13, 14], 12838*da0073e9SAndroid Build Coastguard Worker ]], device=device 12839*da0073e9SAndroid Build Coastguard Worker ) 12840*da0073e9SAndroid Build Coastguard Worker # ChannelsFirst 12841*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 12842*da0073e9SAndroid Build Coastguard Worker y = F.channel_shuffle(x, 2).to(device) 12843*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 0) 12844*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, y_ref) 12845*da0073e9SAndroid Build Coastguard Worker # ChannelsLast not supported for 3dim 12846*da0073e9SAndroid Build Coastguard Worker 12847*da0073e9SAndroid Build Coastguard Worker # 4D tensor 12848*da0073e9SAndroid Build Coastguard Worker x = torch.tensor( 12849*da0073e9SAndroid Build Coastguard Worker [[[[1, 2], 12850*da0073e9SAndroid Build Coastguard Worker [3, 4]], 12851*da0073e9SAndroid Build Coastguard Worker [[5, 6], 12852*da0073e9SAndroid Build Coastguard Worker [7, 8]], 12853*da0073e9SAndroid Build Coastguard Worker [[9, 10], 12854*da0073e9SAndroid Build Coastguard Worker [11, 12]], 12855*da0073e9SAndroid Build Coastguard Worker [[13, 14], 12856*da0073e9SAndroid Build Coastguard Worker [15, 16]], 12857*da0073e9SAndroid Build Coastguard Worker ]], device=device 12858*da0073e9SAndroid Build Coastguard Worker ) 12859*da0073e9SAndroid Build Coastguard Worker y_ref = torch.tensor( 12860*da0073e9SAndroid Build Coastguard Worker [[[[1, 2], 12861*da0073e9SAndroid Build Coastguard Worker [3, 4]], 12862*da0073e9SAndroid Build Coastguard Worker [[9, 10], 12863*da0073e9SAndroid Build Coastguard 
Worker [11, 12]], 12864*da0073e9SAndroid Build Coastguard Worker [[5, 6], 12865*da0073e9SAndroid Build Coastguard Worker [7, 8]], 12866*da0073e9SAndroid Build Coastguard Worker [[13, 14], 12867*da0073e9SAndroid Build Coastguard Worker [15, 16]], 12868*da0073e9SAndroid Build Coastguard Worker ]], device=device 12869*da0073e9SAndroid Build Coastguard Worker ) 12870*da0073e9SAndroid Build Coastguard Worker # ChannelsFirst NCHW 12871*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 12872*da0073e9SAndroid Build Coastguard Worker y = F.channel_shuffle(x, 2).to(device) 12873*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 0) 12874*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, y_ref) 12875*da0073e9SAndroid Build Coastguard Worker # ChannelsLast NHWC 12876*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 12877*da0073e9SAndroid Build Coastguard Worker y = F.channel_shuffle(x.contiguous(memory_format=torch.channels_last), 2).to(device) 12878*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 0) 12879*da0073e9SAndroid Build Coastguard Worker y = y.contiguous(memory_format=torch.contiguous_format) 12880*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, y_ref) 12881*da0073e9SAndroid Build Coastguard Worker 12882*da0073e9SAndroid Build Coastguard Worker # 5D tensor 12883*da0073e9SAndroid Build Coastguard Worker x = torch.tensor( 12884*da0073e9SAndroid Build Coastguard Worker [[[[[1, 2], 12885*da0073e9SAndroid Build Coastguard Worker [3, 4]]], 12886*da0073e9SAndroid Build Coastguard Worker [[[5, 6], 12887*da0073e9SAndroid Build Coastguard Worker [7, 8]]], 12888*da0073e9SAndroid Build Coastguard Worker [[[9, 10], 12889*da0073e9SAndroid Build Coastguard Worker [11, 12]]], 12890*da0073e9SAndroid Build Coastguard Worker [[[13, 14], 12891*da0073e9SAndroid Build Coastguard Worker [15, 16]]], 12892*da0073e9SAndroid Build Coastguard Worker ]], device=device 12893*da0073e9SAndroid Build Coastguard Worker ) 12894*da0073e9SAndroid Build Coastguard Worker y_ref = torch.tensor( 12895*da0073e9SAndroid Build Coastguard Worker [[[[[1, 2], 12896*da0073e9SAndroid Build Coastguard Worker [3, 4]]], 12897*da0073e9SAndroid Build Coastguard Worker [[[9, 10], 12898*da0073e9SAndroid Build Coastguard Worker [11, 12]]], 12899*da0073e9SAndroid Build Coastguard Worker [[[5, 6], 12900*da0073e9SAndroid Build Coastguard Worker [7, 8]]], 12901*da0073e9SAndroid Build Coastguard Worker [[[13, 14], 12902*da0073e9SAndroid Build Coastguard Worker [15, 16]]], 12903*da0073e9SAndroid Build Coastguard Worker ]], device=device 12904*da0073e9SAndroid Build Coastguard Worker ) 12905*da0073e9SAndroid Build Coastguard Worker # ChannelsFirst NCHW 12906*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 12907*da0073e9SAndroid Build Coastguard Worker y = F.channel_shuffle(x, 2).to(device) 12908*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 0) 12909*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, y_ref) 12910*da0073e9SAndroid Build Coastguard Worker # ChannelsLast NHWC 12911*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 12912*da0073e9SAndroid Build Coastguard Worker y = F.channel_shuffle(x.contiguous(memory_format=torch.channels_last_3d), 2).to(device) 12913*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 0) 12914*da0073e9SAndroid Build Coastguard Worker y = 
y.contiguous(memory_format=torch.contiguous_format) 12915*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, y_ref) 12916*da0073e9SAndroid Build Coastguard Worker 12917*da0073e9SAndroid Build Coastguard Worker 12918*da0073e9SAndroid Build Coastguard Workerclass TestFunctionalPickle(TestCase): 12919*da0073e9SAndroid Build Coastguard Worker 12920*da0073e9SAndroid Build Coastguard Worker # issue gh-38137 12921*da0073e9SAndroid Build Coastguard Worker def test_pickle_softsign(self): 12922*da0073e9SAndroid Build Coastguard Worker # Make sure it does not throw an exception 12923*da0073e9SAndroid Build Coastguard Worker s = pickle.dumps(F.softsign) 12924*da0073e9SAndroid Build Coastguard Worker 12925*da0073e9SAndroid Build Coastguard Worker 12926*da0073e9SAndroid Build Coastguard Workerclass TestFusionUtils(TestCase): 12927*da0073e9SAndroid Build Coastguard Worker def test_fuse_conv_bn_requires_grad(self): 12928*da0073e9SAndroid Build Coastguard Worker conv = torch.nn.Conv2d(3, 3, 3) 12929*da0073e9SAndroid Build Coastguard Worker bn = torch.nn.BatchNorm2d(3) 12930*da0073e9SAndroid Build Coastguard Worker cases = itertools.product([True, False], [True, False]) 12931*da0073e9SAndroid Build Coastguard Worker for w_rg, b_rg in cases: 12932*da0073e9SAndroid Build Coastguard Worker conv.weight.requires_grad = w_rg 12933*da0073e9SAndroid Build Coastguard Worker conv.bias.requires_grad = b_rg 12934*da0073e9SAndroid Build Coastguard Worker weight, bias = \ 12935*da0073e9SAndroid Build Coastguard Worker fuse_conv_bn_weights(conv.weight, conv.bias, 12936*da0073e9SAndroid Build Coastguard Worker bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias) 12937*da0073e9SAndroid Build Coastguard Worker self.assertEqual(weight.requires_grad, w_rg) 12938*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bias.requires_grad, b_rg) 12939*da0073e9SAndroid Build Coastguard Worker 12940*da0073e9SAndroid Build Coastguard Worker def test_fuse_linear_bn_requires_grad(self): 12941*da0073e9SAndroid Build Coastguard Worker linear = torch.nn.Linear(3, 3) 12942*da0073e9SAndroid Build Coastguard Worker bn = torch.nn.BatchNorm1d(3) 12943*da0073e9SAndroid Build Coastguard Worker cases = itertools.product([True, False], [True, False]) 12944*da0073e9SAndroid Build Coastguard Worker for w_rg, b_rg in cases: 12945*da0073e9SAndroid Build Coastguard Worker linear.weight.requires_grad = w_rg 12946*da0073e9SAndroid Build Coastguard Worker linear.bias.requires_grad = b_rg 12947*da0073e9SAndroid Build Coastguard Worker weight, bias = \ 12948*da0073e9SAndroid Build Coastguard Worker fuse_linear_bn_weights(linear.weight, linear.bias, 12949*da0073e9SAndroid Build Coastguard Worker bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias) 12950*da0073e9SAndroid Build Coastguard Worker self.assertEqual(weight.requires_grad, w_rg) 12951*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bias.requires_grad, b_rg) 12952*da0073e9SAndroid Build Coastguard Worker 12953*da0073e9SAndroid Build Coastguard Workerclass TestUtils(TestCase): 12954*da0073e9SAndroid Build Coastguard Worker def test_consume_prefix_in_state_dict_if_present(self): 12955*da0073e9SAndroid Build Coastguard Worker class Block(nn.Module): 12956*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12957*da0073e9SAndroid Build Coastguard Worker super().__init__() 12958*da0073e9SAndroid Build Coastguard Worker self.conv1 = nn.Conv2d(3, 3, 3, bias=True) 12959*da0073e9SAndroid Build Coastguard Worker self.conv2 = nn.Conv2d(3, 3, 3, bias=False) 
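            # Background (explanatory comment): DistributedDataParallel wraps a model under a
            # `module.` attribute, so its state_dict keys look like 'module.linear1.weight'.
            # consume_prefix_in_state_dict_if_present() strips that prefix in place, e.g.
            # 'module.linear1.weight' -> 'linear1.weight', so a DDP checkpoint can be loaded
            # back into the unwrapped module; the cases below cover empty, plain, and
            # DDP-style state_dicts.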


class TestUtils(TestCase):
    def test_consume_prefix_in_state_dict_if_present(self):
        class Block(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(3, 3, 3, bias=True)
                self.conv2 = nn.Conv2d(3, 3, 3, bias=False)

        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(5, 5)
                self.linear2 = nn.Linear(5, 5)
                self.bn = nn.BatchNorm2d(2)
                self.block = Block()

        # 0. Case non-DDP model with an empty state_dict
        net = nn.Module()
        state_dict = net.state_dict()
        nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, 'module.')
        # Check they are the same, preserving order
        self.assertEqual(list(state_dict.keys()), list(net.state_dict().keys()))
        self.assertEqual(list(state_dict._metadata.keys()), list(net.state_dict()._metadata.keys()))

        # 1. Case non-DDP model with an example state_dict
        net = Net()
        state_dict = net.state_dict()
        nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, 'module.')
        # Check they are the same, preserving order
        self.assertEqual(list(state_dict.keys()), list(net.state_dict().keys()))
        self.assertEqual(list(state_dict._metadata.keys()), list(net.state_dict()._metadata.keys()))
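
        # DistributedDataParallel stores the wrapped model as `self.module`, so a
        # DDP state_dict carries keys such as 'module.linear1.weight'.  A minimal
        # sketch of the round trip exercised in case 2 below (names are only
        # illustrative):
        #     ddp_sd = OrderedDict(('module.' + k, v) for k, v in net.state_dict().items())
        #     nn.modules.utils.consume_prefix_in_state_dict_if_present(ddp_sd, 'module.')
        #     # ddp_sd keys (and _metadata keys) now match net.state_dict() again,
        #     # in the original order.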
        # 2. Case DDP model with an example state_dict
        state_dict = net.state_dict()
        metadata = state_dict._metadata
        # Build a DDP-style state_dict: prefix every key (and every _metadata key,
        # where the root '' entry becomes 'module') with 'module.'.
        ddp_state_dict = OrderedDict((f'module.{k}', v) for k, v in state_dict.items())
        ddp_state_dict._metadata = OrderedDict({'': metadata['']})
        ddp_state_dict._metadata.update(('module' if k == '' else f'module.{k}', v) for k, v in metadata.items())
        nn.modules.utils.consume_prefix_in_state_dict_if_present(ddp_state_dict, 'module.')
        # Check they are the same, preserving order
        self.assertEqual(list(state_dict.keys()), list(ddp_state_dict.keys()))
        self.assertEqual(list(state_dict._metadata.keys()), list(ddp_state_dict._metadata.keys()))


instantiate_device_type_tests(TestNNDeviceType, globals(), allow_mps=True)
instantiate_parametrized_tests(TestNN)

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()