xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/RangeOps.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1
2//
3//  Copyright (c) 2023 Apple Inc. All rights reserved.
4//  Provided subject to the LICENSE file in the top level directory.
5//
6
7#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
8
9namespace executorch {
10namespace backends {
11namespace mps {
12namespace delegate {
13
14
/**
 * Lowers an MPSArange node into MPSGraph operations.
 *
 * The output is computed as `start + i * step` for `i` in `[0, count)`, where
 * `count = ceil((end - start) / step)`. The sequence is assembled from an
 * index tensor generated along axis 0, which is cast to the node's dtype,
 * multiplied by a `step` constant and offset by a `start` constant. The
 * resulting tensor is registered in `_idToMPSGraphTensor` under the node's
 * output id.
 *
 * @param nodePtr Serialized graph node; read through
 *                `mpsnode_union_as_MPSArange()`.
 * @return `Error::Ok` on success. `Error::InvalidArgument` when `step` is
 *         zero, or when the computed element count is negative, NaN, or does
 *         not fit in an `int32_t` (in which case no tensor is registered).
 */
Error
MPSGraphBuilder::mpsArangeOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSArange();
  ET_LOG(
    Debug, "%s: () -> %d",
    __FUNCTION__,
    graphNode->output_id()
  );

  auto start = graphNode->start();
  auto end = graphNode->end();
  auto step = graphNode->step();
  MPSDataType dataType = getMPSDataType(graphNode->dtype());

  // A zero step would divide by zero below and describes an infinite range.
  if (step == 0) {
    ET_LOG(Error, "%s: step must be non-zero", __FUNCTION__);
    return Error::InvalidArgument;
  }

  // Number of elements in the half-open range [start, end).
  const double elementCount = std::ceil(static_cast<double>(end - start) / step);

  // `!(x >= 0)` also rejects NaN. A negative count means the sign of `step`
  // disagrees with the direction from `start` to `end`. The upper bound keeps
  // the conversion to int32_t below well-defined.
  if (!(elementCount >= 0.0) || elementCount > static_cast<double>(INT32_MAX)) {
    ET_LOG(
      Error, "%s: invalid element count %f (start/end/step are inconsistent or out of range)",
      __FUNCTION__,
      elementCount
    );
    return Error::InvalidArgument;
  }

  // The shape is handed to MPSGraph as a 1-element Int32 constant tensor.
  int32_t size_d = static_cast<int32_t>(elementCount);
  auto shapeTensor = [_mpsGraph constantWithData:[NSData dataWithBytes:&size_d length:sizeof(int32_t)]
                                           shape:@[ @1 ]
                                        dataType:MPSDataTypeInt32];

  // Index tensor along axis 0, converted to the requested output dtype so the
  // arithmetic below is carried out in that dtype.
  auto coordsTensor = [_mpsGraph coordinateAlongAxis:0 withShapeTensor:shapeTensor name:nil];
  coordsTensor = [_mpsGraph castTensor:coordsTensor toType:dataType name:@"coords"];

  auto startTensor = [_mpsGraph constantWithScalar:start
                                          dataType:dataType];
  auto multiplyTensor = [_mpsGraph constantWithScalar:step
                                             dataType:dataType];

  // result = coords * step + start
  auto scaledCoords = [_mpsGraph multiplicationWithPrimaryTensor:coordsTensor
                                                 secondaryTensor:multiplyTensor
                                                            name:nil];
  _idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph additionWithPrimaryTensor:scaledCoords
                                                                     secondaryTensor:startTensor
                                                                                name:nil];

  return Error::Ok;
}
49
50} // namespace delegate
51} // namespace mps
52} // namespace backends
53} // namespace executorch
54