xref: /aosp_15_r20/external/pytorch/test/test_spectral_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: fft"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Workerimport unittest
5*da0073e9SAndroid Build Coastguard Workerimport math
6*da0073e9SAndroid Build Coastguard Workerfrom contextlib import contextmanager
7*da0073e9SAndroid Build Coastguard Workerfrom itertools import product
8*da0073e9SAndroid Build Coastguard Workerimport itertools
9*da0073e9SAndroid Build Coastguard Workerimport doctest
10*da0073e9SAndroid Build Coastguard Workerimport inspect
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import \
13*da0073e9SAndroid Build Coastguard Worker    (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL, first_sample, TEST_WITH_ROCM,
14*da0073e9SAndroid Build Coastguard Worker     make_tensor, skipIfTorchDynamo)
15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import \
16*da0073e9SAndroid Build Coastguard Worker    (instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes,
17*da0073e9SAndroid Build Coastguard Worker     skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf, toleranceOverride, tol)
18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import (
19*da0073e9SAndroid Build Coastguard Worker    spectral_funcs, SpectralFuncType)
20*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import SM53OrLater
21*da0073e9SAndroid Build Coastguard Workerfrom torch._prims_common import corresponding_complex_dtype
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Workerfrom typing import Optional, List
24*da0073e9SAndroid Build Coastguard Workerfrom packaging import version
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
if TEST_NUMPY:
    import numpy as np


if TEST_LIBROSA:
    import librosa

# scipy.fft is optional; when it imports, its version also gates the
# reference norm modes below.
has_scipy_fft = False
try:
    import scipy.fft
    has_scipy_fft = True
except ModuleNotFoundError:
    pass

# Normalization modes exercised by the reference-comparison tests. The full set
# is used only with NumPy >= 1.20 and, when scipy.fft imported, SciPy >= 1.6;
# otherwise only the default (None) and "ortho" modes are compared.
# TEST_NUMPY is checked first: ``np`` is only bound when NumPy is available, so
# reading ``np.__version__`` unguarded would raise NameError at import time.
REFERENCE_NORM_MODES = (
    (None, "forward", "backward", "ortho")
    if TEST_NUMPY and version.parse(np.__version__) >= version.parse('1.20.0') and (
        not has_scipy_fft or version.parse(scipy.__version__) >= version.parse('1.6.0'))
    else (None, "ortho"))
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Workerdef _complex_stft(x, *args, **kwargs):
49*da0073e9SAndroid Build Coastguard Worker    # Transform real and imaginary components separably
50*da0073e9SAndroid Build Coastguard Worker    stft_real = torch.stft(x.real, *args, **kwargs, return_complex=True, onesided=False)
51*da0073e9SAndroid Build Coastguard Worker    stft_imag = torch.stft(x.imag, *args, **kwargs, return_complex=True, onesided=False)
52*da0073e9SAndroid Build Coastguard Worker    return stft_real + 1j * stft_imag
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Workerdef _hermitian_conj(x, dim):
56*da0073e9SAndroid Build Coastguard Worker    """Returns the hermitian conjugate along a single dimension
57*da0073e9SAndroid Build Coastguard Worker
58*da0073e9SAndroid Build Coastguard Worker    H(x)[i] = conj(x[-i])
59*da0073e9SAndroid Build Coastguard Worker    """
60*da0073e9SAndroid Build Coastguard Worker    out = torch.empty_like(x)
61*da0073e9SAndroid Build Coastguard Worker    mid = (x.size(dim) - 1) // 2
62*da0073e9SAndroid Build Coastguard Worker    idx = [slice(None)] * out.dim()
63*da0073e9SAndroid Build Coastguard Worker    idx_center = list(idx)
64*da0073e9SAndroid Build Coastguard Worker    idx_center[dim] = 0
65*da0073e9SAndroid Build Coastguard Worker    out[idx] = x[idx]
66*da0073e9SAndroid Build Coastguard Worker
67*da0073e9SAndroid Build Coastguard Worker    idx_neg = list(idx)
68*da0073e9SAndroid Build Coastguard Worker    idx_neg[dim] = slice(-mid, None)
69*da0073e9SAndroid Build Coastguard Worker    idx_pos = idx
70*da0073e9SAndroid Build Coastguard Worker    idx_pos[dim] = slice(1, mid + 1)
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker    out[idx_pos] = x[idx_neg].flip(dim)
73*da0073e9SAndroid Build Coastguard Worker    out[idx_neg] = x[idx_pos].flip(dim)
74*da0073e9SAndroid Build Coastguard Worker    if (2 * mid + 1 < x.size(dim)):
75*da0073e9SAndroid Build Coastguard Worker        idx[dim] = mid + 1
76*da0073e9SAndroid Build Coastguard Worker        out[idx] = x[idx]
77*da0073e9SAndroid Build Coastguard Worker    return out.conj()
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker
def _complex_istft(x, *args, **kwargs):
    """Complex-valued istft built from two real-output ``torch.istft`` calls.

    The two-sided spectrogram ``x`` is split along dim -2 into two parts:
    its Hermitian part (the FFT of the real signal) and its anti-Hermitian
    part (the FFT of the imaginary signal). The one-sided half of each part
    is inverted with ``torch.istft(..., onesided=True)``. The two real
    signals are then joined with ``torch.complex``. Extra positional and
    keyword arguments are forwarded to both istft calls.
    """
    num_bins = x.size(-2)
    # Keep frequency bins 0 .. num_bins // 2 (the one-sided half).
    one_sided = (Ellipsis, slice(None, num_bins // 2 + 1), slice(None))

    mirrored = _hermitian_conj(x, dim=-2)
    hermitian_part = (x + mirrored) / 2
    anti_hermitian_part = (x - mirrored) / 2

    real_signal = torch.istft(hermitian_part[one_sided], *args, **kwargs, onesided=True)
    # The anti-Hermitian part carries a factor of 1j; multiplying by -1j removes it.
    imag_signal = torch.istft(-1j * anti_hermitian_part[one_sided], *args, **kwargs, onesided=True)
    return torch.complex(real_signal, imag_signal)
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Workerdef _stft_reference(x, hop_length, window):
94*da0073e9SAndroid Build Coastguard Worker    r"""Reference stft implementation
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker    This doesn't implement all of torch.stft, only the STFT definition:
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker    .. math:: X(m, \omega) = \sum_n x[n]w[n - m] e^{-jn\omega}
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker    """
101*da0073e9SAndroid Build Coastguard Worker    n_fft = window.numel()
102*da0073e9SAndroid Build Coastguard Worker    X = torch.empty((n_fft, (x.numel() - n_fft + hop_length) // hop_length),
103*da0073e9SAndroid Build Coastguard Worker                    device=x.device, dtype=torch.cdouble)
104*da0073e9SAndroid Build Coastguard Worker    for m in range(X.size(1)):
105*da0073e9SAndroid Build Coastguard Worker        start = m * hop_length
106*da0073e9SAndroid Build Coastguard Worker        if start + n_fft > x.numel():
107*da0073e9SAndroid Build Coastguard Worker            slc = torch.empty(n_fft, device=x.device, dtype=x.dtype)
108*da0073e9SAndroid Build Coastguard Worker            tmp = x[start:]
109*da0073e9SAndroid Build Coastguard Worker            slc[:tmp.numel()] = tmp
110*da0073e9SAndroid Build Coastguard Worker        else:
111*da0073e9SAndroid Build Coastguard Worker            slc = x[start: start + n_fft]
112*da0073e9SAndroid Build Coastguard Worker        X[:, m] = torch.fft.fft(slc * window)
113*da0073e9SAndroid Build Coastguard Worker    return X
114*da0073e9SAndroid Build Coastguard Worker
115*da0073e9SAndroid Build Coastguard Worker
def skip_helper_for_fft(device, dtype):
    """Skip the calling test when a half-precision FFT can't run on ``device``.

    Only ``torch.half`` and ``torch.complex32`` are affected. They are skipped
    on CPU, and on other devices whenever ``SM53OrLater`` is false. Every other
    dtype returns immediately (``None``).

    Raises:
        unittest.SkipTest: if the half-precision dtype cannot be tested here.
    """
    device_type = torch.device(device).type
    if dtype not in (torch.half, torch.complex32):
        return

    if device_type == 'cpu':
        raise unittest.SkipTest("half and complex32 are not supported on CPU")
    if not SM53OrLater:
        # SM53OrLater is true for compute capability >= SM_53, so the message
        # says ">=" (it previously said ">").
        raise unittest.SkipTest("half and complex32 are only supported on CUDA device with SM>=53")
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker# Tests of functions related to Fourier analysis in the torch.fft namespace
128*da0073e9SAndroid Build Coastguard Workerclass TestFFT(TestCase):
129*da0073e9SAndroid Build Coastguard Worker    exact_dtype = True
130*da0073e9SAndroid Build Coastguard Worker
131*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
132*da0073e9SAndroid Build Coastguard Worker    @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.OneD],
133*da0073e9SAndroid Build Coastguard Worker         allowed_dtypes=(torch.float, torch.cfloat))
134*da0073e9SAndroid Build Coastguard Worker    def test_reference_1d(self, device, dtype, op):
135*da0073e9SAndroid Build Coastguard Worker        if op.ref is None:
136*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("No reference implementation")
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker        norm_modes = REFERENCE_NORM_MODES
139*da0073e9SAndroid Build Coastguard Worker        test_args = [
140*da0073e9SAndroid Build Coastguard Worker            *product(
141*da0073e9SAndroid Build Coastguard Worker                # input
142*da0073e9SAndroid Build Coastguard Worker                (torch.randn(67, device=device, dtype=dtype),
143*da0073e9SAndroid Build Coastguard Worker                 torch.randn(80, device=device, dtype=dtype),
144*da0073e9SAndroid Build Coastguard Worker                 torch.randn(12, 14, device=device, dtype=dtype),
145*da0073e9SAndroid Build Coastguard Worker                 torch.randn(9, 6, 3, device=device, dtype=dtype)),
146*da0073e9SAndroid Build Coastguard Worker                # n
147*da0073e9SAndroid Build Coastguard Worker                (None, 50, 6),
148*da0073e9SAndroid Build Coastguard Worker                # dim
149*da0073e9SAndroid Build Coastguard Worker                (-1, 0),
150*da0073e9SAndroid Build Coastguard Worker                # norm
151*da0073e9SAndroid Build Coastguard Worker                norm_modes
152*da0073e9SAndroid Build Coastguard Worker            ),
153*da0073e9SAndroid Build Coastguard Worker            # Test transforming middle dimensions of multi-dim tensor
154*da0073e9SAndroid Build Coastguard Worker            *product(
155*da0073e9SAndroid Build Coastguard Worker                (torch.randn(4, 5, 6, 7, device=device, dtype=dtype),),
156*da0073e9SAndroid Build Coastguard Worker                (None,),
157*da0073e9SAndroid Build Coastguard Worker                (1, 2, -2,),
158*da0073e9SAndroid Build Coastguard Worker                norm_modes
159*da0073e9SAndroid Build Coastguard Worker            )
160*da0073e9SAndroid Build Coastguard Worker        ]
161*da0073e9SAndroid Build Coastguard Worker
162*da0073e9SAndroid Build Coastguard Worker        for iargs in test_args:
163*da0073e9SAndroid Build Coastguard Worker            args = list(iargs)
164*da0073e9SAndroid Build Coastguard Worker            input = args[0]
165*da0073e9SAndroid Build Coastguard Worker            args = args[1:]
166*da0073e9SAndroid Build Coastguard Worker
167*da0073e9SAndroid Build Coastguard Worker            expected = op.ref(input.cpu().numpy(), *args)
168*da0073e9SAndroid Build Coastguard Worker            exact_dtype = dtype in (torch.double, torch.complex128)
169*da0073e9SAndroid Build Coastguard Worker            actual = op(input, *args)
170*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected, exact_dtype=exact_dtype)
171*da0073e9SAndroid Build Coastguard Worker
172*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
173*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
174*da0073e9SAndroid Build Coastguard Worker    @toleranceOverride({
175*da0073e9SAndroid Build Coastguard Worker        torch.half : tol(1e-2, 1e-2),
176*da0073e9SAndroid Build Coastguard Worker        torch.chalf : tol(1e-2, 1e-2),
177*da0073e9SAndroid Build Coastguard Worker    })
178*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128)
179*da0073e9SAndroid Build Coastguard Worker    def test_fft_round_trip(self, device, dtype):
180*da0073e9SAndroid Build Coastguard Worker        skip_helper_for_fft(device, dtype)
181*da0073e9SAndroid Build Coastguard Worker        # Test that round trip through ifft(fft(x)) is the identity
182*da0073e9SAndroid Build Coastguard Worker        if dtype not in (torch.half, torch.complex32):
183*da0073e9SAndroid Build Coastguard Worker            test_args = list(product(
184*da0073e9SAndroid Build Coastguard Worker                # input
185*da0073e9SAndroid Build Coastguard Worker                (torch.randn(67, device=device, dtype=dtype),
186*da0073e9SAndroid Build Coastguard Worker                 torch.randn(80, device=device, dtype=dtype),
187*da0073e9SAndroid Build Coastguard Worker                 torch.randn(12, 14, device=device, dtype=dtype),
188*da0073e9SAndroid Build Coastguard Worker                 torch.randn(9, 6, 3, device=device, dtype=dtype)),
189*da0073e9SAndroid Build Coastguard Worker                # dim
190*da0073e9SAndroid Build Coastguard Worker                (-1, 0),
191*da0073e9SAndroid Build Coastguard Worker                # norm
192*da0073e9SAndroid Build Coastguard Worker                (None, "forward", "backward", "ortho")
193*da0073e9SAndroid Build Coastguard Worker            ))
194*da0073e9SAndroid Build Coastguard Worker        else:
195*da0073e9SAndroid Build Coastguard Worker            # cuFFT supports powers of 2 for half and complex half precision
196*da0073e9SAndroid Build Coastguard Worker            test_args = list(product(
197*da0073e9SAndroid Build Coastguard Worker                # input
198*da0073e9SAndroid Build Coastguard Worker                (torch.randn(64, device=device, dtype=dtype),
199*da0073e9SAndroid Build Coastguard Worker                 torch.randn(128, device=device, dtype=dtype),
200*da0073e9SAndroid Build Coastguard Worker                 torch.randn(4, 16, device=device, dtype=dtype),
201*da0073e9SAndroid Build Coastguard Worker                 torch.randn(8, 6, 2, device=device, dtype=dtype)),
202*da0073e9SAndroid Build Coastguard Worker                # dim
203*da0073e9SAndroid Build Coastguard Worker                (-1, 0),
204*da0073e9SAndroid Build Coastguard Worker                # norm
205*da0073e9SAndroid Build Coastguard Worker                (None, "forward", "backward", "ortho")
206*da0073e9SAndroid Build Coastguard Worker            ))
207*da0073e9SAndroid Build Coastguard Worker
208*da0073e9SAndroid Build Coastguard Worker        fft_functions = [(torch.fft.fft, torch.fft.ifft)]
209*da0073e9SAndroid Build Coastguard Worker        # Real-only functions
210*da0073e9SAndroid Build Coastguard Worker        if not dtype.is_complex:
211*da0073e9SAndroid Build Coastguard Worker            # NOTE: Using ihfft as "forward" transform to avoid needing to
212*da0073e9SAndroid Build Coastguard Worker            # generate true half-complex input
213*da0073e9SAndroid Build Coastguard Worker            fft_functions += [(torch.fft.rfft, torch.fft.irfft),
214*da0073e9SAndroid Build Coastguard Worker                              (torch.fft.ihfft, torch.fft.hfft)]
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker        for forward, backward in fft_functions:
217*da0073e9SAndroid Build Coastguard Worker            for x, dim, norm in test_args:
218*da0073e9SAndroid Build Coastguard Worker                kwargs = {
219*da0073e9SAndroid Build Coastguard Worker                    'n': x.size(dim),
220*da0073e9SAndroid Build Coastguard Worker                    'dim': dim,
221*da0073e9SAndroid Build Coastguard Worker                    'norm': norm,
222*da0073e9SAndroid Build Coastguard Worker                }
223*da0073e9SAndroid Build Coastguard Worker
224*da0073e9SAndroid Build Coastguard Worker                y = backward(forward(x, **kwargs), **kwargs)
225*da0073e9SAndroid Build Coastguard Worker                if x.dtype is torch.half and y.dtype is torch.complex32:
226*da0073e9SAndroid Build Coastguard Worker                    # Since type promotion currently doesn't work with complex32
227*da0073e9SAndroid Build Coastguard Worker                    # manually promote `x` to complex32
228*da0073e9SAndroid Build Coastguard Worker                    x = x.to(torch.complex32)
229*da0073e9SAndroid Build Coastguard Worker                # For real input, ifft(fft(x)) will convert to complex
230*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x, y, exact_dtype=(
231*da0073e9SAndroid Build Coastguard Worker                    forward != torch.fft.fft or x.is_complex()))
232*da0073e9SAndroid Build Coastguard Worker
233*da0073e9SAndroid Build Coastguard Worker    # Note: NumPy will throw a ValueError for an empty input
234*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
235*da0073e9SAndroid Build Coastguard Worker    @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.float, torch.complex32, torch.cfloat))
236*da0073e9SAndroid Build Coastguard Worker    def test_empty_fft(self, device, dtype, op):
237*da0073e9SAndroid Build Coastguard Worker        t = torch.empty(1, 0, device=device, dtype=dtype)
238*da0073e9SAndroid Build Coastguard Worker        match = r"Invalid number of data points \([-\d]*\) specified"
239*da0073e9SAndroid Build Coastguard Worker
240*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, match):
241*da0073e9SAndroid Build Coastguard Worker            op(t)
242*da0073e9SAndroid Build Coastguard Worker
243*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
244*da0073e9SAndroid Build Coastguard Worker    def test_empty_ifft(self, device):
245*da0073e9SAndroid Build Coastguard Worker        t = torch.empty(2, 1, device=device, dtype=torch.complex64)
246*da0073e9SAndroid Build Coastguard Worker        match = r"Invalid number of data points \([-\d]*\) specified"
247*da0073e9SAndroid Build Coastguard Worker
248*da0073e9SAndroid Build Coastguard Worker        for f in [torch.fft.irfft, torch.fft.irfft2, torch.fft.irfftn,
249*da0073e9SAndroid Build Coastguard Worker                  torch.fft.hfft, torch.fft.hfft2, torch.fft.hfftn]:
250*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, match):
251*da0073e9SAndroid Build Coastguard Worker                f(t)
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
254*da0073e9SAndroid Build Coastguard Worker    def test_fft_invalid_dtypes(self, device):
255*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(64, device=device, dtype=torch.complex128)
256*da0073e9SAndroid Build Coastguard Worker
257*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "rfft expects a real input tensor"):
258*da0073e9SAndroid Build Coastguard Worker            torch.fft.rfft(t)
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input tensor"):
261*da0073e9SAndroid Build Coastguard Worker            torch.fft.rfftn(t)
262*da0073e9SAndroid Build Coastguard Worker
263*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "ihfft expects a real input tensor"):
264*da0073e9SAndroid Build Coastguard Worker            torch.fft.ihfft(t)
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
267*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
268*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.int8, torch.half, torch.float, torch.double,
269*da0073e9SAndroid Build Coastguard Worker            torch.complex32, torch.complex64, torch.complex128)
270*da0073e9SAndroid Build Coastguard Worker    def test_fft_type_promotion(self, device, dtype):
271*da0073e9SAndroid Build Coastguard Worker        skip_helper_for_fft(device, dtype)
272*da0073e9SAndroid Build Coastguard Worker
273*da0073e9SAndroid Build Coastguard Worker        if dtype.is_complex or dtype.is_floating_point:
274*da0073e9SAndroid Build Coastguard Worker            t = torch.randn(64, device=device, dtype=dtype)
275*da0073e9SAndroid Build Coastguard Worker        else:
276*da0073e9SAndroid Build Coastguard Worker            t = torch.randint(-2, 2, (64,), device=device, dtype=dtype)
277*da0073e9SAndroid Build Coastguard Worker
278*da0073e9SAndroid Build Coastguard Worker        PROMOTION_MAP = {
279*da0073e9SAndroid Build Coastguard Worker            torch.int8: torch.complex64,
280*da0073e9SAndroid Build Coastguard Worker            torch.half: torch.complex32,
281*da0073e9SAndroid Build Coastguard Worker            torch.float: torch.complex64,
282*da0073e9SAndroid Build Coastguard Worker            torch.double: torch.complex128,
283*da0073e9SAndroid Build Coastguard Worker            torch.complex32: torch.complex32,
284*da0073e9SAndroid Build Coastguard Worker            torch.complex64: torch.complex64,
285*da0073e9SAndroid Build Coastguard Worker            torch.complex128: torch.complex128,
286*da0073e9SAndroid Build Coastguard Worker        }
287*da0073e9SAndroid Build Coastguard Worker        T = torch.fft.fft(t)
288*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(T.dtype, PROMOTION_MAP[dtype])
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Worker        PROMOTION_MAP_C2R = {
291*da0073e9SAndroid Build Coastguard Worker            torch.int8: torch.float,
292*da0073e9SAndroid Build Coastguard Worker            torch.half: torch.half,
293*da0073e9SAndroid Build Coastguard Worker            torch.float: torch.float,
294*da0073e9SAndroid Build Coastguard Worker            torch.double: torch.double,
295*da0073e9SAndroid Build Coastguard Worker            torch.complex32: torch.half,
296*da0073e9SAndroid Build Coastguard Worker            torch.complex64: torch.float,
297*da0073e9SAndroid Build Coastguard Worker            torch.complex128: torch.double,
298*da0073e9SAndroid Build Coastguard Worker        }
299*da0073e9SAndroid Build Coastguard Worker        if dtype in (torch.half, torch.complex32):
300*da0073e9SAndroid Build Coastguard Worker            # cuFFT supports powers of 2 for half and complex half precision
301*da0073e9SAndroid Build Coastguard Worker            # NOTE: With hfft and default args where output_size n=2*(input_size - 1),
302*da0073e9SAndroid Build Coastguard Worker            # we make sure that logical fft size is a power of two.
303*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(65, device=device, dtype=dtype)
304*da0073e9SAndroid Build Coastguard Worker            R = torch.fft.hfft(x)
305*da0073e9SAndroid Build Coastguard Worker        else:
306*da0073e9SAndroid Build Coastguard Worker            R = torch.fft.hfft(t)
307*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(R.dtype, PROMOTION_MAP_C2R[dtype])
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker        if not dtype.is_complex:
310*da0073e9SAndroid Build Coastguard Worker            PROMOTION_MAP_R2C = {
311*da0073e9SAndroid Build Coastguard Worker                torch.int8: torch.complex64,
312*da0073e9SAndroid Build Coastguard Worker                torch.half: torch.complex32,
313*da0073e9SAndroid Build Coastguard Worker                torch.float: torch.complex64,
314*da0073e9SAndroid Build Coastguard Worker                torch.double: torch.complex128,
315*da0073e9SAndroid Build Coastguard Worker            }
316*da0073e9SAndroid Build Coastguard Worker            C = torch.fft.rfft(t)
317*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(C.dtype, PROMOTION_MAP_R2C[dtype])
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
320*da0073e9SAndroid Build Coastguard Worker    @ops(spectral_funcs, dtypes=OpDTypes.unsupported,
321*da0073e9SAndroid Build Coastguard Worker         allowed_dtypes=[torch.half, torch.bfloat16])
322*da0073e9SAndroid Build Coastguard Worker    def test_fft_half_and_bfloat16_errors(self, device, dtype, op):
323*da0073e9SAndroid Build Coastguard Worker        # TODO: Remove torch.half error when complex32 is fully implemented
324*da0073e9SAndroid Build Coastguard Worker        sample = first_sample(self, op.sample_inputs(device, dtype))
325*da0073e9SAndroid Build Coastguard Worker        device_type = torch.device(device).type
326*da0073e9SAndroid Build Coastguard Worker        default_msg = "Unsupported dtype"
327*da0073e9SAndroid Build Coastguard Worker        if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM:
328*da0073e9SAndroid Build Coastguard Worker            err_msg = default_msg
329*da0073e9SAndroid Build Coastguard Worker        elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater:
330*da0073e9SAndroid Build Coastguard Worker            err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53"
331*da0073e9SAndroid Build Coastguard Worker        else:
332*da0073e9SAndroid Build Coastguard Worker            err_msg = default_msg
333*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, err_msg):
334*da0073e9SAndroid Build Coastguard Worker            op(sample.input, *sample.args, **sample.kwargs)
335*da0073e9SAndroid Build Coastguard Worker
336*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
337*da0073e9SAndroid Build Coastguard Worker    @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.chalf))
338*da0073e9SAndroid Build Coastguard Worker    def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op):
339*da0073e9SAndroid Build Coastguard Worker        t = make_tensor(13, 13, device=device, dtype=dtype)
340*da0073e9SAndroid Build Coastguard Worker        err_msg = "cuFFT only supports dimensions whose sizes are powers of two"
341*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, err_msg):
342*da0073e9SAndroid Build Coastguard Worker            op(t)
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker        if op.ndimensional in (SpectralFuncType.ND, SpectralFuncType.TwoD):
345*da0073e9SAndroid Build Coastguard Worker            kwargs = {'s': (12, 12)}
346*da0073e9SAndroid Build Coastguard Worker        else:
347*da0073e9SAndroid Build Coastguard Worker            kwargs = {'n': 12}
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, err_msg):
350*da0073e9SAndroid Build Coastguard Worker            op(t, **kwargs)
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker    # nd-fft tests
353*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
354*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
355*da0073e9SAndroid Build Coastguard Worker    @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
356*da0073e9SAndroid Build Coastguard Worker         allowed_dtypes=(torch.cfloat, torch.cdouble))
357*da0073e9SAndroid Build Coastguard Worker    def test_reference_nd(self, device, dtype, op):
358*da0073e9SAndroid Build Coastguard Worker        if op.ref is None:
359*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("No reference implementation")
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker        norm_modes = REFERENCE_NORM_MODES
362*da0073e9SAndroid Build Coastguard Worker
363*da0073e9SAndroid Build Coastguard Worker        # input_ndim, s, dim
364*da0073e9SAndroid Build Coastguard Worker        transform_desc = [
365*da0073e9SAndroid Build Coastguard Worker            *product(range(2, 5), (None,), (None, (0,), (0, -1))),
366*da0073e9SAndroid Build Coastguard Worker            *product(range(2, 5), (None, (4, 10)), (None,)),
367*da0073e9SAndroid Build Coastguard Worker            (6, None, None),
368*da0073e9SAndroid Build Coastguard Worker            (5, None, (1, 3, 4)),
369*da0073e9SAndroid Build Coastguard Worker            (3, None, (1,)),
370*da0073e9SAndroid Build Coastguard Worker            (1, None, (0,)),
371*da0073e9SAndroid Build Coastguard Worker            (4, (10, 10), None),
372*da0073e9SAndroid Build Coastguard Worker            (4, (10, 10), (0, 1))
373*da0073e9SAndroid Build Coastguard Worker        ]
374*da0073e9SAndroid Build Coastguard Worker
375*da0073e9SAndroid Build Coastguard Worker        for input_ndim, s, dim in transform_desc:
376*da0073e9SAndroid Build Coastguard Worker            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
377*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(*shape, device=device, dtype=dtype)
378*da0073e9SAndroid Build Coastguard Worker
379*da0073e9SAndroid Build Coastguard Worker            for norm in norm_modes:
380*da0073e9SAndroid Build Coastguard Worker                expected = op.ref(input.cpu().numpy(), s, dim, norm)
381*da0073e9SAndroid Build Coastguard Worker                exact_dtype = dtype in (torch.double, torch.complex128)
382*da0073e9SAndroid Build Coastguard Worker                actual = op(input, s, dim, norm)
383*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected, exact_dtype=exact_dtype)
384*da0073e9SAndroid Build Coastguard Worker
385*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
386*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
387*da0073e9SAndroid Build Coastguard Worker    @toleranceOverride({
388*da0073e9SAndroid Build Coastguard Worker        torch.half : tol(1e-2, 1e-2),
389*da0073e9SAndroid Build Coastguard Worker        torch.chalf : tol(1e-2, 1e-2),
390*da0073e9SAndroid Build Coastguard Worker    })
391*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.float, torch.double,
392*da0073e9SAndroid Build Coastguard Worker            torch.complex32, torch.complex64, torch.complex128)
393*da0073e9SAndroid Build Coastguard Worker    def test_fftn_round_trip(self, device, dtype):
394*da0073e9SAndroid Build Coastguard Worker        skip_helper_for_fft(device, dtype)
395*da0073e9SAndroid Build Coastguard Worker
396*da0073e9SAndroid Build Coastguard Worker        norm_modes = (None, "forward", "backward", "ortho")
397*da0073e9SAndroid Build Coastguard Worker
398*da0073e9SAndroid Build Coastguard Worker        # input_ndim, dim
399*da0073e9SAndroid Build Coastguard Worker        transform_desc = [
400*da0073e9SAndroid Build Coastguard Worker            *product(range(2, 5), (None, (0,), (0, -1))),
401*da0073e9SAndroid Build Coastguard Worker            (7, None),
402*da0073e9SAndroid Build Coastguard Worker            (5, (1, 3, 4)),
403*da0073e9SAndroid Build Coastguard Worker            (3, (1,)),
404*da0073e9SAndroid Build Coastguard Worker            (1, 0),
405*da0073e9SAndroid Build Coastguard Worker        ]
406*da0073e9SAndroid Build Coastguard Worker
407*da0073e9SAndroid Build Coastguard Worker        fft_functions = [(torch.fft.fftn, torch.fft.ifftn)]
408*da0073e9SAndroid Build Coastguard Worker
409*da0073e9SAndroid Build Coastguard Worker        # Real-only functions
410*da0073e9SAndroid Build Coastguard Worker        if not dtype.is_complex:
411*da0073e9SAndroid Build Coastguard Worker            # NOTE: Using ihfftn as "forward" transform to avoid needing to
412*da0073e9SAndroid Build Coastguard Worker            # generate true half-complex input
413*da0073e9SAndroid Build Coastguard Worker            fft_functions += [(torch.fft.rfftn, torch.fft.irfftn),
414*da0073e9SAndroid Build Coastguard Worker                              (torch.fft.ihfftn, torch.fft.hfftn)]
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker        for input_ndim, dim in transform_desc:
417*da0073e9SAndroid Build Coastguard Worker            if dtype in (torch.half, torch.complex32):
418*da0073e9SAndroid Build Coastguard Worker                # cuFFT supports powers of 2 for half and complex half precision
419*da0073e9SAndroid Build Coastguard Worker                shape = itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)
420*da0073e9SAndroid Build Coastguard Worker            else:
421*da0073e9SAndroid Build Coastguard Worker                shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
422*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(*shape, device=device, dtype=dtype)
423*da0073e9SAndroid Build Coastguard Worker
424*da0073e9SAndroid Build Coastguard Worker            for (forward, backward), norm in product(fft_functions, norm_modes):
425*da0073e9SAndroid Build Coastguard Worker                if isinstance(dim, tuple):
426*da0073e9SAndroid Build Coastguard Worker                    s = [x.size(d) for d in dim]
427*da0073e9SAndroid Build Coastguard Worker                else:
428*da0073e9SAndroid Build Coastguard Worker                    s = x.size() if dim is None else x.size(dim)
429*da0073e9SAndroid Build Coastguard Worker
430*da0073e9SAndroid Build Coastguard Worker                kwargs = {'s': s, 'dim': dim, 'norm': norm}
431*da0073e9SAndroid Build Coastguard Worker                y = backward(forward(x, **kwargs), **kwargs)
432*da0073e9SAndroid Build Coastguard Worker                # For real input, ifftn(fftn(x)) will convert to complex
433*da0073e9SAndroid Build Coastguard Worker                if x.dtype is torch.half and y.dtype is torch.chalf:
434*da0073e9SAndroid Build Coastguard Worker                    # Since type promotion currently doesn't work with complex32
435*da0073e9SAndroid Build Coastguard Worker                    # manually promote `x` to complex32
436*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(x.to(torch.chalf), y)
437*da0073e9SAndroid Build Coastguard Worker                else:
438*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(x, y, exact_dtype=(
439*da0073e9SAndroid Build Coastguard Worker                        forward != torch.fft.fftn or x.is_complex()))
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
442*da0073e9SAndroid Build Coastguard Worker    @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
443*da0073e9SAndroid Build Coastguard Worker         allowed_dtypes=[torch.float, torch.cfloat])
444*da0073e9SAndroid Build Coastguard Worker    def test_fftn_invalid(self, device, dtype, op):
445*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(10, 10, 10, device=device, dtype=dtype)
446*da0073e9SAndroid Build Coastguard Worker        # FIXME: https://github.com/pytorch/pytorch/issues/108205
447*da0073e9SAndroid Build Coastguard Worker        errMsg = "dims must be unique"
448*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, errMsg):
449*da0073e9SAndroid Build Coastguard Worker            op(a, dim=(0, 1, 0))
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, errMsg):
452*da0073e9SAndroid Build Coastguard Worker            op(a, dim=(2, -1))
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"):
455*da0073e9SAndroid Build Coastguard Worker            op(a, s=(1,), dim=(0, 1))
456*da0073e9SAndroid Build Coastguard Worker
457*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
458*da0073e9SAndroid Build Coastguard Worker            op(a, dim=(3,))
459*da0073e9SAndroid Build Coastguard Worker
460*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "tensor only has 3 dimensions"):
461*da0073e9SAndroid Build Coastguard Worker            op(a, s=(10, 10, 10, 10))
462*da0073e9SAndroid Build Coastguard Worker
463*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
464*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
465*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble)
466*da0073e9SAndroid Build Coastguard Worker    def test_fftn_noop_transform(self, device, dtype):
467*da0073e9SAndroid Build Coastguard Worker        skip_helper_for_fft(device, dtype)
468*da0073e9SAndroid Build Coastguard Worker        RESULT_TYPE = {
469*da0073e9SAndroid Build Coastguard Worker            torch.half: torch.chalf,
470*da0073e9SAndroid Build Coastguard Worker            torch.float: torch.cfloat,
471*da0073e9SAndroid Build Coastguard Worker            torch.double: torch.cdouble,
472*da0073e9SAndroid Build Coastguard Worker        }
473*da0073e9SAndroid Build Coastguard Worker
474*da0073e9SAndroid Build Coastguard Worker        for op in [
475*da0073e9SAndroid Build Coastguard Worker            torch.fft.fftn,
476*da0073e9SAndroid Build Coastguard Worker            torch.fft.ifftn,
477*da0073e9SAndroid Build Coastguard Worker            torch.fft.fft2,
478*da0073e9SAndroid Build Coastguard Worker            torch.fft.ifft2,
479*da0073e9SAndroid Build Coastguard Worker        ]:
480*da0073e9SAndroid Build Coastguard Worker            inp = make_tensor((10, 10), device=device, dtype=dtype)
481*da0073e9SAndroid Build Coastguard Worker            out = torch.fft.fftn(inp, dim=[])
482*da0073e9SAndroid Build Coastguard Worker
483*da0073e9SAndroid Build Coastguard Worker            expect_dtype = RESULT_TYPE.get(inp.dtype, inp.dtype)
484*da0073e9SAndroid Build Coastguard Worker            expect = inp.to(expect_dtype)
485*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, out)
486*da0073e9SAndroid Build Coastguard Worker
487*da0073e9SAndroid Build Coastguard Worker
488*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
489*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
490*da0073e9SAndroid Build Coastguard Worker    @toleranceOverride({
491*da0073e9SAndroid Build Coastguard Worker        torch.half : tol(1e-2, 1e-2),
492*da0073e9SAndroid Build Coastguard Worker    })
493*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.float, torch.double)
494*da0073e9SAndroid Build Coastguard Worker    def test_hfftn(self, device, dtype):
495*da0073e9SAndroid Build Coastguard Worker        skip_helper_for_fft(device, dtype)
496*da0073e9SAndroid Build Coastguard Worker
497*da0073e9SAndroid Build Coastguard Worker        # input_ndim, dim
498*da0073e9SAndroid Build Coastguard Worker        transform_desc = [
499*da0073e9SAndroid Build Coastguard Worker            *product(range(2, 5), (None, (0,), (0, -1))),
500*da0073e9SAndroid Build Coastguard Worker            (6, None),
501*da0073e9SAndroid Build Coastguard Worker            (5, (1, 3, 4)),
502*da0073e9SAndroid Build Coastguard Worker            (3, (1,)),
503*da0073e9SAndroid Build Coastguard Worker            (1, (0,)),
504*da0073e9SAndroid Build Coastguard Worker            (4, (0, 1))
505*da0073e9SAndroid Build Coastguard Worker        ]
506*da0073e9SAndroid Build Coastguard Worker
507*da0073e9SAndroid Build Coastguard Worker        for input_ndim, dim in transform_desc:
508*da0073e9SAndroid Build Coastguard Worker            actual_dims = list(range(input_ndim)) if dim is None else dim
509*da0073e9SAndroid Build Coastguard Worker            if dtype is torch.half:
510*da0073e9SAndroid Build Coastguard Worker                shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
511*da0073e9SAndroid Build Coastguard Worker            else:
512*da0073e9SAndroid Build Coastguard Worker                shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
513*da0073e9SAndroid Build Coastguard Worker            expect = torch.randn(*shape, device=device, dtype=dtype)
514*da0073e9SAndroid Build Coastguard Worker            input = torch.fft.ifftn(expect, dim=dim, norm="ortho")
515*da0073e9SAndroid Build Coastguard Worker
516*da0073e9SAndroid Build Coastguard Worker            lastdim = actual_dims[-1]
517*da0073e9SAndroid Build Coastguard Worker            lastdim_size = input.size(lastdim) // 2 + 1
518*da0073e9SAndroid Build Coastguard Worker            idx = [slice(None)] * input_ndim
519*da0073e9SAndroid Build Coastguard Worker            idx[lastdim] = slice(0, lastdim_size)
520*da0073e9SAndroid Build Coastguard Worker            input = input[idx]
521*da0073e9SAndroid Build Coastguard Worker
522*da0073e9SAndroid Build Coastguard Worker            s = [shape[dim] for dim in actual_dims]
523*da0073e9SAndroid Build Coastguard Worker            actual = torch.fft.hfftn(input, s=s, dim=dim, norm="ortho")
524*da0073e9SAndroid Build Coastguard Worker
525*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, actual)
526*da0073e9SAndroid Build Coastguard Worker
527*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
528*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
529*da0073e9SAndroid Build Coastguard Worker    @toleranceOverride({
530*da0073e9SAndroid Build Coastguard Worker        torch.half : tol(1e-2, 1e-2),
531*da0073e9SAndroid Build Coastguard Worker    })
532*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.float, torch.double)
533*da0073e9SAndroid Build Coastguard Worker    def test_ihfftn(self, device, dtype):
534*da0073e9SAndroid Build Coastguard Worker        skip_helper_for_fft(device, dtype)
535*da0073e9SAndroid Build Coastguard Worker
536*da0073e9SAndroid Build Coastguard Worker        # input_ndim, dim
537*da0073e9SAndroid Build Coastguard Worker        transform_desc = [
538*da0073e9SAndroid Build Coastguard Worker            *product(range(2, 5), (None, (0,), (0, -1))),
539*da0073e9SAndroid Build Coastguard Worker            (6, None),
540*da0073e9SAndroid Build Coastguard Worker            (5, (1, 3, 4)),
541*da0073e9SAndroid Build Coastguard Worker            (3, (1,)),
542*da0073e9SAndroid Build Coastguard Worker            (1, (0,)),
543*da0073e9SAndroid Build Coastguard Worker            (4, (0, 1))
544*da0073e9SAndroid Build Coastguard Worker        ]
545*da0073e9SAndroid Build Coastguard Worker
546*da0073e9SAndroid Build Coastguard Worker        for input_ndim, dim in transform_desc:
547*da0073e9SAndroid Build Coastguard Worker            if dtype is torch.half:
548*da0073e9SAndroid Build Coastguard Worker                shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
549*da0073e9SAndroid Build Coastguard Worker            else:
550*da0073e9SAndroid Build Coastguard Worker                shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
551*da0073e9SAndroid Build Coastguard Worker
552*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(*shape, device=device, dtype=dtype)
553*da0073e9SAndroid Build Coastguard Worker            expect = torch.fft.ifftn(input, dim=dim, norm="ortho")
554*da0073e9SAndroid Build Coastguard Worker
555*da0073e9SAndroid Build Coastguard Worker            # Slice off the half-symmetric component
556*da0073e9SAndroid Build Coastguard Worker            lastdim = -1 if dim is None else dim[-1]
557*da0073e9SAndroid Build Coastguard Worker            lastdim_size = expect.size(lastdim) // 2 + 1
558*da0073e9SAndroid Build Coastguard Worker            idx = [slice(None)] * input_ndim
559*da0073e9SAndroid Build Coastguard Worker            idx[lastdim] = slice(0, lastdim_size)
560*da0073e9SAndroid Build Coastguard Worker            expect = expect[idx]
561*da0073e9SAndroid Build Coastguard Worker
562*da0073e9SAndroid Build Coastguard Worker            actual = torch.fft.ihfftn(input, dim=dim, norm="ortho")
563*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, actual)
564*da0073e9SAndroid Build Coastguard Worker
565*da0073e9SAndroid Build Coastguard Worker
566*da0073e9SAndroid Build Coastguard Worker    # 2d-fft tests
567*da0073e9SAndroid Build Coastguard Worker
568*da0073e9SAndroid Build Coastguard Worker    # NOTE: 2d transforms are only thin wrappers over n-dim transforms,
569*da0073e9SAndroid Build Coastguard Worker    # so don't require exhaustive testing.
570*da0073e9SAndroid Build Coastguard Worker
571*da0073e9SAndroid Build Coastguard Worker
572*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
573*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
574*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.complex128)
575*da0073e9SAndroid Build Coastguard Worker    def test_fft2_numpy(self, device, dtype):
576*da0073e9SAndroid Build Coastguard Worker        norm_modes = REFERENCE_NORM_MODES
577*da0073e9SAndroid Build Coastguard Worker
578*da0073e9SAndroid Build Coastguard Worker        # input_ndim, s
579*da0073e9SAndroid Build Coastguard Worker        transform_desc = [
580*da0073e9SAndroid Build Coastguard Worker            *product(range(2, 5), (None, (4, 10))),
581*da0073e9SAndroid Build Coastguard Worker        ]
582*da0073e9SAndroid Build Coastguard Worker
583*da0073e9SAndroid Build Coastguard Worker        fft_functions = ['fft2', 'ifft2', 'irfft2', 'hfft2']
584*da0073e9SAndroid Build Coastguard Worker        if dtype.is_floating_point:
585*da0073e9SAndroid Build Coastguard Worker            fft_functions += ['rfft2', 'ihfft2']
586*da0073e9SAndroid Build Coastguard Worker
587*da0073e9SAndroid Build Coastguard Worker        for input_ndim, s in transform_desc:
588*da0073e9SAndroid Build Coastguard Worker            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
589*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(*shape, device=device, dtype=dtype)
590*da0073e9SAndroid Build Coastguard Worker            for fname, norm in product(fft_functions, norm_modes):
591*da0073e9SAndroid Build Coastguard Worker                torch_fn = getattr(torch.fft, fname)
592*da0073e9SAndroid Build Coastguard Worker                if "hfft" in fname:
593*da0073e9SAndroid Build Coastguard Worker                    if not has_scipy_fft:
594*da0073e9SAndroid Build Coastguard Worker                        continue  # Requires scipy to compare against
595*da0073e9SAndroid Build Coastguard Worker                    numpy_fn = getattr(scipy.fft, fname)
596*da0073e9SAndroid Build Coastguard Worker                else:
597*da0073e9SAndroid Build Coastguard Worker                    numpy_fn = getattr(np.fft, fname)
598*da0073e9SAndroid Build Coastguard Worker
599*da0073e9SAndroid Build Coastguard Worker                def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None):
600*da0073e9SAndroid Build Coastguard Worker                    return torch_fn(t, s, dim, norm)
601*da0073e9SAndroid Build Coastguard Worker
602*da0073e9SAndroid Build Coastguard Worker                torch_fns = (torch_fn, torch.jit.script(fn))
603*da0073e9SAndroid Build Coastguard Worker
604*da0073e9SAndroid Build Coastguard Worker                # Once with dim defaulted
605*da0073e9SAndroid Build Coastguard Worker                input_np = input.cpu().numpy()
606*da0073e9SAndroid Build Coastguard Worker                expected = numpy_fn(input_np, s, norm=norm)
607*da0073e9SAndroid Build Coastguard Worker                for fn in torch_fns:
608*da0073e9SAndroid Build Coastguard Worker                    actual = fn(input, s, norm=norm)
609*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(actual, expected)
610*da0073e9SAndroid Build Coastguard Worker
611*da0073e9SAndroid Build Coastguard Worker                # Once with explicit dims
612*da0073e9SAndroid Build Coastguard Worker                dim = (1, 0)
613*da0073e9SAndroid Build Coastguard Worker                expected = numpy_fn(input_np, s, dim, norm)
614*da0073e9SAndroid Build Coastguard Worker                for fn in torch_fns:
615*da0073e9SAndroid Build Coastguard Worker                    actual = fn(input, s, dim, norm)
616*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(actual, expected)
617*da0073e9SAndroid Build Coastguard Worker
618*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
619*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
620*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.complex64)
621*da0073e9SAndroid Build Coastguard Worker    def test_fft2_fftn_equivalence(self, device, dtype):
622*da0073e9SAndroid Build Coastguard Worker        norm_modes = (None, "forward", "backward", "ortho")
623*da0073e9SAndroid Build Coastguard Worker
624*da0073e9SAndroid Build Coastguard Worker        # input_ndim, s, dim
625*da0073e9SAndroid Build Coastguard Worker        transform_desc = [
626*da0073e9SAndroid Build Coastguard Worker            *product(range(2, 5), (None, (4, 10)), (None, (1, 0))),
627*da0073e9SAndroid Build Coastguard Worker            (3, None, (0, 2)),
628*da0073e9SAndroid Build Coastguard Worker        ]
629*da0073e9SAndroid Build Coastguard Worker
630*da0073e9SAndroid Build Coastguard Worker        fft_functions = ['fft', 'ifft', 'irfft', 'hfft']
631*da0073e9SAndroid Build Coastguard Worker        # Real-only functions
632*da0073e9SAndroid Build Coastguard Worker        if dtype.is_floating_point:
633*da0073e9SAndroid Build Coastguard Worker            fft_functions += ['rfft', 'ihfft']
634*da0073e9SAndroid Build Coastguard Worker
635*da0073e9SAndroid Build Coastguard Worker        for input_ndim, s, dim in transform_desc:
636*da0073e9SAndroid Build Coastguard Worker            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
637*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(*shape, device=device, dtype=dtype)
638*da0073e9SAndroid Build Coastguard Worker
639*da0073e9SAndroid Build Coastguard Worker            for func, norm in product(fft_functions, norm_modes):
640*da0073e9SAndroid Build Coastguard Worker                f2d = getattr(torch.fft, func + '2')
641*da0073e9SAndroid Build Coastguard Worker                fnd = getattr(torch.fft, func + 'n')
642*da0073e9SAndroid Build Coastguard Worker
643*da0073e9SAndroid Build Coastguard Worker                kwargs = {'s': s, 'norm': norm}
644*da0073e9SAndroid Build Coastguard Worker
645*da0073e9SAndroid Build Coastguard Worker                if dim is not None:
646*da0073e9SAndroid Build Coastguard Worker                    kwargs['dim'] = dim
647*da0073e9SAndroid Build Coastguard Worker                    expect = fnd(x, **kwargs)
648*da0073e9SAndroid Build Coastguard Worker                else:
649*da0073e9SAndroid Build Coastguard Worker                    expect = fnd(x, dim=(-2, -1), **kwargs)
650*da0073e9SAndroid Build Coastguard Worker
651*da0073e9SAndroid Build Coastguard Worker                actual = f2d(x, **kwargs)
652*da0073e9SAndroid Build Coastguard Worker
653*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expect)
654*da0073e9SAndroid Build Coastguard Worker
655*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
656*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
657*da0073e9SAndroid Build Coastguard Worker    def test_fft2_invalid(self, device):
658*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(10, 10, 10, device=device)
659*da0073e9SAndroid Build Coastguard Worker        fft_funcs = (torch.fft.fft2, torch.fft.ifft2,
660*da0073e9SAndroid Build Coastguard Worker                     torch.fft.rfft2, torch.fft.irfft2)
661*da0073e9SAndroid Build Coastguard Worker
662*da0073e9SAndroid Build Coastguard Worker        for func in fft_funcs:
663*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "dims must be unique"):
664*da0073e9SAndroid Build Coastguard Worker                func(a, dim=(0, 0))
665*da0073e9SAndroid Build Coastguard Worker
666*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "dims must be unique"):
667*da0073e9SAndroid Build Coastguard Worker                func(a, dim=(2, -1))
668*da0073e9SAndroid Build Coastguard Worker
669*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"):
670*da0073e9SAndroid Build Coastguard Worker                func(a, s=(1,))
671*da0073e9SAndroid Build Coastguard Worker
672*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
673*da0073e9SAndroid Build Coastguard Worker                func(a, dim=(2, 3))
674*da0073e9SAndroid Build Coastguard Worker
675*da0073e9SAndroid Build Coastguard Worker        c = torch.complex(a, a)
676*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input"):
677*da0073e9SAndroid Build Coastguard Worker            torch.fft.rfft2(c)
678*da0073e9SAndroid Build Coastguard Worker
679*da0073e9SAndroid Build Coastguard Worker    # Helper functions
680*da0073e9SAndroid Build Coastguard Worker
681*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
682*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
683*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
684*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
685*da0073e9SAndroid Build Coastguard Worker    def test_fftfreq_numpy(self, device, dtype):
686*da0073e9SAndroid Build Coastguard Worker        test_args = [
687*da0073e9SAndroid Build Coastguard Worker            *product(
688*da0073e9SAndroid Build Coastguard Worker                # n
689*da0073e9SAndroid Build Coastguard Worker                range(1, 20),
690*da0073e9SAndroid Build Coastguard Worker                # d
691*da0073e9SAndroid Build Coastguard Worker                (None, 10.0),
692*da0073e9SAndroid Build Coastguard Worker            )
693*da0073e9SAndroid Build Coastguard Worker        ]
694*da0073e9SAndroid Build Coastguard Worker
695*da0073e9SAndroid Build Coastguard Worker        functions = ['fftfreq', 'rfftfreq']
696*da0073e9SAndroid Build Coastguard Worker
697*da0073e9SAndroid Build Coastguard Worker        for fname in functions:
698*da0073e9SAndroid Build Coastguard Worker            torch_fn = getattr(torch.fft, fname)
699*da0073e9SAndroid Build Coastguard Worker            numpy_fn = getattr(np.fft, fname)
700*da0073e9SAndroid Build Coastguard Worker
701*da0073e9SAndroid Build Coastguard Worker            for n, d in test_args:
702*da0073e9SAndroid Build Coastguard Worker                args = (n,) if d is None else (n, d)
703*da0073e9SAndroid Build Coastguard Worker                expected = numpy_fn(*args)
704*da0073e9SAndroid Build Coastguard Worker                actual = torch_fn(*args, device=device, dtype=dtype)
705*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected, exact_dtype=False)
706*da0073e9SAndroid Build Coastguard Worker
707*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
708*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
709*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
710*da0073e9SAndroid Build Coastguard Worker    def test_fftfreq_out(self, device, dtype):
711*da0073e9SAndroid Build Coastguard Worker        for func in (torch.fft.fftfreq, torch.fft.rfftfreq):
712*da0073e9SAndroid Build Coastguard Worker            expect = func(n=100, d=.5, device=device, dtype=dtype)
713*da0073e9SAndroid Build Coastguard Worker            actual = torch.empty((), device=device, dtype=dtype)
714*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsRegex(UserWarning, "out tensor will be resized"):
715*da0073e9SAndroid Build Coastguard Worker                func(n=100, d=.5, out=actual)
716*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expect)
717*da0073e9SAndroid Build Coastguard Worker
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
720*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
721*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
722*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
723*da0073e9SAndroid Build Coastguard Worker    def test_fftshift_numpy(self, device, dtype):
724*da0073e9SAndroid Build Coastguard Worker        test_args = [
725*da0073e9SAndroid Build Coastguard Worker            # shape, dim
726*da0073e9SAndroid Build Coastguard Worker            *product(((11,), (12,)), (None, 0, -1)),
727*da0073e9SAndroid Build Coastguard Worker            *product(((4, 5), (6, 6)), (None, 0, (-1,))),
728*da0073e9SAndroid Build Coastguard Worker            *product(((1, 1, 4, 6, 7, 2),), (None, (3, 4))),
729*da0073e9SAndroid Build Coastguard Worker        ]
730*da0073e9SAndroid Build Coastguard Worker
731*da0073e9SAndroid Build Coastguard Worker        functions = ['fftshift', 'ifftshift']
732*da0073e9SAndroid Build Coastguard Worker
733*da0073e9SAndroid Build Coastguard Worker        for shape, dim in test_args:
734*da0073e9SAndroid Build Coastguard Worker            input = torch.rand(*shape, device=device, dtype=dtype)
735*da0073e9SAndroid Build Coastguard Worker            input_np = input.cpu().numpy()
736*da0073e9SAndroid Build Coastguard Worker
737*da0073e9SAndroid Build Coastguard Worker            for fname in functions:
738*da0073e9SAndroid Build Coastguard Worker                torch_fn = getattr(torch.fft, fname)
739*da0073e9SAndroid Build Coastguard Worker                numpy_fn = getattr(np.fft, fname)
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Worker                expected = numpy_fn(input_np, axes=dim)
742*da0073e9SAndroid Build Coastguard Worker                actual = torch_fn(input, dim=dim)
743*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected)
744*da0073e9SAndroid Build Coastguard Worker
745*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
746*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
747*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
748*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
749*da0073e9SAndroid Build Coastguard Worker    def test_fftshift_frequencies(self, device, dtype):
750*da0073e9SAndroid Build Coastguard Worker        for n in range(10, 15):
751*da0073e9SAndroid Build Coastguard Worker            sorted_fft_freqs = torch.arange(-(n // 2), n - (n // 2),
752*da0073e9SAndroid Build Coastguard Worker                                            device=device, dtype=dtype)
753*da0073e9SAndroid Build Coastguard Worker            x = torch.fft.fftfreq(n, d=1 / n, device=device, dtype=dtype)
754*da0073e9SAndroid Build Coastguard Worker
755*da0073e9SAndroid Build Coastguard Worker            # Test fftshift sorts the fftfreq output
756*da0073e9SAndroid Build Coastguard Worker            shifted = torch.fft.fftshift(x)
757*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(shifted, shifted.sort().values)
758*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sorted_fft_freqs, shifted)
759*da0073e9SAndroid Build Coastguard Worker
760*da0073e9SAndroid Build Coastguard Worker            # And ifftshift is the inverse
761*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, torch.fft.ifftshift(shifted))
762*da0073e9SAndroid Build Coastguard Worker
763*da0073e9SAndroid Build Coastguard Worker    # Legacy fft tests
764*da0073e9SAndroid Build Coastguard Worker    def _test_fft_ifft_rfft_irfft(self, device, dtype):
765*da0073e9SAndroid Build Coastguard Worker        complex_dtype = corresponding_complex_dtype(dtype)
766*da0073e9SAndroid Build Coastguard Worker
767*da0073e9SAndroid Build Coastguard Worker        def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x):
768*da0073e9SAndroid Build Coastguard Worker            x = prepro_fn(torch.randn(*sizes, dtype=complex_dtype, device=device))
769*da0073e9SAndroid Build Coastguard Worker            dim = tuple(range(-signal_ndim, 0))
770*da0073e9SAndroid Build Coastguard Worker            for norm in ('ortho', None):
771*da0073e9SAndroid Build Coastguard Worker                res = torch.fft.fftn(x, dim=dim, norm=norm)
772*da0073e9SAndroid Build Coastguard Worker                rec = torch.fft.ifftn(res, dim=dim, norm=norm)
773*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='fft and ifft')
774*da0073e9SAndroid Build Coastguard Worker                res = torch.fft.ifftn(x, dim=dim, norm=norm)
775*da0073e9SAndroid Build Coastguard Worker                rec = torch.fft.fftn(res, dim=dim, norm=norm)
776*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='ifft and fft')
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Worker        def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x):
779*da0073e9SAndroid Build Coastguard Worker            x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device))
780*da0073e9SAndroid Build Coastguard Worker            signal_numel = 1
781*da0073e9SAndroid Build Coastguard Worker            signal_sizes = x.size()[-signal_ndim:]
782*da0073e9SAndroid Build Coastguard Worker            dim = tuple(range(-signal_ndim, 0))
783*da0073e9SAndroid Build Coastguard Worker            for norm in (None, 'ortho'):
784*da0073e9SAndroid Build Coastguard Worker                res = torch.fft.rfftn(x, dim=dim, norm=norm)
785*da0073e9SAndroid Build Coastguard Worker                rec = torch.fft.irfftn(res, s=signal_sizes, dim=dim, norm=norm)
786*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='rfft and irfft')
787*da0073e9SAndroid Build Coastguard Worker                res = torch.fft.fftn(x, dim=dim, norm=norm)
788*da0073e9SAndroid Build Coastguard Worker                rec = torch.fft.ifftn(res, dim=dim, norm=norm)
789*da0073e9SAndroid Build Coastguard Worker                x_complex = torch.complex(x, torch.zeros_like(x))
790*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x_complex, rec, atol=1e-8, rtol=0, msg='fft and ifft (from real)')
791*da0073e9SAndroid Build Coastguard Worker
792*da0073e9SAndroid Build Coastguard Worker        # contiguous case
793*da0073e9SAndroid Build Coastguard Worker        _test_real((100,), 1)
794*da0073e9SAndroid Build Coastguard Worker        _test_real((10, 1, 10, 100), 1)
795*da0073e9SAndroid Build Coastguard Worker        _test_real((100, 100), 2)
796*da0073e9SAndroid Build Coastguard Worker        _test_real((2, 2, 5, 80, 60), 2)
797*da0073e9SAndroid Build Coastguard Worker        _test_real((50, 40, 70), 3)
798*da0073e9SAndroid Build Coastguard Worker        _test_real((30, 1, 50, 25, 20), 3)
799*da0073e9SAndroid Build Coastguard Worker
800*da0073e9SAndroid Build Coastguard Worker        _test_complex((100,), 1)
801*da0073e9SAndroid Build Coastguard Worker        _test_complex((100, 100), 1)
802*da0073e9SAndroid Build Coastguard Worker        _test_complex((100, 100), 2)
803*da0073e9SAndroid Build Coastguard Worker        _test_complex((1, 20, 80, 60), 2)
804*da0073e9SAndroid Build Coastguard Worker        _test_complex((50, 40, 70), 3)
805*da0073e9SAndroid Build Coastguard Worker        _test_complex((6, 5, 50, 25, 20), 3)
806*da0073e9SAndroid Build Coastguard Worker
807*da0073e9SAndroid Build Coastguard Worker        # non-contiguous case
808*da0073e9SAndroid Build Coastguard Worker        _test_real((165,), 1, lambda x: x.narrow(0, 25, 100))  # input is not aligned to complex type
809*da0073e9SAndroid Build Coastguard Worker        _test_real((100, 100, 3), 1, lambda x: x[:, :, 0])
810*da0073e9SAndroid Build Coastguard Worker        _test_real((100, 100), 2, lambda x: x.t())
811*da0073e9SAndroid Build Coastguard Worker        _test_real((20, 100, 10, 10), 2, lambda x: x.view(20, 100, 100)[:, :60])
812*da0073e9SAndroid Build Coastguard Worker        _test_real((65, 80, 115), 3, lambda x: x[10:60, 13:53, 10:80])
813*da0073e9SAndroid Build Coastguard Worker        _test_real((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3))
814*da0073e9SAndroid Build Coastguard Worker
815*da0073e9SAndroid Build Coastguard Worker        _test_complex((100,), 1, lambda x: x.expand(100, 100))
816*da0073e9SAndroid Build Coastguard Worker        _test_complex((20, 90, 110), 2, lambda x: x[:, 5:85].narrow(2, 5, 100))
817*da0073e9SAndroid Build Coastguard Worker        _test_complex((40, 60, 3, 80), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:])
818*da0073e9SAndroid Build Coastguard Worker        _test_complex((30, 55, 50, 22), 3, lambda x: x[:, 3:53, 15:40, 1:21])
819*da0073e9SAndroid Build Coastguard Worker
820*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
821*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
822*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
823*da0073e9SAndroid Build Coastguard Worker    def test_fft_ifft_rfft_irfft(self, device, dtype):
824*da0073e9SAndroid Build Coastguard Worker        self._test_fft_ifft_rfft_irfft(device, dtype)
825*da0073e9SAndroid Build Coastguard Worker
826*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(1)
827*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
828*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
829*da0073e9SAndroid Build Coastguard Worker    def test_cufft_plan_cache(self, devices, dtype):
830*da0073e9SAndroid Build Coastguard Worker        @contextmanager
831*da0073e9SAndroid Build Coastguard Worker        def plan_cache_max_size(device, n):
832*da0073e9SAndroid Build Coastguard Worker            if device is None:
833*da0073e9SAndroid Build Coastguard Worker                plan_cache = torch.backends.cuda.cufft_plan_cache
834*da0073e9SAndroid Build Coastguard Worker            else:
835*da0073e9SAndroid Build Coastguard Worker                plan_cache = torch.backends.cuda.cufft_plan_cache[device]
836*da0073e9SAndroid Build Coastguard Worker            original = plan_cache.max_size
837*da0073e9SAndroid Build Coastguard Worker            plan_cache.max_size = n
838*da0073e9SAndroid Build Coastguard Worker            try:
839*da0073e9SAndroid Build Coastguard Worker                yield
840*da0073e9SAndroid Build Coastguard Worker            finally:
841*da0073e9SAndroid Build Coastguard Worker                plan_cache.max_size = original
842*da0073e9SAndroid Build Coastguard Worker
843*da0073e9SAndroid Build Coastguard Worker        with plan_cache_max_size(devices[0], max(1, torch.backends.cuda.cufft_plan_cache.size - 10)):
844*da0073e9SAndroid Build Coastguard Worker            self._test_fft_ifft_rfft_irfft(devices[0], dtype)
845*da0073e9SAndroid Build Coastguard Worker
846*da0073e9SAndroid Build Coastguard Worker        with plan_cache_max_size(devices[0], 0):
847*da0073e9SAndroid Build Coastguard Worker            self._test_fft_ifft_rfft_irfft(devices[0], dtype)
848*da0073e9SAndroid Build Coastguard Worker
849*da0073e9SAndroid Build Coastguard Worker        torch.backends.cuda.cufft_plan_cache.clear()
850*da0073e9SAndroid Build Coastguard Worker
851*da0073e9SAndroid Build Coastguard Worker        # check that stll works after clearing cache
852*da0073e9SAndroid Build Coastguard Worker        with plan_cache_max_size(devices[0], 10):
853*da0073e9SAndroid Build Coastguard Worker            self._test_fft_ifft_rfft_irfft(devices[0], dtype)
854*da0073e9SAndroid Build Coastguard Worker
855*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"must be non-negative"):
856*da0073e9SAndroid Build Coastguard Worker            torch.backends.cuda.cufft_plan_cache.max_size = -1
857*da0073e9SAndroid Build Coastguard Worker
858*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"read-only property"):
859*da0073e9SAndroid Build Coastguard Worker            torch.backends.cuda.cufft_plan_cache.size = -1
860*da0073e9SAndroid Build Coastguard Worker
861*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"but got device with index"):
862*da0073e9SAndroid Build Coastguard Worker            torch.backends.cuda.cufft_plan_cache[torch.cuda.device_count() + 10]
863*da0073e9SAndroid Build Coastguard Worker
864*da0073e9SAndroid Build Coastguard Worker        # Multigpu tests
865*da0073e9SAndroid Build Coastguard Worker        if len(devices) > 1:
866*da0073e9SAndroid Build Coastguard Worker            # Test that different GPU has different cache
867*da0073e9SAndroid Build Coastguard Worker            x0 = torch.randn(2, 3, 3, device=devices[0])
868*da0073e9SAndroid Build Coastguard Worker            x1 = x0.to(devices[1])
869*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.fft.rfftn(x0, dim=(-2, -1)), torch.fft.rfftn(x1, dim=(-2, -1)))
870*da0073e9SAndroid Build Coastguard Worker            # If a plan is used across different devices, the following line (or
871*da0073e9SAndroid Build Coastguard Worker            # the assert above) would trigger illegal memory access. Other ways
872*da0073e9SAndroid Build Coastguard Worker            # to trigger the error include
873*da0073e9SAndroid Build Coastguard Worker            #   (1) setting CUDA_LAUNCH_BLOCKING=1 (pytorch/pytorch#19224) and
874*da0073e9SAndroid Build Coastguard Worker            #   (2) printing a device 1 tensor.
875*da0073e9SAndroid Build Coastguard Worker            x0.copy_(x1)
876*da0073e9SAndroid Build Coastguard Worker
877*da0073e9SAndroid Build Coastguard Worker            # Test that un-indexed `torch.backends.cuda.cufft_plan_cache` uses current device
878*da0073e9SAndroid Build Coastguard Worker            with plan_cache_max_size(devices[0], 10):
879*da0073e9SAndroid Build Coastguard Worker                with plan_cache_max_size(devices[1], 11):
880*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
881*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11)
882*da0073e9SAndroid Build Coastguard Worker
883*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)  # default is cuda:0
884*da0073e9SAndroid Build Coastguard Worker                    with torch.cuda.device(devices[1]):
885*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)  # default is cuda:1
886*da0073e9SAndroid Build Coastguard Worker                        with torch.cuda.device(devices[0]):
887*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)  # default is cuda:0
888*da0073e9SAndroid Build Coastguard Worker
889*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
890*da0073e9SAndroid Build Coastguard Worker                with torch.cuda.device(devices[1]):
891*da0073e9SAndroid Build Coastguard Worker                    with plan_cache_max_size(None, 11):  # default is cuda:1
892*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
893*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11)
894*da0073e9SAndroid Build Coastguard Worker
895*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)  # default is cuda:1
896*da0073e9SAndroid Build Coastguard Worker                        with torch.cuda.device(devices[0]):
897*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)  # default is cuda:0
898*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)  # default is cuda:1
899*da0073e9SAndroid Build Coastguard Worker
900*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
901*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.cfloat, torch.cdouble)
902*da0073e9SAndroid Build Coastguard Worker    def test_cufft_context(self, device, dtype):
903*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/109448
904*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(32, dtype=dtype, device=device, requires_grad=True)
905*da0073e9SAndroid Build Coastguard Worker        dout = torch.zeros(32, dtype=dtype, device=device)
906*da0073e9SAndroid Build Coastguard Worker
907*da0073e9SAndroid Build Coastguard Worker        # compute iFFT(FFT(x))
908*da0073e9SAndroid Build Coastguard Worker        out = torch.fft.ifft(torch.fft.fft(x))
909*da0073e9SAndroid Build Coastguard Worker        out.backward(dout, retain_graph=True)
910*da0073e9SAndroid Build Coastguard Worker
911*da0073e9SAndroid Build Coastguard Worker        dx = torch.fft.fft(torch.fft.ifft(dout))
912*da0073e9SAndroid Build Coastguard Worker
913*da0073e9SAndroid Build Coastguard Worker        self.assertTrue((x.grad - dx).abs().max() == 0)
914*da0073e9SAndroid Build Coastguard Worker        self.assertFalse((x.grad - x).abs().max() == 0)
915*da0073e9SAndroid Build Coastguard Worker
916*da0073e9SAndroid Build Coastguard Worker    # passes on ROCm w/ python 2.7, fails w/ python 3.6
917*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("cannot set WRITEABLE flag to True of this array")
918*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
919*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
920*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
921*da0073e9SAndroid Build Coastguard Worker    def test_stft(self, device, dtype):
922*da0073e9SAndroid Build Coastguard Worker        if not TEST_LIBROSA:
923*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest('librosa not found')
924*da0073e9SAndroid Build Coastguard Worker
925*da0073e9SAndroid Build Coastguard Worker        def librosa_stft(x, n_fft, hop_length, win_length, window, center):
926*da0073e9SAndroid Build Coastguard Worker            if window is None:
927*da0073e9SAndroid Build Coastguard Worker                window = np.ones(n_fft if win_length is None else win_length)
928*da0073e9SAndroid Build Coastguard Worker            else:
929*da0073e9SAndroid Build Coastguard Worker                window = window.cpu().numpy()
930*da0073e9SAndroid Build Coastguard Worker            input_1d = x.dim() == 1
931*da0073e9SAndroid Build Coastguard Worker            if input_1d:
932*da0073e9SAndroid Build Coastguard Worker                x = x.view(1, -1)
933*da0073e9SAndroid Build Coastguard Worker
934*da0073e9SAndroid Build Coastguard Worker            # NOTE: librosa 0.9 changed default pad_mode to 'constant' (zero padding)
935*da0073e9SAndroid Build Coastguard Worker            # however, we use the pre-0.9 default ('reflect')
936*da0073e9SAndroid Build Coastguard Worker            pad_mode = 'reflect'
937*da0073e9SAndroid Build Coastguard Worker
938*da0073e9SAndroid Build Coastguard Worker            result = []
939*da0073e9SAndroid Build Coastguard Worker            for xi in x:
940*da0073e9SAndroid Build Coastguard Worker                ri = librosa.stft(xi.cpu().numpy(), n_fft=n_fft, hop_length=hop_length,
941*da0073e9SAndroid Build Coastguard Worker                                  win_length=win_length, window=window, center=center,
942*da0073e9SAndroid Build Coastguard Worker                                  pad_mode=pad_mode)
943*da0073e9SAndroid Build Coastguard Worker                result.append(torch.from_numpy(np.stack([ri.real, ri.imag], -1)))
944*da0073e9SAndroid Build Coastguard Worker            result = torch.stack(result, 0)
945*da0073e9SAndroid Build Coastguard Worker            if input_1d:
946*da0073e9SAndroid Build Coastguard Worker                result = result[0]
947*da0073e9SAndroid Build Coastguard Worker            return result
948*da0073e9SAndroid Build Coastguard Worker
949*da0073e9SAndroid Build Coastguard Worker        def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None,
950*da0073e9SAndroid Build Coastguard Worker                  center=True, expected_error=None):
951*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(*sizes, dtype=dtype, device=device)
952*da0073e9SAndroid Build Coastguard Worker            if win_sizes is not None:
953*da0073e9SAndroid Build Coastguard Worker                window = torch.randn(*win_sizes, dtype=dtype, device=device)
954*da0073e9SAndroid Build Coastguard Worker            else:
955*da0073e9SAndroid Build Coastguard Worker                window = None
956*da0073e9SAndroid Build Coastguard Worker            if expected_error is None:
957*da0073e9SAndroid Build Coastguard Worker                result = x.stft(n_fft, hop_length, win_length, window,
958*da0073e9SAndroid Build Coastguard Worker                                center=center, return_complex=False)
959*da0073e9SAndroid Build Coastguard Worker                # NB: librosa defaults to np.complex64 output, no matter what
960*da0073e9SAndroid Build Coastguard Worker                # the input dtype
961*da0073e9SAndroid Build Coastguard Worker                ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center)
962*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result, ref_result, atol=7e-6, rtol=0, msg='stft comparison against librosa', exact_dtype=False)
963*da0073e9SAndroid Build Coastguard Worker                # With return_complex=True, the result is the same but viewed as complex instead of real
964*da0073e9SAndroid Build Coastguard Worker                result_complex = x.stft(n_fft, hop_length, win_length, window, center=center, return_complex=True)
965*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result_complex, torch.view_as_complex(result))
966*da0073e9SAndroid Build Coastguard Worker            else:
967*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(expected_error,
968*da0073e9SAndroid Build Coastguard Worker                                  lambda: x.stft(n_fft, hop_length, win_length, window, center=center))
969*da0073e9SAndroid Build Coastguard Worker
970*da0073e9SAndroid Build Coastguard Worker        for center in [True, False]:
971*da0073e9SAndroid Build Coastguard Worker            _test((10,), 7, center=center)
972*da0073e9SAndroid Build Coastguard Worker            _test((10, 4000), 1024, center=center)
973*da0073e9SAndroid Build Coastguard Worker
974*da0073e9SAndroid Build Coastguard Worker            _test((10,), 7, 2, center=center)
975*da0073e9SAndroid Build Coastguard Worker            _test((10, 4000), 1024, 512, center=center)
976*da0073e9SAndroid Build Coastguard Worker
977*da0073e9SAndroid Build Coastguard Worker            _test((10,), 7, 2, win_sizes=(7,), center=center)
978*da0073e9SAndroid Build Coastguard Worker            _test((10, 4000), 1024, 512, win_sizes=(1024,), center=center)
979*da0073e9SAndroid Build Coastguard Worker
980*da0073e9SAndroid Build Coastguard Worker            # spectral oversample
981*da0073e9SAndroid Build Coastguard Worker            _test((10,), 7, 2, win_length=5, center=center)
982*da0073e9SAndroid Build Coastguard Worker            _test((10, 4000), 1024, 512, win_length=100, center=center)
983*da0073e9SAndroid Build Coastguard Worker
984*da0073e9SAndroid Build Coastguard Worker        _test((10, 4, 2), 1, 1, expected_error=RuntimeError)
985*da0073e9SAndroid Build Coastguard Worker        _test((10,), 11, 1, center=False, expected_error=RuntimeError)
986*da0073e9SAndroid Build Coastguard Worker        _test((10,), -1, 1, expected_error=RuntimeError)
987*da0073e9SAndroid Build Coastguard Worker        _test((10,), 3, win_length=5, expected_error=RuntimeError)
988*da0073e9SAndroid Build Coastguard Worker        _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError)
989*da0073e9SAndroid Build Coastguard Worker        _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError)
990*da0073e9SAndroid Build Coastguard Worker
991*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("double")
992*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
993*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
994*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
995*da0073e9SAndroid Build Coastguard Worker    def test_istft_against_librosa(self, device, dtype):
996*da0073e9SAndroid Build Coastguard Worker        if not TEST_LIBROSA:
997*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest('librosa not found')
998*da0073e9SAndroid Build Coastguard Worker
999*da0073e9SAndroid Build Coastguard Worker        def librosa_istft(x, n_fft, hop_length, win_length, window, length, center):
1000*da0073e9SAndroid Build Coastguard Worker            if window is None:
1001*da0073e9SAndroid Build Coastguard Worker                window = np.ones(n_fft if win_length is None else win_length)
1002*da0073e9SAndroid Build Coastguard Worker            else:
1003*da0073e9SAndroid Build Coastguard Worker                window = window.cpu().numpy()
1004*da0073e9SAndroid Build Coastguard Worker
1005*da0073e9SAndroid Build Coastguard Worker            return librosa.istft(x.cpu().numpy(), n_fft=n_fft, hop_length=hop_length,
1006*da0073e9SAndroid Build Coastguard Worker                                 win_length=win_length, length=length, window=window, center=center)
1007*da0073e9SAndroid Build Coastguard Worker
1008*da0073e9SAndroid Build Coastguard Worker        def _test(size, n_fft, hop_length=None, win_length=None, win_sizes=None,
1009*da0073e9SAndroid Build Coastguard Worker                  length=None, center=True):
1010*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(size, dtype=dtype, device=device)
1011*da0073e9SAndroid Build Coastguard Worker            if win_sizes is not None:
1012*da0073e9SAndroid Build Coastguard Worker                window = torch.randn(*win_sizes, dtype=dtype, device=device)
1013*da0073e9SAndroid Build Coastguard Worker            else:
1014*da0073e9SAndroid Build Coastguard Worker                window = None
1015*da0073e9SAndroid Build Coastguard Worker
1016*da0073e9SAndroid Build Coastguard Worker            x_stft = x.stft(n_fft, hop_length, win_length, window, center=center,
1017*da0073e9SAndroid Build Coastguard Worker                            onesided=True, return_complex=True)
1018*da0073e9SAndroid Build Coastguard Worker
1019*da0073e9SAndroid Build Coastguard Worker            ref_result = librosa_istft(x_stft, n_fft, hop_length, win_length,
1020*da0073e9SAndroid Build Coastguard Worker                                       window, length, center)
1021*da0073e9SAndroid Build Coastguard Worker            result = x_stft.istft(n_fft, hop_length, win_length, window,
1022*da0073e9SAndroid Build Coastguard Worker                                  length=length, center=center)
1023*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, ref_result)
1024*da0073e9SAndroid Build Coastguard Worker
1025*da0073e9SAndroid Build Coastguard Worker        for center in [True, False]:
1026*da0073e9SAndroid Build Coastguard Worker            _test(10, 7, center=center)
1027*da0073e9SAndroid Build Coastguard Worker            _test(4000, 1024, center=center)
1028*da0073e9SAndroid Build Coastguard Worker            _test(4000, 1024, center=center, length=4000)
1029*da0073e9SAndroid Build Coastguard Worker
1030*da0073e9SAndroid Build Coastguard Worker            _test(10, 7, 2, center=center)
1031*da0073e9SAndroid Build Coastguard Worker            _test(4000, 1024, 512, center=center)
1032*da0073e9SAndroid Build Coastguard Worker            _test(4000, 1024, 512, center=center, length=4000)
1033*da0073e9SAndroid Build Coastguard Worker
1034*da0073e9SAndroid Build Coastguard Worker            _test(10, 7, 2, win_sizes=(7,), center=center)
1035*da0073e9SAndroid Build Coastguard Worker            _test(4000, 1024, 512, win_sizes=(1024,), center=center)
1036*da0073e9SAndroid Build Coastguard Worker            _test(4000, 1024, 512, win_sizes=(1024,), center=center, length=4000)
1037*da0073e9SAndroid Build Coastguard Worker
1038*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1039*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1040*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
1041*da0073e9SAndroid Build Coastguard Worker    def test_complex_stft_roundtrip(self, device, dtype):
1042*da0073e9SAndroid Build Coastguard Worker        test_args = list(product(
1043*da0073e9SAndroid Build Coastguard Worker            # input
1044*da0073e9SAndroid Build Coastguard Worker            (torch.randn(600, device=device, dtype=dtype),
1045*da0073e9SAndroid Build Coastguard Worker             torch.randn(807, device=device, dtype=dtype),
1046*da0073e9SAndroid Build Coastguard Worker             torch.randn(12, 60, device=device, dtype=dtype)),
1047*da0073e9SAndroid Build Coastguard Worker            # n_fft
1048*da0073e9SAndroid Build Coastguard Worker            (50, 27),
1049*da0073e9SAndroid Build Coastguard Worker            # hop_length
1050*da0073e9SAndroid Build Coastguard Worker            (None, 10),
1051*da0073e9SAndroid Build Coastguard Worker            # center
1052*da0073e9SAndroid Build Coastguard Worker            (True,),
1053*da0073e9SAndroid Build Coastguard Worker            # pad_mode
1054*da0073e9SAndroid Build Coastguard Worker            ("constant", "reflect", "circular"),
1055*da0073e9SAndroid Build Coastguard Worker            # normalized
1056*da0073e9SAndroid Build Coastguard Worker            (True, False),
1057*da0073e9SAndroid Build Coastguard Worker            # onesided
1058*da0073e9SAndroid Build Coastguard Worker            (True, False) if not dtype.is_complex else (False,),
1059*da0073e9SAndroid Build Coastguard Worker        ))
1060*da0073e9SAndroid Build Coastguard Worker
1061*da0073e9SAndroid Build Coastguard Worker        for args in test_args:
1062*da0073e9SAndroid Build Coastguard Worker            x, n_fft, hop_length, center, pad_mode, normalized, onesided = args
1063*da0073e9SAndroid Build Coastguard Worker            common_kwargs = {
1064*da0073e9SAndroid Build Coastguard Worker                'n_fft': n_fft, 'hop_length': hop_length, 'center': center,
1065*da0073e9SAndroid Build Coastguard Worker                'normalized': normalized, 'onesided': onesided,
1066*da0073e9SAndroid Build Coastguard Worker            }
1067*da0073e9SAndroid Build Coastguard Worker
1068*da0073e9SAndroid Build Coastguard Worker            # Functional interface
1069*da0073e9SAndroid Build Coastguard Worker            x_stft = torch.stft(x, pad_mode=pad_mode, return_complex=True, **common_kwargs)
1070*da0073e9SAndroid Build Coastguard Worker            x_roundtrip = torch.istft(x_stft, return_complex=dtype.is_complex,
1071*da0073e9SAndroid Build Coastguard Worker                                      length=x.size(-1), **common_kwargs)
1072*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_roundtrip, x)
1073*da0073e9SAndroid Build Coastguard Worker
1074*da0073e9SAndroid Build Coastguard Worker            # Tensor method interface
1075*da0073e9SAndroid Build Coastguard Worker            x_stft = x.stft(pad_mode=pad_mode, return_complex=True, **common_kwargs)
1076*da0073e9SAndroid Build Coastguard Worker            x_roundtrip = torch.istft(x_stft, return_complex=dtype.is_complex,
1077*da0073e9SAndroid Build Coastguard Worker                                      length=x.size(-1), **common_kwargs)
1078*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_roundtrip, x)
1079*da0073e9SAndroid Build Coastguard Worker
1080*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1081*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1082*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
1083*da0073e9SAndroid Build Coastguard Worker    def test_stft_roundtrip_complex_window(self, device, dtype):
1084*da0073e9SAndroid Build Coastguard Worker        test_args = list(product(
1085*da0073e9SAndroid Build Coastguard Worker            # input
1086*da0073e9SAndroid Build Coastguard Worker            (torch.randn(600, device=device, dtype=dtype),
1087*da0073e9SAndroid Build Coastguard Worker             torch.randn(807, device=device, dtype=dtype),
1088*da0073e9SAndroid Build Coastguard Worker             torch.randn(12, 60, device=device, dtype=dtype)),
1089*da0073e9SAndroid Build Coastguard Worker            # n_fft
1090*da0073e9SAndroid Build Coastguard Worker            (50, 27),
1091*da0073e9SAndroid Build Coastguard Worker            # hop_length
1092*da0073e9SAndroid Build Coastguard Worker            (None, 10),
1093*da0073e9SAndroid Build Coastguard Worker            # pad_mode
1094*da0073e9SAndroid Build Coastguard Worker            ("constant", "reflect", "replicate", "circular"),
1095*da0073e9SAndroid Build Coastguard Worker            # normalized
1096*da0073e9SAndroid Build Coastguard Worker            (True, False),
1097*da0073e9SAndroid Build Coastguard Worker        ))
1098*da0073e9SAndroid Build Coastguard Worker        for args in test_args:
1099*da0073e9SAndroid Build Coastguard Worker            x, n_fft, hop_length, pad_mode, normalized = args
1100*da0073e9SAndroid Build Coastguard Worker            window = torch.rand(n_fft, device=device, dtype=torch.cdouble)
1101*da0073e9SAndroid Build Coastguard Worker            x_stft = torch.stft(
1102*da0073e9SAndroid Build Coastguard Worker                x, n_fft=n_fft, hop_length=hop_length, window=window,
1103*da0073e9SAndroid Build Coastguard Worker                center=True, pad_mode=pad_mode, normalized=normalized)
1104*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_stft.dtype, torch.cdouble)
1105*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_stft.size(-2), n_fft)  # Not onesided
1106*da0073e9SAndroid Build Coastguard Worker
1107*da0073e9SAndroid Build Coastguard Worker            x_roundtrip = torch.istft(
1108*da0073e9SAndroid Build Coastguard Worker                x_stft, n_fft=n_fft, hop_length=hop_length, window=window,
1109*da0073e9SAndroid Build Coastguard Worker                center=True, normalized=normalized, length=x.size(-1),
1110*da0073e9SAndroid Build Coastguard Worker                return_complex=True)
1111*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_stft.dtype, torch.cdouble)
1112*da0073e9SAndroid Build Coastguard Worker
1113*da0073e9SAndroid Build Coastguard Worker            if not dtype.is_complex:
1114*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x_roundtrip.imag, torch.zeros_like(x_roundtrip.imag),
1115*da0073e9SAndroid Build Coastguard Worker                                 atol=1e-6, rtol=0)
1116*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x_roundtrip.real, x)
1117*da0073e9SAndroid Build Coastguard Worker            else:
1118*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x_roundtrip, x)
1119*da0073e9SAndroid Build Coastguard Worker
1120*da0073e9SAndroid Build Coastguard Worker
1121*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1122*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.cdouble)
1123*da0073e9SAndroid Build Coastguard Worker    def test_complex_stft_definition(self, device, dtype):
1124*da0073e9SAndroid Build Coastguard Worker        test_args = list(product(
1125*da0073e9SAndroid Build Coastguard Worker            # input
1126*da0073e9SAndroid Build Coastguard Worker            (torch.randn(600, device=device, dtype=dtype),
1127*da0073e9SAndroid Build Coastguard Worker             torch.randn(807, device=device, dtype=dtype)),
1128*da0073e9SAndroid Build Coastguard Worker            # n_fft
1129*da0073e9SAndroid Build Coastguard Worker            (50, 27),
1130*da0073e9SAndroid Build Coastguard Worker            # hop_length
1131*da0073e9SAndroid Build Coastguard Worker            (10, 15)
1132*da0073e9SAndroid Build Coastguard Worker        ))
1133*da0073e9SAndroid Build Coastguard Worker
1134*da0073e9SAndroid Build Coastguard Worker        for args in test_args:
1135*da0073e9SAndroid Build Coastguard Worker            window = torch.randn(args[1], device=device, dtype=dtype)
1136*da0073e9SAndroid Build Coastguard Worker            expected = _stft_reference(args[0], args[2], window)
1137*da0073e9SAndroid Build Coastguard Worker            actual = torch.stft(*args, window=window, center=False)
1138*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected)
1139*da0073e9SAndroid Build Coastguard Worker
1140*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1141*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1142*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.cdouble)
1143*da0073e9SAndroid Build Coastguard Worker    def test_complex_stft_real_equiv(self, device, dtype):
1144*da0073e9SAndroid Build Coastguard Worker        test_args = list(product(
1145*da0073e9SAndroid Build Coastguard Worker            # input
1146*da0073e9SAndroid Build Coastguard Worker            (torch.rand(600, device=device, dtype=dtype),
1147*da0073e9SAndroid Build Coastguard Worker             torch.rand(807, device=device, dtype=dtype),
1148*da0073e9SAndroid Build Coastguard Worker             torch.rand(14, 50, device=device, dtype=dtype),
1149*da0073e9SAndroid Build Coastguard Worker             torch.rand(6, 51, device=device, dtype=dtype)),
1150*da0073e9SAndroid Build Coastguard Worker            # n_fft
1151*da0073e9SAndroid Build Coastguard Worker            (50, 27),
1152*da0073e9SAndroid Build Coastguard Worker            # hop_length
1153*da0073e9SAndroid Build Coastguard Worker            (None, 10),
1154*da0073e9SAndroid Build Coastguard Worker            # win_length
1155*da0073e9SAndroid Build Coastguard Worker            (None, 20),
1156*da0073e9SAndroid Build Coastguard Worker            # center
1157*da0073e9SAndroid Build Coastguard Worker            (False, True),
1158*da0073e9SAndroid Build Coastguard Worker            # pad_mode
1159*da0073e9SAndroid Build Coastguard Worker            ("constant", "reflect", "circular"),
1160*da0073e9SAndroid Build Coastguard Worker            # normalized
1161*da0073e9SAndroid Build Coastguard Worker            (True, False),
1162*da0073e9SAndroid Build Coastguard Worker        ))
1163*da0073e9SAndroid Build Coastguard Worker
1164*da0073e9SAndroid Build Coastguard Worker        for args in test_args:
1165*da0073e9SAndroid Build Coastguard Worker            x, n_fft, hop_length, win_length, center, pad_mode, normalized = args
1166*da0073e9SAndroid Build Coastguard Worker            expected = _complex_stft(x, n_fft, hop_length=hop_length,
1167*da0073e9SAndroid Build Coastguard Worker                                     win_length=win_length, pad_mode=pad_mode,
1168*da0073e9SAndroid Build Coastguard Worker                                     center=center, normalized=normalized)
1169*da0073e9SAndroid Build Coastguard Worker            actual = torch.stft(x, n_fft, hop_length=hop_length,
1170*da0073e9SAndroid Build Coastguard Worker                                win_length=win_length, pad_mode=pad_mode,
1171*da0073e9SAndroid Build Coastguard Worker                                center=center, normalized=normalized)
1172*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
1173*da0073e9SAndroid Build Coastguard Worker
1174*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1175*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.cdouble)
1176*da0073e9SAndroid Build Coastguard Worker    def test_complex_istft_real_equiv(self, device, dtype):
1177*da0073e9SAndroid Build Coastguard Worker        test_args = list(product(
1178*da0073e9SAndroid Build Coastguard Worker            # input
1179*da0073e9SAndroid Build Coastguard Worker            (torch.rand(40, 20, device=device, dtype=dtype),
1180*da0073e9SAndroid Build Coastguard Worker             torch.rand(25, 1, device=device, dtype=dtype),
1181*da0073e9SAndroid Build Coastguard Worker             torch.rand(4, 20, 10, device=device, dtype=dtype)),
1182*da0073e9SAndroid Build Coastguard Worker            # hop_length
1183*da0073e9SAndroid Build Coastguard Worker            (None, 10),
1184*da0073e9SAndroid Build Coastguard Worker            # center
1185*da0073e9SAndroid Build Coastguard Worker            (False, True),
1186*da0073e9SAndroid Build Coastguard Worker            # normalized
1187*da0073e9SAndroid Build Coastguard Worker            (True, False),
1188*da0073e9SAndroid Build Coastguard Worker        ))
1189*da0073e9SAndroid Build Coastguard Worker
1190*da0073e9SAndroid Build Coastguard Worker        for args in test_args:
1191*da0073e9SAndroid Build Coastguard Worker            x, hop_length, center, normalized = args
1192*da0073e9SAndroid Build Coastguard Worker            n_fft = x.size(-2)
1193*da0073e9SAndroid Build Coastguard Worker            expected = _complex_istft(x, n_fft, hop_length=hop_length,
1194*da0073e9SAndroid Build Coastguard Worker                                      center=center, normalized=normalized)
1195*da0073e9SAndroid Build Coastguard Worker            actual = torch.istft(x, n_fft, hop_length=hop_length,
1196*da0073e9SAndroid Build Coastguard Worker                                 center=center, normalized=normalized,
1197*da0073e9SAndroid Build Coastguard Worker                                 return_complex=True)
1198*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
1199*da0073e9SAndroid Build Coastguard Worker
1200*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1201*da0073e9SAndroid Build Coastguard Worker    def test_complex_stft_onesided(self, device):
1202*da0073e9SAndroid Build Coastguard Worker        # stft of complex input cannot be onesided
1203*da0073e9SAndroid Build Coastguard Worker        for x_dtype, window_dtype in product((torch.double, torch.cdouble), repeat=2):
1204*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(100, device=device, dtype=x_dtype)
1205*da0073e9SAndroid Build Coastguard Worker            window = torch.rand(10, device=device, dtype=window_dtype)
1206*da0073e9SAndroid Build Coastguard Worker
1207*da0073e9SAndroid Build Coastguard Worker            if x_dtype.is_complex or window_dtype.is_complex:
1208*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, 'complex'):
1209*da0073e9SAndroid Build Coastguard Worker                    x.stft(10, window=window, pad_mode='constant', onesided=True)
1210*da0073e9SAndroid Build Coastguard Worker            else:
1211*da0073e9SAndroid Build Coastguard Worker                y = x.stft(10, window=window, pad_mode='constant', onesided=True,
1212*da0073e9SAndroid Build Coastguard Worker                           return_complex=True)
1213*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y.dtype, torch.cdouble)
1214*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y.size(), (6, 51))
1215*da0073e9SAndroid Build Coastguard Worker
1216*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(100, device=device, dtype=torch.cdouble)
1217*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'complex'):
1218*da0073e9SAndroid Build Coastguard Worker            x.stft(10, pad_mode='constant', onesided=True)
1219*da0073e9SAndroid Build Coastguard Worker
1220*da0073e9SAndroid Build Coastguard Worker    # stft is currently warning that it requires return-complex while an upgrader is written
1221*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1222*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1223*da0073e9SAndroid Build Coastguard Worker    def test_stft_requires_complex(self, device):
1224*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(100)
1225*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
1226*da0073e9SAndroid Build Coastguard Worker            y = x.stft(10, pad_mode='constant')
1227*da0073e9SAndroid Build Coastguard Worker
1228*da0073e9SAndroid Build Coastguard Worker    # stft and istft are currently warning if a window is not provided
1229*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1230*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1231*da0073e9SAndroid Build Coastguard Worker    def test_stft_requires_window(self, device):
1232*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(100)
1233*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"):
1234*da0073e9SAndroid Build Coastguard Worker            y = x.stft(10, pad_mode='constant', return_complex=True)
1235*da0073e9SAndroid Build Coastguard Worker
1236*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1237*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1238*da0073e9SAndroid Build Coastguard Worker    def test_istft_requires_window(self, device):
1239*da0073e9SAndroid Build Coastguard Worker        stft = torch.rand((51, 5), dtype=torch.cdouble)
1240*da0073e9SAndroid Build Coastguard Worker        # 51 = 2 * n_fft + 1, 5 = number of frames
1241*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"):
1242*da0073e9SAndroid Build Coastguard Worker            x = torch.istft(stft, n_fft=100, length=100)
1243*da0073e9SAndroid Build Coastguard Worker
1244*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1245*da0073e9SAndroid Build Coastguard Worker    def test_fft_input_modification(self, device):
1246*da0073e9SAndroid Build Coastguard Worker        # FFT functions should not modify their input (gh-34551)
1247*da0073e9SAndroid Build Coastguard Worker
1248*da0073e9SAndroid Build Coastguard Worker        signal = torch.ones((2, 2, 2), device=device)
1249*da0073e9SAndroid Build Coastguard Worker        signal_copy = signal.clone()
1250*da0073e9SAndroid Build Coastguard Worker        spectrum = torch.fft.fftn(signal, dim=(-2, -1))
1251*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(signal, signal_copy)
1252*da0073e9SAndroid Build Coastguard Worker
1253*da0073e9SAndroid Build Coastguard Worker        spectrum_copy = spectrum.clone()
1254*da0073e9SAndroid Build Coastguard Worker        _ = torch.fft.ifftn(spectrum, dim=(-2, -1))
1255*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(spectrum, spectrum_copy)
1256*da0073e9SAndroid Build Coastguard Worker
1257*da0073e9SAndroid Build Coastguard Worker        half_spectrum = torch.fft.rfftn(signal, dim=(-2, -1))
1258*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(signal, signal_copy)
1259*da0073e9SAndroid Build Coastguard Worker
1260*da0073e9SAndroid Build Coastguard Worker        half_spectrum_copy = half_spectrum.clone()
1261*da0073e9SAndroid Build Coastguard Worker        _ = torch.fft.irfftn(half_spectrum_copy, s=(2, 2), dim=(-2, -1))
1262*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(half_spectrum, half_spectrum_copy)
1263*da0073e9SAndroid Build Coastguard Worker
1264*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1265*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1266*da0073e9SAndroid Build Coastguard Worker    def test_fft_plan_repeatable(self, device):
1267*da0073e9SAndroid Build Coastguard Worker        # Regression test for gh-58724 and gh-63152
1268*da0073e9SAndroid Build Coastguard Worker        for n in [2048, 3199, 5999]:
1269*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(n, device=device, dtype=torch.complex64)
1270*da0073e9SAndroid Build Coastguard Worker            res1 = torch.fft.fftn(a)
1271*da0073e9SAndroid Build Coastguard Worker            res2 = torch.fft.fftn(a.clone())
1272*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1, res2)
1273*da0073e9SAndroid Build Coastguard Worker
1274*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(n, device=device, dtype=torch.float64)
1275*da0073e9SAndroid Build Coastguard Worker            res1 = torch.fft.rfft(a)
1276*da0073e9SAndroid Build Coastguard Worker            res2 = torch.fft.rfft(a.clone())
1277*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1, res2)
1278*da0073e9SAndroid Build Coastguard Worker
1279*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1280*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1281*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
1282*da0073e9SAndroid Build Coastguard Worker    def test_istft_round_trip_simple_cases(self, device, dtype):
1283*da0073e9SAndroid Build Coastguard Worker        """stft -> istft should recover the original signale"""
1284*da0073e9SAndroid Build Coastguard Worker        def _test(input, n_fft, length):
1285*da0073e9SAndroid Build Coastguard Worker            stft = torch.stft(input, n_fft=n_fft, return_complex=True)
1286*da0073e9SAndroid Build Coastguard Worker            inverse = torch.istft(stft, n_fft=n_fft, length=length)
1287*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input, inverse, exact_dtype=True)
1288*da0073e9SAndroid Build Coastguard Worker
1289*da0073e9SAndroid Build Coastguard Worker        _test(torch.ones(4, dtype=dtype, device=device), 4, 4)
1290*da0073e9SAndroid Build Coastguard Worker        _test(torch.zeros(4, dtype=dtype, device=device), 4, 4)
1291*da0073e9SAndroid Build Coastguard Worker
1292*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1293*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1294*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
1295*da0073e9SAndroid Build Coastguard Worker    def test_istft_round_trip_various_params(self, device, dtype):
1296*da0073e9SAndroid Build Coastguard Worker        """stft -> istft should recover the original signale"""
1297*da0073e9SAndroid Build Coastguard Worker        def _test_istft_is_inverse_of_stft(stft_kwargs):
1298*da0073e9SAndroid Build Coastguard Worker            # generates a random sound signal for each tril and then does the stft/istft
1299*da0073e9SAndroid Build Coastguard Worker            # operation to check whether we can reconstruct signal
1300*da0073e9SAndroid Build Coastguard Worker            data_sizes = [(2, 20), (3, 15), (4, 10)]
1301*da0073e9SAndroid Build Coastguard Worker            num_trials = 100
1302*da0073e9SAndroid Build Coastguard Worker            istft_kwargs = stft_kwargs.copy()
1303*da0073e9SAndroid Build Coastguard Worker            del istft_kwargs['pad_mode']
1304*da0073e9SAndroid Build Coastguard Worker            for sizes in data_sizes:
1305*da0073e9SAndroid Build Coastguard Worker                for i in range(num_trials):
1306*da0073e9SAndroid Build Coastguard Worker                    original = torch.randn(*sizes, dtype=dtype, device=device)
1307*da0073e9SAndroid Build Coastguard Worker                    stft = torch.stft(original, return_complex=True, **stft_kwargs)
1308*da0073e9SAndroid Build Coastguard Worker                    inversed = torch.istft(stft, length=original.size(1), **istft_kwargs)
1309*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
1310*da0073e9SAndroid Build Coastguard Worker                        inversed, original, msg='istft comparison against original',
1311*da0073e9SAndroid Build Coastguard Worker                        atol=7e-6, rtol=0, exact_dtype=True)
1312*da0073e9SAndroid Build Coastguard Worker
1313*da0073e9SAndroid Build Coastguard Worker        patterns = [
1314*da0073e9SAndroid Build Coastguard Worker            # hann_window, centered, normalized, onesided
1315*da0073e9SAndroid Build Coastguard Worker            {
1316*da0073e9SAndroid Build Coastguard Worker                'n_fft': 12,
1317*da0073e9SAndroid Build Coastguard Worker                'hop_length': 4,
1318*da0073e9SAndroid Build Coastguard Worker                'win_length': 12,
1319*da0073e9SAndroid Build Coastguard Worker                'window': torch.hann_window(12, dtype=dtype, device=device),
1320*da0073e9SAndroid Build Coastguard Worker                'center': True,
1321*da0073e9SAndroid Build Coastguard Worker                'pad_mode': 'reflect',
1322*da0073e9SAndroid Build Coastguard Worker                'normalized': True,
1323*da0073e9SAndroid Build Coastguard Worker                'onesided': True,
1324*da0073e9SAndroid Build Coastguard Worker            },
1325*da0073e9SAndroid Build Coastguard Worker            # hann_window, centered, not normalized, not onesided
1326*da0073e9SAndroid Build Coastguard Worker            {
1327*da0073e9SAndroid Build Coastguard Worker                'n_fft': 12,
1328*da0073e9SAndroid Build Coastguard Worker                'hop_length': 2,
1329*da0073e9SAndroid Build Coastguard Worker                'win_length': 8,
1330*da0073e9SAndroid Build Coastguard Worker                'window': torch.hann_window(8, dtype=dtype, device=device),
1331*da0073e9SAndroid Build Coastguard Worker                'center': True,
1332*da0073e9SAndroid Build Coastguard Worker                'pad_mode': 'reflect',
1333*da0073e9SAndroid Build Coastguard Worker                'normalized': False,
1334*da0073e9SAndroid Build Coastguard Worker                'onesided': False,
1335*da0073e9SAndroid Build Coastguard Worker            },
1336*da0073e9SAndroid Build Coastguard Worker            # hamming_window, centered, normalized, not onesided
1337*da0073e9SAndroid Build Coastguard Worker            {
1338*da0073e9SAndroid Build Coastguard Worker                'n_fft': 15,
1339*da0073e9SAndroid Build Coastguard Worker                'hop_length': 3,
1340*da0073e9SAndroid Build Coastguard Worker                'win_length': 11,
1341*da0073e9SAndroid Build Coastguard Worker                'window': torch.hamming_window(11, dtype=dtype, device=device),
1342*da0073e9SAndroid Build Coastguard Worker                'center': True,
1343*da0073e9SAndroid Build Coastguard Worker                'pad_mode': 'constant',
1344*da0073e9SAndroid Build Coastguard Worker                'normalized': True,
1345*da0073e9SAndroid Build Coastguard Worker                'onesided': False,
1346*da0073e9SAndroid Build Coastguard Worker            },
1347*da0073e9SAndroid Build Coastguard Worker            # hamming_window, centered, not normalized, onesided
1348*da0073e9SAndroid Build Coastguard Worker            # window same size as n_fft
1349*da0073e9SAndroid Build Coastguard Worker            {
1350*da0073e9SAndroid Build Coastguard Worker                'n_fft': 5,
1351*da0073e9SAndroid Build Coastguard Worker                'hop_length': 2,
1352*da0073e9SAndroid Build Coastguard Worker                'win_length': 5,
1353*da0073e9SAndroid Build Coastguard Worker                'window': torch.hamming_window(5, dtype=dtype, device=device),
1354*da0073e9SAndroid Build Coastguard Worker                'center': True,
1355*da0073e9SAndroid Build Coastguard Worker                'pad_mode': 'constant',
1356*da0073e9SAndroid Build Coastguard Worker                'normalized': False,
1357*da0073e9SAndroid Build Coastguard Worker                'onesided': True,
1358*da0073e9SAndroid Build Coastguard Worker            },
1359*da0073e9SAndroid Build Coastguard Worker        ]
1360*da0073e9SAndroid Build Coastguard Worker        for i, pattern in enumerate(patterns):
1361*da0073e9SAndroid Build Coastguard Worker            _test_istft_is_inverse_of_stft(pattern)
1362*da0073e9SAndroid Build Coastguard Worker
1363*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1364*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1365*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
1366*da0073e9SAndroid Build Coastguard Worker    def test_istft_round_trip_with_padding(self, device, dtype):
1367*da0073e9SAndroid Build Coastguard Worker        """long hop_length or not centered may cause length mismatch in the inversed signal"""
1368*da0073e9SAndroid Build Coastguard Worker        def _test_istft_is_inverse_of_stft_with_padding(stft_kwargs):
1369*da0073e9SAndroid Build Coastguard Worker            # generates a random sound signal for each tril and then does the stft/istft
1370*da0073e9SAndroid Build Coastguard Worker            # operation to check whether we can reconstruct signal
1371*da0073e9SAndroid Build Coastguard Worker            num_trials = 100
1372*da0073e9SAndroid Build Coastguard Worker            sizes = stft_kwargs['size']
1373*da0073e9SAndroid Build Coastguard Worker            del stft_kwargs['size']
1374*da0073e9SAndroid Build Coastguard Worker            istft_kwargs = stft_kwargs.copy()
1375*da0073e9SAndroid Build Coastguard Worker            del istft_kwargs['pad_mode']
1376*da0073e9SAndroid Build Coastguard Worker            for i in range(num_trials):
1377*da0073e9SAndroid Build Coastguard Worker                original = torch.randn(*sizes, dtype=dtype, device=device)
1378*da0073e9SAndroid Build Coastguard Worker                stft = torch.stft(original, return_complex=True, **stft_kwargs)
1379*da0073e9SAndroid Build Coastguard Worker                with self.assertWarnsOnceRegex(UserWarning, "The length of signal is shorter than the length parameter."):
1380*da0073e9SAndroid Build Coastguard Worker                    inversed = torch.istft(stft, length=original.size(-1), **istft_kwargs)
1381*da0073e9SAndroid Build Coastguard Worker                n_frames = stft.size(-1)
1382*da0073e9SAndroid Build Coastguard Worker                if stft_kwargs["center"] is True:
1383*da0073e9SAndroid Build Coastguard Worker                    len_expected = stft_kwargs["n_fft"] // 2 + stft_kwargs["hop_length"] * (n_frames - 1)
1384*da0073e9SAndroid Build Coastguard Worker                else:
1385*da0073e9SAndroid Build Coastguard Worker                    len_expected = stft_kwargs["n_fft"] + stft_kwargs["hop_length"] * (n_frames - 1)
1386*da0073e9SAndroid Build Coastguard Worker                # trim the original for case when constructed signal is shorter than original
1387*da0073e9SAndroid Build Coastguard Worker                padding = inversed[..., len_expected:]
1388*da0073e9SAndroid Build Coastguard Worker                inversed = inversed[..., :len_expected]
1389*da0073e9SAndroid Build Coastguard Worker                original = original[..., :len_expected]
1390*da0073e9SAndroid Build Coastguard Worker                # test the padding points of the inversed signal are all zeros
1391*da0073e9SAndroid Build Coastguard Worker                zeros = torch.zeros_like(padding, device=padding.device)
1392*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1393*da0073e9SAndroid Build Coastguard Worker                    padding, zeros, msg='istft padding values against zeros',
1394*da0073e9SAndroid Build Coastguard Worker                    atol=7e-6, rtol=0, exact_dtype=True)
1395*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1396*da0073e9SAndroid Build Coastguard Worker                    inversed, original, msg='istft comparison against original',
1397*da0073e9SAndroid Build Coastguard Worker                    atol=7e-6, rtol=0, exact_dtype=True)
1398*da0073e9SAndroid Build Coastguard Worker
1399*da0073e9SAndroid Build Coastguard Worker        patterns = [
1400*da0073e9SAndroid Build Coastguard Worker            # hamming_window, not centered, not normalized, not onesided
1401*da0073e9SAndroid Build Coastguard Worker            # window same size as n_fft
1402*da0073e9SAndroid Build Coastguard Worker            {
1403*da0073e9SAndroid Build Coastguard Worker                'size': [2, 20],
1404*da0073e9SAndroid Build Coastguard Worker                'n_fft': 3,
1405*da0073e9SAndroid Build Coastguard Worker                'hop_length': 2,
1406*da0073e9SAndroid Build Coastguard Worker                'win_length': 3,
1407*da0073e9SAndroid Build Coastguard Worker                'window': torch.hamming_window(3, dtype=dtype, device=device),
1408*da0073e9SAndroid Build Coastguard Worker                'center': False,
1409*da0073e9SAndroid Build Coastguard Worker                'pad_mode': 'reflect',
1410*da0073e9SAndroid Build Coastguard Worker                'normalized': False,
1411*da0073e9SAndroid Build Coastguard Worker                'onesided': False,
1412*da0073e9SAndroid Build Coastguard Worker            },
1413*da0073e9SAndroid Build Coastguard Worker            # hamming_window, centered, not normalized, onesided, long hop_length
1414*da0073e9SAndroid Build Coastguard Worker            # window same size as n_fft
1415*da0073e9SAndroid Build Coastguard Worker            {
1416*da0073e9SAndroid Build Coastguard Worker                'size': [2, 500],
1417*da0073e9SAndroid Build Coastguard Worker                'n_fft': 256,
1418*da0073e9SAndroid Build Coastguard Worker                'hop_length': 254,
1419*da0073e9SAndroid Build Coastguard Worker                'win_length': 256,
1420*da0073e9SAndroid Build Coastguard Worker                'window': torch.hamming_window(256, dtype=dtype, device=device),
1421*da0073e9SAndroid Build Coastguard Worker                'center': True,
1422*da0073e9SAndroid Build Coastguard Worker                'pad_mode': 'constant',
1423*da0073e9SAndroid Build Coastguard Worker                'normalized': False,
1424*da0073e9SAndroid Build Coastguard Worker                'onesided': True,
1425*da0073e9SAndroid Build Coastguard Worker            },
1426*da0073e9SAndroid Build Coastguard Worker        ]
1427*da0073e9SAndroid Build Coastguard Worker        for i, pattern in enumerate(patterns):
1428*da0073e9SAndroid Build Coastguard Worker            _test_istft_is_inverse_of_stft_with_padding(pattern)
1429*da0073e9SAndroid Build Coastguard Worker
1430*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1431*da0073e9SAndroid Build Coastguard Worker    def test_istft_throws(self, device):
1432*da0073e9SAndroid Build Coastguard Worker        """istft should throw exception for invalid parameters"""
1433*da0073e9SAndroid Build Coastguard Worker        stft = torch.zeros((3, 5, 2), device=device)
1434*da0073e9SAndroid Build Coastguard Worker        # the window is size 1 but it hops 20 so there is a gap which throw an error
1435*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(
1436*da0073e9SAndroid Build Coastguard Worker            RuntimeError, torch.istft, stft, n_fft=4,
1437*da0073e9SAndroid Build Coastguard Worker            hop_length=20, win_length=1, window=torch.ones(1))
1438*da0073e9SAndroid Build Coastguard Worker        # A window of zeros does not meet NOLA
1439*da0073e9SAndroid Build Coastguard Worker        invalid_window = torch.zeros(4, device=device)
1440*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(
1441*da0073e9SAndroid Build Coastguard Worker            RuntimeError, torch.istft, stft, n_fft=4, win_length=4, window=invalid_window)
1442*da0073e9SAndroid Build Coastguard Worker        # Input cannot be empty
1443*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, torch.istft, torch.zeros((3, 0, 2)), 2)
1444*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, torch.istft, torch.zeros((0, 3, 2)), 2)
1445*da0073e9SAndroid Build Coastguard Worker
1446*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Failed running call_function")
1447*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1448*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1449*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
1450*da0073e9SAndroid Build Coastguard Worker    def test_istft_of_sine(self, device, dtype):
1451*da0073e9SAndroid Build Coastguard Worker        complex_dtype = corresponding_complex_dtype(dtype)
1452*da0073e9SAndroid Build Coastguard Worker
1453*da0073e9SAndroid Build Coastguard Worker        def _test(amplitude, L, n):
1454*da0073e9SAndroid Build Coastguard Worker            # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
1455*da0073e9SAndroid Build Coastguard Worker            x = torch.arange(2 * L + 1, device=device, dtype=dtype)
1456*da0073e9SAndroid Build Coastguard Worker            original = amplitude * torch.sin(2 * math.pi / L * x * n)
1457*da0073e9SAndroid Build Coastguard Worker            # stft = torch.stft(original, L, hop_length=L, win_length=L,
1458*da0073e9SAndroid Build Coastguard Worker            #                   window=torch.ones(L), center=False, normalized=False)
1459*da0073e9SAndroid Build Coastguard Worker            stft = torch.zeros((L // 2 + 1, 2), device=device, dtype=complex_dtype)
1460*da0073e9SAndroid Build Coastguard Worker            stft_largest_val = (amplitude * L) / 2.0
1461*da0073e9SAndroid Build Coastguard Worker            if n < stft.size(0):
1462*da0073e9SAndroid Build Coastguard Worker                stft[n].imag = torch.tensor(-stft_largest_val, dtype=dtype)
1463*da0073e9SAndroid Build Coastguard Worker
1464*da0073e9SAndroid Build Coastguard Worker            if 0 <= L - n < stft.size(0):
1465*da0073e9SAndroid Build Coastguard Worker                # symmetric about L // 2
1466*da0073e9SAndroid Build Coastguard Worker                stft[L - n].imag = torch.tensor(stft_largest_val, dtype=dtype)
1467*da0073e9SAndroid Build Coastguard Worker
1468*da0073e9SAndroid Build Coastguard Worker            inverse = torch.istft(
1469*da0073e9SAndroid Build Coastguard Worker                stft, L, hop_length=L, win_length=L,
1470*da0073e9SAndroid Build Coastguard Worker                window=torch.ones(L, device=device, dtype=dtype), center=False, normalized=False)
1471*da0073e9SAndroid Build Coastguard Worker            # There is a larger error due to the scaling of amplitude
1472*da0073e9SAndroid Build Coastguard Worker            original = original[..., :inverse.size(-1)]
1473*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(inverse, original, atol=1e-3, rtol=0)
1474*da0073e9SAndroid Build Coastguard Worker
1475*da0073e9SAndroid Build Coastguard Worker        _test(amplitude=123, L=5, n=1)
1476*da0073e9SAndroid Build Coastguard Worker        _test(amplitude=150, L=5, n=2)
1477*da0073e9SAndroid Build Coastguard Worker        _test(amplitude=111, L=5, n=3)
1478*da0073e9SAndroid Build Coastguard Worker        _test(amplitude=160, L=7, n=4)
1479*da0073e9SAndroid Build Coastguard Worker        _test(amplitude=145, L=8, n=5)
1480*da0073e9SAndroid Build Coastguard Worker        _test(amplitude=80, L=9, n=6)
1481*da0073e9SAndroid Build Coastguard Worker        _test(amplitude=99, L=10, n=7)
1482*da0073e9SAndroid Build Coastguard Worker
1483*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1484*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1485*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
1486*da0073e9SAndroid Build Coastguard Worker    def test_istft_linearity(self, device, dtype):
1487*da0073e9SAndroid Build Coastguard Worker        num_trials = 100
1488*da0073e9SAndroid Build Coastguard Worker        complex_dtype = corresponding_complex_dtype(dtype)
1489*da0073e9SAndroid Build Coastguard Worker
1490*da0073e9SAndroid Build Coastguard Worker        def _test(data_size, kwargs):
1491*da0073e9SAndroid Build Coastguard Worker            for i in range(num_trials):
1492*da0073e9SAndroid Build Coastguard Worker                tensor1 = torch.randn(data_size, device=device, dtype=complex_dtype)
1493*da0073e9SAndroid Build Coastguard Worker                tensor2 = torch.randn(data_size, device=device, dtype=complex_dtype)
1494*da0073e9SAndroid Build Coastguard Worker                a, b = torch.rand(2, dtype=dtype, device=device)
1495*da0073e9SAndroid Build Coastguard Worker                # Also compare method vs. functional call signature
1496*da0073e9SAndroid Build Coastguard Worker                istft1 = tensor1.istft(**kwargs)
1497*da0073e9SAndroid Build Coastguard Worker                istft2 = tensor2.istft(**kwargs)
1498*da0073e9SAndroid Build Coastguard Worker                istft = a * istft1 + b * istft2
1499*da0073e9SAndroid Build Coastguard Worker                estimate = torch.istft(a * tensor1 + b * tensor2, **kwargs)
1500*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(istft, estimate, atol=1e-5, rtol=0)
1501*da0073e9SAndroid Build Coastguard Worker        patterns = [
1502*da0073e9SAndroid Build Coastguard Worker            # hann_window, centered, normalized, onesided
1503*da0073e9SAndroid Build Coastguard Worker            (
1504*da0073e9SAndroid Build Coastguard Worker                (2, 7, 7),
1505*da0073e9SAndroid Build Coastguard Worker                {
1506*da0073e9SAndroid Build Coastguard Worker                    'n_fft': 12,
1507*da0073e9SAndroid Build Coastguard Worker                    'window': torch.hann_window(12, device=device, dtype=dtype),
1508*da0073e9SAndroid Build Coastguard Worker                    'center': True,
1509*da0073e9SAndroid Build Coastguard Worker                    'normalized': True,
1510*da0073e9SAndroid Build Coastguard Worker                    'onesided': True,
1511*da0073e9SAndroid Build Coastguard Worker                },
1512*da0073e9SAndroid Build Coastguard Worker            ),
1513*da0073e9SAndroid Build Coastguard Worker            # hann_window, centered, not normalized, not onesided
1514*da0073e9SAndroid Build Coastguard Worker            (
1515*da0073e9SAndroid Build Coastguard Worker                (2, 12, 7),
1516*da0073e9SAndroid Build Coastguard Worker                {
1517*da0073e9SAndroid Build Coastguard Worker                    'n_fft': 12,
1518*da0073e9SAndroid Build Coastguard Worker                    'window': torch.hann_window(12, device=device, dtype=dtype),
1519*da0073e9SAndroid Build Coastguard Worker                    'center': True,
1520*da0073e9SAndroid Build Coastguard Worker                    'normalized': False,
1521*da0073e9SAndroid Build Coastguard Worker                    'onesided': False,
1522*da0073e9SAndroid Build Coastguard Worker                },
1523*da0073e9SAndroid Build Coastguard Worker            ),
1524*da0073e9SAndroid Build Coastguard Worker            # hamming_window, centered, normalized, not onesided
1525*da0073e9SAndroid Build Coastguard Worker            (
1526*da0073e9SAndroid Build Coastguard Worker                (2, 12, 7),
1527*da0073e9SAndroid Build Coastguard Worker                {
1528*da0073e9SAndroid Build Coastguard Worker                    'n_fft': 12,
1529*da0073e9SAndroid Build Coastguard Worker                    'window': torch.hamming_window(12, device=device, dtype=dtype),
1530*da0073e9SAndroid Build Coastguard Worker                    'center': True,
1531*da0073e9SAndroid Build Coastguard Worker                    'normalized': True,
1532*da0073e9SAndroid Build Coastguard Worker                    'onesided': False,
1533*da0073e9SAndroid Build Coastguard Worker                },
1534*da0073e9SAndroid Build Coastguard Worker            ),
1535*da0073e9SAndroid Build Coastguard Worker            # hamming_window, not centered, not normalized, onesided
1536*da0073e9SAndroid Build Coastguard Worker            (
1537*da0073e9SAndroid Build Coastguard Worker                (2, 7, 3),
1538*da0073e9SAndroid Build Coastguard Worker                {
1539*da0073e9SAndroid Build Coastguard Worker                    'n_fft': 12,
1540*da0073e9SAndroid Build Coastguard Worker                    'window': torch.hamming_window(12, device=device, dtype=dtype),
1541*da0073e9SAndroid Build Coastguard Worker                    'center': False,
1542*da0073e9SAndroid Build Coastguard Worker                    'normalized': False,
1543*da0073e9SAndroid Build Coastguard Worker                    'onesided': True,
1544*da0073e9SAndroid Build Coastguard Worker                },
1545*da0073e9SAndroid Build Coastguard Worker            )
1546*da0073e9SAndroid Build Coastguard Worker        ]
1547*da0073e9SAndroid Build Coastguard Worker        for data_size, kwargs in patterns:
1548*da0073e9SAndroid Build Coastguard Worker            _test(data_size, kwargs)
1549*da0073e9SAndroid Build Coastguard Worker
1550*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1551*da0073e9SAndroid Build Coastguard Worker    @skipCPUIfNoFFT
1552*da0073e9SAndroid Build Coastguard Worker    def test_batch_istft(self, device):
1553*da0073e9SAndroid Build Coastguard Worker        original = torch.tensor([
1554*da0073e9SAndroid Build Coastguard Worker            [4., 4., 4., 4., 4.],
1555*da0073e9SAndroid Build Coastguard Worker            [0., 0., 0., 0., 0.],
1556*da0073e9SAndroid Build Coastguard Worker            [0., 0., 0., 0., 0.]
1557*da0073e9SAndroid Build Coastguard Worker        ], device=device, dtype=torch.complex64)
1558*da0073e9SAndroid Build Coastguard Worker
1559*da0073e9SAndroid Build Coastguard Worker        single = original.repeat(1, 1, 1)
1560*da0073e9SAndroid Build Coastguard Worker        multi = original.repeat(4, 1, 1)
1561*da0073e9SAndroid Build Coastguard Worker
1562*da0073e9SAndroid Build Coastguard Worker        i_original = torch.istft(original, n_fft=4, length=4)
1563*da0073e9SAndroid Build Coastguard Worker        i_single = torch.istft(single, n_fft=4, length=4)
1564*da0073e9SAndroid Build Coastguard Worker        i_multi = torch.istft(multi, n_fft=4, length=4)
1565*da0073e9SAndroid Build Coastguard Worker
1566*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(i_original.repeat(1, 1), i_single, atol=1e-6, rtol=0, exact_dtype=True)
1567*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(i_original.repeat(4, 1), i_multi, atol=1e-6, rtol=0, exact_dtype=True)
1568*da0073e9SAndroid Build Coastguard Worker
1569*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1570*da0073e9SAndroid Build Coastguard Worker    @skipIf(not TEST_MKL, "Test requires MKL")
1571*da0073e9SAndroid Build Coastguard Worker    def test_stft_window_device(self, device):
1572*da0073e9SAndroid Build Coastguard Worker        # Test the (i)stft window must be on the same device as the input
1573*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1000, dtype=torch.complex64)
1574*da0073e9SAndroid Build Coastguard Worker        window = torch.randn(100, dtype=torch.complex64)
1575*da0073e9SAndroid Build Coastguard Worker
1576*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"):
1577*da0073e9SAndroid Build Coastguard Worker            torch.stft(x, n_fft=100, window=window.to(device))
1578*da0073e9SAndroid Build Coastguard Worker
1579*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"):
1580*da0073e9SAndroid Build Coastguard Worker            torch.stft(x.to(device), n_fft=100, window=window)
1581*da0073e9SAndroid Build Coastguard Worker
1582*da0073e9SAndroid Build Coastguard Worker        X = torch.stft(x, n_fft=100, window=window)
1583*da0073e9SAndroid Build Coastguard Worker
1584*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"):
1585*da0073e9SAndroid Build Coastguard Worker            torch.istft(X, n_fft=100, window=window.to(device))
1586*da0073e9SAndroid Build Coastguard Worker
1587*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"):
1588*da0073e9SAndroid Build Coastguard Worker            torch.istft(x.to(device), n_fft=100, window=window)
1589*da0073e9SAndroid Build Coastguard Worker
1590*da0073e9SAndroid Build Coastguard Worker
class FFTDocTestFinder:
    '''Doctest finder for the routines exported by a module's ``__all__``.

    The default doctest finder doesn't like that function.__module__ doesn't
    match torch.fft. It assumes the functions are leaked imports. This finder
    instead walks ``obj.__all__`` directly and parses each routine's docstring.
    '''
    def __init__(self) -> None:
        self.parser = doctest.DocTestParser()

    def find(self, obj, name=None, module=None, globs=None, extraglobs=None):
        '''Return one ``doctest.DocTest`` per documented routine in ``obj.__all__``.

        The signature mirrors ``doctest.DocTestFinder.find``.

        Args:
            obj: object (normally a module) exposing ``__all__``.
            name: unused; kept for signature compatibility. Each test is
                named after the bare routine name from ``__all__``.
            module: unused; kept for signature compatibility.
            globs: globals for the examples; this dict is never mutated here.
            extraglobs: optional extra globals merged over ``globs``, as
                ``doctest.DocTestFinder.find`` does.

        Entries of ``__all__`` that are not routines, and routines without a
        docstring, are skipped. A name in ``__all__`` that ``obj`` lacks
        raises ``AttributeError`` from ``getattr``.
        '''
        # Copy so merging extraglobs never touches the caller's dict.
        test_globs = {} if globs is None else dict(globs)
        if extraglobs is not None:
            test_globs.update(extraglobs)

        doctests = []
        for fname in obj.__all__:
            func = getattr(obj, fname)
            if not inspect.isroutine(func):
                continue
            docstring = inspect.getdoc(func)
            if docstring is None:
                continue
            doctests.append(self.parser.get_doctest(
                docstring, globs=test_globs, name=fname, filename=None, lineno=None))

        return doctests
1617*da0073e9SAndroid Build Coastguard Worker
1618*da0073e9SAndroid Build Coastguard Worker
# Container for the generated doc-example tests. It is declared empty;
# generate_doc_test() attaches one ``test_<name>`` method to it for each
# doctest that FFTDocTestFinder collects from torch.fft.
1619*da0073e9SAndroid Build Coastguard Workerclass TestFFTDocExamples(TestCase):
1620*da0073e9SAndroid Build Coastguard Worker    pass
1621*da0073e9SAndroid Build Coastguard Worker
def generate_doc_test(doc_test):
    '''Attach a test method running ``doc_test`` to TestFFTDocExamples.

    The method is named ``'test_' + doc_test.name``, is wrapped with
    ``skipCPUIfNoFFT``, and asserts it is only invoked with device ``'cpu'``.
    '''
    def test(self, device):
        self.assertEqual(device, 'cpu')
        # DocTestRunner.run clears the globals of the test it runs once it
        # finishes. Running doc_test directly would therefore empty its shared
        # globs (e.g. ``torch``), and any second run of this method in the same
        # process would fail with NameError. The DocTest constructor copies
        # globs, so each run works on a private copy.
        fresh_test = doctest.DocTest(
            doc_test.examples, doc_test.globs, doc_test.name,
            doc_test.filename, doc_test.lineno, doc_test.docstring)
        runner = doctest.DocTestRunner()
        results = runner.run(fresh_test)

        if results.failed != 0:
            runner.summarize()
            self.fail('Doctest failed')

    setattr(TestFFTDocExamples, 'test_' + doc_test.name, skipCPUIfNoFFT(test))
1633*da0073e9SAndroid Build Coastguard Worker
# Register one generated test on TestFFTDocExamples for every doctest that
# FFTDocTestFinder collects from torch.fft; ``torch`` is the only global the
# examples are given up front.
for doc_test in FFTDocTestFinder().find(torch.fft, globs={'torch': torch}):
    generate_doc_test(doc_test)
1636*da0073e9SAndroid Build Coastguard Worker
1637*da0073e9SAndroid Build Coastguard Worker
# Instantiate the device-generic TestFFT template; globals() is passed as the
# scope, presumably so the per-device test classes are added to this module.
1638*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestFFT, globals())
# The generated doc-example tests assert device == 'cpu', so only a CPU
# variant of TestFFTDocExamples is instantiated.
1639*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestFFTDocExamples, globals(), only_for='cpu')
1640*da0073e9SAndroid Build Coastguard Worker
1641*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
1642*da0073e9SAndroid Build Coastguard Worker    run_tests()
1643