#import <ATen/native/metal/MetalConvParams.h>
#import <ATen/native/metal/MetalNeuronType.h>
#import <ATen/native/metal/mpscnn/MPSCNNOp.h>
#import <Foundation/Foundation.h>

/// Supplies convolution weights and bias to Metal Performance Shaders.
///
/// Conforms to `MPSCNNConvolutionDataSource`, the protocol through which MPS
/// convolution kernels obtain their weights, bias and descriptor.
API_AVAILABLE(ios(11.0), macos(10.13))
@interface MPSCNNConvDataSource : NSObject<MPSCNNConvolutionDataSource>

/// Raw weight buffer.
///
/// Typed `void*`, so this header does not fix the element type.
/// NOTE(review): presumably float32 data — confirm against the implementation.
///
/// The property is declared `assign`, so it only stores the pointer. This
/// header does not say who owns or frees the buffer.
@property(nonatomic, assign) void* weights;

/// Raw bias buffer (`float*`), also stored as a plain `assign` pointer.
///
/// NOTE(review): whether NULL (no bias) is accepted is not visible here.
@property(nonatomic, assign) float* bias;

/// Creates a data source for a single convolution.
///
/// The keywords `Bias:`/`Desc:` are capitalized against Cocoa convention. They
/// are kept as-is because changing the selector would break existing callers.
///
/// @param weights Weight buffer; see the `weights` property.
/// @param bias Bias buffer; see the `bias` property.
/// @param desc Convolution descriptor. NOTE(review): presumably the value this
///        object reports from `-descriptor` — confirm in the implementation.
/// @return The initialized data source.
- (instancetype)initWithWeights:(void*)weights
                           Bias:(float*)bias
                           Desc:(MPSCNNConvolutionDescriptor*)desc;

@end

// NOTE(review): a using-directive in a header applies to every file that
// imports it. It is kept because includers may rely on the unqualified names
// (`Conv2DParams`, `NeuronType`) it makes visible.
using namespace at::native::metal;

/// Convolution operation built on MPSCNN; conforms to `MPSCNNOp`.
API_AVAILABLE(ios(11.0), macos(10.13))
@interface MPSCNNConvOp : NSObject<MPSCNNOp>

/// Builds a 2-D convolution op.
///
/// @param params Convolution configuration (`Conv2DParams`).
/// @param w Pointer to weight data. NOTE(review): the expected weight layout
///        and whether the data is copied are not stated here — confirm.
/// @param b Pointer to bias data. NOTE(review): NULL handling not visible here.
/// @param t Neuron type. Presumably an activation applied with the
///        convolution — confirm in the implementation.
/// @return The constructed convolution op.
+ (MPSCNNConvOp*)conv2d:(const Conv2DParams&)params
                weights:(float*)w
                   bias:(float*)b
           neuronFilter:(NeuronType)t;

@end