//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

namespace {

// sqrt(2 / pi): scale applied to the tanh argument in GELU's tanh approximation.
constexpr double kGeluTanhBeta = 0.7978845608028654;
// Cubic coefficient of GELU's tanh approximation.
constexpr double kGeluTanhKappa = 0.044715;

// Builds 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), the
// tanh-approximated Gaussian CDF used by GELU(approximate="tanh").
// All constants are rank-0 scalars in the input's dtype, so they broadcast
// against the input without changing its shape.
MPSGraphTensor*
geluTanhApproximateCdf(MPSGraph* graph, MPSGraphTensor* x) {
  MPSDataType dataType = [x dataType];
  MPSGraphTensor* beta = [graph constantWithScalar:kGeluTanhBeta dataType:dataType];
  MPSGraphTensor* kappa = [graph constantWithScalar:kGeluTanhKappa dataType:dataType];
  MPSGraphTensor* one = [graph constantWithScalar:1.0 dataType:dataType];
  MPSGraphTensor* half = [graph constantWithScalar:0.5 dataType:dataType];

  MPSGraphTensor* xSquared = [graph multiplicationWithPrimaryTensor:x
                                                    secondaryTensor:x
                                                               name:nil];
  MPSGraphTensor* xCubed = [graph multiplicationWithPrimaryTensor:xSquared
                                                  secondaryTensor:x
                                                             name:nil];
  // x + kappa * x^3
  MPSGraphTensor* inner = [graph multiplicationWithPrimaryTensor:xCubed
                                                 secondaryTensor:kappa
                                                            name:nil];
  inner = [graph additionWithPrimaryTensor:x
                           secondaryTensor:inner
                                      name:nil];
  // beta * (x + kappa * x^3)
  inner = [graph multiplicationWithPrimaryTensor:inner
                                 secondaryTensor:beta
                                            name:nil];
  MPSGraphTensor* tanhTensor = [graph tanhWithTensor:inner name:nil];
  MPSGraphTensor* onePlusTanh = [graph additionWithPrimaryTensor:tanhTensor
                                                 secondaryTensor:one
                                                            name:nil];
  return [graph multiplicationWithPrimaryTensor:onePlusTanh
                                secondaryTensor:half
                                           name:nil];
}

} // namespace

// Lowers an MPSHardTanh node: clamps the input into [min_value, max_value].
// Implemented with two comparisons and two selects; NaN inputs fail both
// comparisons and therefore pass through unchanged.
Error
MPSGraphBuilder::mpsHardTanhOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSHardTanh();

  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  float minValue = graphNode->min_value();
  float maxValue = graphNode->max_value();
  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSDataType inputType = [inputTensor dataType];

  // Rank-0 constants broadcast against the input, so no input-shaped constant
  // buffers are created and the input's static shape is not needed.
  MPSGraphTensor* minTensor = [_mpsGraph constantWithScalar:minValue dataType:inputType];
  MPSGraphTensor* maxTensor = [_mpsGraph constantWithScalar:maxValue dataType:inputType];

  MPSGraphTensor* lessThanMinPredicateTensor = [_mpsGraph lessThanWithPrimaryTensor:inputTensor
                                                                    secondaryTensor:minTensor
                                                                               name:@"LessThanPredicate"];
  MPSGraphTensor* greaterThanMaxPredicateTensor = [_mpsGraph greaterThanWithPrimaryTensor:inputTensor
                                                                          secondaryTensor:maxTensor
                                                                                     name:@"MoreThanPredicate"];

  // First raise values below the lower bound, then lower values above the upper bound.
  MPSGraphTensor* temp = [_mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor
                                          truePredicateTensor:minTensor
                                         falsePredicateTensor:inputTensor
                                                         name:@"minOutput"];

  _idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph selectWithPredicateTensor:greaterThanMaxPredicateTensor
                                                                 truePredicateTensor:maxTensor
                                                                falsePredicateTensor:temp
                                                                                name:@"hardTanh"];

  return Error::Ok;
}

// Lowers an MPSReLU node to MPSGraph's reLU op.
Error
MPSGraphBuilder::mpsReLUOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSReLU();

  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph reLUWithTensor:getMPSGraphTensor(graphNode->input1_id())
                         name:@"relu"];

  return Error::Ok;
}

// Lowers an MPSGELU node: GELU(x) = x * CDF(x).
// approximate == "tanh" uses the tanh approximation of the Gaussian CDF; any
// other value (or a missing attribute) uses the exact CDF built by
// mpsNormCdfOp.
// Returns Error::Internal if the CDF sub-graph could not be added.
Error
MPSGraphBuilder::mpsGELUOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSGELU();
  // Guard against an absent flatbuffer string; treat it like "none".
  auto approximateField = graphNode->approximate();
  std::string approximation =
    approximateField ? approximateField->str() : std::string("none");
  Error status = Error::Ok;

  ET_LOG(
    Debug, "%s: %d (%s) -> %d",
    __FUNCTION__, graphNode->input1_id(), approximation.c_str(), graphNode->output_id()
  );

  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSGraphTensor* cdfTensor = nil;
  if (approximation == "tanh") {
    // The tanh variant needs 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3))),
    // not plain tanh(x), so it is built here rather than via a unary tanh op.
    cdfTensor = geluTanhApproximateCdf(_mpsGraph, inputTensor);
  } else {
    // mpsNormCdfOp (defined elsewhere) is expected to store the Gaussian CDF
    // of this node's input under this node's output id.
    status = mpsNormCdfOp(nodePtr);
    ET_CHECK_OR_RETURN_ERROR(
      status == Error::Ok,
      Internal,
      "[ERROR] Couldn't add GELU node to MPSGraph");
    cdfTensor = _idToMPSGraphTensor[graphNode->output_id()];
  }

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph multiplicationWithPrimaryTensor:cdfTensor
                               secondaryTensor:inputTensor
                                          name:nil];

  return status;
}

// Lowers an MPSLeakyReLU node, using negative_slope as the leak coefficient.
Error
MPSGraphBuilder::mpsLeakyReLUOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSLeakyReLU();

  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph leakyReLUWithTensor:getMPSGraphTensor(graphNode->input1_id())
                             alpha:graphNode->negative_slope()
                              name:@"leaky_relu"];

  return Error::Ok;
}

// Lowers an MPSSoftmax node along graphNode->dim().
// Returns Error::NotSupported when half_to_float conversion is requested.
Error
MPSGraphBuilder::mpsSoftmaxOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSSoftmax();

  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  // Report an error to the caller instead of aborting the process.
  ET_CHECK_OR_RETURN_ERROR(
    !graphNode->half_to_float(),
    NotSupported,
    "softmax with half to float conversion is not supported on MPS");

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph softMaxWithTensor:getMPSGraphTensor(graphNode->input1_id())
                            axis:graphNode->dim()
                            name:@"softmax"];
  return Error::Ok;
}

// Lowers an MPSLogSoftmax node along graphNode->dim() using the numerically
// stable form log_softmax(x) = (x - max) - log(sum(exp(x - max))).
// Returns Error::NotSupported when half_to_float conversion is requested.
Error
MPSGraphBuilder::mpsLogSoftmaxOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSLogSoftmax();

  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  // Report an error to the caller instead of aborting the process.
  ET_CHECK_OR_RETURN_ERROR(
    !graphNode->half_to_float(),
    NotSupported,
    "log_softmax with half to float conversion is not supported on MPS");

  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  NSInteger axis = graphNode->dim();

  // Computing log(softmax(x)) directly yields -inf whenever a probability
  // underflows to zero. Subtracting the per-axis maximum keeps exp() in range,
  // and taking the log of the sum keeps the result finite.
  // MPSGraph reductions keep the reduced axis (size 1), so the subtractions
  // below broadcast along `axis`.
  MPSGraphTensor* maxTensor = [_mpsGraph reductionMaximumWithTensor:inputTensor
                                                               axis:axis
                                                               name:nil];
  MPSGraphTensor* shiftedTensor = [_mpsGraph subtractionWithPrimaryTensor:inputTensor
                                                          secondaryTensor:maxTensor
                                                                     name:nil];
  MPSGraphTensor* expTensor = [_mpsGraph exponentWithTensor:shiftedTensor name:nil];
  MPSGraphTensor* sumTensor = [_mpsGraph reductionSumWithTensor:expTensor
                                                           axis:axis
                                                           name:nil];
  MPSGraphTensor* logSumTensor = [_mpsGraph logarithmWithTensor:sumTensor name:nil];

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph subtractionWithPrimaryTensor:shiftedTensor
                            secondaryTensor:logSumTensor
                                       name:@"log_softmax"];

  return Error::Ok;
}

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch