//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

// Obj-C headers
#import <Foundation/Foundation.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

// MPS headers
#include <executorch/backends/apple/mps/runtime/MPSDevice.h>
#include <executorch/backends/apple/mps/runtime/MPSCompiler.h>
#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
#include <executorch/backends/apple/mps/schema_generated.h>

// Runtime headers
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

#include <memory>
#include <unordered_map>
#include <string>
#include <iostream>

#define MPS_UNUSED(x) ( (void)(x) )

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

using executorch::runtime::ArrayRef;
using executorch::runtime::CompileSpec;
using executorch::runtime::Error;
using executorch::runtime::MemoryAllocator;

/*
Builds the mps runtime object using the buffer pointer. The buffer pointer
must be a valid pointer to the serialized mps object.

Steps:
  1. Construct an MPSGraphBuilder over [buffer_pointer, buffer_pointer + num_bytes)
     and compile it into an MPSGraph.
  2. Store the resulting MPSGraphExecutable in executor->_executable.
  3. Ask the executor to allocate its data buffers (executor->initDataBuffers()).

Parameters:
  buffer_pointer    - serialized MPS delegate payload; must not be null.
  num_bytes         - size of the payload in bytes.
  executor          - receives the compiled executable; must not be null. Its
                      _mpsGraphTensorToId member is handed to the builder
                      (presumably populated with graph-tensor -> id mappings
                      during compilation; confirm in MPSGraphBuilder).
  runtime_allocator - currently unused.
  compile_specs     - currently unused.

Returns:
  Error::Ok on success.
  Error::InvalidArgument if executor or buffer_pointer is null.
  Error::Internal if graph construction or data-buffer allocation fails.
  Error::InvalidProgram if no MPSGraphExecutable could be produced.
*/
ET_NODISCARD Error MPSCompiler::compileModel(
    const void* buffer_pointer,
    size_t num_bytes,
    MPSExecutor* executor,
    MemoryAllocator* runtime_allocator,
    ArrayRef<CompileSpec> compile_specs) {
  // Neither argument is consumed by this implementation yet; mark both so
  // -Wunused-parameter stays quiet consistently.
  MPS_UNUSED(runtime_allocator);
  MPS_UNUSED(compile_specs);

  // executor is dereferenced unconditionally below and buffer_pointer is
  // parsed by the builder; reject null inputs instead of crashing.
  ET_CHECK_OR_RETURN_ERROR(
      executor != nullptr, InvalidArgument, "MPS executor must not be null");
  ET_CHECK_OR_RETURN_ERROR(
      buffer_pointer != nullptr,
      InvalidArgument,
      "MPS delegate buffer must not be null");

  // The builder is only needed for the duration of compilation; unique_ptr
  // releases it on every return path.
  auto mpsGraphBuilder = std::make_unique<MPSGraphBuilder>(
      buffer_pointer, num_bytes, executor->_mpsGraphTensorToId);
  Error err = mpsGraphBuilder->compileModel();
  ET_CHECK_OR_RETURN_ERROR(
      err == Error::Ok, Internal, "Failed to construct the MPS graph object");

  executor->_executable = mpsGraphBuilder->getMPSGraphExecutable();
  ET_CHECK_OR_RETURN_ERROR(
      executor->_executable != nil,
      InvalidProgram,
      "Invalid FlatBuffer contents - could not create MPSGraphExecutable");

  err = executor->initDataBuffers();
  ET_CHECK_OR_RETURN_ERROR(
      err == Error::Ok, Internal, "Could not allocate data buffers");

  // NSUInteger is not unsigned long on every Apple target; cast explicitly so
  // the argument always matches the %lu conversion.
  ET_LOG(
      Debug,
      "MPSGraphExecutable total inputs: %lu",
      (unsigned long)[executor->_inputShapes count]);
  ET_LOG(
      Debug,
      "MPSGraphExecutable total outputs: %lu",
      (unsigned long)[executor->_outputShapes count]);

  return Error::Ok;
}

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch