xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/MPSGraphSequoiaOps.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 #pragma once
2 
3 #include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
4 
// Everything up to the matching #endif is compiled only when the SDK does not
// define __MAC_15_0 and, in addition, either MAC_OS_X_VERSION_15_0 is not
// defined or the minimum required macOS version is below it.
// NOTE(review): presumably the declarations in this section mirror APIs that
// newer SDK headers declare themselves; confirm before removing this shim.
#if !defined(__MAC_15_0) &&                  \
    (!defined(MAC_OS_X_VERSION_15_0) ||      \
     (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_15_0))

// Data type tag for 4-bit signed integers: MPSDataTypeSignedBit OR'ed with 4
// (the low bits presumably encode the bit width, as for other MPSDataType
// values). Kept as a macro so it is a constant expression in C and C++ alike.
#define MPSDataTypeInt4 ((MPSDataType)(MPSDataTypeSignedBit | 4))
8 
/// Declaration of the MPSNDArrayIdentity kernel, exposing only its reshape
/// entry point.
/// NOTE(review): no @implementation is provided here; presumably the class is
/// resolved at runtime on OS versions that ship it. Confirm that callers guard usage.
@interface MPSNDArrayIdentity : MPSNDArrayUnaryKernel

/// Reshape entry point (semantics inferred from the selector; confirm against
/// Apple's documentation).
/// @param commandBuffer Optional command buffer; may be nil.
/// @param sourceArray Input array; must not be nil.
/// @param shape Requested shape; must not be nil.
/// @param destinationArray Optional destination array; may be nil.
/// @return An MPSNDArray, or nil.
- (nullable MPSNDArray *)reshapeWithCommandBuffer:(nullable id<MTLCommandBuffer>)commandBuffer
                                      sourceArray:(nonnull MPSNDArray *)sourceArray
                                            shape:(nonnull MPSShape *)shape
                                 destinationArray:(nullable MPSNDArray *)destinationArray;

@end
15 
/// Class extension on MPSNDArrayDescriptor that declares a read-write
/// `preferPackedRows` flag.
@interface MPSNDArrayDescriptor ()
@property(nonatomic, readwrite) BOOL preferPackedRows;
@end
19 
/// Class extension on MPSNDArray that declares a buffer-backed initializer and
/// a shape/stride view accessor.
@interface MPSNDArray ()

/// Creates an array backed by @p buffer.
/// @param buffer Backing Metal buffer; must not be nil.
/// @param offset Offset into @p buffer. Units are not stated here (presumably
///        bytes; confirm).
/// @param descriptor Array descriptor; must not be nil.
- (nonnull instancetype)initWithBuffer:(nonnull id<MTLBuffer>)buffer
                                offset:(NSUInteger)offset
                            descriptor:(nonnull MPSNDArrayDescriptor *)descriptor;

/// Presumably returns a view of the receiver using @p shape and @p strides
/// (inferred from the selector; confirm).
/// @param shape Optional shape; may be nil.
/// @param strides Strides for the view; must not be nil.
/// @return An MPSNDArray, or nil.
- (nullable MPSNDArray *)arrayViewWithShape:(nullable MPSShape *)shape
                                    strides:(nonnull MPSShape *)strides;

@end
26 
/// Base quantization descriptor. It is declared with no members of its own; the
/// quantized-matmul initializer and the affine descriptor refer to it.
@interface MPSNDArrayQuantizationDescriptor : NSObject <NSCopying>
@end
29 
/// Matrix-multiplication kernel whose initializer accepts optional quantization
/// descriptors for the left and right operands.
@interface MPSNDArrayQuantizedMatrixMultiplication : MPSNDArrayMatrixMultiplication

/// @param device Metal device; must not be nil.
/// @param leftDescriptor Optional quantization descriptor for the left operand; may be nil.
/// @param rightDescriptor Optional quantization descriptor for the right operand; may be nil.
- (nonnull instancetype)initWithDevice:(nonnull id<MTLDevice>)device
            leftQuantizationDescriptor:(nullable MPSNDArrayQuantizationDescriptor *)leftDescriptor
           rightQuantizationDescriptor:(nullable MPSNDArrayQuantizationDescriptor *)rightDescriptor;

/// Encodes the kernel.
/// @param encoder Optional compute command encoder; may be nil.
/// @param commandBuffer Command buffer; must not be nil.
/// @param inputs Source arrays; must not be nil.
/// @param destinationArray Output array; must not be nil.
- (void)encodeToCommandEncoder:(nullable id<MTLComputeCommandEncoder>)encoder
                 commandBuffer:(nonnull id<MTLCommandBuffer>)commandBuffer
                  sourceArrays:(nonnull NSArray<MPSNDArray *> *)inputs
              destinationArray:(nonnull MPSNDArray *)destinationArray;

@end
40 
/// Affine quantization descriptor (subclass of MPSNDArrayQuantizationDescriptor).
@interface MPSNDArrayAffineQuantizationDescriptor : MPSNDArrayQuantizationDescriptor

/// @param quantizationDataType Data type of the quantized values.
/// @param hasZeroPoint Whether a zero point is present.
/// @param hasMinValue Whether a minimum value is present.
- (nonnull instancetype)initWithDataType:(MPSDataType)quantizationDataType
                            hasZeroPoint:(BOOL)hasZeroPoint
                             hasMinValue:(BOOL)hasMinValue;

/// Implicit zero-point flag. It is declared as Objective-C BOOL (not C `bool`)
/// to match the BOOL flags of the initializer above and `preferPackedRows`.
/// BOOL is `signed char` on x86_64 and `bool` on arm64.
/// NOTE(review): presumably this also matches the SDK's own declaration of
/// this property; confirm.
@property(readwrite, nonatomic) BOOL implicitZeroPoint;

@end
47 
/// Class extension on MPSGraph that declares a dequantize operation taking
/// explicit scale and zero-point tensors.
@interface MPSGraph ()

/// @param inputTensor Tensor to dequantize; must not be nil.
/// @param scaleTensor Scale tensor; must not be nil.
/// @param zeroPointTensor Zero-point tensor; must not be nil.
/// @param dataType Presumably the data type of the result (confirm).
/// @param name Optional operation name; may be nil.
/// @return A non-nil MPSGraphTensor.
- (nonnull MPSGraphTensor *)dequantizeTensor:(nonnull MPSGraphTensor *)inputTensor
                                 scaleTensor:(nonnull MPSGraphTensor *)scaleTensor
                             zeroPointTensor:(nonnull MPSGraphTensor *)zeroPointTensor
                                    dataType:(MPSDataType)dataType
                                        name:(nullable NSString *)name;

@end
55 
56 #endif
57