xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/MPSGraphBuilder.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1//
2//  Copyright (c) 2023 Apple Inc. All rights reserved.
3//  Provided subject to the LICENSE file in the top level directory.
4//
5
6#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
7#include <executorch/backends/apple/mps/runtime/MPSDevice.h>
8#include <executorch/backends/apple/mps/runtime/MPSDelegateHeader.h>
9
10namespace executorch {
11namespace backends {
12namespace mps {
13namespace delegate {
14
15using executorch::runtime::Result;
16
// Creates a builder over a serialized MPS delegate blob.
//
// `buffer_pointer` / `num_bytes` describe the blob parsed later by
// compileModel(); only the pointer is stored, so the caller must keep the
// buffer alive while the builder uses it. `mpsGraphTensorToId` receives one
// entry per placeholder created in mpsGraphRankedPlaceholder().
// NOTE(review): presumably `_mpsGraphTensorToId` is a reference member so the
// caller observes those entries — confirm against MPSGraphBuilder.h.
MPSGraphBuilder::MPSGraphBuilder(
    const void* buffer_pointer,
    size_t num_bytes,
    std::unordered_map<MPSGraphTensor*, int32_t>& mpsGraphTensorToId)
    : _mpsGraphTensorToId(mpsGraphTensorToId),
      _buffer_pointer(buffer_pointer),
      _num_bytes(num_bytes) {
  _mpsGraph = [MPSGraph new];

  // Both containers are allocated the same way (+1 via `new`). Under ARC this
  // is equivalent to the autoreleased `[NSMutableDictionary dictionary]`
  // factory. Without ARC, the factory would leave the builder holding an
  // object it does not own, even though getMPSGraphExecutable() may read
  // `_feeds` much later.
  _feeds = [NSMutableDictionary new];
  _targetTensors = [NSMutableArray new];

  // Compiled lazily (and cached) by getMPSGraphExecutable().
  _mpsGraphExecutable = nil;
  // Flipped to true by compileModel() when the blob holds a Metal kernel.
  _metal_kernel = false;
}
30
// Parses the delegate header, validates the flatbuffer identifier, and hands
// the deserialized graph to the compiler matching its graph type.
//
// On success `_flatBufferGraph` points at the flatbuffer root table and
// `_constant_data_ptr` at the constant-data section of the blob.
//
// Returns:
//   - the header's parse error (after logging it) when the header is unusable;
//   - DelegateInvalidCompatibility when the flatbuffer identifier does not
//     match, or when the graph type is neither metal_kernel nor mps_graph;
//   - otherwise whatever compileMetalKernel() / compileMPSGraph() returns.
Error
MPSGraphBuilder::compileModel() {
  Result<MPSDelegateHeader> header =
      MPSDelegateHeader::Parse(_buffer_pointer, _num_bytes);

  if (!header.ok()) {
    if (header.error() == Error::NotFound) {
      ET_LOG(
          Error,
          "MPSDelegateHeader version mismatch: '%.4s' != expected '%.4s'",
          // The header magic sits at the same offset and has the same size as
          // a flatbuffer file identifier, so this prints the blob's magic.
          flatbuffers::GetBufferIdentifier(_buffer_pointer),
          MPSDelegateHeader::kMagic);
    } else {
      ET_LOG(Error, "MPSDelegateHeader may be corrupt");
    }
    return header.error();
  }

  const uint8_t* const blob = reinterpret_cast<const uint8_t*>(_buffer_pointer);
  const uint8_t* const flatbuffer_data_ptr = blob + header->flatbuffer_offset;
  _constant_data_ptr = blob + header->constant_data_offset;

  ET_CHECK(flatbuffer_data_ptr != nullptr);
  ET_CHECK_OR_RETURN_ERROR(
      mpsgraph::MPSGraphBufferHasIdentifier(flatbuffer_data_ptr),
      DelegateInvalidCompatibility,
      "MPS Delegate Serialization Format version identifier '%.4s' != expected '%.4s'",
      flatbuffers::GetBufferIdentifier(flatbuffer_data_ptr),
      mpsgraph::MPSGraphIdentifier());

  _flatBufferGraph = mpsgraph::GetMPSGraph(flatbuffer_data_ptr);
  const auto graph_type = _flatBufferGraph->graph_type();

  if (graph_type == mpsgraph::OpType::metal_kernel) {
    _metal_kernel = true;
    return compileMetalKernel();
  }
  if (graph_type == mpsgraph::OpType::mps_graph) {
    return compileMPSGraph();
  }

  ET_CHECK_OR_RETURN_ERROR(
      false,
      DelegateInvalidCompatibility,
      "Received an invalid operation type: expected MPSGraph or metal kernel, but got: %s",
      EnumNameOpType(graph_type));
  // Not reached: the check above always fails and returns.
  return Error::DelegateInvalidCompatibility;
}
87
// Builds the MPSGraph described by `_flatBufferGraph`:
//   1. a placeholder for every graph input,
//   2. every serialized constant,
//   3. every serialized node, lowered via addNodeToMPSGraph(),
//   4. every graph output, recorded in `_targetTensors` for compilation.
//
// `_idToMPSGraphTensor` is sized to the number of serialized values and is
// filled in by the stages above.
//
// Returns the first non-Ok error from any stage, InvalidProgram when a
// required flatbuffer vector is absent or an output id is out of range, and
// InvalidState when an output id was never produced.
Error
MPSGraphBuilder::compileMPSGraph() {
  const auto* values = _flatBufferGraph->mps_values();
  const auto* input_ids = _flatBufferGraph->input_ids();
  const auto* constant_ids = _flatBufferGraph->constant_ids();
  const auto* nodes = _flatBufferGraph->mps_nodes();
  const auto* output_ids = _flatBufferGraph->output_ids();

  // Flatbuffer accessors return nullptr for absent vectors. Reject a
  // malformed blob here instead of dereferencing a null pointer below.
  ET_CHECK_OR_RETURN_ERROR(
      values != nullptr && input_ids != nullptr && constant_ids != nullptr &&
          nodes != nullptr && output_ids != nullptr,
      InvalidProgram,
      "Serialized MPSGraph is missing required fields");

  _idToMPSGraphTensor.resize(values->size(), nullptr);

  // Add the placeholder nodes to the graph.
  for (auto in_id : *input_ids) {
    Error err = mpsGraphRankedPlaceholder(in_id);
    if (err != Error::Ok) {
      return err;
    }
  }

  // Parse all the serialized constant values and add them to MPSGraph.
  for (auto constant_id : *constant_ids) {
    Error err = mpsConstantOp(constant_id);
    if (err != Error::Ok) {
      return err;
    }
  }

  // Create the corresponding MPSGraph ops of the serialized nodes.
  for (auto node : *nodes) {
    Error err = addNodeToMPSGraph(node);
    if (err != Error::Ok) {
      return err;
    }
  }

  // Register the outputs as targets for the MPSGraphExecutable.
  for (auto out_id : *output_ids) {
    // Output ids come straight from the blob. Range-check them before
    // indexing, so a malformed program fails cleanly instead of reading past
    // the end of the vector (a negative id wraps to a huge size_t and fails).
    ET_CHECK_OR_RETURN_ERROR(
        static_cast<size_t>(out_id) < _idToMPSGraphTensor.size(),
        InvalidProgram,
        "Output id %lld is out of range for %zu serialized values",
        static_cast<long long>(out_id),
        _idToMPSGraphTensor.size());

    MPSGraphTensor* output = _idToMPSGraphTensor[out_id];
    ET_CHECK_OR_RETURN_ERROR(
        output != nil,
        InvalidState,
        "Failed to deserialize the model");

    [_targetTensors addObject:output];
  }

  return Error::Ok;
}
129
// Compiles a blob whose graph type is metal_kernel.
//
// Only exactly one node and no constants are supported; anything else is
// rejected with DelegateInvalidCompatibility. The node itself is compiled by
// the compileMetalKernel(node) overload, whose error is passed through as-is.
Error
MPSGraphBuilder::compileMetalKernel() {
  const auto* kernelNodes = _flatBufferGraph->mps_nodes();

  ET_CHECK_OR_RETURN_ERROR(
      kernelNodes->size() == 1,
      DelegateInvalidCompatibility,
      "Currently supporting dispatching a single Metal kernel.");
  ET_CHECK_OR_RETURN_ERROR(
      _flatBufferGraph->constant_ids()->size() == 0,
      DelegateInvalidCompatibility,
      "Currently not supporting dispatching Metal kernels with constants.");

  // Compile the corresponding Metal kernel.
  for (auto kernelNode : *kernelNodes) {
    const Error status = compileMetalKernel(kernelNode);
    if (status != Error::Ok) {
      return status;
    }
  }

  return Error::Ok;
}
153
// Creates a ranked placeholder (graph input) for serialized value `id`.
//
// The placeholder is stored in `_idToMPSGraphTensor[id]`. Its shape and dtype
// are registered in `_feeds`, which getMPSGraphExecutable() passes to graph
// compilation. `_mpsGraphTensorToId` maps the placeholder back to `id`.
//
// Returns InvalidProgram if `id` does not index into `_idToMPSGraphTensor`
// (sized by compileMPSGraph() to the number of serialized values); otherwise
// Error::Ok.
Error
MPSGraphBuilder::mpsGraphRankedPlaceholder(int32_t id) {
  ET_LOG(Debug, "%s: %d", __FUNCTION__, id);

  // `id` comes from the serialized blob. Reject out-of-range values before
  // they are used to look up the value's shape/dtype or to index the table.
  ET_CHECK_OR_RETURN_ERROR(
      id >= 0 && static_cast<size_t>(id) < _idToMPSGraphTensor.size(),
      InvalidProgram,
      "Placeholder id %d is out of range for %zu serialized values",
      id,
      _idToMPSGraphTensor.size());

  MPSShape* mpsShape = getMPSShape(id);
  MPSDataType mpsDataType = getMPSDataType(id);
  MPSGraphTensor* placeholder = [_mpsGraph placeholderWithShape:mpsShape
                                                       dataType:mpsDataType
                                                           name:nil];
  _idToMPSGraphTensor[id] = placeholder;
  _feeds[placeholder] = [[MPSGraphShapedType alloc] initWithShape:mpsShape
                                                         dataType:mpsDataType];
  _mpsGraphTensorToId[placeholder] = id;
  return Error::Ok;
}
168
// Accessor for the builder's MPSGraph instance (created in the constructor and
// populated by compileMPSGraph()).
MPSGraph*
MPSGraphBuilder::getMPSGraph() {
  MPSGraph* graph = _mpsGraph;
  return graph;
}
173
// Returns the compiled executable for the graph, compiling it on first use.
//
// Compilation targets the Metal device from MPSDevice::getInstance(). It uses
// the placeholders registered in `_feeds` and the outputs collected in
// `_targetTensors`. A non-nil result is cached in `_mpsGraphExecutable`, so
// later calls return it without recompiling.
MPSGraphExecutable*
MPSGraphBuilder::getMPSGraphExecutable() {
  if (_mpsGraphExecutable == nil) {
    MPSGraphDevice* graphDevice =
        [MPSGraphDevice deviceWithMTLDevice:MPSDevice::getInstance()->device()];
    _mpsGraphExecutable = [_mpsGraph compileWithDevice:graphDevice
                                                 feeds:_feeds
                                         targetTensors:_targetTensors
                                      targetOperations:nil
                                 compilationDescriptor:nil];
  }
  return _mpsGraphExecutable;
}
188
189} // namespace delegate
190} // namespace mps
191} // namespace backends
192} // namespace executorch
193