#pragma once

#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

// Fallback declarations for MetalPerformanceShaders / MPSGraph APIs used by this
// code that the SDK headers being compiled against do not declare themselves.
// The guard below skips this whole section when the SDK defines `__MAC_15_0`.
// NOTE(review): presumably these mirror the macOS 15 SDK additions. Every
// signature here must match the framework's real types exactly, because the
// compiler trusts these declarations when marshalling message sends.
#if !defined(__MAC_15_0) && \
    (!defined(MAC_OS_X_VERSION_15_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_15_0))

/// 4-bit signed integer data type: the signedness flag OR'd with a bit width of 4.
/// Kept as a macro (not a `static const`) so it stays usable as an integer
/// constant expression, e.g. in `case` labels, from plain Objective-C files.
#define MPSDataTypeInt4 ((MPSDataType)(MPSDataTypeSignedBit | 4))

/// Unary identity kernel; only its reshape entry point is declared here.
@interface MPSNDArrayIdentity : MPSNDArrayUnaryKernel
/// Reshapes `sourceArray` to `shape`, optionally writing into `destinationArray`.
/// NOTE(review): whether a view or a copy is returned, and when nil is
/// returned, is not visible from this header — confirm before relying on it.
- (MPSNDArray *__nullable)reshapeWithCommandBuffer:(__nullable id<MTLCommandBuffer>)cmdBuf
                                       sourceArray:(MPSNDArray *__nonnull)sourceArray
                                             shape:(MPSShape *__nonnull)shape
                                  destinationArray:(MPSNDArray *__nullable)destinationArray;
@end

/// Class extension exposing a writable `preferPackedRows` flag on the descriptor.
@interface MPSNDArrayDescriptor ()
@property(readwrite, nonatomic) BOOL preferPackedRows;
@end

/// Class extension exposing buffer-backed construction and strided views.
@interface MPSNDArray ()
/// Creates an array described by `descriptor` over `buffer`, starting at `offset`.
/// NOTE(review): the unit of `offset` (bytes vs. elements) is not stated here.
- (nonnull instancetype)initWithBuffer:(id<MTLBuffer> _Nonnull)buffer
                                offset:(NSUInteger)offset
                            descriptor:(MPSNDArrayDescriptor *_Nonnull)descriptor;
/// Returns a view with the given `shape` and `strides`; may return nil.
- (MPSNDArray *__nullable)arrayViewWithShape:(MPSShape *_Nullable)shape
                                     strides:(MPSShape *_Nonnull)strides;
@end

/// Base class for quantization descriptors (copyable via NSCopying).
@interface MPSNDArrayQuantizationDescriptor : NSObject <NSCopying>
@end

/// Matrix multiplication in which either operand may carry a quantization
/// descriptor (nil means that side has no quantization descriptor).
@interface MPSNDArrayQuantizedMatrixMultiplication : MPSNDArrayMatrixMultiplication
- (nonnull instancetype)initWithDevice:(nonnull id<MTLDevice>)device
            leftQuantizationDescriptor:(MPSNDArrayQuantizationDescriptor *_Nullable)leftQuantizationDescriptor
           rightQuantizationDescriptor:(MPSNDArrayQuantizationDescriptor *_Nullable)rightQuantizationDescriptor;

/// Encodes the multiplication of `sourceArrays` into `destination`.
/// `encoder` is nullable; `commandBuffer` is required.
/// NOTE(review): the expected count/order of `sourceArrays` (e.g. where scale
/// and zero-point arrays go) is not documented here — verify against callers.
- (void)encodeToCommandEncoder:(id<MTLComputeCommandEncoder> _Nullable)encoder
                 commandBuffer:(nonnull id<MTLCommandBuffer>)commandBuffer
                  sourceArrays:(nonnull NSArray<MPSNDArray *> *)sourceArrays
              destinationArray:(nonnull MPSNDArray *)destination;
@end

/// Affine quantization descriptor: a quantized data type plus flags selecting
/// whether a zero point and/or a minimum value are present.
@interface MPSNDArrayAffineQuantizationDescriptor : MPSNDArrayQuantizationDescriptor
- (nonnull instancetype)initWithDataType:(MPSDataType)quantizationDataType
                            hasZeroPoint:(BOOL)hasZeroPoint
                             hasMinValue:(BOOL)hasMinValue;
/// Declared as Objective-C `BOOL`, like every other flag in this file. It was
/// previously C `bool`, which differs from `BOOL` (`signed char`) on x86_64.
/// NOTE(review): confirm the type against the framework's own header.
@property(readwrite, nonatomic) BOOL implicitZeroPoint;
@end

/// Class extension declaring tensor dequantization with an explicit zero point.
@interface MPSGraph ()
/// Builds a dequantize op from `tensor`, `scaleTensor` and `zeroPointTensor`,
/// producing `dataType`. `name` is an optional op name.
- (MPSGraphTensor *_Nonnull)dequantizeTensor:(MPSGraphTensor *_Nonnull)tensor
                                 scaleTensor:(MPSGraphTensor *_Nonnull)scaleTensor
                             zeroPointTensor:(MPSGraphTensor *_Nonnull)zeroPointTensor
                                    dataType:(MPSDataType)dataType
                                        name:(NSString *_Nullable)name;
@end

#endif  // !defined(__MAC_15_0) && ...