1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: fft"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport torch 4*da0073e9SAndroid Build Coastguard Workerimport unittest 5*da0073e9SAndroid Build Coastguard Workerimport math 6*da0073e9SAndroid Build Coastguard Workerfrom contextlib import contextmanager 7*da0073e9SAndroid Build Coastguard Workerfrom itertools import product 8*da0073e9SAndroid Build Coastguard Workerimport itertools 9*da0073e9SAndroid Build Coastguard Workerimport doctest 10*da0073e9SAndroid Build Coastguard Workerimport inspect 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import \ 13*da0073e9SAndroid Build Coastguard Worker (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL, first_sample, TEST_WITH_ROCM, 14*da0073e9SAndroid Build Coastguard Worker make_tensor, skipIfTorchDynamo) 15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import \ 16*da0073e9SAndroid Build Coastguard Worker (instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes, 17*da0073e9SAndroid Build Coastguard Worker skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf, toleranceOverride, tol) 18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import ( 19*da0073e9SAndroid Build Coastguard Worker spectral_funcs, SpectralFuncType) 20*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import SM53OrLater 21*da0073e9SAndroid Build Coastguard Workerfrom torch._prims_common import corresponding_complex_dtype 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Workerfrom typing import Optional, List 24*da0073e9SAndroid Build Coastguard Workerfrom packaging import version 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Workerif TEST_NUMPY: 28*da0073e9SAndroid Build Coastguard Worker import numpy as np 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Workerif TEST_LIBROSA: 32*da0073e9SAndroid Build Coastguard Worker import librosa 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Workerhas_scipy_fft = False 35*da0073e9SAndroid Build Coastguard Workertry: 36*da0073e9SAndroid Build Coastguard Worker import scipy.fft 37*da0073e9SAndroid Build Coastguard Worker has_scipy_fft = True 38*da0073e9SAndroid Build Coastguard Workerexcept ModuleNotFoundError: 39*da0073e9SAndroid Build Coastguard Worker pass 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard WorkerREFERENCE_NORM_MODES = ( 42*da0073e9SAndroid Build Coastguard Worker (None, "forward", "backward", "ortho") 43*da0073e9SAndroid Build Coastguard Worker if version.parse(np.__version__) >= version.parse('1.20.0') and ( 44*da0073e9SAndroid Build Coastguard Worker not has_scipy_fft or version.parse(scipy.__version__) >= version.parse('1.6.0')) 45*da0073e9SAndroid Build Coastguard Worker else (None, "ortho")) 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Workerdef _complex_stft(x, *args, **kwargs): 49*da0073e9SAndroid Build Coastguard Worker # Transform real and imaginary components separably 50*da0073e9SAndroid Build Coastguard Worker stft_real = torch.stft(x.real, *args, **kwargs, return_complex=True, onesided=False) 51*da0073e9SAndroid Build Coastguard Worker stft_imag = torch.stft(x.imag, *args, **kwargs, return_complex=True, onesided=False) 52*da0073e9SAndroid Build Coastguard Worker return stft_real + 1j * stft_imag 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Workerdef _hermitian_conj(x, dim): 56*da0073e9SAndroid Build Coastguard Worker """Returns the hermitian conjugate along a single dimension 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker H(x)[i] = conj(x[-i]) 59*da0073e9SAndroid Build Coastguard Worker """ 60*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(x) 61*da0073e9SAndroid Build Coastguard Worker mid = (x.size(dim) - 1) // 2 62*da0073e9SAndroid Build Coastguard Worker idx = [slice(None)] * out.dim() 63*da0073e9SAndroid Build Coastguard Worker idx_center = list(idx) 64*da0073e9SAndroid Build Coastguard Worker idx_center[dim] = 0 65*da0073e9SAndroid Build Coastguard Worker out[idx] = x[idx] 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker idx_neg = list(idx) 68*da0073e9SAndroid Build Coastguard Worker idx_neg[dim] = slice(-mid, None) 69*da0073e9SAndroid Build Coastguard Worker idx_pos = idx 70*da0073e9SAndroid Build Coastguard Worker idx_pos[dim] = slice(1, mid + 1) 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker out[idx_pos] = x[idx_neg].flip(dim) 73*da0073e9SAndroid Build Coastguard Worker out[idx_neg] = x[idx_pos].flip(dim) 74*da0073e9SAndroid Build Coastguard Worker if (2 * mid + 1 < x.size(dim)): 75*da0073e9SAndroid Build Coastguard Worker idx[dim] = mid + 1 76*da0073e9SAndroid Build Coastguard Worker out[idx] = x[idx] 77*da0073e9SAndroid Build Coastguard Worker return out.conj() 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Workerdef _complex_istft(x, *args, **kwargs): 81*da0073e9SAndroid Build Coastguard Worker # Decompose into Hermitian (FFT of real) and anti-Hermitian (FFT of imaginary) 82*da0073e9SAndroid Build Coastguard Worker n_fft = x.size(-2) 83*da0073e9SAndroid Build Coastguard Worker slc = (Ellipsis, slice(None, n_fft // 2 + 1), slice(None)) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker hconj = _hermitian_conj(x, dim=-2) 86*da0073e9SAndroid Build Coastguard Worker x_hermitian = (x + hconj) / 2 87*da0073e9SAndroid Build Coastguard Worker x_antihermitian = (x - hconj) / 2 88*da0073e9SAndroid Build Coastguard Worker istft_real = torch.istft(x_hermitian[slc], *args, **kwargs, onesided=True) 89*da0073e9SAndroid Build Coastguard Worker istft_imag = torch.istft(-1j * x_antihermitian[slc], *args, **kwargs, onesided=True) 90*da0073e9SAndroid Build Coastguard Worker return torch.complex(istft_real, istft_imag) 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Workerdef _stft_reference(x, hop_length, window): 94*da0073e9SAndroid Build Coastguard Worker r"""Reference stft implementation 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker This doesn't implement all of torch.stft, only the STFT definition: 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker .. math:: X(m, \omega) = \sum_n x[n]w[n - m] e^{-jn\omega} 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker """ 101*da0073e9SAndroid Build Coastguard Worker n_fft = window.numel() 102*da0073e9SAndroid Build Coastguard Worker X = torch.empty((n_fft, (x.numel() - n_fft + hop_length) // hop_length), 103*da0073e9SAndroid Build Coastguard Worker device=x.device, dtype=torch.cdouble) 104*da0073e9SAndroid Build Coastguard Worker for m in range(X.size(1)): 105*da0073e9SAndroid Build Coastguard Worker start = m * hop_length 106*da0073e9SAndroid Build Coastguard Worker if start + n_fft > x.numel(): 107*da0073e9SAndroid Build Coastguard Worker slc = torch.empty(n_fft, device=x.device, dtype=x.dtype) 108*da0073e9SAndroid Build Coastguard Worker tmp = x[start:] 109*da0073e9SAndroid Build Coastguard Worker slc[:tmp.numel()] = tmp 110*da0073e9SAndroid Build Coastguard Worker else: 111*da0073e9SAndroid Build Coastguard Worker slc = x[start: start + n_fft] 112*da0073e9SAndroid Build Coastguard Worker X[:, m] = torch.fft.fft(slc * window) 113*da0073e9SAndroid Build Coastguard Worker return X 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Workerdef skip_helper_for_fft(device, dtype): 117*da0073e9SAndroid Build Coastguard Worker device_type = torch.device(device).type 118*da0073e9SAndroid Build Coastguard Worker if dtype not in (torch.half, torch.complex32): 119*da0073e9SAndroid Build Coastguard Worker return 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker if device_type == 'cpu': 122*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("half and complex32 are not supported on CPU") 123*da0073e9SAndroid Build Coastguard Worker if not SM53OrLater: 124*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("half and complex32 are only supported on CUDA device with SM>53") 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Worker# Tests of functions related to Fourier analysis in the torch.fft namespace 128*da0073e9SAndroid Build Coastguard Workerclass TestFFT(TestCase): 129*da0073e9SAndroid Build Coastguard Worker exact_dtype = True 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 132*da0073e9SAndroid Build Coastguard Worker @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.OneD], 133*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=(torch.float, torch.cfloat)) 134*da0073e9SAndroid Build Coastguard Worker def test_reference_1d(self, device, dtype, op): 135*da0073e9SAndroid Build Coastguard Worker if op.ref is None: 136*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("No reference implementation") 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker norm_modes = REFERENCE_NORM_MODES 139*da0073e9SAndroid Build Coastguard Worker test_args = [ 140*da0073e9SAndroid Build Coastguard Worker *product( 141*da0073e9SAndroid Build Coastguard Worker # input 142*da0073e9SAndroid Build Coastguard Worker (torch.randn(67, device=device, dtype=dtype), 143*da0073e9SAndroid Build Coastguard Worker torch.randn(80, device=device, dtype=dtype), 144*da0073e9SAndroid Build Coastguard Worker torch.randn(12, 14, device=device, dtype=dtype), 145*da0073e9SAndroid Build Coastguard Worker torch.randn(9, 6, 3, device=device, dtype=dtype)), 146*da0073e9SAndroid Build Coastguard Worker # n 147*da0073e9SAndroid Build Coastguard Worker (None, 50, 6), 148*da0073e9SAndroid Build Coastguard Worker # dim 149*da0073e9SAndroid Build Coastguard Worker (-1, 0), 150*da0073e9SAndroid Build Coastguard Worker # norm 151*da0073e9SAndroid Build Coastguard Worker norm_modes 152*da0073e9SAndroid Build Coastguard Worker ), 153*da0073e9SAndroid Build Coastguard Worker # Test transforming middle dimensions of multi-dim tensor 154*da0073e9SAndroid Build Coastguard Worker *product( 155*da0073e9SAndroid Build Coastguard Worker (torch.randn(4, 5, 6, 7, device=device, dtype=dtype),), 156*da0073e9SAndroid Build Coastguard Worker (None,), 157*da0073e9SAndroid Build Coastguard Worker (1, 2, -2,), 158*da0073e9SAndroid Build Coastguard Worker norm_modes 159*da0073e9SAndroid Build Coastguard Worker ) 160*da0073e9SAndroid Build Coastguard Worker ] 161*da0073e9SAndroid Build Coastguard Worker 162*da0073e9SAndroid Build Coastguard Worker for iargs in test_args: 163*da0073e9SAndroid Build Coastguard Worker args = list(iargs) 164*da0073e9SAndroid Build Coastguard Worker input = args[0] 165*da0073e9SAndroid Build Coastguard Worker args = args[1:] 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker expected = op.ref(input.cpu().numpy(), *args) 168*da0073e9SAndroid Build Coastguard Worker exact_dtype = dtype in (torch.double, torch.complex128) 169*da0073e9SAndroid Build Coastguard Worker actual = op(input, *args) 170*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, exact_dtype=exact_dtype) 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 173*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 174*da0073e9SAndroid Build Coastguard Worker @toleranceOverride({ 175*da0073e9SAndroid Build Coastguard Worker torch.half : tol(1e-2, 1e-2), 176*da0073e9SAndroid Build Coastguard Worker torch.chalf : tol(1e-2, 1e-2), 177*da0073e9SAndroid Build Coastguard Worker }) 178*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128) 179*da0073e9SAndroid Build Coastguard Worker def test_fft_round_trip(self, device, dtype): 180*da0073e9SAndroid Build Coastguard Worker skip_helper_for_fft(device, dtype) 181*da0073e9SAndroid Build Coastguard Worker # Test that round trip through ifft(fft(x)) is the identity 182*da0073e9SAndroid Build Coastguard Worker if dtype not in (torch.half, torch.complex32): 183*da0073e9SAndroid Build Coastguard Worker test_args = list(product( 184*da0073e9SAndroid Build Coastguard Worker # input 185*da0073e9SAndroid Build Coastguard Worker (torch.randn(67, device=device, dtype=dtype), 186*da0073e9SAndroid Build Coastguard Worker torch.randn(80, device=device, dtype=dtype), 187*da0073e9SAndroid Build Coastguard Worker torch.randn(12, 14, device=device, dtype=dtype), 188*da0073e9SAndroid Build Coastguard Worker torch.randn(9, 6, 3, device=device, dtype=dtype)), 189*da0073e9SAndroid Build Coastguard Worker # dim 190*da0073e9SAndroid Build Coastguard Worker (-1, 0), 191*da0073e9SAndroid Build Coastguard Worker # norm 192*da0073e9SAndroid Build Coastguard Worker (None, "forward", "backward", "ortho") 193*da0073e9SAndroid Build Coastguard Worker )) 194*da0073e9SAndroid Build Coastguard Worker else: 195*da0073e9SAndroid Build Coastguard Worker # cuFFT supports powers of 2 for half and complex half precision 196*da0073e9SAndroid Build Coastguard Worker test_args = list(product( 197*da0073e9SAndroid Build Coastguard Worker # input 198*da0073e9SAndroid Build Coastguard Worker (torch.randn(64, device=device, dtype=dtype), 199*da0073e9SAndroid Build Coastguard Worker torch.randn(128, device=device, dtype=dtype), 200*da0073e9SAndroid Build Coastguard Worker torch.randn(4, 16, device=device, dtype=dtype), 201*da0073e9SAndroid Build Coastguard Worker torch.randn(8, 6, 2, device=device, dtype=dtype)), 202*da0073e9SAndroid Build Coastguard Worker # dim 203*da0073e9SAndroid Build Coastguard Worker (-1, 0), 204*da0073e9SAndroid Build Coastguard Worker # norm 205*da0073e9SAndroid Build Coastguard Worker (None, "forward", "backward", "ortho") 206*da0073e9SAndroid Build Coastguard Worker )) 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker fft_functions = [(torch.fft.fft, torch.fft.ifft)] 209*da0073e9SAndroid Build Coastguard Worker # Real-only functions 210*da0073e9SAndroid Build Coastguard Worker if not dtype.is_complex: 211*da0073e9SAndroid Build Coastguard Worker # NOTE: Using ihfft as "forward" transform to avoid needing to 212*da0073e9SAndroid Build Coastguard Worker # generate true half-complex input 213*da0073e9SAndroid Build Coastguard Worker fft_functions += [(torch.fft.rfft, torch.fft.irfft), 214*da0073e9SAndroid Build Coastguard Worker (torch.fft.ihfft, torch.fft.hfft)] 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker for forward, backward in fft_functions: 217*da0073e9SAndroid Build Coastguard Worker for x, dim, norm in test_args: 218*da0073e9SAndroid Build Coastguard Worker kwargs = { 219*da0073e9SAndroid Build Coastguard Worker 'n': x.size(dim), 220*da0073e9SAndroid Build Coastguard Worker 'dim': dim, 221*da0073e9SAndroid Build Coastguard Worker 'norm': norm, 222*da0073e9SAndroid Build Coastguard Worker } 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker y = backward(forward(x, **kwargs), **kwargs) 225*da0073e9SAndroid Build Coastguard Worker if x.dtype is torch.half and y.dtype is torch.complex32: 226*da0073e9SAndroid Build Coastguard Worker # Since type promotion currently doesn't work with complex32 227*da0073e9SAndroid Build Coastguard Worker # manually promote `x` to complex32 228*da0073e9SAndroid Build Coastguard Worker x = x.to(torch.complex32) 229*da0073e9SAndroid Build Coastguard Worker # For real input, ifft(fft(x)) will convert to complex 230*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, y, exact_dtype=( 231*da0073e9SAndroid Build Coastguard Worker forward != torch.fft.fft or x.is_complex())) 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker # Note: NumPy will throw a ValueError for an empty input 234*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 235*da0073e9SAndroid Build Coastguard Worker @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.float, torch.complex32, torch.cfloat)) 236*da0073e9SAndroid Build Coastguard Worker def test_empty_fft(self, device, dtype, op): 237*da0073e9SAndroid Build Coastguard Worker t = torch.empty(1, 0, device=device, dtype=dtype) 238*da0073e9SAndroid Build Coastguard Worker match = r"Invalid number of data points \([-\d]*\) specified" 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, match): 241*da0073e9SAndroid Build Coastguard Worker op(t) 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 244*da0073e9SAndroid Build Coastguard Worker def test_empty_ifft(self, device): 245*da0073e9SAndroid Build Coastguard Worker t = torch.empty(2, 1, device=device, dtype=torch.complex64) 246*da0073e9SAndroid Build Coastguard Worker match = r"Invalid number of data points \([-\d]*\) specified" 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker for f in [torch.fft.irfft, torch.fft.irfft2, torch.fft.irfftn, 249*da0073e9SAndroid Build Coastguard Worker torch.fft.hfft, torch.fft.hfft2, torch.fft.hfftn]: 250*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, match): 251*da0073e9SAndroid Build Coastguard Worker f(t) 252*da0073e9SAndroid Build Coastguard Worker 253*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 254*da0073e9SAndroid Build Coastguard Worker def test_fft_invalid_dtypes(self, device): 255*da0073e9SAndroid Build Coastguard Worker t = torch.randn(64, device=device, dtype=torch.complex128) 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "rfft expects a real input tensor"): 258*da0073e9SAndroid Build Coastguard Worker torch.fft.rfft(t) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input tensor"): 261*da0073e9SAndroid Build Coastguard Worker torch.fft.rfftn(t) 262*da0073e9SAndroid Build Coastguard Worker 263*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "ihfft expects a real input tensor"): 264*da0073e9SAndroid Build Coastguard Worker torch.fft.ihfft(t) 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 267*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 268*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int8, torch.half, torch.float, torch.double, 269*da0073e9SAndroid Build Coastguard Worker torch.complex32, torch.complex64, torch.complex128) 270*da0073e9SAndroid Build Coastguard Worker def test_fft_type_promotion(self, device, dtype): 271*da0073e9SAndroid Build Coastguard Worker skip_helper_for_fft(device, dtype) 272*da0073e9SAndroid Build Coastguard Worker 273*da0073e9SAndroid Build Coastguard Worker if dtype.is_complex or dtype.is_floating_point: 274*da0073e9SAndroid Build Coastguard Worker t = torch.randn(64, device=device, dtype=dtype) 275*da0073e9SAndroid Build Coastguard Worker else: 276*da0073e9SAndroid Build Coastguard Worker t = torch.randint(-2, 2, (64,), device=device, dtype=dtype) 277*da0073e9SAndroid Build Coastguard Worker 278*da0073e9SAndroid Build Coastguard Worker PROMOTION_MAP = { 279*da0073e9SAndroid Build Coastguard Worker torch.int8: torch.complex64, 280*da0073e9SAndroid Build Coastguard Worker torch.half: torch.complex32, 281*da0073e9SAndroid Build Coastguard Worker torch.float: torch.complex64, 282*da0073e9SAndroid Build Coastguard Worker torch.double: torch.complex128, 283*da0073e9SAndroid Build Coastguard Worker torch.complex32: torch.complex32, 284*da0073e9SAndroid Build Coastguard Worker torch.complex64: torch.complex64, 285*da0073e9SAndroid Build Coastguard Worker torch.complex128: torch.complex128, 286*da0073e9SAndroid Build Coastguard Worker } 287*da0073e9SAndroid Build Coastguard Worker T = torch.fft.fft(t) 288*da0073e9SAndroid Build Coastguard Worker self.assertEqual(T.dtype, PROMOTION_MAP[dtype]) 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Worker PROMOTION_MAP_C2R = { 291*da0073e9SAndroid Build Coastguard Worker torch.int8: torch.float, 292*da0073e9SAndroid Build Coastguard Worker torch.half: torch.half, 293*da0073e9SAndroid Build Coastguard Worker torch.float: torch.float, 294*da0073e9SAndroid Build Coastguard Worker torch.double: torch.double, 295*da0073e9SAndroid Build Coastguard Worker torch.complex32: torch.half, 296*da0073e9SAndroid Build Coastguard Worker torch.complex64: torch.float, 297*da0073e9SAndroid Build Coastguard Worker torch.complex128: torch.double, 298*da0073e9SAndroid Build Coastguard Worker } 299*da0073e9SAndroid Build Coastguard Worker if dtype in (torch.half, torch.complex32): 300*da0073e9SAndroid Build Coastguard Worker # cuFFT supports powers of 2 for half and complex half precision 301*da0073e9SAndroid Build Coastguard Worker # NOTE: With hfft and default args where output_size n=2*(input_size - 1), 302*da0073e9SAndroid Build Coastguard Worker # we make sure that logical fft size is a power of two. 303*da0073e9SAndroid Build Coastguard Worker x = torch.randn(65, device=device, dtype=dtype) 304*da0073e9SAndroid Build Coastguard Worker R = torch.fft.hfft(x) 305*da0073e9SAndroid Build Coastguard Worker else: 306*da0073e9SAndroid Build Coastguard Worker R = torch.fft.hfft(t) 307*da0073e9SAndroid Build Coastguard Worker self.assertEqual(R.dtype, PROMOTION_MAP_C2R[dtype]) 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker if not dtype.is_complex: 310*da0073e9SAndroid Build Coastguard Worker PROMOTION_MAP_R2C = { 311*da0073e9SAndroid Build Coastguard Worker torch.int8: torch.complex64, 312*da0073e9SAndroid Build Coastguard Worker torch.half: torch.complex32, 313*da0073e9SAndroid Build Coastguard Worker torch.float: torch.complex64, 314*da0073e9SAndroid Build Coastguard Worker torch.double: torch.complex128, 315*da0073e9SAndroid Build Coastguard Worker } 316*da0073e9SAndroid Build Coastguard Worker C = torch.fft.rfft(t) 317*da0073e9SAndroid Build Coastguard Worker self.assertEqual(C.dtype, PROMOTION_MAP_R2C[dtype]) 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 320*da0073e9SAndroid Build Coastguard Worker @ops(spectral_funcs, dtypes=OpDTypes.unsupported, 321*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=[torch.half, torch.bfloat16]) 322*da0073e9SAndroid Build Coastguard Worker def test_fft_half_and_bfloat16_errors(self, device, dtype, op): 323*da0073e9SAndroid Build Coastguard Worker # TODO: Remove torch.half error when complex32 is fully implemented 324*da0073e9SAndroid Build Coastguard Worker sample = first_sample(self, op.sample_inputs(device, dtype)) 325*da0073e9SAndroid Build Coastguard Worker device_type = torch.device(device).type 326*da0073e9SAndroid Build Coastguard Worker default_msg = "Unsupported dtype" 327*da0073e9SAndroid Build Coastguard Worker if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM: 328*da0073e9SAndroid Build Coastguard Worker err_msg = default_msg 329*da0073e9SAndroid Build Coastguard Worker elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater: 330*da0073e9SAndroid Build Coastguard Worker err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53" 331*da0073e9SAndroid Build Coastguard Worker else: 332*da0073e9SAndroid Build Coastguard Worker err_msg = default_msg 333*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 334*da0073e9SAndroid Build Coastguard Worker op(sample.input, *sample.args, **sample.kwargs) 335*da0073e9SAndroid Build Coastguard Worker 336*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 337*da0073e9SAndroid Build Coastguard Worker @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.chalf)) 338*da0073e9SAndroid Build Coastguard Worker def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op): 339*da0073e9SAndroid Build Coastguard Worker t = make_tensor(13, 13, device=device, dtype=dtype) 340*da0073e9SAndroid Build Coastguard Worker err_msg = "cuFFT only supports dimensions whose sizes are powers of two" 341*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 342*da0073e9SAndroid Build Coastguard Worker op(t) 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker if op.ndimensional in (SpectralFuncType.ND, SpectralFuncType.TwoD): 345*da0073e9SAndroid Build Coastguard Worker kwargs = {'s': (12, 12)} 346*da0073e9SAndroid Build Coastguard Worker else: 347*da0073e9SAndroid Build Coastguard Worker kwargs = {'n': 12} 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 350*da0073e9SAndroid Build Coastguard Worker op(t, **kwargs) 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker # nd-fft tests 353*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 354*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') 355*da0073e9SAndroid Build Coastguard Worker @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND], 356*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=(torch.cfloat, torch.cdouble)) 357*da0073e9SAndroid Build Coastguard Worker def test_reference_nd(self, device, dtype, op): 358*da0073e9SAndroid Build Coastguard Worker if op.ref is None: 359*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("No reference implementation") 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker norm_modes = REFERENCE_NORM_MODES 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker # input_ndim, s, dim 364*da0073e9SAndroid Build Coastguard Worker transform_desc = [ 365*da0073e9SAndroid Build Coastguard Worker *product(range(2, 5), (None,), (None, (0,), (0, -1))), 366*da0073e9SAndroid Build Coastguard Worker *product(range(2, 5), (None, (4, 10)), (None,)), 367*da0073e9SAndroid Build Coastguard Worker (6, None, None), 368*da0073e9SAndroid Build Coastguard Worker (5, None, (1, 3, 4)), 369*da0073e9SAndroid Build Coastguard Worker (3, None, (1,)), 370*da0073e9SAndroid Build Coastguard Worker (1, None, (0,)), 371*da0073e9SAndroid Build Coastguard Worker (4, (10, 10), None), 372*da0073e9SAndroid Build Coastguard Worker (4, (10, 10), (0, 1)) 373*da0073e9SAndroid Build Coastguard Worker ] 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker for input_ndim, s, dim in transform_desc: 376*da0073e9SAndroid Build Coastguard Worker shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) 377*da0073e9SAndroid Build Coastguard Worker input = torch.randn(*shape, device=device, dtype=dtype) 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker for norm in norm_modes: 380*da0073e9SAndroid Build Coastguard Worker expected = op.ref(input.cpu().numpy(), s, dim, norm) 381*da0073e9SAndroid Build Coastguard Worker exact_dtype = dtype in (torch.double, torch.complex128) 382*da0073e9SAndroid Build Coastguard Worker actual = op(input, s, dim, norm) 383*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, exact_dtype=exact_dtype) 384*da0073e9SAndroid Build Coastguard Worker 385*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 386*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 387*da0073e9SAndroid Build Coastguard Worker @toleranceOverride({ 388*da0073e9SAndroid Build Coastguard Worker torch.half : tol(1e-2, 1e-2), 389*da0073e9SAndroid Build Coastguard Worker torch.chalf : tol(1e-2, 1e-2), 390*da0073e9SAndroid Build Coastguard Worker }) 391*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.float, torch.double, 392*da0073e9SAndroid Build Coastguard Worker torch.complex32, torch.complex64, torch.complex128) 393*da0073e9SAndroid Build Coastguard Worker def test_fftn_round_trip(self, device, dtype): 394*da0073e9SAndroid Build Coastguard Worker skip_helper_for_fft(device, dtype) 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker norm_modes = (None, "forward", "backward", "ortho") 397*da0073e9SAndroid Build Coastguard Worker 398*da0073e9SAndroid Build Coastguard Worker # input_ndim, dim 399*da0073e9SAndroid Build Coastguard Worker transform_desc = [ 400*da0073e9SAndroid Build Coastguard Worker *product(range(2, 5), (None, (0,), (0, -1))), 401*da0073e9SAndroid Build Coastguard Worker (7, None), 402*da0073e9SAndroid Build Coastguard Worker (5, (1, 3, 4)), 403*da0073e9SAndroid Build Coastguard Worker (3, (1,)), 404*da0073e9SAndroid Build Coastguard Worker (1, 0), 405*da0073e9SAndroid Build Coastguard Worker ] 406*da0073e9SAndroid Build Coastguard Worker 407*da0073e9SAndroid Build Coastguard Worker fft_functions = [(torch.fft.fftn, torch.fft.ifftn)] 408*da0073e9SAndroid Build Coastguard Worker 409*da0073e9SAndroid Build Coastguard Worker # Real-only functions 410*da0073e9SAndroid Build Coastguard Worker if not dtype.is_complex: 411*da0073e9SAndroid Build Coastguard Worker # NOTE: Using ihfftn as "forward" transform to avoid needing to 412*da0073e9SAndroid Build Coastguard Worker # generate true half-complex input 413*da0073e9SAndroid Build Coastguard Worker fft_functions += [(torch.fft.rfftn, torch.fft.irfftn), 414*da0073e9SAndroid Build Coastguard Worker (torch.fft.ihfftn, torch.fft.hfftn)] 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker for input_ndim, dim in transform_desc: 417*da0073e9SAndroid Build Coastguard Worker if dtype in (torch.half, torch.complex32): 418*da0073e9SAndroid Build Coastguard Worker # cuFFT supports powers of 2 for half and complex half precision 419*da0073e9SAndroid Build Coastguard Worker shape = itertools.islice(itertools.cycle((2, 4, 8)), input_ndim) 420*da0073e9SAndroid Build Coastguard Worker else: 421*da0073e9SAndroid Build Coastguard Worker shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) 422*da0073e9SAndroid Build Coastguard Worker x = torch.randn(*shape, device=device, dtype=dtype) 423*da0073e9SAndroid Build Coastguard Worker 424*da0073e9SAndroid Build Coastguard Worker for (forward, backward), norm in product(fft_functions, norm_modes): 425*da0073e9SAndroid Build Coastguard Worker if isinstance(dim, tuple): 426*da0073e9SAndroid Build Coastguard Worker s = [x.size(d) for d in dim] 427*da0073e9SAndroid Build Coastguard Worker else: 428*da0073e9SAndroid Build Coastguard Worker s = x.size() if dim is None else x.size(dim) 429*da0073e9SAndroid Build Coastguard Worker 430*da0073e9SAndroid Build Coastguard Worker kwargs = {'s': s, 'dim': dim, 'norm': norm} 431*da0073e9SAndroid Build Coastguard Worker y = backward(forward(x, **kwargs), **kwargs) 432*da0073e9SAndroid Build Coastguard Worker # For real input, ifftn(fftn(x)) will convert to complex 433*da0073e9SAndroid Build Coastguard Worker if x.dtype is torch.half and y.dtype is torch.chalf: 434*da0073e9SAndroid Build Coastguard Worker # Since type promotion currently doesn't work with complex32 435*da0073e9SAndroid Build Coastguard Worker # manually promote `x` to complex32 436*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.to(torch.chalf), y) 437*da0073e9SAndroid Build Coastguard Worker else: 438*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, y, exact_dtype=( 439*da0073e9SAndroid Build Coastguard Worker forward != torch.fft.fftn or x.is_complex())) 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 442*da0073e9SAndroid Build Coastguard Worker @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND], 443*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=[torch.float, torch.cfloat]) 444*da0073e9SAndroid Build Coastguard Worker def test_fftn_invalid(self, device, dtype, op): 445*da0073e9SAndroid Build Coastguard Worker a = torch.rand(10, 10, 10, device=device, dtype=dtype) 446*da0073e9SAndroid Build Coastguard Worker # FIXME: https://github.com/pytorch/pytorch/issues/108205 447*da0073e9SAndroid Build Coastguard Worker errMsg = "dims must be unique" 448*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, errMsg): 449*da0073e9SAndroid Build Coastguard Worker op(a, dim=(0, 1, 0)) 450*da0073e9SAndroid Build Coastguard Worker 451*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, errMsg): 452*da0073e9SAndroid Build Coastguard Worker op(a, dim=(2, -1)) 453*da0073e9SAndroid Build Coastguard Worker 454*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"): 455*da0073e9SAndroid Build Coastguard Worker op(a, s=(1,), dim=(0, 1)) 456*da0073e9SAndroid Build Coastguard Worker 457*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, "Dimension out of range"): 458*da0073e9SAndroid Build Coastguard Worker op(a, dim=(3,)) 459*da0073e9SAndroid Build Coastguard Worker 460*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "tensor only has 3 dimensions"): 461*da0073e9SAndroid Build Coastguard Worker op(a, s=(10, 10, 10, 10)) 462*da0073e9SAndroid Build Coastguard Worker 463*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 464*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 465*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble) 466*da0073e9SAndroid Build Coastguard Worker def test_fftn_noop_transform(self, device, dtype): 467*da0073e9SAndroid Build Coastguard Worker skip_helper_for_fft(device, dtype) 468*da0073e9SAndroid Build Coastguard Worker RESULT_TYPE = { 469*da0073e9SAndroid Build Coastguard Worker torch.half: torch.chalf, 470*da0073e9SAndroid Build Coastguard Worker torch.float: torch.cfloat, 471*da0073e9SAndroid Build Coastguard Worker torch.double: torch.cdouble, 472*da0073e9SAndroid Build Coastguard Worker } 473*da0073e9SAndroid Build Coastguard Worker 474*da0073e9SAndroid Build Coastguard Worker for op in [ 475*da0073e9SAndroid Build Coastguard Worker torch.fft.fftn, 476*da0073e9SAndroid Build Coastguard Worker torch.fft.ifftn, 477*da0073e9SAndroid Build Coastguard Worker torch.fft.fft2, 478*da0073e9SAndroid Build Coastguard Worker torch.fft.ifft2, 479*da0073e9SAndroid Build Coastguard Worker ]: 480*da0073e9SAndroid Build Coastguard Worker inp = make_tensor((10, 10), device=device, dtype=dtype) 481*da0073e9SAndroid Build Coastguard Worker out = torch.fft.fftn(inp, dim=[]) 482*da0073e9SAndroid Build Coastguard Worker 483*da0073e9SAndroid Build Coastguard Worker expect_dtype = RESULT_TYPE.get(inp.dtype, inp.dtype) 484*da0073e9SAndroid Build Coastguard Worker expect = inp.to(expect_dtype) 485*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expect, out) 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker 488*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 489*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 490*da0073e9SAndroid Build Coastguard Worker @toleranceOverride({ 491*da0073e9SAndroid Build Coastguard Worker torch.half : tol(1e-2, 1e-2), 492*da0073e9SAndroid Build Coastguard Worker }) 493*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.float, torch.double) 494*da0073e9SAndroid Build Coastguard Worker def test_hfftn(self, device, dtype): 495*da0073e9SAndroid Build Coastguard Worker skip_helper_for_fft(device, dtype) 496*da0073e9SAndroid Build Coastguard Worker 497*da0073e9SAndroid Build Coastguard Worker # input_ndim, dim 498*da0073e9SAndroid Build Coastguard Worker transform_desc = [ 499*da0073e9SAndroid Build Coastguard Worker *product(range(2, 5), (None, (0,), (0, -1))), 500*da0073e9SAndroid Build Coastguard Worker (6, None), 501*da0073e9SAndroid Build Coastguard Worker (5, (1, 3, 4)), 502*da0073e9SAndroid Build Coastguard Worker (3, (1,)), 503*da0073e9SAndroid Build Coastguard Worker (1, (0,)), 504*da0073e9SAndroid Build Coastguard Worker (4, (0, 1)) 505*da0073e9SAndroid Build Coastguard Worker ] 506*da0073e9SAndroid Build Coastguard Worker 507*da0073e9SAndroid Build Coastguard Worker for input_ndim, dim in transform_desc: 508*da0073e9SAndroid Build Coastguard Worker actual_dims = list(range(input_ndim)) if dim is None else dim 509*da0073e9SAndroid Build Coastguard Worker if dtype is torch.half: 510*da0073e9SAndroid Build Coastguard Worker shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)) 511*da0073e9SAndroid Build Coastguard Worker else: 512*da0073e9SAndroid Build Coastguard Worker shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim)) 513*da0073e9SAndroid Build Coastguard Worker expect = torch.randn(*shape, device=device, dtype=dtype) 514*da0073e9SAndroid Build Coastguard Worker input = torch.fft.ifftn(expect, dim=dim, norm="ortho") 515*da0073e9SAndroid Build Coastguard Worker 516*da0073e9SAndroid Build Coastguard Worker lastdim = actual_dims[-1] 517*da0073e9SAndroid Build Coastguard Worker lastdim_size = input.size(lastdim) // 2 + 1 518*da0073e9SAndroid Build Coastguard Worker idx = [slice(None)] * input_ndim 519*da0073e9SAndroid Build Coastguard Worker idx[lastdim] = slice(0, lastdim_size) 520*da0073e9SAndroid Build Coastguard Worker input = input[idx] 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker s = [shape[dim] for dim in actual_dims] 523*da0073e9SAndroid Build Coastguard Worker actual = torch.fft.hfftn(input, s=s, dim=dim, norm="ortho") 524*da0073e9SAndroid Build Coastguard Worker 525*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expect, actual) 526*da0073e9SAndroid Build Coastguard Worker 527*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 528*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 529*da0073e9SAndroid Build Coastguard Worker @toleranceOverride({ 530*da0073e9SAndroid Build Coastguard Worker torch.half : tol(1e-2, 1e-2), 531*da0073e9SAndroid Build Coastguard Worker }) 532*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.float, torch.double) 533*da0073e9SAndroid Build Coastguard Worker def test_ihfftn(self, device, dtype): 534*da0073e9SAndroid Build Coastguard Worker skip_helper_for_fft(device, dtype) 535*da0073e9SAndroid Build Coastguard Worker 536*da0073e9SAndroid Build Coastguard Worker # input_ndim, dim 537*da0073e9SAndroid Build Coastguard Worker transform_desc = [ 538*da0073e9SAndroid Build Coastguard Worker *product(range(2, 5), (None, (0,), (0, -1))), 539*da0073e9SAndroid Build Coastguard Worker (6, None), 540*da0073e9SAndroid Build Coastguard Worker (5, (1, 3, 4)), 541*da0073e9SAndroid Build Coastguard Worker (3, (1,)), 542*da0073e9SAndroid Build Coastguard Worker (1, (0,)), 543*da0073e9SAndroid Build Coastguard Worker (4, (0, 1)) 544*da0073e9SAndroid Build Coastguard Worker ] 545*da0073e9SAndroid Build Coastguard Worker 546*da0073e9SAndroid Build Coastguard Worker for input_ndim, dim in transform_desc: 547*da0073e9SAndroid Build Coastguard Worker if dtype is torch.half: 548*da0073e9SAndroid Build Coastguard Worker shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)) 549*da0073e9SAndroid Build Coastguard Worker else: 550*da0073e9SAndroid Build Coastguard Worker shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim)) 551*da0073e9SAndroid Build Coastguard Worker 552*da0073e9SAndroid Build Coastguard Worker input = torch.randn(*shape, device=device, dtype=dtype) 553*da0073e9SAndroid Build Coastguard Worker expect = torch.fft.ifftn(input, dim=dim, norm="ortho") 554*da0073e9SAndroid Build Coastguard Worker 555*da0073e9SAndroid Build Coastguard Worker # Slice off the half-symmetric component 556*da0073e9SAndroid Build Coastguard Worker lastdim = -1 if dim is None else dim[-1] 557*da0073e9SAndroid Build Coastguard Worker lastdim_size = expect.size(lastdim) // 2 + 1 558*da0073e9SAndroid Build Coastguard Worker idx = [slice(None)] * input_ndim 559*da0073e9SAndroid Build Coastguard Worker idx[lastdim] = slice(0, lastdim_size) 560*da0073e9SAndroid Build Coastguard Worker expect = expect[idx] 561*da0073e9SAndroid Build Coastguard Worker 562*da0073e9SAndroid Build Coastguard Worker actual = torch.fft.ihfftn(input, dim=dim, norm="ortho") 563*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expect, actual) 564*da0073e9SAndroid Build Coastguard Worker 565*da0073e9SAndroid Build Coastguard Worker 566*da0073e9SAndroid Build Coastguard Worker # 2d-fft tests 567*da0073e9SAndroid Build Coastguard Worker 568*da0073e9SAndroid Build Coastguard Worker # NOTE: 2d transforms are only thin wrappers over n-dim transforms, 569*da0073e9SAndroid Build Coastguard Worker # so don't require exhaustive testing. 570*da0073e9SAndroid Build Coastguard Worker 571*da0073e9SAndroid Build Coastguard Worker 572*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 573*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 574*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.complex128) 575*da0073e9SAndroid Build Coastguard Worker def test_fft2_numpy(self, device, dtype): 576*da0073e9SAndroid Build Coastguard Worker norm_modes = REFERENCE_NORM_MODES 577*da0073e9SAndroid Build Coastguard Worker 578*da0073e9SAndroid Build Coastguard Worker # input_ndim, s 579*da0073e9SAndroid Build Coastguard Worker transform_desc = [ 580*da0073e9SAndroid Build Coastguard Worker *product(range(2, 5), (None, (4, 10))), 581*da0073e9SAndroid Build Coastguard Worker ] 582*da0073e9SAndroid Build Coastguard Worker 583*da0073e9SAndroid Build Coastguard Worker fft_functions = ['fft2', 'ifft2', 'irfft2', 'hfft2'] 584*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point: 585*da0073e9SAndroid Build Coastguard Worker fft_functions += ['rfft2', 'ihfft2'] 586*da0073e9SAndroid Build Coastguard Worker 587*da0073e9SAndroid Build Coastguard Worker for input_ndim, s in transform_desc: 588*da0073e9SAndroid Build Coastguard Worker shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) 589*da0073e9SAndroid Build Coastguard Worker input = torch.randn(*shape, device=device, dtype=dtype) 590*da0073e9SAndroid Build Coastguard Worker for fname, norm in product(fft_functions, norm_modes): 591*da0073e9SAndroid Build Coastguard Worker torch_fn = getattr(torch.fft, fname) 592*da0073e9SAndroid Build Coastguard Worker if "hfft" in fname: 593*da0073e9SAndroid Build Coastguard Worker if not has_scipy_fft: 594*da0073e9SAndroid Build Coastguard Worker continue # Requires scipy to compare against 595*da0073e9SAndroid Build Coastguard Worker numpy_fn = getattr(scipy.fft, fname) 596*da0073e9SAndroid Build Coastguard Worker else: 597*da0073e9SAndroid Build Coastguard Worker numpy_fn = getattr(np.fft, fname) 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None): 600*da0073e9SAndroid Build Coastguard Worker return torch_fn(t, s, dim, norm) 601*da0073e9SAndroid Build Coastguard Worker 602*da0073e9SAndroid Build Coastguard Worker torch_fns = (torch_fn, torch.jit.script(fn)) 603*da0073e9SAndroid Build Coastguard Worker 604*da0073e9SAndroid Build Coastguard Worker # Once with dim defaulted 605*da0073e9SAndroid Build Coastguard Worker input_np = input.cpu().numpy() 606*da0073e9SAndroid Build Coastguard Worker expected = numpy_fn(input_np, s, norm=norm) 607*da0073e9SAndroid Build Coastguard Worker for fn in torch_fns: 608*da0073e9SAndroid Build Coastguard Worker actual = fn(input, s, norm=norm) 609*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 610*da0073e9SAndroid Build Coastguard Worker 611*da0073e9SAndroid Build Coastguard Worker # Once with explicit dims 612*da0073e9SAndroid Build Coastguard Worker dim = (1, 0) 613*da0073e9SAndroid Build Coastguard Worker expected = numpy_fn(input_np, s, dim, norm) 614*da0073e9SAndroid Build Coastguard Worker for fn in torch_fns: 615*da0073e9SAndroid Build Coastguard Worker actual = fn(input, s, dim, norm) 616*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 617*da0073e9SAndroid Build Coastguard Worker 618*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 619*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 620*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.complex64) 621*da0073e9SAndroid Build Coastguard Worker def test_fft2_fftn_equivalence(self, device, dtype): 622*da0073e9SAndroid Build Coastguard Worker norm_modes = (None, "forward", "backward", "ortho") 623*da0073e9SAndroid Build Coastguard Worker 624*da0073e9SAndroid Build Coastguard Worker # input_ndim, s, dim 625*da0073e9SAndroid Build Coastguard Worker transform_desc = [ 626*da0073e9SAndroid Build Coastguard Worker *product(range(2, 5), (None, (4, 10)), (None, (1, 0))), 627*da0073e9SAndroid Build Coastguard Worker (3, None, (0, 2)), 628*da0073e9SAndroid Build Coastguard Worker ] 629*da0073e9SAndroid Build Coastguard Worker 630*da0073e9SAndroid Build Coastguard Worker fft_functions = ['fft', 'ifft', 'irfft', 'hfft'] 631*da0073e9SAndroid Build Coastguard Worker # Real-only functions 632*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point: 633*da0073e9SAndroid Build Coastguard Worker fft_functions += ['rfft', 'ihfft'] 634*da0073e9SAndroid Build Coastguard Worker 635*da0073e9SAndroid Build Coastguard Worker for input_ndim, s, dim in transform_desc: 636*da0073e9SAndroid Build Coastguard Worker shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) 637*da0073e9SAndroid Build Coastguard Worker x = torch.randn(*shape, device=device, dtype=dtype) 638*da0073e9SAndroid Build Coastguard Worker 639*da0073e9SAndroid Build Coastguard Worker for func, norm in product(fft_functions, norm_modes): 640*da0073e9SAndroid Build Coastguard Worker f2d = getattr(torch.fft, func + '2') 641*da0073e9SAndroid Build Coastguard Worker fnd = getattr(torch.fft, func + 'n') 642*da0073e9SAndroid Build Coastguard Worker 643*da0073e9SAndroid Build Coastguard Worker kwargs = {'s': s, 'norm': norm} 644*da0073e9SAndroid Build Coastguard Worker 645*da0073e9SAndroid Build Coastguard Worker if dim is not None: 646*da0073e9SAndroid Build Coastguard Worker kwargs['dim'] = dim 647*da0073e9SAndroid Build Coastguard Worker expect = fnd(x, **kwargs) 648*da0073e9SAndroid Build Coastguard Worker else: 649*da0073e9SAndroid Build Coastguard Worker expect = fnd(x, dim=(-2, -1), **kwargs) 650*da0073e9SAndroid Build Coastguard Worker 651*da0073e9SAndroid Build Coastguard Worker actual = f2d(x, **kwargs) 652*da0073e9SAndroid Build Coastguard Worker 653*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expect) 654*da0073e9SAndroid Build Coastguard Worker 655*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 656*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 657*da0073e9SAndroid Build Coastguard Worker def test_fft2_invalid(self, device): 658*da0073e9SAndroid Build Coastguard Worker a = torch.rand(10, 10, 10, device=device) 659*da0073e9SAndroid Build Coastguard Worker fft_funcs = (torch.fft.fft2, torch.fft.ifft2, 660*da0073e9SAndroid Build Coastguard Worker torch.fft.rfft2, torch.fft.irfft2) 661*da0073e9SAndroid Build Coastguard Worker 662*da0073e9SAndroid Build Coastguard Worker for func in fft_funcs: 663*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "dims must be unique"): 664*da0073e9SAndroid Build Coastguard Worker func(a, dim=(0, 0)) 665*da0073e9SAndroid Build Coastguard Worker 666*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "dims must be unique"): 667*da0073e9SAndroid Build Coastguard Worker func(a, dim=(2, -1)) 668*da0073e9SAndroid Build Coastguard Worker 669*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"): 670*da0073e9SAndroid Build Coastguard Worker func(a, s=(1,)) 671*da0073e9SAndroid Build Coastguard Worker 672*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, "Dimension out of range"): 673*da0073e9SAndroid Build Coastguard Worker func(a, dim=(2, 3)) 674*da0073e9SAndroid Build Coastguard Worker 675*da0073e9SAndroid Build Coastguard Worker c = torch.complex(a, a) 676*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input"): 677*da0073e9SAndroid Build Coastguard Worker torch.fft.rfft2(c) 678*da0073e9SAndroid Build Coastguard Worker 679*da0073e9SAndroid Build Coastguard Worker # Helper functions 680*da0073e9SAndroid Build Coastguard Worker 681*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 682*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 683*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') 684*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 685*da0073e9SAndroid Build Coastguard Worker def test_fftfreq_numpy(self, device, dtype): 686*da0073e9SAndroid Build Coastguard Worker test_args = [ 687*da0073e9SAndroid Build Coastguard Worker *product( 688*da0073e9SAndroid Build Coastguard Worker # n 689*da0073e9SAndroid Build Coastguard Worker range(1, 20), 690*da0073e9SAndroid Build Coastguard Worker # d 691*da0073e9SAndroid Build Coastguard Worker (None, 10.0), 692*da0073e9SAndroid Build Coastguard Worker ) 693*da0073e9SAndroid Build Coastguard Worker ] 694*da0073e9SAndroid Build Coastguard Worker 695*da0073e9SAndroid Build Coastguard Worker functions = ['fftfreq', 'rfftfreq'] 696*da0073e9SAndroid Build Coastguard Worker 697*da0073e9SAndroid Build Coastguard Worker for fname in functions: 698*da0073e9SAndroid Build Coastguard Worker torch_fn = getattr(torch.fft, fname) 699*da0073e9SAndroid Build Coastguard Worker numpy_fn = getattr(np.fft, fname) 700*da0073e9SAndroid Build Coastguard Worker 701*da0073e9SAndroid Build Coastguard Worker for n, d in test_args: 702*da0073e9SAndroid Build Coastguard Worker args = (n,) if d is None else (n, d) 703*da0073e9SAndroid Build Coastguard Worker expected = numpy_fn(*args) 704*da0073e9SAndroid Build Coastguard Worker actual = torch_fn(*args, device=device, dtype=dtype) 705*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, exact_dtype=False) 706*da0073e9SAndroid Build Coastguard Worker 707*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 708*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 709*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 710*da0073e9SAndroid Build Coastguard Worker def test_fftfreq_out(self, device, dtype): 711*da0073e9SAndroid Build Coastguard Worker for func in (torch.fft.fftfreq, torch.fft.rfftfreq): 712*da0073e9SAndroid Build Coastguard Worker expect = func(n=100, d=.5, device=device, dtype=dtype) 713*da0073e9SAndroid Build Coastguard Worker actual = torch.empty((), device=device, dtype=dtype) 714*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "out tensor will be resized"): 715*da0073e9SAndroid Build Coastguard Worker func(n=100, d=.5, out=actual) 716*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expect) 717*da0073e9SAndroid Build Coastguard Worker 718*da0073e9SAndroid Build Coastguard Worker 719*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 720*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 721*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') 722*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double, torch.complex64, torch.complex128) 723*da0073e9SAndroid Build Coastguard Worker def test_fftshift_numpy(self, device, dtype): 724*da0073e9SAndroid Build Coastguard Worker test_args = [ 725*da0073e9SAndroid Build Coastguard Worker # shape, dim 726*da0073e9SAndroid Build Coastguard Worker *product(((11,), (12,)), (None, 0, -1)), 727*da0073e9SAndroid Build Coastguard Worker *product(((4, 5), (6, 6)), (None, 0, (-1,))), 728*da0073e9SAndroid Build Coastguard Worker *product(((1, 1, 4, 6, 7, 2),), (None, (3, 4))), 729*da0073e9SAndroid Build Coastguard Worker ] 730*da0073e9SAndroid Build Coastguard Worker 731*da0073e9SAndroid Build Coastguard Worker functions = ['fftshift', 'ifftshift'] 732*da0073e9SAndroid Build Coastguard Worker 733*da0073e9SAndroid Build Coastguard Worker for shape, dim in test_args: 734*da0073e9SAndroid Build Coastguard Worker input = torch.rand(*shape, device=device, dtype=dtype) 735*da0073e9SAndroid Build Coastguard Worker input_np = input.cpu().numpy() 736*da0073e9SAndroid Build Coastguard Worker 737*da0073e9SAndroid Build Coastguard Worker for fname in functions: 738*da0073e9SAndroid Build Coastguard Worker torch_fn = getattr(torch.fft, fname) 739*da0073e9SAndroid Build Coastguard Worker numpy_fn = getattr(np.fft, fname) 740*da0073e9SAndroid Build Coastguard Worker 741*da0073e9SAndroid Build Coastguard Worker expected = numpy_fn(input_np, axes=dim) 742*da0073e9SAndroid Build Coastguard Worker actual = torch_fn(input, dim=dim) 743*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 744*da0073e9SAndroid Build Coastguard Worker 745*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 746*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 747*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') 748*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 749*da0073e9SAndroid Build Coastguard Worker def test_fftshift_frequencies(self, device, dtype): 750*da0073e9SAndroid Build Coastguard Worker for n in range(10, 15): 751*da0073e9SAndroid Build Coastguard Worker sorted_fft_freqs = torch.arange(-(n // 2), n - (n // 2), 752*da0073e9SAndroid Build Coastguard Worker device=device, dtype=dtype) 753*da0073e9SAndroid Build Coastguard Worker x = torch.fft.fftfreq(n, d=1 / n, device=device, dtype=dtype) 754*da0073e9SAndroid Build Coastguard Worker 755*da0073e9SAndroid Build Coastguard Worker # Test fftshift sorts the fftfreq output 756*da0073e9SAndroid Build Coastguard Worker shifted = torch.fft.fftshift(x) 757*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shifted, shifted.sort().values) 758*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sorted_fft_freqs, shifted) 759*da0073e9SAndroid Build Coastguard Worker 760*da0073e9SAndroid Build Coastguard Worker # And ifftshift is the inverse 761*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.fft.ifftshift(shifted)) 762*da0073e9SAndroid Build Coastguard Worker 763*da0073e9SAndroid Build Coastguard Worker # Legacy fft tests 764*da0073e9SAndroid Build Coastguard Worker def _test_fft_ifft_rfft_irfft(self, device, dtype): 765*da0073e9SAndroid Build Coastguard Worker complex_dtype = corresponding_complex_dtype(dtype) 766*da0073e9SAndroid Build Coastguard Worker 767*da0073e9SAndroid Build Coastguard Worker def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x): 768*da0073e9SAndroid Build Coastguard Worker x = prepro_fn(torch.randn(*sizes, dtype=complex_dtype, device=device)) 769*da0073e9SAndroid Build Coastguard Worker dim = tuple(range(-signal_ndim, 0)) 770*da0073e9SAndroid Build Coastguard Worker for norm in ('ortho', None): 771*da0073e9SAndroid Build Coastguard Worker res = torch.fft.fftn(x, dim=dim, norm=norm) 772*da0073e9SAndroid Build Coastguard Worker rec = torch.fft.ifftn(res, dim=dim, norm=norm) 773*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='fft and ifft') 774*da0073e9SAndroid Build Coastguard Worker res = torch.fft.ifftn(x, dim=dim, norm=norm) 775*da0073e9SAndroid Build Coastguard Worker rec = torch.fft.fftn(res, dim=dim, norm=norm) 776*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='ifft and fft') 777*da0073e9SAndroid Build Coastguard Worker 778*da0073e9SAndroid Build Coastguard Worker def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x): 779*da0073e9SAndroid Build Coastguard Worker x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device)) 780*da0073e9SAndroid Build Coastguard Worker signal_numel = 1 781*da0073e9SAndroid Build Coastguard Worker signal_sizes = x.size()[-signal_ndim:] 782*da0073e9SAndroid Build Coastguard Worker dim = tuple(range(-signal_ndim, 0)) 783*da0073e9SAndroid Build Coastguard Worker for norm in (None, 'ortho'): 784*da0073e9SAndroid Build Coastguard Worker res = torch.fft.rfftn(x, dim=dim, norm=norm) 785*da0073e9SAndroid Build Coastguard Worker rec = torch.fft.irfftn(res, s=signal_sizes, dim=dim, norm=norm) 786*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='rfft and irfft') 787*da0073e9SAndroid Build Coastguard Worker res = torch.fft.fftn(x, dim=dim, norm=norm) 788*da0073e9SAndroid Build Coastguard Worker rec = torch.fft.ifftn(res, dim=dim, norm=norm) 789*da0073e9SAndroid Build Coastguard Worker x_complex = torch.complex(x, torch.zeros_like(x)) 790*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_complex, rec, atol=1e-8, rtol=0, msg='fft and ifft (from real)') 791*da0073e9SAndroid Build Coastguard Worker 792*da0073e9SAndroid Build Coastguard Worker # contiguous case 793*da0073e9SAndroid Build Coastguard Worker _test_real((100,), 1) 794*da0073e9SAndroid Build Coastguard Worker _test_real((10, 1, 10, 100), 1) 795*da0073e9SAndroid Build Coastguard Worker _test_real((100, 100), 2) 796*da0073e9SAndroid Build Coastguard Worker _test_real((2, 2, 5, 80, 60), 2) 797*da0073e9SAndroid Build Coastguard Worker _test_real((50, 40, 70), 3) 798*da0073e9SAndroid Build Coastguard Worker _test_real((30, 1, 50, 25, 20), 3) 799*da0073e9SAndroid Build Coastguard Worker 800*da0073e9SAndroid Build Coastguard Worker _test_complex((100,), 1) 801*da0073e9SAndroid Build Coastguard Worker _test_complex((100, 100), 1) 802*da0073e9SAndroid Build Coastguard Worker _test_complex((100, 100), 2) 803*da0073e9SAndroid Build Coastguard Worker _test_complex((1, 20, 80, 60), 2) 804*da0073e9SAndroid Build Coastguard Worker _test_complex((50, 40, 70), 3) 805*da0073e9SAndroid Build Coastguard Worker _test_complex((6, 5, 50, 25, 20), 3) 806*da0073e9SAndroid Build Coastguard Worker 807*da0073e9SAndroid Build Coastguard Worker # non-contiguous case 808*da0073e9SAndroid Build Coastguard Worker _test_real((165,), 1, lambda x: x.narrow(0, 25, 100)) # input is not aligned to complex type 809*da0073e9SAndroid Build Coastguard Worker _test_real((100, 100, 3), 1, lambda x: x[:, :, 0]) 810*da0073e9SAndroid Build Coastguard Worker _test_real((100, 100), 2, lambda x: x.t()) 811*da0073e9SAndroid Build Coastguard Worker _test_real((20, 100, 10, 10), 2, lambda x: x.view(20, 100, 100)[:, :60]) 812*da0073e9SAndroid Build Coastguard Worker _test_real((65, 80, 115), 3, lambda x: x[10:60, 13:53, 10:80]) 813*da0073e9SAndroid Build Coastguard Worker _test_real((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3)) 814*da0073e9SAndroid Build Coastguard Worker 815*da0073e9SAndroid Build Coastguard Worker _test_complex((100,), 1, lambda x: x.expand(100, 100)) 816*da0073e9SAndroid Build Coastguard Worker _test_complex((20, 90, 110), 2, lambda x: x[:, 5:85].narrow(2, 5, 100)) 817*da0073e9SAndroid Build Coastguard Worker _test_complex((40, 60, 3, 80), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:]) 818*da0073e9SAndroid Build Coastguard Worker _test_complex((30, 55, 50, 22), 3, lambda x: x[:, 3:53, 15:40, 1:21]) 819*da0073e9SAndroid Build Coastguard Worker 820*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 821*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 822*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 823*da0073e9SAndroid Build Coastguard Worker def test_fft_ifft_rfft_irfft(self, device, dtype): 824*da0073e9SAndroid Build Coastguard Worker self._test_fft_ifft_rfft_irfft(device, dtype) 825*da0073e9SAndroid Build Coastguard Worker 826*da0073e9SAndroid Build Coastguard Worker @deviceCountAtLeast(1) 827*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 828*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 829*da0073e9SAndroid Build Coastguard Worker def test_cufft_plan_cache(self, devices, dtype): 830*da0073e9SAndroid Build Coastguard Worker @contextmanager 831*da0073e9SAndroid Build Coastguard Worker def plan_cache_max_size(device, n): 832*da0073e9SAndroid Build Coastguard Worker if device is None: 833*da0073e9SAndroid Build Coastguard Worker plan_cache = torch.backends.cuda.cufft_plan_cache 834*da0073e9SAndroid Build Coastguard Worker else: 835*da0073e9SAndroid Build Coastguard Worker plan_cache = torch.backends.cuda.cufft_plan_cache[device] 836*da0073e9SAndroid Build Coastguard Worker original = plan_cache.max_size 837*da0073e9SAndroid Build Coastguard Worker plan_cache.max_size = n 838*da0073e9SAndroid Build Coastguard Worker try: 839*da0073e9SAndroid Build Coastguard Worker yield 840*da0073e9SAndroid Build Coastguard Worker finally: 841*da0073e9SAndroid Build Coastguard Worker plan_cache.max_size = original 842*da0073e9SAndroid Build Coastguard Worker 843*da0073e9SAndroid Build Coastguard Worker with plan_cache_max_size(devices[0], max(1, torch.backends.cuda.cufft_plan_cache.size - 10)): 844*da0073e9SAndroid Build Coastguard Worker self._test_fft_ifft_rfft_irfft(devices[0], dtype) 845*da0073e9SAndroid Build Coastguard Worker 846*da0073e9SAndroid Build Coastguard Worker with plan_cache_max_size(devices[0], 0): 847*da0073e9SAndroid Build Coastguard Worker self._test_fft_ifft_rfft_irfft(devices[0], dtype) 848*da0073e9SAndroid Build Coastguard Worker 849*da0073e9SAndroid Build Coastguard Worker torch.backends.cuda.cufft_plan_cache.clear() 850*da0073e9SAndroid Build Coastguard Worker 851*da0073e9SAndroid Build Coastguard Worker # check that stll works after clearing cache 852*da0073e9SAndroid Build Coastguard Worker with plan_cache_max_size(devices[0], 10): 853*da0073e9SAndroid Build Coastguard Worker self._test_fft_ifft_rfft_irfft(devices[0], dtype) 854*da0073e9SAndroid Build Coastguard Worker 855*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"must be non-negative"): 856*da0073e9SAndroid Build Coastguard Worker torch.backends.cuda.cufft_plan_cache.max_size = -1 857*da0073e9SAndroid Build Coastguard Worker 858*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"read-only property"): 859*da0073e9SAndroid Build Coastguard Worker torch.backends.cuda.cufft_plan_cache.size = -1 860*da0073e9SAndroid Build Coastguard Worker 861*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"but got device with index"): 862*da0073e9SAndroid Build Coastguard Worker torch.backends.cuda.cufft_plan_cache[torch.cuda.device_count() + 10] 863*da0073e9SAndroid Build Coastguard Worker 864*da0073e9SAndroid Build Coastguard Worker # Multigpu tests 865*da0073e9SAndroid Build Coastguard Worker if len(devices) > 1: 866*da0073e9SAndroid Build Coastguard Worker # Test that different GPU has different cache 867*da0073e9SAndroid Build Coastguard Worker x0 = torch.randn(2, 3, 3, device=devices[0]) 868*da0073e9SAndroid Build Coastguard Worker x1 = x0.to(devices[1]) 869*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.fft.rfftn(x0, dim=(-2, -1)), torch.fft.rfftn(x1, dim=(-2, -1))) 870*da0073e9SAndroid Build Coastguard Worker # If a plan is used across different devices, the following line (or 871*da0073e9SAndroid Build Coastguard Worker # the assert above) would trigger illegal memory access. Other ways 872*da0073e9SAndroid Build Coastguard Worker # to trigger the error include 873*da0073e9SAndroid Build Coastguard Worker # (1) setting CUDA_LAUNCH_BLOCKING=1 (pytorch/pytorch#19224) and 874*da0073e9SAndroid Build Coastguard Worker # (2) printing a device 1 tensor. 875*da0073e9SAndroid Build Coastguard Worker x0.copy_(x1) 876*da0073e9SAndroid Build Coastguard Worker 877*da0073e9SAndroid Build Coastguard Worker # Test that un-indexed `torch.backends.cuda.cufft_plan_cache` uses current device 878*da0073e9SAndroid Build Coastguard Worker with plan_cache_max_size(devices[0], 10): 879*da0073e9SAndroid Build Coastguard Worker with plan_cache_max_size(devices[1], 11): 880*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10) 881*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11) 882*da0073e9SAndroid Build Coastguard Worker 883*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0 884*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(devices[1]): 885*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1 886*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(devices[0]): 887*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0 888*da0073e9SAndroid Build Coastguard Worker 889*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10) 890*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(devices[1]): 891*da0073e9SAndroid Build Coastguard Worker with plan_cache_max_size(None, 11): # default is cuda:1 892*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10) 893*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11) 894*da0073e9SAndroid Build Coastguard Worker 895*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1 896*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(devices[0]): 897*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0 898*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1 899*da0073e9SAndroid Build Coastguard Worker 900*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 901*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.cfloat, torch.cdouble) 902*da0073e9SAndroid Build Coastguard Worker def test_cufft_context(self, device, dtype): 903*da0073e9SAndroid Build Coastguard Worker # Regression test for https://github.com/pytorch/pytorch/issues/109448 904*da0073e9SAndroid Build Coastguard Worker x = torch.randn(32, dtype=dtype, device=device, requires_grad=True) 905*da0073e9SAndroid Build Coastguard Worker dout = torch.zeros(32, dtype=dtype, device=device) 906*da0073e9SAndroid Build Coastguard Worker 907*da0073e9SAndroid Build Coastguard Worker # compute iFFT(FFT(x)) 908*da0073e9SAndroid Build Coastguard Worker out = torch.fft.ifft(torch.fft.fft(x)) 909*da0073e9SAndroid Build Coastguard Worker out.backward(dout, retain_graph=True) 910*da0073e9SAndroid Build Coastguard Worker 911*da0073e9SAndroid Build Coastguard Worker dx = torch.fft.fft(torch.fft.ifft(dout)) 912*da0073e9SAndroid Build Coastguard Worker 913*da0073e9SAndroid Build Coastguard Worker self.assertTrue((x.grad - dx).abs().max() == 0) 914*da0073e9SAndroid Build Coastguard Worker self.assertFalse((x.grad - x).abs().max() == 0) 915*da0073e9SAndroid Build Coastguard Worker 916*da0073e9SAndroid Build Coastguard Worker # passes on ROCm w/ python 2.7, fails w/ python 3.6 917*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("cannot set WRITEABLE flag to True of this array") 918*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 919*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 920*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 921*da0073e9SAndroid Build Coastguard Worker def test_stft(self, device, dtype): 922*da0073e9SAndroid Build Coastguard Worker if not TEST_LIBROSA: 923*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest('librosa not found') 924*da0073e9SAndroid Build Coastguard Worker 925*da0073e9SAndroid Build Coastguard Worker def librosa_stft(x, n_fft, hop_length, win_length, window, center): 926*da0073e9SAndroid Build Coastguard Worker if window is None: 927*da0073e9SAndroid Build Coastguard Worker window = np.ones(n_fft if win_length is None else win_length) 928*da0073e9SAndroid Build Coastguard Worker else: 929*da0073e9SAndroid Build Coastguard Worker window = window.cpu().numpy() 930*da0073e9SAndroid Build Coastguard Worker input_1d = x.dim() == 1 931*da0073e9SAndroid Build Coastguard Worker if input_1d: 932*da0073e9SAndroid Build Coastguard Worker x = x.view(1, -1) 933*da0073e9SAndroid Build Coastguard Worker 934*da0073e9SAndroid Build Coastguard Worker # NOTE: librosa 0.9 changed default pad_mode to 'constant' (zero padding) 935*da0073e9SAndroid Build Coastguard Worker # however, we use the pre-0.9 default ('reflect') 936*da0073e9SAndroid Build Coastguard Worker pad_mode = 'reflect' 937*da0073e9SAndroid Build Coastguard Worker 938*da0073e9SAndroid Build Coastguard Worker result = [] 939*da0073e9SAndroid Build Coastguard Worker for xi in x: 940*da0073e9SAndroid Build Coastguard Worker ri = librosa.stft(xi.cpu().numpy(), n_fft=n_fft, hop_length=hop_length, 941*da0073e9SAndroid Build Coastguard Worker win_length=win_length, window=window, center=center, 942*da0073e9SAndroid Build Coastguard Worker pad_mode=pad_mode) 943*da0073e9SAndroid Build Coastguard Worker result.append(torch.from_numpy(np.stack([ri.real, ri.imag], -1))) 944*da0073e9SAndroid Build Coastguard Worker result = torch.stack(result, 0) 945*da0073e9SAndroid Build Coastguard Worker if input_1d: 946*da0073e9SAndroid Build Coastguard Worker result = result[0] 947*da0073e9SAndroid Build Coastguard Worker return result 948*da0073e9SAndroid Build Coastguard Worker 949*da0073e9SAndroid Build Coastguard Worker def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None, 950*da0073e9SAndroid Build Coastguard Worker center=True, expected_error=None): 951*da0073e9SAndroid Build Coastguard Worker x = torch.randn(*sizes, dtype=dtype, device=device) 952*da0073e9SAndroid Build Coastguard Worker if win_sizes is not None: 953*da0073e9SAndroid Build Coastguard Worker window = torch.randn(*win_sizes, dtype=dtype, device=device) 954*da0073e9SAndroid Build Coastguard Worker else: 955*da0073e9SAndroid Build Coastguard Worker window = None 956*da0073e9SAndroid Build Coastguard Worker if expected_error is None: 957*da0073e9SAndroid Build Coastguard Worker result = x.stft(n_fft, hop_length, win_length, window, 958*da0073e9SAndroid Build Coastguard Worker center=center, return_complex=False) 959*da0073e9SAndroid Build Coastguard Worker # NB: librosa defaults to np.complex64 output, no matter what 960*da0073e9SAndroid Build Coastguard Worker # the input dtype 961*da0073e9SAndroid Build Coastguard Worker ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center) 962*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, ref_result, atol=7e-6, rtol=0, msg='stft comparison against librosa', exact_dtype=False) 963*da0073e9SAndroid Build Coastguard Worker # With return_complex=True, the result is the same but viewed as complex instead of real 964*da0073e9SAndroid Build Coastguard Worker result_complex = x.stft(n_fft, hop_length, win_length, window, center=center, return_complex=True) 965*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_complex, torch.view_as_complex(result)) 966*da0073e9SAndroid Build Coastguard Worker else: 967*da0073e9SAndroid Build Coastguard Worker self.assertRaises(expected_error, 968*da0073e9SAndroid Build Coastguard Worker lambda: x.stft(n_fft, hop_length, win_length, window, center=center)) 969*da0073e9SAndroid Build Coastguard Worker 970*da0073e9SAndroid Build Coastguard Worker for center in [True, False]: 971*da0073e9SAndroid Build Coastguard Worker _test((10,), 7, center=center) 972*da0073e9SAndroid Build Coastguard Worker _test((10, 4000), 1024, center=center) 973*da0073e9SAndroid Build Coastguard Worker 974*da0073e9SAndroid Build Coastguard Worker _test((10,), 7, 2, center=center) 975*da0073e9SAndroid Build Coastguard Worker _test((10, 4000), 1024, 512, center=center) 976*da0073e9SAndroid Build Coastguard Worker 977*da0073e9SAndroid Build Coastguard Worker _test((10,), 7, 2, win_sizes=(7,), center=center) 978*da0073e9SAndroid Build Coastguard Worker _test((10, 4000), 1024, 512, win_sizes=(1024,), center=center) 979*da0073e9SAndroid Build Coastguard Worker 980*da0073e9SAndroid Build Coastguard Worker # spectral oversample 981*da0073e9SAndroid Build Coastguard Worker _test((10,), 7, 2, win_length=5, center=center) 982*da0073e9SAndroid Build Coastguard Worker _test((10, 4000), 1024, 512, win_length=100, center=center) 983*da0073e9SAndroid Build Coastguard Worker 984*da0073e9SAndroid Build Coastguard Worker _test((10, 4, 2), 1, 1, expected_error=RuntimeError) 985*da0073e9SAndroid Build Coastguard Worker _test((10,), 11, 1, center=False, expected_error=RuntimeError) 986*da0073e9SAndroid Build Coastguard Worker _test((10,), -1, 1, expected_error=RuntimeError) 987*da0073e9SAndroid Build Coastguard Worker _test((10,), 3, win_length=5, expected_error=RuntimeError) 988*da0073e9SAndroid Build Coastguard Worker _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError) 989*da0073e9SAndroid Build Coastguard Worker _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError) 990*da0073e9SAndroid Build Coastguard Worker 991*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("double") 992*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 993*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 994*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 995*da0073e9SAndroid Build Coastguard Worker def test_istft_against_librosa(self, device, dtype): 996*da0073e9SAndroid Build Coastguard Worker if not TEST_LIBROSA: 997*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest('librosa not found') 998*da0073e9SAndroid Build Coastguard Worker 999*da0073e9SAndroid Build Coastguard Worker def librosa_istft(x, n_fft, hop_length, win_length, window, length, center): 1000*da0073e9SAndroid Build Coastguard Worker if window is None: 1001*da0073e9SAndroid Build Coastguard Worker window = np.ones(n_fft if win_length is None else win_length) 1002*da0073e9SAndroid Build Coastguard Worker else: 1003*da0073e9SAndroid Build Coastguard Worker window = window.cpu().numpy() 1004*da0073e9SAndroid Build Coastguard Worker 1005*da0073e9SAndroid Build Coastguard Worker return librosa.istft(x.cpu().numpy(), n_fft=n_fft, hop_length=hop_length, 1006*da0073e9SAndroid Build Coastguard Worker win_length=win_length, length=length, window=window, center=center) 1007*da0073e9SAndroid Build Coastguard Worker 1008*da0073e9SAndroid Build Coastguard Worker def _test(size, n_fft, hop_length=None, win_length=None, win_sizes=None, 1009*da0073e9SAndroid Build Coastguard Worker length=None, center=True): 1010*da0073e9SAndroid Build Coastguard Worker x = torch.randn(size, dtype=dtype, device=device) 1011*da0073e9SAndroid Build Coastguard Worker if win_sizes is not None: 1012*da0073e9SAndroid Build Coastguard Worker window = torch.randn(*win_sizes, dtype=dtype, device=device) 1013*da0073e9SAndroid Build Coastguard Worker else: 1014*da0073e9SAndroid Build Coastguard Worker window = None 1015*da0073e9SAndroid Build Coastguard Worker 1016*da0073e9SAndroid Build Coastguard Worker x_stft = x.stft(n_fft, hop_length, win_length, window, center=center, 1017*da0073e9SAndroid Build Coastguard Worker onesided=True, return_complex=True) 1018*da0073e9SAndroid Build Coastguard Worker 1019*da0073e9SAndroid Build Coastguard Worker ref_result = librosa_istft(x_stft, n_fft, hop_length, win_length, 1020*da0073e9SAndroid Build Coastguard Worker window, length, center) 1021*da0073e9SAndroid Build Coastguard Worker result = x_stft.istft(n_fft, hop_length, win_length, window, 1022*da0073e9SAndroid Build Coastguard Worker length=length, center=center) 1023*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, ref_result) 1024*da0073e9SAndroid Build Coastguard Worker 1025*da0073e9SAndroid Build Coastguard Worker for center in [True, False]: 1026*da0073e9SAndroid Build Coastguard Worker _test(10, 7, center=center) 1027*da0073e9SAndroid Build Coastguard Worker _test(4000, 1024, center=center) 1028*da0073e9SAndroid Build Coastguard Worker _test(4000, 1024, center=center, length=4000) 1029*da0073e9SAndroid Build Coastguard Worker 1030*da0073e9SAndroid Build Coastguard Worker _test(10, 7, 2, center=center) 1031*da0073e9SAndroid Build Coastguard Worker _test(4000, 1024, 512, center=center) 1032*da0073e9SAndroid Build Coastguard Worker _test(4000, 1024, 512, center=center, length=4000) 1033*da0073e9SAndroid Build Coastguard Worker 1034*da0073e9SAndroid Build Coastguard Worker _test(10, 7, 2, win_sizes=(7,), center=center) 1035*da0073e9SAndroid Build Coastguard Worker _test(4000, 1024, 512, win_sizes=(1024,), center=center) 1036*da0073e9SAndroid Build Coastguard Worker _test(4000, 1024, 512, win_sizes=(1024,), center=center, length=4000) 1037*da0073e9SAndroid Build Coastguard Worker 1038*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1039*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1040*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 1041*da0073e9SAndroid Build Coastguard Worker def test_complex_stft_roundtrip(self, device, dtype): 1042*da0073e9SAndroid Build Coastguard Worker test_args = list(product( 1043*da0073e9SAndroid Build Coastguard Worker # input 1044*da0073e9SAndroid Build Coastguard Worker (torch.randn(600, device=device, dtype=dtype), 1045*da0073e9SAndroid Build Coastguard Worker torch.randn(807, device=device, dtype=dtype), 1046*da0073e9SAndroid Build Coastguard Worker torch.randn(12, 60, device=device, dtype=dtype)), 1047*da0073e9SAndroid Build Coastguard Worker # n_fft 1048*da0073e9SAndroid Build Coastguard Worker (50, 27), 1049*da0073e9SAndroid Build Coastguard Worker # hop_length 1050*da0073e9SAndroid Build Coastguard Worker (None, 10), 1051*da0073e9SAndroid Build Coastguard Worker # center 1052*da0073e9SAndroid Build Coastguard Worker (True,), 1053*da0073e9SAndroid Build Coastguard Worker # pad_mode 1054*da0073e9SAndroid Build Coastguard Worker ("constant", "reflect", "circular"), 1055*da0073e9SAndroid Build Coastguard Worker # normalized 1056*da0073e9SAndroid Build Coastguard Worker (True, False), 1057*da0073e9SAndroid Build Coastguard Worker # onesided 1058*da0073e9SAndroid Build Coastguard Worker (True, False) if not dtype.is_complex else (False,), 1059*da0073e9SAndroid Build Coastguard Worker )) 1060*da0073e9SAndroid Build Coastguard Worker 1061*da0073e9SAndroid Build Coastguard Worker for args in test_args: 1062*da0073e9SAndroid Build Coastguard Worker x, n_fft, hop_length, center, pad_mode, normalized, onesided = args 1063*da0073e9SAndroid Build Coastguard Worker common_kwargs = { 1064*da0073e9SAndroid Build Coastguard Worker 'n_fft': n_fft, 'hop_length': hop_length, 'center': center, 1065*da0073e9SAndroid Build Coastguard Worker 'normalized': normalized, 'onesided': onesided, 1066*da0073e9SAndroid Build Coastguard Worker } 1067*da0073e9SAndroid Build Coastguard Worker 1068*da0073e9SAndroid Build Coastguard Worker # Functional interface 1069*da0073e9SAndroid Build Coastguard Worker x_stft = torch.stft(x, pad_mode=pad_mode, return_complex=True, **common_kwargs) 1070*da0073e9SAndroid Build Coastguard Worker x_roundtrip = torch.istft(x_stft, return_complex=dtype.is_complex, 1071*da0073e9SAndroid Build Coastguard Worker length=x.size(-1), **common_kwargs) 1072*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_roundtrip, x) 1073*da0073e9SAndroid Build Coastguard Worker 1074*da0073e9SAndroid Build Coastguard Worker # Tensor method interface 1075*da0073e9SAndroid Build Coastguard Worker x_stft = x.stft(pad_mode=pad_mode, return_complex=True, **common_kwargs) 1076*da0073e9SAndroid Build Coastguard Worker x_roundtrip = torch.istft(x_stft, return_complex=dtype.is_complex, 1077*da0073e9SAndroid Build Coastguard Worker length=x.size(-1), **common_kwargs) 1078*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_roundtrip, x) 1079*da0073e9SAndroid Build Coastguard Worker 1080*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1081*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1082*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 1083*da0073e9SAndroid Build Coastguard Worker def test_stft_roundtrip_complex_window(self, device, dtype): 1084*da0073e9SAndroid Build Coastguard Worker test_args = list(product( 1085*da0073e9SAndroid Build Coastguard Worker # input 1086*da0073e9SAndroid Build Coastguard Worker (torch.randn(600, device=device, dtype=dtype), 1087*da0073e9SAndroid Build Coastguard Worker torch.randn(807, device=device, dtype=dtype), 1088*da0073e9SAndroid Build Coastguard Worker torch.randn(12, 60, device=device, dtype=dtype)), 1089*da0073e9SAndroid Build Coastguard Worker # n_fft 1090*da0073e9SAndroid Build Coastguard Worker (50, 27), 1091*da0073e9SAndroid Build Coastguard Worker # hop_length 1092*da0073e9SAndroid Build Coastguard Worker (None, 10), 1093*da0073e9SAndroid Build Coastguard Worker # pad_mode 1094*da0073e9SAndroid Build Coastguard Worker ("constant", "reflect", "replicate", "circular"), 1095*da0073e9SAndroid Build Coastguard Worker # normalized 1096*da0073e9SAndroid Build Coastguard Worker (True, False), 1097*da0073e9SAndroid Build Coastguard Worker )) 1098*da0073e9SAndroid Build Coastguard Worker for args in test_args: 1099*da0073e9SAndroid Build Coastguard Worker x, n_fft, hop_length, pad_mode, normalized = args 1100*da0073e9SAndroid Build Coastguard Worker window = torch.rand(n_fft, device=device, dtype=torch.cdouble) 1101*da0073e9SAndroid Build Coastguard Worker x_stft = torch.stft( 1102*da0073e9SAndroid Build Coastguard Worker x, n_fft=n_fft, hop_length=hop_length, window=window, 1103*da0073e9SAndroid Build Coastguard Worker center=True, pad_mode=pad_mode, normalized=normalized) 1104*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_stft.dtype, torch.cdouble) 1105*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_stft.size(-2), n_fft) # Not onesided 1106*da0073e9SAndroid Build Coastguard Worker 1107*da0073e9SAndroid Build Coastguard Worker x_roundtrip = torch.istft( 1108*da0073e9SAndroid Build Coastguard Worker x_stft, n_fft=n_fft, hop_length=hop_length, window=window, 1109*da0073e9SAndroid Build Coastguard Worker center=True, normalized=normalized, length=x.size(-1), 1110*da0073e9SAndroid Build Coastguard Worker return_complex=True) 1111*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_stft.dtype, torch.cdouble) 1112*da0073e9SAndroid Build Coastguard Worker 1113*da0073e9SAndroid Build Coastguard Worker if not dtype.is_complex: 1114*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_roundtrip.imag, torch.zeros_like(x_roundtrip.imag), 1115*da0073e9SAndroid Build Coastguard Worker atol=1e-6, rtol=0) 1116*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_roundtrip.real, x) 1117*da0073e9SAndroid Build Coastguard Worker else: 1118*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_roundtrip, x) 1119*da0073e9SAndroid Build Coastguard Worker 1120*da0073e9SAndroid Build Coastguard Worker 1121*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1122*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.cdouble) 1123*da0073e9SAndroid Build Coastguard Worker def test_complex_stft_definition(self, device, dtype): 1124*da0073e9SAndroid Build Coastguard Worker test_args = list(product( 1125*da0073e9SAndroid Build Coastguard Worker # input 1126*da0073e9SAndroid Build Coastguard Worker (torch.randn(600, device=device, dtype=dtype), 1127*da0073e9SAndroid Build Coastguard Worker torch.randn(807, device=device, dtype=dtype)), 1128*da0073e9SAndroid Build Coastguard Worker # n_fft 1129*da0073e9SAndroid Build Coastguard Worker (50, 27), 1130*da0073e9SAndroid Build Coastguard Worker # hop_length 1131*da0073e9SAndroid Build Coastguard Worker (10, 15) 1132*da0073e9SAndroid Build Coastguard Worker )) 1133*da0073e9SAndroid Build Coastguard Worker 1134*da0073e9SAndroid Build Coastguard Worker for args in test_args: 1135*da0073e9SAndroid Build Coastguard Worker window = torch.randn(args[1], device=device, dtype=dtype) 1136*da0073e9SAndroid Build Coastguard Worker expected = _stft_reference(args[0], args[2], window) 1137*da0073e9SAndroid Build Coastguard Worker actual = torch.stft(*args, window=window, center=False) 1138*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 1139*da0073e9SAndroid Build Coastguard Worker 1140*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1141*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1142*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.cdouble) 1143*da0073e9SAndroid Build Coastguard Worker def test_complex_stft_real_equiv(self, device, dtype): 1144*da0073e9SAndroid Build Coastguard Worker test_args = list(product( 1145*da0073e9SAndroid Build Coastguard Worker # input 1146*da0073e9SAndroid Build Coastguard Worker (torch.rand(600, device=device, dtype=dtype), 1147*da0073e9SAndroid Build Coastguard Worker torch.rand(807, device=device, dtype=dtype), 1148*da0073e9SAndroid Build Coastguard Worker torch.rand(14, 50, device=device, dtype=dtype), 1149*da0073e9SAndroid Build Coastguard Worker torch.rand(6, 51, device=device, dtype=dtype)), 1150*da0073e9SAndroid Build Coastguard Worker # n_fft 1151*da0073e9SAndroid Build Coastguard Worker (50, 27), 1152*da0073e9SAndroid Build Coastguard Worker # hop_length 1153*da0073e9SAndroid Build Coastguard Worker (None, 10), 1154*da0073e9SAndroid Build Coastguard Worker # win_length 1155*da0073e9SAndroid Build Coastguard Worker (None, 20), 1156*da0073e9SAndroid Build Coastguard Worker # center 1157*da0073e9SAndroid Build Coastguard Worker (False, True), 1158*da0073e9SAndroid Build Coastguard Worker # pad_mode 1159*da0073e9SAndroid Build Coastguard Worker ("constant", "reflect", "circular"), 1160*da0073e9SAndroid Build Coastguard Worker # normalized 1161*da0073e9SAndroid Build Coastguard Worker (True, False), 1162*da0073e9SAndroid Build Coastguard Worker )) 1163*da0073e9SAndroid Build Coastguard Worker 1164*da0073e9SAndroid Build Coastguard Worker for args in test_args: 1165*da0073e9SAndroid Build Coastguard Worker x, n_fft, hop_length, win_length, center, pad_mode, normalized = args 1166*da0073e9SAndroid Build Coastguard Worker expected = _complex_stft(x, n_fft, hop_length=hop_length, 1167*da0073e9SAndroid Build Coastguard Worker win_length=win_length, pad_mode=pad_mode, 1168*da0073e9SAndroid Build Coastguard Worker center=center, normalized=normalized) 1169*da0073e9SAndroid Build Coastguard Worker actual = torch.stft(x, n_fft, hop_length=hop_length, 1170*da0073e9SAndroid Build Coastguard Worker win_length=win_length, pad_mode=pad_mode, 1171*da0073e9SAndroid Build Coastguard Worker center=center, normalized=normalized) 1172*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 1173*da0073e9SAndroid Build Coastguard Worker 1174*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1175*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.cdouble) 1176*da0073e9SAndroid Build Coastguard Worker def test_complex_istft_real_equiv(self, device, dtype): 1177*da0073e9SAndroid Build Coastguard Worker test_args = list(product( 1178*da0073e9SAndroid Build Coastguard Worker # input 1179*da0073e9SAndroid Build Coastguard Worker (torch.rand(40, 20, device=device, dtype=dtype), 1180*da0073e9SAndroid Build Coastguard Worker torch.rand(25, 1, device=device, dtype=dtype), 1181*da0073e9SAndroid Build Coastguard Worker torch.rand(4, 20, 10, device=device, dtype=dtype)), 1182*da0073e9SAndroid Build Coastguard Worker # hop_length 1183*da0073e9SAndroid Build Coastguard Worker (None, 10), 1184*da0073e9SAndroid Build Coastguard Worker # center 1185*da0073e9SAndroid Build Coastguard Worker (False, True), 1186*da0073e9SAndroid Build Coastguard Worker # normalized 1187*da0073e9SAndroid Build Coastguard Worker (True, False), 1188*da0073e9SAndroid Build Coastguard Worker )) 1189*da0073e9SAndroid Build Coastguard Worker 1190*da0073e9SAndroid Build Coastguard Worker for args in test_args: 1191*da0073e9SAndroid Build Coastguard Worker x, hop_length, center, normalized = args 1192*da0073e9SAndroid Build Coastguard Worker n_fft = x.size(-2) 1193*da0073e9SAndroid Build Coastguard Worker expected = _complex_istft(x, n_fft, hop_length=hop_length, 1194*da0073e9SAndroid Build Coastguard Worker center=center, normalized=normalized) 1195*da0073e9SAndroid Build Coastguard Worker actual = torch.istft(x, n_fft, hop_length=hop_length, 1196*da0073e9SAndroid Build Coastguard Worker center=center, normalized=normalized, 1197*da0073e9SAndroid Build Coastguard Worker return_complex=True) 1198*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 1199*da0073e9SAndroid Build Coastguard Worker 1200*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1201*da0073e9SAndroid Build Coastguard Worker def test_complex_stft_onesided(self, device): 1202*da0073e9SAndroid Build Coastguard Worker # stft of complex input cannot be onesided 1203*da0073e9SAndroid Build Coastguard Worker for x_dtype, window_dtype in product((torch.double, torch.cdouble), repeat=2): 1204*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100, device=device, dtype=x_dtype) 1205*da0073e9SAndroid Build Coastguard Worker window = torch.rand(10, device=device, dtype=window_dtype) 1206*da0073e9SAndroid Build Coastguard Worker 1207*da0073e9SAndroid Build Coastguard Worker if x_dtype.is_complex or window_dtype.is_complex: 1208*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'complex'): 1209*da0073e9SAndroid Build Coastguard Worker x.stft(10, window=window, pad_mode='constant', onesided=True) 1210*da0073e9SAndroid Build Coastguard Worker else: 1211*da0073e9SAndroid Build Coastguard Worker y = x.stft(10, window=window, pad_mode='constant', onesided=True, 1212*da0073e9SAndroid Build Coastguard Worker return_complex=True) 1213*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.dtype, torch.cdouble) 1214*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.size(), (6, 51)) 1215*da0073e9SAndroid Build Coastguard Worker 1216*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100, device=device, dtype=torch.cdouble) 1217*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'complex'): 1218*da0073e9SAndroid Build Coastguard Worker x.stft(10, pad_mode='constant', onesided=True) 1219*da0073e9SAndroid Build Coastguard Worker 1220*da0073e9SAndroid Build Coastguard Worker # stft is currently warning that it requires return-complex while an upgrader is written 1221*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1222*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1223*da0073e9SAndroid Build Coastguard Worker def test_stft_requires_complex(self, device): 1224*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100) 1225*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'): 1226*da0073e9SAndroid Build Coastguard Worker y = x.stft(10, pad_mode='constant') 1227*da0073e9SAndroid Build Coastguard Worker 1228*da0073e9SAndroid Build Coastguard Worker # stft and istft are currently warning if a window is not provided 1229*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1230*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1231*da0073e9SAndroid Build Coastguard Worker def test_stft_requires_window(self, device): 1232*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100) 1233*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"): 1234*da0073e9SAndroid Build Coastguard Worker y = x.stft(10, pad_mode='constant', return_complex=True) 1235*da0073e9SAndroid Build Coastguard Worker 1236*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1237*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1238*da0073e9SAndroid Build Coastguard Worker def test_istft_requires_window(self, device): 1239*da0073e9SAndroid Build Coastguard Worker stft = torch.rand((51, 5), dtype=torch.cdouble) 1240*da0073e9SAndroid Build Coastguard Worker # 51 = 2 * n_fft + 1, 5 = number of frames 1241*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"): 1242*da0073e9SAndroid Build Coastguard Worker x = torch.istft(stft, n_fft=100, length=100) 1243*da0073e9SAndroid Build Coastguard Worker 1244*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1245*da0073e9SAndroid Build Coastguard Worker def test_fft_input_modification(self, device): 1246*da0073e9SAndroid Build Coastguard Worker # FFT functions should not modify their input (gh-34551) 1247*da0073e9SAndroid Build Coastguard Worker 1248*da0073e9SAndroid Build Coastguard Worker signal = torch.ones((2, 2, 2), device=device) 1249*da0073e9SAndroid Build Coastguard Worker signal_copy = signal.clone() 1250*da0073e9SAndroid Build Coastguard Worker spectrum = torch.fft.fftn(signal, dim=(-2, -1)) 1251*da0073e9SAndroid Build Coastguard Worker self.assertEqual(signal, signal_copy) 1252*da0073e9SAndroid Build Coastguard Worker 1253*da0073e9SAndroid Build Coastguard Worker spectrum_copy = spectrum.clone() 1254*da0073e9SAndroid Build Coastguard Worker _ = torch.fft.ifftn(spectrum, dim=(-2, -1)) 1255*da0073e9SAndroid Build Coastguard Worker self.assertEqual(spectrum, spectrum_copy) 1256*da0073e9SAndroid Build Coastguard Worker 1257*da0073e9SAndroid Build Coastguard Worker half_spectrum = torch.fft.rfftn(signal, dim=(-2, -1)) 1258*da0073e9SAndroid Build Coastguard Worker self.assertEqual(signal, signal_copy) 1259*da0073e9SAndroid Build Coastguard Worker 1260*da0073e9SAndroid Build Coastguard Worker half_spectrum_copy = half_spectrum.clone() 1261*da0073e9SAndroid Build Coastguard Worker _ = torch.fft.irfftn(half_spectrum_copy, s=(2, 2), dim=(-2, -1)) 1262*da0073e9SAndroid Build Coastguard Worker self.assertEqual(half_spectrum, half_spectrum_copy) 1263*da0073e9SAndroid Build Coastguard Worker 1264*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1265*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1266*da0073e9SAndroid Build Coastguard Worker def test_fft_plan_repeatable(self, device): 1267*da0073e9SAndroid Build Coastguard Worker # Regression test for gh-58724 and gh-63152 1268*da0073e9SAndroid Build Coastguard Worker for n in [2048, 3199, 5999]: 1269*da0073e9SAndroid Build Coastguard Worker a = torch.randn(n, device=device, dtype=torch.complex64) 1270*da0073e9SAndroid Build Coastguard Worker res1 = torch.fft.fftn(a) 1271*da0073e9SAndroid Build Coastguard Worker res2 = torch.fft.fftn(a.clone()) 1272*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 1273*da0073e9SAndroid Build Coastguard Worker 1274*da0073e9SAndroid Build Coastguard Worker a = torch.randn(n, device=device, dtype=torch.float64) 1275*da0073e9SAndroid Build Coastguard Worker res1 = torch.fft.rfft(a) 1276*da0073e9SAndroid Build Coastguard Worker res2 = torch.fft.rfft(a.clone()) 1277*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 1278*da0073e9SAndroid Build Coastguard Worker 1279*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1280*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1281*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1282*da0073e9SAndroid Build Coastguard Worker def test_istft_round_trip_simple_cases(self, device, dtype): 1283*da0073e9SAndroid Build Coastguard Worker """stft -> istft should recover the original signale""" 1284*da0073e9SAndroid Build Coastguard Worker def _test(input, n_fft, length): 1285*da0073e9SAndroid Build Coastguard Worker stft = torch.stft(input, n_fft=n_fft, return_complex=True) 1286*da0073e9SAndroid Build Coastguard Worker inverse = torch.istft(stft, n_fft=n_fft, length=length) 1287*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, inverse, exact_dtype=True) 1288*da0073e9SAndroid Build Coastguard Worker 1289*da0073e9SAndroid Build Coastguard Worker _test(torch.ones(4, dtype=dtype, device=device), 4, 4) 1290*da0073e9SAndroid Build Coastguard Worker _test(torch.zeros(4, dtype=dtype, device=device), 4, 4) 1291*da0073e9SAndroid Build Coastguard Worker 1292*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1293*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1294*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1295*da0073e9SAndroid Build Coastguard Worker def test_istft_round_trip_various_params(self, device, dtype): 1296*da0073e9SAndroid Build Coastguard Worker """stft -> istft should recover the original signale""" 1297*da0073e9SAndroid Build Coastguard Worker def _test_istft_is_inverse_of_stft(stft_kwargs): 1298*da0073e9SAndroid Build Coastguard Worker # generates a random sound signal for each tril and then does the stft/istft 1299*da0073e9SAndroid Build Coastguard Worker # operation to check whether we can reconstruct signal 1300*da0073e9SAndroid Build Coastguard Worker data_sizes = [(2, 20), (3, 15), (4, 10)] 1301*da0073e9SAndroid Build Coastguard Worker num_trials = 100 1302*da0073e9SAndroid Build Coastguard Worker istft_kwargs = stft_kwargs.copy() 1303*da0073e9SAndroid Build Coastguard Worker del istft_kwargs['pad_mode'] 1304*da0073e9SAndroid Build Coastguard Worker for sizes in data_sizes: 1305*da0073e9SAndroid Build Coastguard Worker for i in range(num_trials): 1306*da0073e9SAndroid Build Coastguard Worker original = torch.randn(*sizes, dtype=dtype, device=device) 1307*da0073e9SAndroid Build Coastguard Worker stft = torch.stft(original, return_complex=True, **stft_kwargs) 1308*da0073e9SAndroid Build Coastguard Worker inversed = torch.istft(stft, length=original.size(1), **istft_kwargs) 1309*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1310*da0073e9SAndroid Build Coastguard Worker inversed, original, msg='istft comparison against original', 1311*da0073e9SAndroid Build Coastguard Worker atol=7e-6, rtol=0, exact_dtype=True) 1312*da0073e9SAndroid Build Coastguard Worker 1313*da0073e9SAndroid Build Coastguard Worker patterns = [ 1314*da0073e9SAndroid Build Coastguard Worker # hann_window, centered, normalized, onesided 1315*da0073e9SAndroid Build Coastguard Worker { 1316*da0073e9SAndroid Build Coastguard Worker 'n_fft': 12, 1317*da0073e9SAndroid Build Coastguard Worker 'hop_length': 4, 1318*da0073e9SAndroid Build Coastguard Worker 'win_length': 12, 1319*da0073e9SAndroid Build Coastguard Worker 'window': torch.hann_window(12, dtype=dtype, device=device), 1320*da0073e9SAndroid Build Coastguard Worker 'center': True, 1321*da0073e9SAndroid Build Coastguard Worker 'pad_mode': 'reflect', 1322*da0073e9SAndroid Build Coastguard Worker 'normalized': True, 1323*da0073e9SAndroid Build Coastguard Worker 'onesided': True, 1324*da0073e9SAndroid Build Coastguard Worker }, 1325*da0073e9SAndroid Build Coastguard Worker # hann_window, centered, not normalized, not onesided 1326*da0073e9SAndroid Build Coastguard Worker { 1327*da0073e9SAndroid Build Coastguard Worker 'n_fft': 12, 1328*da0073e9SAndroid Build Coastguard Worker 'hop_length': 2, 1329*da0073e9SAndroid Build Coastguard Worker 'win_length': 8, 1330*da0073e9SAndroid Build Coastguard Worker 'window': torch.hann_window(8, dtype=dtype, device=device), 1331*da0073e9SAndroid Build Coastguard Worker 'center': True, 1332*da0073e9SAndroid Build Coastguard Worker 'pad_mode': 'reflect', 1333*da0073e9SAndroid Build Coastguard Worker 'normalized': False, 1334*da0073e9SAndroid Build Coastguard Worker 'onesided': False, 1335*da0073e9SAndroid Build Coastguard Worker }, 1336*da0073e9SAndroid Build Coastguard Worker # hamming_window, centered, normalized, not onesided 1337*da0073e9SAndroid Build Coastguard Worker { 1338*da0073e9SAndroid Build Coastguard Worker 'n_fft': 15, 1339*da0073e9SAndroid Build Coastguard Worker 'hop_length': 3, 1340*da0073e9SAndroid Build Coastguard Worker 'win_length': 11, 1341*da0073e9SAndroid Build Coastguard Worker 'window': torch.hamming_window(11, dtype=dtype, device=device), 1342*da0073e9SAndroid Build Coastguard Worker 'center': True, 1343*da0073e9SAndroid Build Coastguard Worker 'pad_mode': 'constant', 1344*da0073e9SAndroid Build Coastguard Worker 'normalized': True, 1345*da0073e9SAndroid Build Coastguard Worker 'onesided': False, 1346*da0073e9SAndroid Build Coastguard Worker }, 1347*da0073e9SAndroid Build Coastguard Worker # hamming_window, centered, not normalized, onesided 1348*da0073e9SAndroid Build Coastguard Worker # window same size as n_fft 1349*da0073e9SAndroid Build Coastguard Worker { 1350*da0073e9SAndroid Build Coastguard Worker 'n_fft': 5, 1351*da0073e9SAndroid Build Coastguard Worker 'hop_length': 2, 1352*da0073e9SAndroid Build Coastguard Worker 'win_length': 5, 1353*da0073e9SAndroid Build Coastguard Worker 'window': torch.hamming_window(5, dtype=dtype, device=device), 1354*da0073e9SAndroid Build Coastguard Worker 'center': True, 1355*da0073e9SAndroid Build Coastguard Worker 'pad_mode': 'constant', 1356*da0073e9SAndroid Build Coastguard Worker 'normalized': False, 1357*da0073e9SAndroid Build Coastguard Worker 'onesided': True, 1358*da0073e9SAndroid Build Coastguard Worker }, 1359*da0073e9SAndroid Build Coastguard Worker ] 1360*da0073e9SAndroid Build Coastguard Worker for i, pattern in enumerate(patterns): 1361*da0073e9SAndroid Build Coastguard Worker _test_istft_is_inverse_of_stft(pattern) 1362*da0073e9SAndroid Build Coastguard Worker 1363*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1364*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1365*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1366*da0073e9SAndroid Build Coastguard Worker def test_istft_round_trip_with_padding(self, device, dtype): 1367*da0073e9SAndroid Build Coastguard Worker """long hop_length or not centered may cause length mismatch in the inversed signal""" 1368*da0073e9SAndroid Build Coastguard Worker def _test_istft_is_inverse_of_stft_with_padding(stft_kwargs): 1369*da0073e9SAndroid Build Coastguard Worker # generates a random sound signal for each tril and then does the stft/istft 1370*da0073e9SAndroid Build Coastguard Worker # operation to check whether we can reconstruct signal 1371*da0073e9SAndroid Build Coastguard Worker num_trials = 100 1372*da0073e9SAndroid Build Coastguard Worker sizes = stft_kwargs['size'] 1373*da0073e9SAndroid Build Coastguard Worker del stft_kwargs['size'] 1374*da0073e9SAndroid Build Coastguard Worker istft_kwargs = stft_kwargs.copy() 1375*da0073e9SAndroid Build Coastguard Worker del istft_kwargs['pad_mode'] 1376*da0073e9SAndroid Build Coastguard Worker for i in range(num_trials): 1377*da0073e9SAndroid Build Coastguard Worker original = torch.randn(*sizes, dtype=dtype, device=device) 1378*da0073e9SAndroid Build Coastguard Worker stft = torch.stft(original, return_complex=True, **stft_kwargs) 1379*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, "The length of signal is shorter than the length parameter."): 1380*da0073e9SAndroid Build Coastguard Worker inversed = torch.istft(stft, length=original.size(-1), **istft_kwargs) 1381*da0073e9SAndroid Build Coastguard Worker n_frames = stft.size(-1) 1382*da0073e9SAndroid Build Coastguard Worker if stft_kwargs["center"] is True: 1383*da0073e9SAndroid Build Coastguard Worker len_expected = stft_kwargs["n_fft"] // 2 + stft_kwargs["hop_length"] * (n_frames - 1) 1384*da0073e9SAndroid Build Coastguard Worker else: 1385*da0073e9SAndroid Build Coastguard Worker len_expected = stft_kwargs["n_fft"] + stft_kwargs["hop_length"] * (n_frames - 1) 1386*da0073e9SAndroid Build Coastguard Worker # trim the original for case when constructed signal is shorter than original 1387*da0073e9SAndroid Build Coastguard Worker padding = inversed[..., len_expected:] 1388*da0073e9SAndroid Build Coastguard Worker inversed = inversed[..., :len_expected] 1389*da0073e9SAndroid Build Coastguard Worker original = original[..., :len_expected] 1390*da0073e9SAndroid Build Coastguard Worker # test the padding points of the inversed signal are all zeros 1391*da0073e9SAndroid Build Coastguard Worker zeros = torch.zeros_like(padding, device=padding.device) 1392*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1393*da0073e9SAndroid Build Coastguard Worker padding, zeros, msg='istft padding values against zeros', 1394*da0073e9SAndroid Build Coastguard Worker atol=7e-6, rtol=0, exact_dtype=True) 1395*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1396*da0073e9SAndroid Build Coastguard Worker inversed, original, msg='istft comparison against original', 1397*da0073e9SAndroid Build Coastguard Worker atol=7e-6, rtol=0, exact_dtype=True) 1398*da0073e9SAndroid Build Coastguard Worker 1399*da0073e9SAndroid Build Coastguard Worker patterns = [ 1400*da0073e9SAndroid Build Coastguard Worker # hamming_window, not centered, not normalized, not onesided 1401*da0073e9SAndroid Build Coastguard Worker # window same size as n_fft 1402*da0073e9SAndroid Build Coastguard Worker { 1403*da0073e9SAndroid Build Coastguard Worker 'size': [2, 20], 1404*da0073e9SAndroid Build Coastguard Worker 'n_fft': 3, 1405*da0073e9SAndroid Build Coastguard Worker 'hop_length': 2, 1406*da0073e9SAndroid Build Coastguard Worker 'win_length': 3, 1407*da0073e9SAndroid Build Coastguard Worker 'window': torch.hamming_window(3, dtype=dtype, device=device), 1408*da0073e9SAndroid Build Coastguard Worker 'center': False, 1409*da0073e9SAndroid Build Coastguard Worker 'pad_mode': 'reflect', 1410*da0073e9SAndroid Build Coastguard Worker 'normalized': False, 1411*da0073e9SAndroid Build Coastguard Worker 'onesided': False, 1412*da0073e9SAndroid Build Coastguard Worker }, 1413*da0073e9SAndroid Build Coastguard Worker # hamming_window, centered, not normalized, onesided, long hop_length 1414*da0073e9SAndroid Build Coastguard Worker # window same size as n_fft 1415*da0073e9SAndroid Build Coastguard Worker { 1416*da0073e9SAndroid Build Coastguard Worker 'size': [2, 500], 1417*da0073e9SAndroid Build Coastguard Worker 'n_fft': 256, 1418*da0073e9SAndroid Build Coastguard Worker 'hop_length': 254, 1419*da0073e9SAndroid Build Coastguard Worker 'win_length': 256, 1420*da0073e9SAndroid Build Coastguard Worker 'window': torch.hamming_window(256, dtype=dtype, device=device), 1421*da0073e9SAndroid Build Coastguard Worker 'center': True, 1422*da0073e9SAndroid Build Coastguard Worker 'pad_mode': 'constant', 1423*da0073e9SAndroid Build Coastguard Worker 'normalized': False, 1424*da0073e9SAndroid Build Coastguard Worker 'onesided': True, 1425*da0073e9SAndroid Build Coastguard Worker }, 1426*da0073e9SAndroid Build Coastguard Worker ] 1427*da0073e9SAndroid Build Coastguard Worker for i, pattern in enumerate(patterns): 1428*da0073e9SAndroid Build Coastguard Worker _test_istft_is_inverse_of_stft_with_padding(pattern) 1429*da0073e9SAndroid Build Coastguard Worker 1430*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1431*da0073e9SAndroid Build Coastguard Worker def test_istft_throws(self, device): 1432*da0073e9SAndroid Build Coastguard Worker """istft should throw exception for invalid parameters""" 1433*da0073e9SAndroid Build Coastguard Worker stft = torch.zeros((3, 5, 2), device=device) 1434*da0073e9SAndroid Build Coastguard Worker # the window is size 1 but it hops 20 so there is a gap which throw an error 1435*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 1436*da0073e9SAndroid Build Coastguard Worker RuntimeError, torch.istft, stft, n_fft=4, 1437*da0073e9SAndroid Build Coastguard Worker hop_length=20, win_length=1, window=torch.ones(1)) 1438*da0073e9SAndroid Build Coastguard Worker # A window of zeros does not meet NOLA 1439*da0073e9SAndroid Build Coastguard Worker invalid_window = torch.zeros(4, device=device) 1440*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 1441*da0073e9SAndroid Build Coastguard Worker RuntimeError, torch.istft, stft, n_fft=4, win_length=4, window=invalid_window) 1442*da0073e9SAndroid Build Coastguard Worker # Input cannot be empty 1443*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, torch.istft, torch.zeros((3, 0, 2)), 2) 1444*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, torch.istft, torch.zeros((0, 3, 2)), 2) 1445*da0073e9SAndroid Build Coastguard Worker 1446*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Failed running call_function") 1447*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1448*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1449*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1450*da0073e9SAndroid Build Coastguard Worker def test_istft_of_sine(self, device, dtype): 1451*da0073e9SAndroid Build Coastguard Worker complex_dtype = corresponding_complex_dtype(dtype) 1452*da0073e9SAndroid Build Coastguard Worker 1453*da0073e9SAndroid Build Coastguard Worker def _test(amplitude, L, n): 1454*da0073e9SAndroid Build Coastguard Worker # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L 1455*da0073e9SAndroid Build Coastguard Worker x = torch.arange(2 * L + 1, device=device, dtype=dtype) 1456*da0073e9SAndroid Build Coastguard Worker original = amplitude * torch.sin(2 * math.pi / L * x * n) 1457*da0073e9SAndroid Build Coastguard Worker # stft = torch.stft(original, L, hop_length=L, win_length=L, 1458*da0073e9SAndroid Build Coastguard Worker # window=torch.ones(L), center=False, normalized=False) 1459*da0073e9SAndroid Build Coastguard Worker stft = torch.zeros((L // 2 + 1, 2), device=device, dtype=complex_dtype) 1460*da0073e9SAndroid Build Coastguard Worker stft_largest_val = (amplitude * L) / 2.0 1461*da0073e9SAndroid Build Coastguard Worker if n < stft.size(0): 1462*da0073e9SAndroid Build Coastguard Worker stft[n].imag = torch.tensor(-stft_largest_val, dtype=dtype) 1463*da0073e9SAndroid Build Coastguard Worker 1464*da0073e9SAndroid Build Coastguard Worker if 0 <= L - n < stft.size(0): 1465*da0073e9SAndroid Build Coastguard Worker # symmetric about L // 2 1466*da0073e9SAndroid Build Coastguard Worker stft[L - n].imag = torch.tensor(stft_largest_val, dtype=dtype) 1467*da0073e9SAndroid Build Coastguard Worker 1468*da0073e9SAndroid Build Coastguard Worker inverse = torch.istft( 1469*da0073e9SAndroid Build Coastguard Worker stft, L, hop_length=L, win_length=L, 1470*da0073e9SAndroid Build Coastguard Worker window=torch.ones(L, device=device, dtype=dtype), center=False, normalized=False) 1471*da0073e9SAndroid Build Coastguard Worker # There is a larger error due to the scaling of amplitude 1472*da0073e9SAndroid Build Coastguard Worker original = original[..., :inverse.size(-1)] 1473*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inverse, original, atol=1e-3, rtol=0) 1474*da0073e9SAndroid Build Coastguard Worker 1475*da0073e9SAndroid Build Coastguard Worker _test(amplitude=123, L=5, n=1) 1476*da0073e9SAndroid Build Coastguard Worker _test(amplitude=150, L=5, n=2) 1477*da0073e9SAndroid Build Coastguard Worker _test(amplitude=111, L=5, n=3) 1478*da0073e9SAndroid Build Coastguard Worker _test(amplitude=160, L=7, n=4) 1479*da0073e9SAndroid Build Coastguard Worker _test(amplitude=145, L=8, n=5) 1480*da0073e9SAndroid Build Coastguard Worker _test(amplitude=80, L=9, n=6) 1481*da0073e9SAndroid Build Coastguard Worker _test(amplitude=99, L=10, n=7) 1482*da0073e9SAndroid Build Coastguard Worker 1483*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1484*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1485*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1486*da0073e9SAndroid Build Coastguard Worker def test_istft_linearity(self, device, dtype): 1487*da0073e9SAndroid Build Coastguard Worker num_trials = 100 1488*da0073e9SAndroid Build Coastguard Worker complex_dtype = corresponding_complex_dtype(dtype) 1489*da0073e9SAndroid Build Coastguard Worker 1490*da0073e9SAndroid Build Coastguard Worker def _test(data_size, kwargs): 1491*da0073e9SAndroid Build Coastguard Worker for i in range(num_trials): 1492*da0073e9SAndroid Build Coastguard Worker tensor1 = torch.randn(data_size, device=device, dtype=complex_dtype) 1493*da0073e9SAndroid Build Coastguard Worker tensor2 = torch.randn(data_size, device=device, dtype=complex_dtype) 1494*da0073e9SAndroid Build Coastguard Worker a, b = torch.rand(2, dtype=dtype, device=device) 1495*da0073e9SAndroid Build Coastguard Worker # Also compare method vs. functional call signature 1496*da0073e9SAndroid Build Coastguard Worker istft1 = tensor1.istft(**kwargs) 1497*da0073e9SAndroid Build Coastguard Worker istft2 = tensor2.istft(**kwargs) 1498*da0073e9SAndroid Build Coastguard Worker istft = a * istft1 + b * istft2 1499*da0073e9SAndroid Build Coastguard Worker estimate = torch.istft(a * tensor1 + b * tensor2, **kwargs) 1500*da0073e9SAndroid Build Coastguard Worker self.assertEqual(istft, estimate, atol=1e-5, rtol=0) 1501*da0073e9SAndroid Build Coastguard Worker patterns = [ 1502*da0073e9SAndroid Build Coastguard Worker # hann_window, centered, normalized, onesided 1503*da0073e9SAndroid Build Coastguard Worker ( 1504*da0073e9SAndroid Build Coastguard Worker (2, 7, 7), 1505*da0073e9SAndroid Build Coastguard Worker { 1506*da0073e9SAndroid Build Coastguard Worker 'n_fft': 12, 1507*da0073e9SAndroid Build Coastguard Worker 'window': torch.hann_window(12, device=device, dtype=dtype), 1508*da0073e9SAndroid Build Coastguard Worker 'center': True, 1509*da0073e9SAndroid Build Coastguard Worker 'normalized': True, 1510*da0073e9SAndroid Build Coastguard Worker 'onesided': True, 1511*da0073e9SAndroid Build Coastguard Worker }, 1512*da0073e9SAndroid Build Coastguard Worker ), 1513*da0073e9SAndroid Build Coastguard Worker # hann_window, centered, not normalized, not onesided 1514*da0073e9SAndroid Build Coastguard Worker ( 1515*da0073e9SAndroid Build Coastguard Worker (2, 12, 7), 1516*da0073e9SAndroid Build Coastguard Worker { 1517*da0073e9SAndroid Build Coastguard Worker 'n_fft': 12, 1518*da0073e9SAndroid Build Coastguard Worker 'window': torch.hann_window(12, device=device, dtype=dtype), 1519*da0073e9SAndroid Build Coastguard Worker 'center': True, 1520*da0073e9SAndroid Build Coastguard Worker 'normalized': False, 1521*da0073e9SAndroid Build Coastguard Worker 'onesided': False, 1522*da0073e9SAndroid Build Coastguard Worker }, 1523*da0073e9SAndroid Build Coastguard Worker ), 1524*da0073e9SAndroid Build Coastguard Worker # hamming_window, centered, normalized, not onesided 1525*da0073e9SAndroid Build Coastguard Worker ( 1526*da0073e9SAndroid Build Coastguard Worker (2, 12, 7), 1527*da0073e9SAndroid Build Coastguard Worker { 1528*da0073e9SAndroid Build Coastguard Worker 'n_fft': 12, 1529*da0073e9SAndroid Build Coastguard Worker 'window': torch.hamming_window(12, device=device, dtype=dtype), 1530*da0073e9SAndroid Build Coastguard Worker 'center': True, 1531*da0073e9SAndroid Build Coastguard Worker 'normalized': True, 1532*da0073e9SAndroid Build Coastguard Worker 'onesided': False, 1533*da0073e9SAndroid Build Coastguard Worker }, 1534*da0073e9SAndroid Build Coastguard Worker ), 1535*da0073e9SAndroid Build Coastguard Worker # hamming_window, not centered, not normalized, onesided 1536*da0073e9SAndroid Build Coastguard Worker ( 1537*da0073e9SAndroid Build Coastguard Worker (2, 7, 3), 1538*da0073e9SAndroid Build Coastguard Worker { 1539*da0073e9SAndroid Build Coastguard Worker 'n_fft': 12, 1540*da0073e9SAndroid Build Coastguard Worker 'window': torch.hamming_window(12, device=device, dtype=dtype), 1541*da0073e9SAndroid Build Coastguard Worker 'center': False, 1542*da0073e9SAndroid Build Coastguard Worker 'normalized': False, 1543*da0073e9SAndroid Build Coastguard Worker 'onesided': True, 1544*da0073e9SAndroid Build Coastguard Worker }, 1545*da0073e9SAndroid Build Coastguard Worker ) 1546*da0073e9SAndroid Build Coastguard Worker ] 1547*da0073e9SAndroid Build Coastguard Worker for data_size, kwargs in patterns: 1548*da0073e9SAndroid Build Coastguard Worker _test(data_size, kwargs) 1549*da0073e9SAndroid Build Coastguard Worker 1550*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1551*da0073e9SAndroid Build Coastguard Worker @skipCPUIfNoFFT 1552*da0073e9SAndroid Build Coastguard Worker def test_batch_istft(self, device): 1553*da0073e9SAndroid Build Coastguard Worker original = torch.tensor([ 1554*da0073e9SAndroid Build Coastguard Worker [4., 4., 4., 4., 4.], 1555*da0073e9SAndroid Build Coastguard Worker [0., 0., 0., 0., 0.], 1556*da0073e9SAndroid Build Coastguard Worker [0., 0., 0., 0., 0.] 1557*da0073e9SAndroid Build Coastguard Worker ], device=device, dtype=torch.complex64) 1558*da0073e9SAndroid Build Coastguard Worker 1559*da0073e9SAndroid Build Coastguard Worker single = original.repeat(1, 1, 1) 1560*da0073e9SAndroid Build Coastguard Worker multi = original.repeat(4, 1, 1) 1561*da0073e9SAndroid Build Coastguard Worker 1562*da0073e9SAndroid Build Coastguard Worker i_original = torch.istft(original, n_fft=4, length=4) 1563*da0073e9SAndroid Build Coastguard Worker i_single = torch.istft(single, n_fft=4, length=4) 1564*da0073e9SAndroid Build Coastguard Worker i_multi = torch.istft(multi, n_fft=4, length=4) 1565*da0073e9SAndroid Build Coastguard Worker 1566*da0073e9SAndroid Build Coastguard Worker self.assertEqual(i_original.repeat(1, 1), i_single, atol=1e-6, rtol=0, exact_dtype=True) 1567*da0073e9SAndroid Build Coastguard Worker self.assertEqual(i_original.repeat(4, 1), i_multi, atol=1e-6, rtol=0, exact_dtype=True) 1568*da0073e9SAndroid Build Coastguard Worker 1569*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1570*da0073e9SAndroid Build Coastguard Worker @skipIf(not TEST_MKL, "Test requires MKL") 1571*da0073e9SAndroid Build Coastguard Worker def test_stft_window_device(self, device): 1572*da0073e9SAndroid Build Coastguard Worker # Test the (i)stft window must be on the same device as the input 1573*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1000, dtype=torch.complex64) 1574*da0073e9SAndroid Build Coastguard Worker window = torch.randn(100, dtype=torch.complex64) 1575*da0073e9SAndroid Build Coastguard Worker 1576*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"): 1577*da0073e9SAndroid Build Coastguard Worker torch.stft(x, n_fft=100, window=window.to(device)) 1578*da0073e9SAndroid Build Coastguard Worker 1579*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"): 1580*da0073e9SAndroid Build Coastguard Worker torch.stft(x.to(device), n_fft=100, window=window) 1581*da0073e9SAndroid Build Coastguard Worker 1582*da0073e9SAndroid Build Coastguard Worker X = torch.stft(x, n_fft=100, window=window) 1583*da0073e9SAndroid Build Coastguard Worker 1584*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"): 1585*da0073e9SAndroid Build Coastguard Worker torch.istft(X, n_fft=100, window=window.to(device)) 1586*da0073e9SAndroid Build Coastguard Worker 1587*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"): 1588*da0073e9SAndroid Build Coastguard Worker torch.istft(x.to(device), n_fft=100, window=window) 1589*da0073e9SAndroid Build Coastguard Worker 1590*da0073e9SAndroid Build Coastguard Worker 1591*da0073e9SAndroid Build Coastguard Workerclass FFTDocTestFinder: 1592*da0073e9SAndroid Build Coastguard Worker '''The default doctest finder doesn't like that function.__module__ doesn't 1593*da0073e9SAndroid Build Coastguard Worker match torch.fft. It assumes the functions are leaked imports. 1594*da0073e9SAndroid Build Coastguard Worker ''' 1595*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1596*da0073e9SAndroid Build Coastguard Worker self.parser = doctest.DocTestParser() 1597*da0073e9SAndroid Build Coastguard Worker 1598*da0073e9SAndroid Build Coastguard Worker def find(self, obj, name=None, module=None, globs=None, extraglobs=None): 1599*da0073e9SAndroid Build Coastguard Worker doctests = [] 1600*da0073e9SAndroid Build Coastguard Worker 1601*da0073e9SAndroid Build Coastguard Worker modname = name if name is not None else obj.__name__ 1602*da0073e9SAndroid Build Coastguard Worker globs = {} if globs is None else globs 1603*da0073e9SAndroid Build Coastguard Worker 1604*da0073e9SAndroid Build Coastguard Worker for fname in obj.__all__: 1605*da0073e9SAndroid Build Coastguard Worker func = getattr(obj, fname) 1606*da0073e9SAndroid Build Coastguard Worker if inspect.isroutine(func): 1607*da0073e9SAndroid Build Coastguard Worker qualname = modname + '.' + fname 1608*da0073e9SAndroid Build Coastguard Worker docstring = inspect.getdoc(func) 1609*da0073e9SAndroid Build Coastguard Worker if docstring is None: 1610*da0073e9SAndroid Build Coastguard Worker continue 1611*da0073e9SAndroid Build Coastguard Worker 1612*da0073e9SAndroid Build Coastguard Worker examples = self.parser.get_doctest( 1613*da0073e9SAndroid Build Coastguard Worker docstring, globs=globs, name=fname, filename=None, lineno=None) 1614*da0073e9SAndroid Build Coastguard Worker doctests.append(examples) 1615*da0073e9SAndroid Build Coastguard Worker 1616*da0073e9SAndroid Build Coastguard Worker return doctests 1617*da0073e9SAndroid Build Coastguard Worker 1618*da0073e9SAndroid Build Coastguard Worker 1619*da0073e9SAndroid Build Coastguard Workerclass TestFFTDocExamples(TestCase): 1620*da0073e9SAndroid Build Coastguard Worker pass 1621*da0073e9SAndroid Build Coastguard Worker 1622*da0073e9SAndroid Build Coastguard Workerdef generate_doc_test(doc_test): 1623*da0073e9SAndroid Build Coastguard Worker def test(self, device): 1624*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device, 'cpu') 1625*da0073e9SAndroid Build Coastguard Worker runner = doctest.DocTestRunner() 1626*da0073e9SAndroid Build Coastguard Worker runner.run(doc_test) 1627*da0073e9SAndroid Build Coastguard Worker 1628*da0073e9SAndroid Build Coastguard Worker if runner.failures != 0: 1629*da0073e9SAndroid Build Coastguard Worker runner.summarize() 1630*da0073e9SAndroid Build Coastguard Worker self.fail('Doctest failed') 1631*da0073e9SAndroid Build Coastguard Worker 1632*da0073e9SAndroid Build Coastguard Worker setattr(TestFFTDocExamples, 'test_' + doc_test.name, skipCPUIfNoFFT(test)) 1633*da0073e9SAndroid Build Coastguard Worker 1634*da0073e9SAndroid Build Coastguard Workerfor doc_test in FFTDocTestFinder().find(torch.fft, globs=dict(torch=torch)): 1635*da0073e9SAndroid Build Coastguard Worker generate_doc_test(doc_test) 1636*da0073e9SAndroid Build Coastguard Worker 1637*da0073e9SAndroid Build Coastguard Worker 1638*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestFFT, globals()) 1639*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestFFTDocExamples, globals(), only_for='cpu') 1640*da0073e9SAndroid Build Coastguard Worker 1641*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 1642*da0073e9SAndroid Build Coastguard Worker run_tests() 1643