xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalContext.h>
2#import <ATen/native/metal/mpscnn/MPSCNNNeuronOp.h>
3
4#include <c10/macros/Macros.h>
5
C10_CLANG_DIAGNOSTIC_PUSH()
// The MPSCNNNeuron* initializers used in this implementation are deprecated in
// newer SDKs; silence those warnings up to the matching POP at the end.
C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-declarations")

// Accessors for shared MPSCNN neuron kernels. Each kernel is created once
// (guarded by dispatch_once) on [MetalContext sharedInstance].device, and that
// same instance is returned on every later call. When building for Mac
// Catalyst, every accessor returns nil instead.
@implementation MPSCNNNeuronOp

// Hard-sigmoid kernel configured with a = 1/6 and b = 0.5.
// NOTE(review): presumably these reproduce PyTorch's hardsigmoid,
// relu6(x + 3) / 6 -- confirm against the MPSCNNNeuronHardSigmoid docs.
+ (MPSCNNNeuronHardSigmoid*)hardSigmoid API_AVAILABLE(ios(11.0), macos(10.13)) {
#if TARGET_OS_MACCATALYST
  // TODO: remove this Mac Catalyst fallback once iOS 11.3 is supported.
  return nil;
#else
  static MPSCNNNeuronHardSigmoid* sharedHardSigmoid;
  static dispatch_once_t hardSigmoidOnce;
  dispatch_once(&hardSigmoidOnce, ^{
    MetalContext* context = [MetalContext sharedInstance];
    sharedHardSigmoid =
        [[MPSCNNNeuronHardSigmoid alloc] initWithDevice:context.device
                                                      a:1.0 / 6.0
                                                      b:0.5];
  });
  return sharedHardSigmoid;
#endif
}

// ReLU kernel configured with a = 0.
// NOTE(review): presumably a plain (non-leaky) ReLU -- confirm against the
// MPSCNNNeuronReLU docs.
+ (MPSCNNNeuronReLU*)relu {
#if TARGET_OS_MACCATALYST
  // TODO: remove this Mac Catalyst fallback once iOS 11.3 is supported.
  return nil;
#else
  static MPSCNNNeuronReLU* sharedReLU;
  static dispatch_once_t reluOnce;
  dispatch_once(&reluOnce, ^{
    MetalContext* context = [MetalContext sharedInstance];
    sharedReLU = [[MPSCNNNeuronReLU alloc] initWithDevice:context.device a:0];
  });
  return sharedReLU;
#endif
}

// Sigmoid kernel; constructed from the device alone, with no extra parameters.
+ (MPSCNNNeuronSigmoid*)sigmoid {
#if TARGET_OS_MACCATALYST
  // TODO: remove this Mac Catalyst fallback once iOS 11.3 is supported.
  return nil;
#else
  static MPSCNNNeuronSigmoid* sharedSigmoid;
  static dispatch_once_t sigmoidOnce;
  dispatch_once(&sigmoidOnce, ^{
    MetalContext* context = [MetalContext sharedInstance];
    sharedSigmoid = [[MPSCNNNeuronSigmoid alloc] initWithDevice:context.device];
  });
  return sharedSigmoid;
#endif
}

// TanH kernel configured with a = 1 and b = 1. The cached instance is named
// sharedTanH so it does not shadow the C tanh() function.
+ (MPSCNNNeuronTanH*)tanh {
#if TARGET_OS_MACCATALYST
  // TODO: remove this Mac Catalyst fallback once iOS 11.3 is supported.
  return nil;
#else
  static MPSCNNNeuronTanH* sharedTanH;
  static dispatch_once_t tanhOnce;
  dispatch_once(&tanhOnce, ^{
    MetalContext* context = [MetalContext sharedInstance];
    sharedTanH = [[MPSCNNNeuronTanH alloc] initWithDevice:context.device a:1 b:1];
  });
  return sharedTanH;
#endif
}

@end

C10_CLANG_DIAGNOSTIC_POP()
79
// Accessors for shared MPSNNNeuronDescriptor instances. Each descriptor is
// built once (guarded by dispatch_once), and the same object is returned on
// every later call.
API_AVAILABLE(ios(11.3), macos(10.13), macCatalyst(13.0))
@implementation MPSCNNNeuronOpDescriptor

// Hard-sigmoid descriptor with a = 1/6 and b = 0.5.
+ (MPSNNNeuronDescriptor*)hardSigmoidDescriptor {
  static MPSNNNeuronDescriptor* hardSigmoidDesc;
  static dispatch_once_t hardSigmoidOnce;
  dispatch_once(&hardSigmoidOnce, ^{
    hardSigmoidDesc =
        [MPSNNNeuronDescriptor cnnNeuronDescriptorWithType:MPSCNNNeuronTypeHardSigmoid
                                                         a:1.0 / 6.0
                                                         b:0.5];
  });
  return hardSigmoidDesc;
}

// ReLU descriptor with a = 0.
+ (MPSNNNeuronDescriptor*)reluDescriptor {
  static MPSNNNeuronDescriptor* reluDesc;
  static dispatch_once_t reluOnce;
  dispatch_once(&reluOnce, ^{
    reluDesc = [MPSNNNeuronDescriptor cnnNeuronDescriptorWithType:MPSCNNNeuronTypeReLU a:0];
  });
  return reluDesc;
}

// Sigmoid descriptor; created from the neuron type alone, with no a/b values.
+ (MPSNNNeuronDescriptor*)sigmoidDescriptor {
  static MPSNNNeuronDescriptor* sigmoidDesc;
  static dispatch_once_t sigmoidOnce;
  dispatch_once(&sigmoidOnce, ^{
    sigmoidDesc = [MPSNNNeuronDescriptor cnnNeuronDescriptorWithType:MPSCNNNeuronTypeSigmoid];
  });
  return sigmoidDesc;
}

// TanH descriptor with a = 1.0 and b = 1.0.
+ (MPSNNNeuronDescriptor*)tanhDescriptor {
  static MPSNNNeuronDescriptor* tanhDesc;
  static dispatch_once_t tanhOnce;
  dispatch_once(&tanhOnce, ^{
    tanhDesc =
        [MPSNNNeuronDescriptor cnnNeuronDescriptorWithType:MPSCNNNeuronTypeTanH
                                                         a:1.0
                                                         b:1.0];
  });
  return tanhDesc;
}

@end
129