xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/ActivationOps.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1
2//
3//  Copyright (c) 2023 Apple Inc. All rights reserved.
4//  Provided subject to the LICENSE file in the top level directory.
5//
6
7#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
8
9namespace executorch {
10namespace backends {
11namespace mps {
12namespace delegate {
13
Error
MPSGraphBuilder::mpsHardTanhOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSHardTanh();

  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  float minValue = graphNode->min_value();
  float maxValue = graphNode->max_value();
  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());

  // Materialize the scalar bounds as constant tensors matching the input's
  // shape and dtype so the clamp is applied elementwise.
  MPSDataType inputType = [inputTensor dataType];
  MPSShape* inputShape = [inputTensor shape];
  MPSGraphTensor* minTensor = [_mpsGraph constantWithScalar:minValue shape:inputShape dataType:inputType];
  MPSGraphTensor* maxTensor = [_mpsGraph constantWithScalar:maxValue shape:inputShape dataType:inputType];

  // hardtanh(x) = min(max(x, min_value), max_value). MPSGraph provides this
  // directly as a clamp op, replacing the previous two-predicate/two-select
  // construction (4 graph nodes -> 1).
  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph clampWithTensor:inputTensor
                minValueTensor:minTensor
                maxValueTensor:maxTensor
                          name:@"hardTanh"];

  return Error::Ok;
}
50
Error
MPSGraphBuilder::mpsReLUOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSReLU();

  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  // ReLU maps one-to-one onto MPSGraph's built-in reLU op; just wire the
  // serialized input id to the op and register the result under the output id.
  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSGraphTensor* reluTensor = [_mpsGraph reLUWithTensor:inputTensor
                                                    name:@"relu"];
  _idToMPSGraphTensor[graphNode->output_id()] = reluTensor;

  return Error::Ok;
}
66
Error
MPSGraphBuilder::mpsGELUOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSGELU();
  std::string approximation = graphNode->approximate()->str();

  ET_LOG(
    Debug, "%s: %d (%s) -> %d",
    __FUNCTION__, graphNode->input1_id(), approximation.c_str(), graphNode->output_id()
  );

  // The helper op writes the CDF term of GELU into the output slot; it is
  // multiplied by the input below. "tanh" selects the tanh approximation of
  // the Gaussian CDF, anything else the exact CDF.
  Error status = (approximation == "tanh") ? mpsTanhOp(nodePtr)
                                           : mpsNormCdfOp(nodePtr);

  ET_CHECK_OR_RETURN_ERROR(
    status == Error::Ok,
    Internal,
    "[ERROR] Couldn't add GELU node to MPSGraph");

  // gelu(x) = cdf_term(x) * x — overwrite the output slot with the product.
  // (Indentation fixed: this multiplication is unconditional, not part of
  // the check above; `status` is necessarily Ok here, so return Ok directly.)
  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph multiplicationWithPrimaryTensor:_idToMPSGraphTensor[graphNode->output_id()]
                               secondaryTensor:getMPSGraphTensor(graphNode->input1_id())
                                          name:nil];

  return Error::Ok;
}
95
Error
MPSGraphBuilder::mpsLeakyReLUOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSLeakyReLU();

  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  // leaky_relu passes the serialized negative_slope through as MPSGraph's
  // `alpha` parameter on the built-in leakyReLU op.
  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSGraphTensor* leakyReluTensor =
    [_mpsGraph leakyReLUWithTensor:inputTensor
                             alpha:graphNode->negative_slope()
                              name:@"leaky_relu"];
  _idToMPSGraphTensor[graphNode->output_id()] = leakyReluTensor;

  return Error::Ok;
}
112
Error
MPSGraphBuilder::mpsSoftmaxOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSSoftmax();

  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  // Report the unsupported half_to_float flag through the Error return value
  // instead of aborting the whole process (ET_CHECK_MSG aborts); this matches
  // the graceful ET_CHECK_OR_RETURN_ERROR pattern used elsewhere in this file.
  ET_CHECK_OR_RETURN_ERROR(
    !graphNode->half_to_float(),
    Internal,
    "softmax with half to float conversion is not supported on MPS");

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph softMaxWithTensor:getMPSGraphTensor(graphNode->input1_id())
                            axis:graphNode->dim()
                            name:@"softmax"];

  return Error::Ok;
}
130
Error
MPSGraphBuilder::mpsLogSoftmaxOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSLogSoftmax();

  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  // Return an Error instead of aborting (ET_CHECK_MSG aborts the process),
  // matching the ET_CHECK_OR_RETURN_ERROR pattern used elsewhere in this
  // file. Message fixed to say "log_softmax" (previously said "softmax").
  ET_CHECK_OR_RETURN_ERROR(
    !graphNode->half_to_float(),
    Internal,
    "log_softmax with half to float conversion is not supported on MPS");

  // log_softmax(x) = log(softmax(x)), composed from the two MPSGraph ops.
  MPSGraphTensor* softmaxTensor = [_mpsGraph softMaxWithTensor:getMPSGraphTensor(graphNode->input1_id())
                                                          axis:graphNode->dim()
                                                          name:@"softmax"];
  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph logarithmWithTensor:softmaxTensor
                              name:@"log_softmax"];

  return Error::Ok;
}
151
152} // namespace delegate
153} // namespace mps
154} // namespace backends
155} // namespace executorch
156