#include <ATen/native/Resize.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <ATen/native/mps/MPSGraphSonomaOps.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fft_c2c_native.h>
#include <ATen/ops/_fft_c2r_native.h>
#include <ATen/ops/_fft_r2c_native.h>
#endif

#if !defined(__MAC_14_0) && (!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0))
// Compiled only when the macOS 14 SDK macros are unavailable. The real descriptor class is
// looked up by name at runtime, so `+descriptor` hands back an instance of the real class.
@implementation FakeMPSGraphFFTDescriptor
+ (nullable instancetype)descriptor {
  // Redispatch the constructor to the actual implementation
  id desc = NSClassFromString(@"MPSGraphFFTDescriptor");
  return (FakeMPSGraphFFTDescriptor*)[desc descriptor];
}

- (nonnull id)copyWithZone:(nullable NSZone*)zone {
  return self;
}
@end
#endif

namespace at::native {
namespace {
// Maps an ATen `fft_norm_mode` value onto the matching MPSGraph FFT scaling mode.
// Throws for any value that is not one of none / by_n / by_root_n.
MPSGraphFFTScalingMode normalization_to_ScalingMode(int64_t normalization) {
  switch (static_cast<fft_norm_mode>(normalization)) {
    case fft_norm_mode::none:
      return MPSGraphFFTScalingModeNone;
    case fft_norm_mode::by_n:
      return MPSGraphFFTScalingModeSize;
    case fft_norm_mode::by_root_n:
      return MPSGraphFFTScalingModeUnitary;
    default:
      break;
  }
  TORCH_CHECK(false, "Unsupported normalization type: ", normalization);
}

// Boxes every element of `arr` into an NSNumber, preserving order (used for MPSGraph `axes:`).
NSArray<NSNumber*>* IntArrayToNSArray(IntArrayRef arr) {
  auto rc = [NSMutableArray<NSNumber*> arrayWithCapacity:arr.size()];
  for (const auto value : arr) {
    [rc addObject:@(value)];
  }
  return rc;
}

// Shape of a real->complex FFT result: the input shape, with the last transformed
// dimension shrunk to n / 2 + 1 when only the one-sided half of the spectrum is kept.
DimVector r2c_output_sizes(const Tensor& self, IntArrayRef dim, bool onesided) {
  TORCH_CHECK(!dim.empty(), "_fft_r2c: expected at least one dimension to transform");
  DimVector out_sizes(self.sizes().begin(), self.sizes().end());
  if (onesided) {
    const auto last_dim = dim.back();
    out_sizes[last_dim] = out_sizes[last_dim] / 2 + 1;
  }
  return out_sizes;
}

// Shape of a complex->real FFT result: the input shape, with the last transformed
// dimension replaced by the requested `last_dim_size`.
DimVector c2r_output_sizes(const Tensor& self, IntArrayRef dim, int64_t last_dim_size) {
  TORCH_CHECK(!dim.empty(), "_fft_c2r: expected at least one dimension to transform");
  DimVector out_sizes(self.sizes().begin(), self.sizes().end());
  out_sizes[dim.back()] = last_dim_size;
  return out_sizes;
}

} // anonymous namespace

// Complex->real inverse FFT over `dim`; allocates a real tensor and delegates to the out variant.
Tensor _fft_c2r_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  TORCH_CHECK(self.is_complex(), "_fft_c2r: expected a complex input tensor, but got ", self.scalar_type());
  const auto out_sizes = c2r_output_sizes(self, dim, last_dim_size);
  auto out = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));
  return _fft_c2r_mps_out(self, dim, normalization, last_dim_size, out);
}

// Real->complex FFT over `dim`; allocates a complex tensor and delegates to the out variant.
Tensor _fft_r2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point(),
              "_fft_r2c: expected a floating point input tensor, but got ",
              self.scalar_type());
  const auto out_sizes = r2c_output_sizes(self, dim, onesided);
  auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
  return _fft_r2c_mps_out(self, dim, normalization, onesided, out);
}

// Complex->complex FFT over `dim`; with no dims to transform the input is returned as a clone.
Tensor _fft_c2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex(), "_fft_c2c: expected a complex input tensor, but got ", self.scalar_type());
  if (dim.empty()) {
    return self.clone();
  }
  auto out = at::empty(self.sizes(), self.options());
  return _fft_c2c_mps_out(self, dim, normalization, forward, out);
}

using namespace mps;

// TODO: Investigate numerical discrepancies see https://github.com/pytorch/pytorch/issues/120237
Tensor& _fft_r2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) {
  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
  // `out` may be user-supplied with an arbitrary shape; size it to the result before it is
  // bound as the graph output. Done before building the key, which includes `out`.
  at::native::resize_output(out, r2c_output_sizes(self, dim, onesided));
  auto key = __func__ + getTensorsStringKey({self, out}) + ":" + getArrayRefString(dim) + ":" +
      std::to_string(normalization) + ":" + std::to_string(onesided);
  @autoreleasepool {
    auto cachedGraph = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
      auto descriptor = [MPSGraphFFTDescriptor descriptor];
      descriptor.scalingMode = normalization_to_ScalingMode(normalization);
      MPSGraphTensor* outputTensor;
      if (onesided) {
        // Return only unique results:
        outputTensor = [mpsGraph realToHermiteanFFTWithTensor:inputTensor
                                                         axes:IntArrayToNSArray(dim)
                                                   descriptor:descriptor
                                                         name:nil];
      } else {
        // Return with Hermitean conjugate results: promote the real input to complex and run
        // a full complex FFT so every frequency bin is produced.
        auto useDataType =
            (inputTensor.dataType == MPSDataTypeFloat16) ? MPSDataTypeComplexFloat16 : MPSDataTypeComplexFloat32;
        auto cTensor = [mpsGraph castTensor:inputTensor toType:useDataType name:nil];
        outputTensor = [mpsGraph fastFourierTransformWithTensor:cTensor
                                                           axes:IntArrayToNSArray(dim)
                                                     descriptor:descriptor
                                                           name:nil];
      }
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });
    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
    auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
    auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
  }
  return out;
}

Tensor& _fft_c2r_mps_out(const Tensor& self,
                         IntArrayRef dim,
                         int64_t normalization,
                         int64_t last_dim_size,
                         Tensor& out) {
  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
  // `out` may be user-supplied with an arbitrary shape; size it to the result before binding it.
  at::native::resize_output(out, c2r_output_sizes(self, dim, last_dim_size));
  auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
      std::to_string(normalization) + ":" + std::to_string(last_dim_size);
  @autoreleasepool {
    auto cachedGraph = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
      auto descriptor = [MPSGraphFFTDescriptor descriptor];
      descriptor.scalingMode = normalization_to_ScalingMode(normalization);
      descriptor.inverse = YES;
      // Set when the requested real length is odd; presumably this tells MPSGraph to produce an
      // odd-length output from the Hermitian input — confirm against the MPSGraph FFT docs.
      descriptor.roundToOddHermitean = ((last_dim_size % 2) == 1) ? YES : NO;
      auto outputTensor = [mpsGraph HermiteanToRealFFTWithTensor:inputTensor
                                                            axes:IntArrayToNSArray(dim)
                                                      descriptor:descriptor
                                                            name:nil];
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });
    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
    auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
    auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
  }
  return out;
}

Tensor& _fft_c2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward, Tensor& out) {
  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
  // A complex->complex FFT preserves the input shape; size `out` accordingly.
  at::native::resize_output(out, self.sizes());
  if (dim.empty()) {
    // No axes to transform: the result is the input unchanged (mirrors _fft_c2c_mps).
    return out.copy_(self);
  }
  auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
      std::to_string(normalization) + ":" + std::to_string(forward);
  @autoreleasepool {
    auto cachedGraph = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
      auto descriptor = [MPSGraphFFTDescriptor descriptor];
      descriptor.scalingMode = normalization_to_ScalingMode(normalization);
      descriptor.inverse = !forward;
      auto outputTensor = [mpsGraph fastFourierTransformWithTensor:inputTensor
                                                              axes:IntArrayToNSArray(dim)
                                                        descriptor:descriptor
                                                              name:nil];
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });
    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
    auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
    auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
  }
  return out;
}

} // namespace at::native