// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mps/MPSGraphVenturaOps.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
3 
// TODO: Remove me when moved to MacOS 13
//
// Stand-in declaration of MPSGraphConvolution3DOpDescriptor for builds whose SDK does
// not provide the macOS 13.2 availability macro tested below. The @compatibility_alias
// at the end lets the rest of the code spell the type MPSGraphConvolution3DOpDescriptor
// regardless of which SDK is in use.
//
// NOTE(review): only the interface is declared here; no @implementation is visible in
// this header. Presumably the real class is used at runtime on macOS 13.2+ — confirm
// at the call sites before instantiating this type directly.
//
// NOTE(review): Apple SDKs spell post-10.x version macros as MAC_OS_VERSION_13_2
// (without "_X_"). If MAC_OS_X_VERSION_13_2 is never defined, the second clause is
// always true and this guard effectively reduces to !defined(__MAC_13_2) — confirm.
#if !defined(__MAC_13_2) && \
    (!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2))

@interface FakeMPSGraphConvolution3DOpDescriptor : NSObject<NSCopying>

// Per-axis stride and dilation-rate values.
@property (readwrite, nonatomic) NSUInteger strideInX;
@property (readwrite, nonatomic) NSUInteger strideInY;
@property (readwrite, nonatomic) NSUInteger strideInZ;
@property (readwrite, nonatomic) NSUInteger dilationRateInX;
@property (readwrite, nonatomic) NSUInteger dilationRateInY;
@property (readwrite, nonatomic) NSUInteger dilationRateInZ;

// Explicit padding amounts, one property per side.
@property (readwrite, nonatomic) NSUInteger paddingLeft;
@property (readwrite, nonatomic) NSUInteger paddingRight;
@property (readwrite, nonatomic) NSUInteger paddingTop;
@property (readwrite, nonatomic) NSUInteger paddingBottom;
@property (readwrite, nonatomic) NSUInteger paddingFront;
@property (readwrite, nonatomic) NSUInteger paddingBack;

// Padding style plus the named layouts of the data tensor and the weights tensor.
@property (readwrite, nonatomic) MPSGraphPaddingStyle paddingStyle;
@property (readwrite, nonatomic) MPSGraphTensorNamedDataLayout dataLayout;
@property (readwrite, nonatomic) MPSGraphTensorNamedDataLayout weightsLayout;

// Presumably the grouped-convolution group count — confirm against the SDK class.
@property (readwrite, nonatomic) NSUInteger groups;

@end

@compatibility_alias MPSGraphConvolution3DOpDescriptor FakeMPSGraphConvolution3DOpDescriptor;

#endif
35 
#if !defined(__MAC_13_0) && \
    (!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0))

// Fallback definitions used when the SDK does not define __MAC_13_0. Typedefs and
// macros are file-scope in Objective-C, so they are declared here rather than inside
// the @interface below, where they would misleadingly look like category members.
//
// NOTE(review): Apple SDKs spell post-10.x version macros as MAC_OS_VERSION_13_0
// (without "_X_"). If MAC_OS_X_VERSION_13_0 is never defined, the second clause is
// always true and this guard effectively reduces to !defined(__MAC_13_0) — confirm.

// Rounding-mode parameter type used by the resizeNearest* selectors and by one
// sampleGrid* overload declared below. Raw values are spelled out explicitly;
// NOTE(review): presumably they must match the SDK enum of the same name — confirm
// against the macOS 13 SDK headers.
typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
{
    MPSGraphResizeNearestRoundingModeRoundPreferCeil   =  0L,
    MPSGraphResizeNearestRoundingModeRoundPreferFloor  =  1L,
    MPSGraphResizeNearestRoundingModeCeil              =  2L,
    MPSGraphResizeNearestRoundingModeFloor             =  3L,
    MPSGraphResizeNearestRoundingModeRoundToEven       =  4L,
    MPSGraphResizeNearestRoundingModeRoundToOdd        =  5L,
};

// Define complex enums for MacOS 12.
// Each value is composed as MPSDataTypeFloatBit | MPSDataTypeComplexBit | width.
// NOTE(review): the width is presumably the total element size in bits
// (2 x 32 = 64, 2 x 16 = 32) — confirm against a newer SDK's MPSDataType.
#define MPSDataTypeComplexBit 0x01000000
#define MPSDataTypeComplexFloat32 ((MPSDataType) (MPSDataTypeFloatBit | MPSDataTypeComplexBit | 64))
#define MPSDataTypeComplexFloat16 ((MPSDataType) (MPSDataTypeFloatBit | MPSDataTypeComplexBit | 32))
#endif

// Declares MPSGraph selectors that, per this category's name, belong to the macOS 13
// (Ventura) SDK, so code can reference them while building against an older SDK.
// NOTE(review): no @implementation is visible in this header; presumably the methods
// are provided by the OS at runtime and callers gate on the OS version before invoking
// them — confirm at the call sites.
@interface MPSGraph (VenturaOps)

#pragma mark - 3D convolution (forward, data gradient, weights gradient)

- (MPSGraphTensor * _Nonnull) convolution3DWithSourceTensor:(MPSGraphTensor * _Nonnull) source
                                              weightsTensor:(MPSGraphTensor * _Nonnull) weights
                                                 descriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) descriptor
                                                       name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) convolution3DDataGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient
                                                                    weightsTensor:(MPSGraphTensor * _Nonnull) weights
                                                                      outputShape:(MPSShape * _Nonnull) outputShape
                                                     forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor
                                                                             name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) convolution3DWeightsGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient
                                                                        sourceTensor:(MPSGraphTensor * _Nonnull) source
                                                                         outputShape:(MPSShape * _Nonnull) outputShape
                                                        forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor
                                                                                name:(NSString * _Nullable) name;

#pragma mark - Scans

- (MPSGraphTensor * _Nonnull) cumulativeSumWithTensor:(MPSGraphTensor * _Nonnull) tensor
                                                 axis:(NSInteger) axis
                                                 name:(NSString * _Nullable) name;

#pragma mark - Sort / argSort
// Overloads take the axis either as an NSInteger or as a tensor, optionally with a
// `descending` flag.

- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor
                                        axis:(NSInteger) axis
                                        name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor
                                        axis:(NSInteger) axis
                                  descending:(BOOL) descending
                                        name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor
                                  axisTensor:(MPSGraphTensor * _Nonnull) axisTensor
                                  descending:(BOOL) descending
                                        name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor
                                  axisTensor:(MPSGraphTensor * _Nonnull) axisTensor
                                        name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor
                                           axis:(NSInteger) axis
                                           name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor
                                           axis:(NSInteger) axis
                                     descending:(BOOL) descending
                                           name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor
                                     axisTensor:(MPSGraphTensor * _Nonnull) axisTensor
                                     descending:(BOOL) descending
                                           name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor
                                     axisTensor:(MPSGraphTensor * _Nonnull) axisTensor
                                           name:(NSString * _Nullable) name;

#pragma mark - Linear algebra

- (MPSGraphTensor * _Nonnull) inverseOfTensor:(MPSGraphTensor * _Nonnull) inputTensor
                                         name:(NSString * _Nullable) name;

#pragma mark - Resize (forward)
// Each mode has an overload taking centerResult/alignCorners flags and one taking a
// scaleOffset tensor instead.

- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
                                           sizeTensor:(MPSGraphTensor * _Nonnull) size
                                  nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
                                         centerResult:(BOOL) centerResult
                                         alignCorners:(BOOL) alignCorners
                                               layout:(MPSGraphTensorNamedDataLayout) layout
                                                 name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
                                           sizeTensor:(MPSGraphTensor * _Nonnull) size
                                    scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
                                  nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
                                               layout:(MPSGraphTensorNamedDataLayout) layout
                                                 name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
                                            sizeTensor:(MPSGraphTensor * _Nonnull) size
                                          centerResult:(BOOL) centerResult
                                          alignCorners:(BOOL) alignCorners
                                                layout:(MPSGraphTensorNamedDataLayout) layout
                                                  name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
                                            sizeTensor:(MPSGraphTensor * _Nonnull) size
                                     scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
                                                layout:(MPSGraphTensorNamedDataLayout) layout
                                                  name:(NSString * _Nullable) name;

#pragma mark - Resize (gradient)

- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
                                                        input:(MPSGraphTensor * _Nonnull) input
                                          nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
                                                 centerResult:(BOOL) centerResult
                                                 alignCorners:(BOOL) alignCorners
                                                       layout:(MPSGraphTensorNamedDataLayout) layout
                                                         name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
                                                        input:(MPSGraphTensor * _Nonnull) input
                                            scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
                                          nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
                                                       layout:(MPSGraphTensorNamedDataLayout) layout
                                                         name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
                                                         input:(MPSGraphTensor * _Nonnull) input
                                                  centerResult:(BOOL) centerResult
                                                  alignCorners:(BOOL) alignCorners
                                                        layout:(MPSGraphTensorNamedDataLayout) layout
                                                          name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
                                                         input:(MPSGraphTensor * _Nonnull) input
                                             scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
                                                        layout:(MPSGraphTensorNamedDataLayout) layout
                                                          name:(NSString * _Nullable) name;

#pragma mark - Grid sampling
// The two overloads differ only in samplingMode: vs nearestRoundingMode:.

- (MPSGraphTensor * _Nonnull) sampleGridWithSourceTensor:(MPSGraphTensor * _Nonnull) source
                                        coordinateTensor:(MPSGraphTensor * _Nonnull) coordinates
                                                  layout:(MPSGraphTensorNamedDataLayout) layout
                                    normalizeCoordinates:(BOOL) normalizeCoordinates
                                     relativeCoordinates:(BOOL) relativeCoordinates
                                            alignCorners:(BOOL) alignCorners
                                             paddingMode:(MPSGraphPaddingMode) paddingMode
                                            samplingMode:(MPSGraphResizeMode) samplingMode
                                           constantValue:(double) constantValue
                                                    name:(NSString * _Nullable) name;

- (MPSGraphTensor * _Nonnull) sampleGridWithSourceTensor:(MPSGraphTensor * _Nonnull) source
                                        coordinateTensor:(MPSGraphTensor * _Nonnull) coordinates
                                                  layout:(MPSGraphTensorNamedDataLayout) layout
                                    normalizeCoordinates:(BOOL) normalizeCoordinates
                                     relativeCoordinates:(BOOL) relativeCoordinates
                                            alignCorners:(BOOL) alignCorners
                                             paddingMode:(MPSGraphPaddingMode) paddingMode
                                     nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
                                           constantValue:(double) constantValue
                                                    name:(NSString * _Nullable) name;

#pragma mark - Elementwise

- (MPSGraphTensor * _Nonnull) truncateWithTensor:(MPSGraphTensor * _Nonnull) tensor
                                            name:(NSString * _Nullable) name;

@end
198