//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
#include <executorch/backends/apple/mps/runtime/MPSDevice.h>
#include <executorch/backends/apple/mps/runtime/MPSDelegateHeader.h>

#include <cstddef>
#include <cstdint>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

using executorch::runtime::Result;

namespace {

// Returns true when `id` can safely index a table holding `count` entries.
// Value ids are read straight out of the serialized FlatBuffer, so they must
// be range-checked before being used as vector indices.
bool isValidValueId(int64_t id, size_t count) {
  return id >= 0 && static_cast<uint64_t>(id) < static_cast<uint64_t>(count);
}

} // namespace

/// Creates a builder over a serialized MPS delegate payload.
///
/// @param buffer_pointer Start of the payload. It begins with an
///        MPSDelegateHeader that records where the FlatBuffer graph and the
///        constant data live. The buffer is not copied here.
/// @param num_bytes Size of the payload in bytes.
/// @param mpsGraphTensorToId Map that receives a placeholder-tensor -> value-id
///        entry for every graph input (see mpsGraphRankedPlaceholder).
MPSGraphBuilder::MPSGraphBuilder(
    const void* buffer_pointer,
    size_t num_bytes,
    std::unordered_map<MPSGraphTensor*, int32_t>& mpsGraphTensorToId) :
  _mpsGraphTensorToId(mpsGraphTensorToId), _buffer_pointer(buffer_pointer), _num_bytes(num_bytes) {

  _mpsGraph = [MPSGraph new];
  _feeds = [NSMutableDictionary dictionary];
  _targetTensors = [NSMutableArray new];

  _mpsGraphExecutable = nil;
  _metal_kernel = false;
}

/// Parses the delegate header, locates the FlatBuffer graph and the constant
/// data, validates the FlatBuffer identifier and then builds either an
/// MPSGraph or a single Metal kernel depending on the serialized graph type.
///
/// @return Error::Ok on success; the header parse error if the header is
///         missing or corrupt; DelegateInvalidCompatibility if the FlatBuffer
///         identifier or graph type is not supported; otherwise whatever the
///         selected compile path returns.
Error
MPSGraphBuilder::compileModel() {
  Error err = Error::Ok;

  Result<MPSDelegateHeader> header = MPSDelegateHeader::Parse(_buffer_pointer, _num_bytes);
  const uint8_t* flatbuffer_data_ptr = nullptr;

  if (header.ok()) {
    flatbuffer_data_ptr = reinterpret_cast<const uint8_t*>(_buffer_pointer) +
                          header->flatbuffer_offset;
    _constant_data_ptr = reinterpret_cast<const uint8_t*>(_buffer_pointer) +
                         header->constant_data_offset;
  } else if (header.error() == Error::NotFound) {
    ET_LOG(
        Error,
        "MPSDelegateHeader version mismatch: '%.4s' != expected '%.4s'",
        // Header Magic and FlatbufferIdentifier are same offset and size
        flatbuffers::GetBufferIdentifier(_buffer_pointer),
        MPSDelegateHeader::kMagic);
    return header.error();
  } else {
    ET_LOG(Error, "MPSDelegateHeader may be corrupt");
    return header.error();
  }

  ET_CHECK(flatbuffer_data_ptr != nullptr);
  ET_CHECK_OR_RETURN_ERROR(
    mpsgraph::MPSGraphBufferHasIdentifier(flatbuffer_data_ptr),
    DelegateInvalidCompatibility,
    "MPS Delegate Serialization Format version identifier '%.4s' != expected '%.4s'",
    flatbuffers::GetBufferIdentifier(flatbuffer_data_ptr),
    mpsgraph::MPSGraphIdentifier());

  // NOTE(review): only the identifier is checked here; the FlatBuffer is not
  // run through a flatbuffers::Verifier in this function, so the offsets
  // inside it are trusted as-is.
  _flatBufferGraph = mpsgraph::GetMPSGraph(flatbuffer_data_ptr);
  switch (_flatBufferGraph->graph_type()) {
    case mpsgraph::OpType::metal_kernel:
    {
      _metal_kernel = true;
      err = compileMetalKernel();
      break;
    }
    case mpsgraph::OpType::mps_graph:
    {
      err = compileMPSGraph();
      break;
    }
    default:
      ET_CHECK_OR_RETURN_ERROR(
        false,
        DelegateInvalidCompatibility,
        "Received an invalid operation type: expected MPSGraph or metal kernel, but got: %s",
        EnumNameOpType(_flatBufferGraph->graph_type()));
  }

  return err;
}

/// Builds the MPSGraph from the deserialized FlatBuffer. It creates a
/// placeholder per input id and a constant per constant id, translates every
/// serialized node, then records the output tensors as compilation targets.
///
/// @return Error::Ok on success, InvalidState if the serialized graph is
///         missing a required table or references an out-of-range / undefined
///         value id, or the first error reported by a per-value/per-node step.
Error
MPSGraphBuilder::compileMPSGraph() {
  Error err = Error::Ok;

  // FlatBuffers returns nullptr for vector fields that were never written;
  // reject such graphs instead of dereferencing null below.
  ET_CHECK_OR_RETURN_ERROR(
    _flatBufferGraph->mps_values() != nullptr &&
    _flatBufferGraph->input_ids() != nullptr &&
    _flatBufferGraph->constant_ids() != nullptr &&
    _flatBufferGraph->mps_nodes() != nullptr &&
    _flatBufferGraph->output_ids() != nullptr,
    InvalidState,
    "Failed to deserialize the model: missing graph tables");

  // One slot per serialized value; slots are filled as tensors are created.
  _idToMPSGraphTensor.resize(_flatBufferGraph->mps_values()->size(), nullptr);

  // Add the placeholder nodes to the graph.
  for (auto in_id : *_flatBufferGraph->input_ids()) {
    err = mpsGraphRankedPlaceholder(in_id);
    if (err != Error::Ok) {
      return err;
    }
  }

  // Parse all the serialized constant values and add them to MPSGraph.
  for (auto constant_id : *_flatBufferGraph->constant_ids()) {
    err = mpsConstantOp(constant_id);
    if (err != Error::Ok) {
      return err;
    }
  }

  // Create the corresponding MPSGraph ops of the serialized nodes from the FlatBuffer.
  for (auto node : *_flatBufferGraph->mps_nodes()) {
    err = addNodeToMPSGraph(node);
    if (err != Error::Ok) {
      return err;
    }
  }

  // Collect the output tensors; they become the target tensors when the
  // graph is later compiled into an MPSGraphExecutable.
  for (auto out_id : *_flatBufferGraph->output_ids()) {
    // The id comes from the serialized graph: range-check it before indexing.
    ET_CHECK_OR_RETURN_ERROR(
      isValidValueId(out_id, _idToMPSGraphTensor.size()),
      InvalidState,
      "Failed to deserialize the model: output id %d is out of range (%zu values)",
      static_cast<int>(out_id),
      _idToMPSGraphTensor.size());
    ET_CHECK_OR_RETURN_ERROR(
      _idToMPSGraphTensor[out_id] != nil,
      InvalidState,
      "Failed to deserialize the model");

    [_targetTensors addObject: _idToMPSGraphTensor[out_id]];
  }

  return err;
}

/// Compiles a graph that consists of exactly one Metal kernel node and no
/// constants.
///
/// @return Error::Ok on success, DelegateInvalidCompatibility if the graph
///         does not have exactly one node or carries constants, or the error
///         from compiling the node.
Error
MPSGraphBuilder::compileMetalKernel() {
  Error err = Error::Ok;

  // A missing node table counts as "not exactly one node".
  ET_CHECK_OR_RETURN_ERROR(
    _flatBufferGraph->mps_nodes() != nullptr &&
    _flatBufferGraph->mps_nodes()->size() == 1,
    DelegateInvalidCompatibility,
    "Currently supporting dispatching a single Metal kernel.");
  // A missing constant table is treated the same as an empty one.
  ET_CHECK_OR_RETURN_ERROR(
    _flatBufferGraph->constant_ids() == nullptr ||
    _flatBufferGraph->constant_ids()->size() == 0,
    DelegateInvalidCompatibility,
    "Currently not supporting dispatching Metal kernels with constants.");

  // Compile the corresponding Metal kernel
  for (auto node : *_flatBufferGraph->mps_nodes()) {
    err = compileMetalKernel(node);
    if (err != Error::Ok) {
      return err;
    }
  }

  return err;
}

/// Creates a ranked placeholder for value `id`. The placeholder's shape and
/// data type come from the serialized value. It is registered in
/// `_idToMPSGraphTensor`, in `_feeds` (as an MPSGraphShapedType) and in the
/// tensor -> id map supplied at construction.
///
/// @return Error::Ok, or InvalidState if `id` is outside the value table.
Error
MPSGraphBuilder::mpsGraphRankedPlaceholder(int32_t id) {
  ET_LOG(Debug, "%s: %d", __FUNCTION__, id);
  // Validate before touching any per-value table with this id.
  ET_CHECK_OR_RETURN_ERROR(
    isValidValueId(id, _idToMPSGraphTensor.size()),
    InvalidState,
    "Failed to deserialize the model: placeholder id %d is out of range (%zu values)",
    id,
    _idToMPSGraphTensor.size());

  MPSShape* mpsShape = getMPSShape(id);
  MPSDataType mpsDataType = getMPSDataType(id);
  MPSGraphTensor* placeholder = [_mpsGraph placeholderWithShape:mpsShape
                                                       dataType:mpsDataType
                                                           name:nil];
  _idToMPSGraphTensor[id] = placeholder;
  _feeds[placeholder] = [[MPSGraphShapedType alloc] initWithShape:mpsShape
                                                         dataType:mpsDataType];
  _mpsGraphTensorToId[placeholder] = id;
  return Error::Ok;
}

/// Returns the MPSGraph being built.
MPSGraph*
MPSGraphBuilder::getMPSGraph() {
  return _mpsGraph;
}

/// Returns the compiled executable. On the first call that has no cached
/// executable, the graph is compiled for the shared MPSDevice's Metal device
/// using the recorded feeds and target tensors, and the result is cached.
MPSGraphExecutable*
MPSGraphBuilder::getMPSGraphExecutable() {
  if (_mpsGraphExecutable) {
    return _mpsGraphExecutable;
  }
  _mpsGraphExecutable = [_mpsGraph compileWithDevice:[MPSGraphDevice deviceWithMTLDevice:MPSDevice::getInstance()->device()]
                                                feeds:_feeds
                                        targetTensors:_targetTensors
                                     targetOperations:nil
                                compilationDescriptor:nil];

  return _mpsGraphExecutable;
}

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch