xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mps/operations/FastFourierTransform.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#include <ATen/native/SpectralOpsUtils.h>
2#include <ATen/native/mps/MPSGraphSonomaOps.h>
3#include <ATen/native/mps/MPSGraphVenturaOps.h>
4#include <ATen/native/mps/OperationUtils.h>
5
6#ifndef AT_PER_OPERATOR_HEADERS
7#include <ATen/Functions.h>
8#include <ATen/NativeFunctions.h>
9#else
10#include <ATen/ops/_fft_c2c_native.h>
11#include <ATen/ops/_fft_c2r_native.h>
12#include <ATen/ops/_fft_r2c_native.h>
13#endif
14
15#if !defined(__MAC_14_0) && (!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0))
@implementation FakeMPSGraphFFTDescriptor
// Factory shim: resolves the MPSGraphFFTDescriptor class by name at runtime and
// forwards +descriptor to it, so the object handed back comes from the real
// implementation rather than from this stand-in class.
+ (nullable instancetype)descriptor {
  id realDescriptorClass = NSClassFromString(@"MPSGraphFFTDescriptor");
  id realDescriptor = [realDescriptorClass descriptor];
  return (FakeMPSGraphFFTDescriptor*)realDescriptor;
}

// NSCopying hook: hands back the receiver itself instead of building a new object.
- (nonnull id)copyWithZone:(nullable NSZone*)zone {
  return self;
}
@end
27#endif
28
29namespace at::native {
30namespace {
// Translates ATen's integer-encoded fft_norm_mode into the corresponding
// MPSGraphFFTScalingMode:
//   fft_norm_mode::none      -> MPSGraphFFTScalingModeNone
//   fft_norm_mode::by_n      -> MPSGraphFFTScalingModeSize
//   fft_norm_mode::by_root_n -> MPSGraphFFTScalingModeUnitary
// Any other value fails the TORCH_CHECK below, with the offending value in the message.
MPSGraphFFTScalingMode normalization_to_ScalingMode(int64_t normalization) {
  switch (static_cast<fft_norm_mode>(normalization)) {
    case fft_norm_mode::none:
      return MPSGraphFFTScalingModeNone;
    case fft_norm_mode::by_n:
      return MPSGraphFFTScalingModeSize;
    case fft_norm_mode::by_root_n:
      return MPSGraphFFTScalingModeUnitary;
    default:
      break;
  }
  // TORCH_CHECK concatenates its message arguments verbatim, so the separator
  // has to be spelled out; without it the value is glued onto the word "type".
  TORCH_CHECK(false, "Unsupported normalization type: ", normalization);
}
44
// Boxes every element of `arr` into an NSNumber and returns them, in the same
// order, as an NSArray.
NSArray<NSNumber*>* IntArrayToNSArray(IntArrayRef arr) {
  NSMutableArray<NSNumber*>* boxed = [NSMutableArray<NSNumber*> arrayWithCapacity:arr.size()];
  for (const auto value : arr) {
    [boxed addObject:[NSNumber numberWithInteger:value]];
  }
  return boxed;
}
52
53} // anonymous namespace
54
// Allocating variant of the complex-to-real FFT: builds an output tensor and
// delegates the transform to _fft_c2r_mps_out.
// The output has the input's shape except that the last entry of `dim` is set
// to `last_dim_size`, and its dtype is the real counterpart of the input dtype.
Tensor _fft_c2r_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  TORCH_CHECK(self.is_complex());
  const auto input_shape = self.sizes();
  DimVector result_shape(input_shape.begin(), input_shape.end());
  result_shape[dim.back()] = last_dim_size;
  const auto real_dtype = c10::toRealValueType(self.scalar_type());
  auto result = at::empty(result_shape, self.options().dtype(real_dtype));
  return _fft_c2r_mps_out(self, dim, normalization, last_dim_size, result);
}
63
// Allocating variant of the real-to-complex FFT: builds an output tensor and
// delegates the transform to _fft_r2c_mps_out.
// The output dtype is the complex counterpart of the input dtype. When
// `onesided` is set, the last entry of `dim` is shrunk to size / 2 + 1;
// otherwise the output keeps the input shape.
Tensor _fft_r2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());
  const auto input_shape = self.sizes();
  DimVector result_shape(input_shape.begin(), input_shape.end());
  const auto axis = dim.back();
  if (onesided) {
    result_shape[axis] = input_shape[axis] / 2 + 1;
  }

  const auto complex_dtype = c10::toComplexType(self.scalar_type());
  auto result = at::empty(result_shape, self.options().dtype(complex_dtype));
  return _fft_r2c_mps_out(self, dim, normalization, onesided, result);
}
77
// Allocating variant of the complex-to-complex FFT: builds an output tensor
// with the input's shape and options and delegates to _fft_c2c_mps_out.
// With an empty `dim` there is nothing to transform, so a clone of the input
// is returned instead.
Tensor _fft_c2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());
  if (!dim.empty()) {
    auto result = at::empty(self.sizes(), self.options());
    return _fft_c2c_mps_out(self, dim, normalization, forward, result);
  }
  return self.clone();
}
86
87using namespace mps;
88
// TODO: Investigate numerical discrepancies, see https://github.com/pytorch/pytorch/issues/120237
//
// Real-to-complex FFT over the axes in `dim`, written into `out`.
// A cached MPSGraph is looked up (or built on first use) under a key derived
// from the tensors, axes, normalization and `onesided`, then run on the
// current MPS stream. Returns `out`.
//   onesided == true : uses realToHermiteanFFT on the input directly.
//   onesided == false: casts the input to a complex type and runs the general FFT.
Tensor& _fft_r2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) {
  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
  std::string graphKey = __func__;
  graphKey += getTensorsStringKey({self, out});
  graphKey += ":";
  graphKey += getArrayRefString(dim);
  graphKey += ":";
  graphKey += std::to_string(normalization);
  graphKey += ":";
  graphKey += std::to_string(onesided);
  @autoreleasepool {
    auto graphEntry = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(graphKey, [&](auto mpsGraph, auto newCachedGraph) {
      auto realInput = mpsGraphRankedPlaceHolder(mpsGraph, self);
      auto fftDescriptor = [MPSGraphFFTDescriptor descriptor];
      fftDescriptor.scalingMode = normalization_to_ScalingMode(normalization);
      MPSGraphTensor* spectrum;
      if (!onesided) {
        // Full spectrum (Hermitean-conjugate half included): promote the real
        // input to complex, half precision staying half precision.
        auto complexType =
            (realInput.dataType != MPSDataTypeFloat16) ? MPSDataTypeComplexFloat32 : MPSDataTypeComplexFloat16;
        auto complexInput = [mpsGraph castTensor:realInput toType:complexType name:nil];
        spectrum = [mpsGraph fastFourierTransformWithTensor:complexInput
                                                       axes:IntArrayToNSArray(dim)
                                                 descriptor:fftDescriptor
                                                       name:nil];
      } else {
        // Only the unique (non-redundant) results are returned.
        spectrum = [mpsGraph realToHermiteanFFTWithTensor:realInput
                                                     axes:IntArrayToNSArray(dim)
                                               descriptor:fftDescriptor
                                                     name:nil];
      }
      newCachedGraph->inputTensor_ = realInput;
      newCachedGraph->outputTensor_ = spectrum;
    });
    auto selfPlaceholder = Placeholder(graphEntry->inputTensor_, self);
    auto outPlaceholder = Placeholder(graphEntry->outputTensor_, out);
    auto graphFeeds = dictionaryFromPlaceholders(selfPlaceholder);
    runMPSGraph(getCurrentMPSStream(), graphEntry->graph(), graphFeeds, outPlaceholder);
  }
  return out;
}
126
// Complex-to-real (inverse) FFT over the axes in `dim`, written into `out`.
// A cached MPSGraph is looked up (or built on first use) under a key derived
// from the input tensor, axes, normalization and `last_dim_size`, then run on
// the current MPS stream. Returns `out`.
Tensor& _fft_c2r_mps_out(const Tensor& self,
                         IntArrayRef dim,
                         int64_t normalization,
                         int64_t last_dim_size,
                         Tensor& out) {
  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
  std::string graphKey = __func__;
  graphKey += getTensorsStringKey({self});
  graphKey += ":";
  graphKey += getArrayRefString(dim);
  graphKey += ":";
  graphKey += std::to_string(normalization);
  graphKey += ":";
  graphKey += std::to_string(last_dim_size);
  @autoreleasepool {
    auto graphEntry = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(graphKey, [&](auto mpsGraph, auto newCachedGraph) {
      auto complexInput = mpsGraphRankedPlaceHolder(mpsGraph, self);
      auto fftDescriptor = [MPSGraphFFTDescriptor descriptor];
      fftDescriptor.scalingMode = normalization_to_ScalingMode(normalization);
      fftDescriptor.inverse = YES;
      // The rounding flag is driven by the parity of the requested output length.
      const bool oddLength = (last_dim_size % 2) == 1;
      fftDescriptor.roundToOddHermitean = oddLength ? YES : NO;
      auto realOutput = [mpsGraph HermiteanToRealFFTWithTensor:complexInput
                                                          axes:IntArrayToNSArray(dim)
                                                    descriptor:fftDescriptor
                                                          name:nil];
      newCachedGraph->inputTensor_ = complexInput;
      newCachedGraph->outputTensor_ = realOutput;
    });
    auto selfPlaceholder = Placeholder(graphEntry->inputTensor_, self);
    auto outPlaceholder = Placeholder(graphEntry->outputTensor_, out);
    auto graphFeeds = dictionaryFromPlaceholders(selfPlaceholder);
    runMPSGraph(getCurrentMPSStream(), graphEntry->graph(), graphFeeds, outPlaceholder);
  }
  return out;
}
156
// Complex-to-complex FFT over the axes in `dim`, written into `out`.
// `forward == false` selects the inverse transform. A cached MPSGraph is
// looked up (or built on first use) under a key derived from the input tensor,
// axes, normalization and direction, then run on the current MPS stream.
// Returns `out`.
Tensor& _fft_c2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward, Tensor& out) {
  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
  std::string graphKey = __func__;
  graphKey += getTensorsStringKey({self});
  graphKey += ":";
  graphKey += getArrayRefString(dim);
  graphKey += ":";
  graphKey += std::to_string(normalization);
  graphKey += ":";
  graphKey += std::to_string(forward);
  @autoreleasepool {
    auto graphEntry = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(graphKey, [&](auto mpsGraph, auto newCachedGraph) {
      auto complexInput = mpsGraphRankedPlaceHolder(mpsGraph, self);
      auto fftDescriptor = [MPSGraphFFTDescriptor descriptor];
      fftDescriptor.scalingMode = normalization_to_ScalingMode(normalization);
      fftDescriptor.inverse = !forward;
      auto transformed = [mpsGraph fastFourierTransformWithTensor:complexInput
                                                             axes:IntArrayToNSArray(dim)
                                                       descriptor:fftDescriptor
                                                             name:nil];
      newCachedGraph->inputTensor_ = complexInput;
      newCachedGraph->outputTensor_ = transformed;
    });
    auto selfPlaceholder = Placeholder(graphEntry->inputTensor_, self);
    auto outPlaceholder = Placeholder(graphEntry->outputTensor_, out);
    auto graphFeeds = dictionaryFromPlaceholders(selfPlaceholder);
    runMPSGraph(getCurrentMPSStream(), graphEntry->graph(), graphFeeds, outPlaceholder);
  }
  return out;
}
181
182} // namespace at::native
183