xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/ClampOps.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1
2//
3//  Copyright (c) 2023 Apple Inc. All rights reserved.
4//  Provided subject to the LICENSE file in the top level directory.
5//
6
7#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
8
9namespace executorch {
10namespace backends {
11namespace mps {
12namespace delegate {
13
// Lowers a serialized MPSClamp node into the MPSGraph under construction.
//
// The (min, max) pair is obtained from getMinMaxValues(nodePtr). A lower bound
// equal to -INF, or an upper bound equal to INF, is treated as "not set".
// Depending on which bounds are set, the node is emitted as:
//   - both bounds : clamp(input, min, max)
//   - min only    : maximum(input, min)
//   - max only    : minimum(input, max)
//   - no bounds   : identity(input)
// The resulting tensor is registered in _idToMPSGraphTensor under the node's
// output id in every case, so later nodes can always look it up.
//
// Always returns Error::Ok.
Error
MPSGraphBuilder::mpsClampOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSClamp();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->output_id()
  );

  std::pair<float, float> minMaxValues = getMinMaxValues(nodePtr);
  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  const bool useMin = minMaxValues.first != -INF;
  const bool useMax = minMaxValues.second != INF;

  // Builds a constant tensor filled with `value`, matching the input's shape
  // and data type, to serve as a clamp bound.
  auto makeBoundTensor = [&](float value) -> MPSGraphTensor* {
    return [_mpsGraph constantWithScalar:value
                                   shape:inputTensor.shape
                                dataType:inputTensor.dataType];
  };

  MPSGraphTensor* outputTensor = nil;
  if (useMin && useMax) {
    // Both min and max values are set
    MPSGraphTensor* minTensor = makeBoundTensor(minMaxValues.first);
    MPSGraphTensor* maxTensor = makeBoundTensor(minMaxValues.second);
    outputTensor = [_mpsGraph clampWithTensor:inputTensor
                               minValueTensor:minTensor
                               maxValueTensor:maxTensor
                                         name:@"clamp"];
  } else if (useMin) {
    // Only min is set
    MPSGraphTensor* minTensor = makeBoundTensor(minMaxValues.first);
    outputTensor = [_mpsGraph maximumWithPrimaryTensor:inputTensor
                                       secondaryTensor:minTensor
                                                  name:nil];
  } else if (useMax) {
    // Only max is set
    MPSGraphTensor* maxTensor = makeBoundTensor(minMaxValues.second);
    outputTensor = [_mpsGraph minimumWithPrimaryTensor:inputTensor
                                       secondaryTensor:maxTensor
                                                  name:nil];
  } else {
    // Neither bound is set: an unbounded clamp is a pass-through. Without this
    // branch the output id would never be registered and any consumer of this
    // node's output would find no tensor for it.
    outputTensor = [_mpsGraph identityWithTensor:inputTensor
                                            name:nil];
  }

  _idToMPSGraphTensor[graphNode->output_id()] = outputTensor;
  return Error::Ok;
}
61
// Lowers a serialized MPSWhere node into the MPSGraph under construction.
//
// input1 is the predicate, input2 supplies the values selected where the
// predicate is true, and input3 the values selected where it is false. If the
// predicate tensor is not already of bool type it is cast to bool first. The
// select result is registered in _idToMPSGraphTensor under the node's output
// id.
//
// Always returns Error::Ok.
Error
MPSGraphBuilder::mpsWhereOp(NodePtr nodePtr) {
  auto whereNode = nodePtr->mpsnode_union_as_MPSWhere();
  ET_LOG(
    Debug, "%s: (%d, %d, %d) -> %d",
    __FUNCTION__,
    whereNode->input1_id(),
    whereNode->input2_id(),
    whereNode->input3_id(),
    whereNode->output_id()
  );

  MPSGraphTensor* predicate = getMPSGraphTensor(whereNode->input1_id());
  MPSGraphTensor* valuesIfTrue = getMPSGraphTensor(whereNode->input2_id());
  MPSGraphTensor* valuesIfFalse = getMPSGraphTensor(whereNode->input3_id());

  // The select op takes a boolean predicate; convert any other data type.
  const bool predicateIsBool = (predicate.dataType == MPSDataTypeBool);
  if (!predicateIsBool) {
    predicate = [_mpsGraph castTensor:predicate
                               toType:MPSDataTypeBool
                                 name:@"condition"];
  }

  MPSGraphTensor* selected = [_mpsGraph selectWithPredicateTensor:predicate
                                              truePredicateTensor:valuesIfTrue
                                             falsePredicateTensor:valuesIfFalse
                                                             name:nil];
  _idToMPSGraphTensor[whereNode->output_id()] = selected;
  return Error::Ok;
}
89
90
91} // namespace delegate
92} // namespace mps
93} // namespace backends
94} // namespace executorch
95