#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNNeuronOp.h>

#include <c10/macros/Macros.h>

// -Wdeprecated-declarations is ignored for the MPSCNNNeuronOp implementation
// only; the previous diagnostic state is restored right after its @end.
C10_CLANG_DIAGNOSTIC_PUSH()
C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-declarations")

// Class-level accessors for shared MPSCNNNeuron* kernels.
//
// Every accessor follows the same shape:
//   * on Mac Catalyst builds it returns nil;
//   * otherwise it creates the kernel exactly once (dispatch_once) on the
//     device of the shared MetalContext and returns that same instance on
//     every subsequent call.
@implementation MPSCNNNeuronOp

// Shared hard-sigmoid kernel, constructed with a = 1/6 and b = 0.5.
+ (MPSCNNNeuronHardSigmoid*)hardSigmoid API_AVAILABLE(ios(11.0), macos(10.13)) {
// Remove this once we support iOS 11.3
#if !TARGET_OS_MACCATALYST
  static MPSCNNNeuronHardSigmoid* sharedKernel = nil;
  static dispatch_once_t once;
  dispatch_once(&once, ^{
    MetalContext* context = [MetalContext sharedInstance];
    sharedKernel =
        [[MPSCNNNeuronHardSigmoid alloc] initWithDevice:context.device
                                                      a:1.0 / 6.0
                                                      b:0.5];
  });
  return sharedKernel;
#else
  return nil;
#endif
}

// Shared ReLU kernel, constructed with a = 0.
+ (MPSCNNNeuronReLU*)relu {
// Remove this once we support iOS 11.3
#if !TARGET_OS_MACCATALYST
  static dispatch_once_t once;
  static MPSCNNNeuronReLU* sharedKernel = nil;
  dispatch_once(&once, ^{
    MetalContext* context = [MetalContext sharedInstance];
    sharedKernel = [[MPSCNNNeuronReLU alloc] initWithDevice:context.device
                                                          a:0];
  });
  return sharedKernel;
#else
  return nil;
#endif
}

// Shared sigmoid kernel; the initializer takes no parameters besides the
// device.
+ (MPSCNNNeuronSigmoid*)sigmoid {
// Remove this once we support iOS 11.3
#if !TARGET_OS_MACCATALYST
  static MPSCNNNeuronSigmoid* sharedKernel = nil;
  static dispatch_once_t once;
  dispatch_once(&once, ^{
    MetalContext* context = [MetalContext sharedInstance];
    sharedKernel =
        [[MPSCNNNeuronSigmoid alloc] initWithDevice:context.device];
  });
  return sharedKernel;
#else
  return nil;
#endif
}

// Shared tanh kernel, constructed with a = 1 and b = 1.
+ (MPSCNNNeuronTanH*)tanh {
// Remove this once we support iOS 11.3
#if !TARGET_OS_MACCATALYST
  static MPSCNNNeuronTanH* sharedKernel = nil;
  static dispatch_once_t once;
  dispatch_once(&once, ^{
    MetalContext* context = [MetalContext sharedInstance];
    sharedKernel = [[MPSCNNNeuronTanH alloc] initWithDevice:context.device
                                                          a:1
                                                          b:1];
  });
  return sharedKernel;
#else
  return nil;
#endif
}

@end

C10_CLANG_DIAGNOSTIC_POP()

// Class-level accessors for shared MPSNNNeuronDescriptor objects. Each
// descriptor is built once (dispatch_once) and the same instance is returned
// on every later call. Unlike MPSCNNNeuronOp, these accessors have no Mac
// Catalyst special case and do not touch the MetalContext.
API_AVAILABLE(ios(11.3), macos(10.13), macCatalyst(13.0))
@implementation MPSCNNNeuronOpDescriptor

// Hard-sigmoid descriptor with a = 1/6 and b = 0.5.
+ (MPSNNNeuronDescriptor*)hardSigmoidDescriptor {
  static MPSNNNeuronDescriptor* descriptor = nil;
  static dispatch_once_t once;
  dispatch_once(&once, ^{
    descriptor = [MPSNNNeuronDescriptor
        cnnNeuronDescriptorWithType:MPSCNNNeuronTypeHardSigmoid
                                  a:1.0 / 6.0
                                  b:0.5];
  });
  return descriptor;
}

// ReLU descriptor with a = 0.
+ (MPSNNNeuronDescriptor*)reluDescriptor {
  static MPSNNNeuronDescriptor* descriptor = nil;
  static dispatch_once_t once;
  dispatch_once(&once, ^{
    descriptor = [MPSNNNeuronDescriptor
        cnnNeuronDescriptorWithType:MPSCNNNeuronTypeReLU
                                  a:0];
  });
  return descriptor;
}

// Sigmoid descriptor; only the neuron type is specified.
+ (MPSNNNeuronDescriptor*)sigmoidDescriptor {
  static MPSNNNeuronDescriptor* descriptor = nil;
  static dispatch_once_t once;
  dispatch_once(&once, ^{
    descriptor = [MPSNNNeuronDescriptor
        cnnNeuronDescriptorWithType:MPSCNNNeuronTypeSigmoid];
  });
  return descriptor;
}

// Tanh descriptor with a = 1.0 and b = 1.0.
+ (MPSNNNeuronDescriptor*)tanhDescriptor {
  static MPSNNNeuronDescriptor* descriptor = nil;
  static dispatch_once_t once;
  dispatch_once(&once, ^{
    descriptor = [MPSNNNeuronDescriptor
        cnnNeuronDescriptorWithType:MPSCNNNeuronTypeTanH
                                  a:1.0
                                  b:1.0];
  });
  return descriptor;
}

@end