xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mps/MPSGraphSonomaOps.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
4 
5 #if !defined(__MAC_14_0) && \
6     (!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0))
7 
/// Fallback declaration of MPSGraphFFTScalingMode, compiled only when the SDK
/// does not already provide the macOS 14 MPSGraph declarations.
/// The raw values presumably have to match Apple's enumerators, because values
/// of this type are stored in the FFT descriptor and handed to MPSGraph as-is.
/// Keep them in sync with the SDK header; see Apple's documentation for the
/// normalization each mode applies.
typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode) {
  MPSGraphFFTScalingModeNone = 0,
  MPSGraphFFTScalingModeSize = 1,
  MPSGraphFFTScalingModeUnitary = 2,
};
14 
/// Compile-time stand-in for MPSGraphFFTDescriptor, declared only when the SDK
/// lacks the macOS 14 MPSGraph declarations.
/// The property names and the +descriptor factory are intended to mirror the
/// real class's selectors (verify against the SDK header when updating).
/// NOTE(review): no @implementation appears in this header. Confirm how
/// callers obtain a live descriptor instance when building with an older SDK.
@interface FakeMPSGraphFFTDescriptor : NSObject <NSCopying>

@property (nonatomic, assign) BOOL inverse;
@property (nonatomic, assign) MPSGraphFFTScalingMode scalingMode;
@property (nonatomic, assign) BOOL roundToOddHermitean;

+ (nullable instancetype)descriptor;

@end

// Lets code spell the SDK class name `MPSGraphFFTDescriptor` even when the SDK
// does not declare it; the alias resolves to the stand-in interface above.
@compatibility_alias MPSGraphFFTDescriptor FakeMPSGraphFFTDescriptor;
23 
/// Declarations of MPSGraph methods that first appear in the macOS 14 (Sonoma)
/// SDK, so this code still compiles against older SDKs.
/// Objective-C resolves these selectors at runtime, so every selector and
/// signature here must match Apple's exactly.
/// NOTE(review): nothing here checks the running OS version. Callers presumably
/// guard these calls with a runtime macOS 14 check; confirm.
@interface MPSGraph (SonomaOps)

// Conjugate / real-part ops.

- (MPSGraphTensor *_Nonnull)conjugateWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                           name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)realPartOfTensor:(MPSGraphTensor *_Nonnull)tensor
                                        name:(NSString *_Nullable)name;

// FFT ops; each one is configured through an MPSGraphFFTDescriptor.

- (MPSGraphTensor *_Nonnull)fastFourierTransformWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                                      axes:(NSArray<NSNumber *> *_Nonnull)axes
                                                descriptor:(MPSGraphFFTDescriptor *_Nonnull)descriptor
                                                      name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)realToHermiteanFFTWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                                    axes:(NSArray<NSNumber *> *_Nonnull)axes
                                              descriptor:(MPSGraphFFTDescriptor *_Nonnull)descriptor
                                                    name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)HermiteanToRealFFTWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                                    axes:(NSArray<NSNumber *> *_Nonnull)axes
                                              descriptor:(MPSGraphFFTDescriptor *_Nonnull)descriptor
                                                    name:(NSString *_Nullable)name;

@end
47 
// SDKs older than macOS 14 have no BFloat16 MPSDataType value. Synthesize it
// as the Float16 type with the alternate-encoding bit set.
#define MPSDataTypeBFloat16 ((MPSDataType)(MPSDataTypeFloat16 | MPSDataTypeAlternateEncodingBit))

// SDKs older than macOS 14 have no MTLLanguageVersion3_1. The value packs the
// version as (major << 16) | minor, i.e. 3.1.
#define MTLLanguageVersion3_1 ((MTLLanguageVersion)((3 << 16) | 1))
53 #endif
54