xref: /aosp_15_r20/external/pytorch/test/cpp/api/fft.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/util/irange.h>
4 #include <test/cpp/api/support.h>
5 #include <torch/torch.h>
6 
7 // Naive DFT of a 1 dimensional tensor
naive_dft(torch::Tensor x,bool forward=true)8 torch::Tensor naive_dft(torch::Tensor x, bool forward = true) {
9   TORCH_INTERNAL_ASSERT(x.dim() == 1);
10   x = x.contiguous();
11   auto out_tensor = torch::zeros_like(x);
12   const int64_t len = x.size(0);
13 
14   // Roots of unity, exp(-2*pi*j*n/N) for n in [0, N), reversed for inverse
15   // transform
16   std::vector<c10::complex<double>> roots(len);
17   const auto angle_base = (forward ? -2.0 : 2.0) * M_PI / len;
18   for (const auto i : c10::irange(len)) {
19     auto angle = i * angle_base;
20     roots[i] = c10::complex<double>(std::cos(angle), std::sin(angle));
21   }
22 
23   const auto in = x.data_ptr<c10::complex<double>>();
24   const auto out = out_tensor.data_ptr<c10::complex<double>>();
25   for (const auto i : c10::irange(len)) {
26     for (const auto j : c10::irange(len)) {
27       out[i] += roots[(j * i) % len] * in[j];
28     }
29   }
30   return out_tensor;
31 }
32 
33 // NOTE: Visual Studio and ROCm builds don't understand complex literals
34 //   as of August 2020
35 
// torch::fft::fft on a complex-double signal must agree with the naive DFT.
TEST(FFTTest, fft) {
  const auto signal = torch::randn(128, torch::kComplexDouble);
  const auto computed = torch::fft::fft(signal);
  const auto reference = naive_dft(signal);
  ASSERT_TRUE(torch::allclose(computed, reference));
}
42 
// fft of a real-double signal must match fft of the same signal converted to
// complex-double first.
TEST(FFTTest, fft_real) {
  const auto real_signal = torch::randn(128, torch::kDouble);
  const auto from_real = torch::fft::fft(real_signal);
  const auto from_complex =
      torch::fft::fft(real_signal.to(torch::kComplexDouble));
  ASSERT_TRUE(torch::allclose(from_real, from_complex));
}
49 
// The `n` argument of fft must behave like explicitly padding (n > length) or
// trimming (n < length) the input at its end before transforming.
TEST(FFTTest, fft_pad) {
  const auto signal = torch::randn(128, torch::kComplexDouble);

  // n = 200: compare against the input padded by 72 trailing elements.
  {
    const auto padded_fft = torch::fft::fft(signal, 200);
    const auto reference =
        torch::fft::fft(torch::constant_pad_nd(signal, {0, 72}));
    ASSERT_TRUE(torch::allclose(padded_fft, reference));
  }

  // n = 64: compare against the input with 64 trailing elements removed.
  {
    const auto trimmed_fft = torch::fft::fft(signal, 64);
    const auto reference =
        torch::fft::fft(torch::constant_pad_nd(signal, {0, -64}));
    ASSERT_TRUE(torch::allclose(trimmed_fft, reference));
  }
}
60 
// Checks the scaling applied by the `norm` argument of fft for a length-128
// input: "forward" equals the default result divided by 128, and "ortho"
// equals it divided by sqrt(128).
TEST(FFTTest, fft_norm) {
  const auto signal = torch::randn(128, torch::kComplexDouble);

  // NOLINTNEXTLINE(bugprone-argument-comment)
  const auto raw = torch::fft::fft(signal, /*n=*/{}, /*axis=*/-1, /*norm=*/{});
  // NOLINTNEXTLINE(bugprone-argument-comment)
  const auto forward_scaled =
      torch::fft::fft(signal, /*n=*/{}, /*axis=*/-1, /*norm=*/"forward");
  ASSERT_TRUE(torch::allclose(raw / 128, forward_scaled));

  // NOLINTNEXTLINE(bugprone-argument-comment)
  const auto ortho_scaled =
      torch::fft::fft(signal, /*n=*/{}, /*axis=*/-1, /*norm=*/"ortho");
  ASSERT_TRUE(torch::allclose(raw / std::sqrt(128), ortho_scaled));
}
73 
// torch::fft::ifft must agree with the naive inverse DFT divided by the
// signal length (128).
TEST(FFTTest, ifft) {
  const auto spectrum = torch::randn(128, torch::kComplexDouble);
  const auto computed = torch::fft::ifft(spectrum);
  const auto reference = naive_dft(spectrum, /*forward=*/false) / 128;
  ASSERT_TRUE(torch::allclose(computed, reference));
}
80 
// Round trip fft -> ifft on a length-77 complex signal: length and dtype are
// checked at each step, and the original values must be recovered.
TEST(FFTTest, fft_ifft) {
  const auto signal = torch::randn(77, torch::kComplexDouble);

  const auto spectrum = torch::fft::fft(signal);
  ASSERT_EQ(spectrum.size(0), 77);
  ASSERT_EQ(spectrum.scalar_type(), torch::kComplexDouble);

  const auto recovered = torch::fft::ifft(spectrum);
  ASSERT_EQ(recovered.size(0), 77);
  ASSERT_EQ(recovered.scalar_type(), torch::kComplexDouble);
  ASSERT_TRUE(torch::allclose(signal, recovered));
}
92 
// rfft of a length-129 real signal must equal the first 65 bins of the full
// complex fft of the same signal.
TEST(FFTTest, rfft) {
  const auto real_signal = torch::randn(129, torch::kDouble);
  const auto half_spectrum = torch::fft::rfft(real_signal);
  const auto full_spectrum =
      torch::fft::fft(real_signal.to(torch::kComplexDouble));
  const auto reference = full_spectrum.slice(0, 0, 65);
  ASSERT_TRUE(torch::allclose(half_spectrum, reference));
}
99 
// Round trip rfft -> irfft on a length-128 real signal: the spectrum has 65
// complex bins, and the inverse restores 128 real samples matching the input.
TEST(FFTTest, rfft_irfft) {
  const auto signal = torch::randn(128, torch::kDouble);

  const auto spectrum = torch::fft::rfft(signal);
  ASSERT_EQ(spectrum.size(0), 65);
  ASSERT_EQ(spectrum.scalar_type(), torch::kComplexDouble);

  const auto recovered = torch::fft::irfft(spectrum);
  ASSERT_EQ(recovered.size(0), 128);
  ASSERT_EQ(recovered.scalar_type(), torch::kDouble);
  ASSERT_TRUE(torch::allclose(signal, recovered));
}
111 
// ihfft of a length-129 real input must equal the first 65 bins of the full
// complex ifft of the same input.
TEST(FFTTest, ihfft) {
  const auto real_input = torch::randn(129, torch::kDouble);
  const auto half_result = torch::fft::ihfft(real_input);
  const auto full_result =
      torch::fft::ifft(real_input.to(torch::kComplexDouble));
  const auto reference = full_result.slice(0, 0, 65);
  ASSERT_TRUE(torch::allclose(half_result, reference));
}
118 
// Round trip hfft -> ihfft: a 64-element complex half-spectrum is expanded to
// 127 real samples and then reduced back to the original 64 complex values.
TEST(FFTTest, hfft_ihfft) {
  auto half_spectrum = torch::randn(64, torch::kComplexDouble);
  // Element 0 must be purely real to satisfy hermitian symmetry.
  half_spectrum[0] = .5;

  const auto real_signal = torch::fft::hfft(half_spectrum, 127);
  ASSERT_EQ(real_signal.size(0), 127);
  ASSERT_EQ(real_signal.scalar_type(), torch::kDouble);

  const auto recovered = torch::fft::ihfft(real_signal);
  ASSERT_EQ(recovered.size(0), 64);
  ASSERT_EQ(recovered.scalar_type(), torch::kComplexDouble);
  ASSERT_TRUE(torch::allclose(half_spectrum, recovered));
}
131