//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

/// Lowers a serialized MPSMean node into the MPSGraph.
///
/// Takes the mean of the input tensor over the node's dims. When keep_dims is
/// false, it then squeezes those axes away. The resulting tensor is registered
/// under the node's output id.
///
/// @param nodePtr Serialized graph node; read through its MPSMean union member.
/// @return Error::Ok on success. Returns Error::InvalidArgument when the input
///         shape is known and a requested dim is still outside [0, rank) after
///         negative-axis normalization.
Error
MPSGraphBuilder::mpsMeanOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSMean();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->output_id()
  );

  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());

  // MPSGraph wants negative axes to be converted to positive, which needs the
  // input rank. A nil shape (unranked tensor) yields a rank of 0. In that case
  // negative dims pass through unchanged and no range check is possible.
  NSArray<NSNumber*>* inputShape = inputTensor.shape;
  const int inputDims = (int)inputShape.count;

  NSMutableArray<NSNumber*>* dimArray = [NSMutableArray array];
  for (int64_t i = 0; i < graphNode->num_dims(); i++) {
    const int32_t requestedDim = graphNode->dims()->Get(i);
    int32_t dim = requestedDim;
    if (dim < 0) {
      dim = inputDims + dim;
    }
    // Fail fast on an axis that is still out of range after normalization,
    // rather than handing MPSGraph an invalid axis to trip over later.
    if (inputShape != nil && (dim < 0 || dim >= inputDims)) {
      ET_LOG(
        Error, "%s: dim %d is out of range for input of rank %d",
        __FUNCTION__,
        (int)requestedDim,
        inputDims
      );
      return Error::InvalidArgument;
    }
    [dimArray addObject:@(dim)];
  }

  // Pass the axes in reverse of their serialized order. The original
  // implementation did this to put the slowest axis first for MPSGraph.
  // NOTE(review): that only holds if the serialized dims are ascending.
  // Confirm whether MPSGraph needs any particular order here.
  NSArray<NSNumber*>* axes = [[dimArray reverseObjectEnumerator] allObjects];

  MPSGraphTensor* meanTensor = [_mpsGraph meanOfTensor:inputTensor
                                                  axes:axes
                                                  name:@"Mean"];
  if (!graphNode->keep_dims()) {
    // meanOfTensor keeps reduced axes as size-1 dims; drop them when the
    // node asks for keep_dims == false.
    meanTensor = [_mpsGraph squeezeTensor:meanTensor
                                     axes:axes
                                     name:@"Mean/squeezed"];
  }

  _idToMPSGraphTensor[graphNode->output_id()] = meanTensor;
  return Error::Ok;
}

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch