xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/ConstantOps.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1
2//
3//  Copyright (c) 2023 Apple Inc. All rights reserved.
4//  Provided subject to the LICENSE file in the top level directory.
5//
6
7#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
8
9namespace executorch {
10namespace backends {
11namespace mps {
12namespace delegate {
13
Error
MPSGraphBuilder::mpsConstantOp(int32_t id) {
  // Turn the serialized constant identified by `id` into an MPSGraph constant
  // tensor, and register it in the id -> tensor table under that same id.
  // The three lookups are kept in the original order: data, shape, dtype.
  auto constantBytes = getConstantData(id);
  auto constantShape = getMPSShape(id);
  auto constantType = getMPSDataType(id);

  MPSGraphTensor* constantTensor = [_mpsGraph constantWithData:constantBytes
                                                         shape:constantShape
                                                      dataType:constantType];
  _idToMPSGraphTensor[id] = constantTensor;

  return Error::Ok;
}
22
Error
MPSGraphBuilder::mpsFullOp(NodePtr nodePtr) {
  // Lower an MPSFull node: create a tensor of the node's shape and dtype whose
  // every element is the node's fill value, and store it under output_id().
  auto fullNode = nodePtr->mpsnode_union_as_MPSFull();
  const auto outputId = fullNode->output_id();
  ET_LOG(Debug, "%s: - -> %d", __FUNCTION__, outputId);

  // A shape with zero elements gets no graph tensor; nil is recorded instead.
  MPSGraphTensor* filledTensor = nil;
  if (numel(fullNode->shape()) != 0) {
    filledTensor =
      [_mpsGraph constantWithScalar:fullNode->fill_value()
                              shape:getMPSShape(fullNode->shape())
                           dataType:getMPSDataType(fullNode->dtype())];
  }
  _idToMPSGraphTensor[outputId] = filledTensor;

  return Error::Ok;
}
42
Error
MPSGraphBuilder::mpsFullLikeOp(NodePtr nodePtr) {
  // Lower an MPSFullLike node: create a tensor with the same shape as the
  // tensor registered under input1_id(), the node's dtype, and every element
  // set to the node's fill value. The result is stored under output_id().
  auto graphNode = nodePtr->mpsnode_union_as_MPSFullLike();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());

  // This builder records zero-element tensors as nil (mpsFullOp does so for
  // empty shapes). Sending `.shape` to nil would yield a nil shape, and that
  // would be handed to the nonnull `shape:` parameter of
  // constantWithScalar:shape:dataType:. Instead, propagate the empty-tensor
  // convention: a full_like of an empty tensor is itself empty (nil).
  if (inputTensor == nil) {
    _idToMPSGraphTensor[graphNode->output_id()] = nil;
    return Error::Ok;
  }

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph constantWithScalar:graphNode->fill_value()
                            shape:inputTensor.shape
                         dataType:getMPSDataType(graphNode->dtype())];

  return Error::Ok;
}
58
59
60} // namespace delegate
61} // namespace mps
62} // namespace backends
63} // namespace executorch
64