//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>

#include <cmath>
#include <cstdint>
#include <limits>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

/// Lowers an MPSArange node into MPSGraph operations.
///
/// Builds the 1-D sequence start, start + step, start + 2*step, ... with
/// ceil((end - start) / step) elements. This is expressed as
/// cast(coordinateAlongAxis(0), dtype) * step + start, and the resulting
/// tensor is stored in _idToMPSGraphTensor under the node's output_id.
///
/// @param nodePtr Serialized node; its MPSArange payload supplies
///                start/end/step/dtype and the output id.
/// @return Error::Ok on success. Returns Error::InvalidArgument if step is zero,
///         or if the computed element count is not a finite value in
///         [0, INT32_MAX] (e.g. the step sign disagrees with end - start).
Error
MPSGraphBuilder::mpsArangeOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSArange();
  ET_LOG(
    Debug, "%s: () -> %d",
    __FUNCTION__,
    graphNode->output_id()
  );

  auto start = graphNode->start();
  auto end = graphNode->end();
  auto step = graphNode->step();
  MPSDataType dataType = getMPSDataType(graphNode->dtype());

  // A zero step would turn the element-count computation below into a
  // division by zero.
  if (step == 0) {
    ET_LOG(Error, "%s: arange step must be non-zero", __FUNCTION__);
    return Error::InvalidArgument;
  }

  // Element count, computed exactly as before: ceil((end - start) / step).
  const double numElements = std::ceil(static_cast<double>(end - start) / step);

  // Narrowing a NaN/inf or out-of-range double to int32_t is undefined
  // behavior, and a negative count is not a valid shape, so reject those
  // before converting. A count of zero (empty range) is allowed.
  if (!std::isfinite(numElements) || numElements < 0 ||
      numElements > static_cast<double>(std::numeric_limits<int32_t>::max())) {
    ET_LOG(
      Error, "%s: invalid arange range (start=%f, end=%f, step=%f)",
      __FUNCTION__,
      static_cast<double>(start),
      static_cast<double>(end),
      static_cast<double>(step)
    );
    return Error::InvalidArgument;
  }
  int32_t size_d = static_cast<int32_t>(numElements);

  // Runtime shape tensor [size_d] that drives coordinateAlongAxis below.
  auto shapeTensor = [_mpsGraph constantWithData:[NSData dataWithBytes:&size_d length:sizeof(int32_t)]
                                           shape:@[ @1 ]
                                        dataType:MPSDataTypeInt32];

  // Indices 0, 1, ..., size_d - 1 along axis 0, cast to the requested dtype.
  auto coordsTensor = [_mpsGraph coordinateAlongAxis:0 withShapeTensor:shapeTensor name:nil];
  coordsTensor = [_mpsGraph castTensor:coordsTensor toType:dataType name:@"coords"];

  auto startTensor = [_mpsGraph constantWithScalar:start
                                          dataType:dataType];
  auto multiplyTensor = [_mpsGraph constantWithScalar:step
                                             dataType:dataType];

  // out[i] = i * step + start
  auto scaledCoords = [_mpsGraph multiplicationWithPrimaryTensor:coordsTensor
                                                 secondaryTensor:multiplyTensor
                                                            name:nil];
  _idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph additionWithPrimaryTensor:scaledCoords
                                                                     secondaryTensor:startTensor
                                                                                name:nil];

  return Error::Ok;
}

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch