1 #ifndef MetalNeuronType_h
2 #define MetalNeuronType_h
3
#import <ATen/native/metal/mpscnn/MPSCNNNeuronOp.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>

#include <ATen/ATen.h>

#include <limits>
#include <optional>
8
9 namespace at::native::metal {
10
/// Activation ("neuron") kinds understood by the helpers in this header.
/// The explicit values below match the implicit ones (0..5) of the
/// declaration order.
enum class NeuronType {
  None = 0, // no activation
  Clamp = 1, // output limited to a finite [min, max] range
  Relu = 2,
  Sigmoid = 3,
  HardSigmoid = 4,
  Tanh = 5,
};
19
neuronType(std::optional<c10::Scalar> output_min,std::optional<c10::Scalar> output_max)20 static inline NeuronType neuronType(
21 std::optional<c10::Scalar> output_min,
22 std::optional<c10::Scalar> output_max) {
23 float inf_max = std::numeric_limits<float>::infinity();
24 float inf_min = -std::numeric_limits<float>::infinity();
25 float output_max_ =
26 output_max.has_value() ? output_max.value().toFloat() : inf_max;
27 float output_min_ =
28 output_min.has_value() ? output_min.value().toFloat() : inf_min;
29 if (output_max_ == inf_max && output_min_ == 0) {
30 return NeuronType::Relu;
31 } else if (output_max_ < inf_max && output_min_ > inf_min) {
32 return NeuronType::Clamp;
33 } else {
34 return NeuronType::None;
35 }
36 }
37
/// Looks up the MPSCNN neuron (activation) kernel for `type`.
///
/// @param type The activation to look up.
/// @return The kernel provided by MPSCNNNeuronOp for Relu, Sigmoid, Tanh and
///         HardSigmoid; nil for every other value (None, Clamp), which has
///         no kernel here.
static inline MPSCNNNeuron* neuron(NeuronType type) {
  switch (type) {
    case NeuronType::Relu:
      return [MPSCNNNeuronOp relu];
    case NeuronType::Sigmoid:
      return [MPSCNNNeuronOp sigmoid];
    case NeuronType::Tanh:
      return [MPSCNNNeuronOp tanh];
    case NeuronType::HardSigmoid:
      return [MPSCNNNeuronOp hardSigmoid];
    default:
      // None, Clamp, and any out-of-range value: no neuron kernel.
      return nil;
  }
}
51
/// Builds the MPSNNNeuronDescriptor describing the activation `type`.
///
/// @param type The activation to describe.
/// @return The descriptor provided by MPSCNNNeuronOpDescriptor for Relu,
///         Sigmoid, Tanh and HardSigmoid; for any other value (None, Clamp)
///         a descriptor created with MPSCNNNeuronTypeNone rather than nil.
API_AVAILABLE(ios(11.3), macos(10.13), macCatalyst(13.0))
static inline MPSNNNeuronDescriptor* neuronDescriptor(NeuronType type) {
  switch (type) {
    case NeuronType::Relu:
      return [MPSCNNNeuronOpDescriptor reluDescriptor];
    case NeuronType::Sigmoid:
      return [MPSCNNNeuronOpDescriptor sigmoidDescriptor];
    case NeuronType::Tanh:
      return [MPSCNNNeuronOpDescriptor tanhDescriptor];
    case NeuronType::HardSigmoid:
      return [MPSCNNNeuronOpDescriptor hardSigmoidDescriptor];
    default:
      break;
  }
  // Types without a dedicated descriptor get an explicit "none" neuron
  // descriptor instead of nil.
  return [MPSNNNeuronDescriptor cnnNeuronDescriptorWithType:MPSCNNNeuronTypeNone];
}
66
67 } // namespace at::native::metal
68
69 #endif /* MetalNeuronType_h */
70