//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

// Lowers an MPSClamp node into MPSGraph ops and registers the result under the
// node's output id.
//
// Bounds come from getMinMaxValues(nodePtr). A lower bound equal to -INF or an
// upper bound equal to INF is treated as "not set":
//   - both bounds set -> clampWithTensor:minValueTensor:maxValueTensor:
//   - only min set    -> maximum(input, min)
//   - only max set    -> minimum(input, max)
//   - neither set     -> identity(input)
//
// Returns Error::Ok; this function has no failure path of its own.
Error
MPSGraphBuilder::mpsClampOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSClamp();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->output_id()
  );

  std::pair<float, float> minMaxValues = getMinMaxValues(nodePtr);
  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  const bool useMin = minMaxValues.first != -INF;
  const bool useMax = minMaxValues.second != INF;

  // Bound tensors are materialized with the input's shape and data type so the
  // min/max/clamp ops see operands that match the input exactly.
  auto makeBoundTensor = [&](float value) -> MPSGraphTensor* {
    return [_mpsGraph constantWithScalar:value
                                   shape:inputTensor.shape
                                dataType:inputTensor.dataType];
  };

  MPSGraphTensor* outputTensor = nil;
  if (useMin && useMax) {
    // Both min and max values are set.
    outputTensor = [_mpsGraph clampWithTensor:inputTensor
                               minValueTensor:makeBoundTensor(minMaxValues.first)
                               maxValueTensor:makeBoundTensor(minMaxValues.second)
                                         name:@"clamp"];
  } else if (useMin) {
    // Only min is set: clamp from below.
    outputTensor = [_mpsGraph maximumWithPrimaryTensor:inputTensor
                                       secondaryTensor:makeBoundTensor(minMaxValues.first)
                                                  name:nil];
  } else if (useMax) {
    // Only max is set: clamp from above.
    outputTensor = [_mpsGraph minimumWithPrimaryTensor:inputTensor
                                       secondaryTensor:makeBoundTensor(minMaxValues.second)
                                                  name:nil];
  } else {
    // Neither bound is set, so clamp is a no-op. Without this branch the output
    // id was never written to _idToMPSGraphTensor. Emit an identity so the
    // output id maps to its own graph tensor.
    outputTensor = [_mpsGraph identityWithTensor:inputTensor name:nil];
  }

  _idToMPSGraphTensor[graphNode->output_id()] = outputTensor;
  return Error::Ok;
}

// Lowers an MPSWhere node into an elementwise select and registers the result
// under the node's output id.
//
// input1 is the predicate. input2 provides the values taken where the
// predicate is true, and input3 provides the values taken where it is false.
// A predicate whose data type is not MPSDataTypeBool is first cast to
// MPSDataTypeBool.
//
// Returns Error::Ok; this function has no failure path of its own.
Error
MPSGraphBuilder::mpsWhereOp(NodePtr nodePtr) {
  auto whereNode = nodePtr->mpsnode_union_as_MPSWhere();
  ET_LOG(
    Debug, "%s: (%d, %d, %d) -> %d",
    __FUNCTION__,
    whereNode->input1_id(),
    whereNode->input2_id(),
    whereNode->input3_id(),
    whereNode->output_id()
  );

  MPSGraphTensor* predicate = getMPSGraphTensor(whereNode->input1_id());
  MPSGraphTensor* valuesIfTrue = getMPSGraphTensor(whereNode->input2_id());
  MPSGraphTensor* valuesIfFalse = getMPSGraphTensor(whereNode->input3_id());

  // Only insert a cast when the predicate is not already boolean.
  MPSGraphTensor* boolPredicate = (predicate.dataType == MPSDataTypeBool)
      ? predicate
      : [_mpsGraph castTensor:predicate
                       toType:MPSDataTypeBool
                         name:@"condition"];

  MPSGraphTensor* selected = [_mpsGraph selectWithPredicateTensor:boolPredicate
                                              truePredicateTensor:valuesIfTrue
                                             falsePredicateTensor:valuesIfFalse
                                                             name:nil];
  _idToMPSGraphTensor[whereNode->output_id()] = selected;
  return Error::Ok;
}


} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch