//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#import <Foundation/Foundation.h>

#include "MPSCompiler.h"
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/platform/profiler.h>
#include <cstdio>
#include <cstdlib> /* strtol */
#include <memory>
#include <new>
#include <string>
#include <iostream>
#include <vector>

namespace executorch {
namespace backends {

using executorch::aten::Tensor;
using executorch::runtime::ArrayRef;
using executorch::runtime::Backend;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
using executorch::runtime::CompileSpec;
using executorch::runtime::DelegateHandle;
using executorch::runtime::EValue;
using executorch::runtime::Error;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::Result;

/// ExecuTorch backend registered under the name "MPSBackend".
///
/// Each delegate handle produced by init() is an mps::delegate::MPSExecutor
/// built by mps::delegate::MPSCompiler from the processed delegate payload.
class MPSBackend final : public ::executorch::runtime::BackendInterface {
 public:
  ~MPSBackend() = default;

  /// Always reports the backend as available; no capability probe is done
  /// here.
  bool is_available() const override {
    return true;
  }

  /// Compiles the processed delegate payload into an MPSExecutor.
  ///
  /// @param context Supplies the runtime allocator used both for the executor
  ///     object itself and by the compiler.
  /// @param processed Serialized delegate payload; freed once compilation
  ///     succeeds.
  /// @param compile_specs Forwarded unchanged to MPSCompiler::compileModel.
  /// @return The executor as an opaque DelegateHandle. Returns
  ///     Error::Internal if compilation fails. Allocation failures are
  ///     returned by ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR.
  Result<DelegateHandle*> init(
      BackendInitContext& context,
      FreeableBuffer* processed,
      ArrayRef<CompileSpec> compile_specs) const override {
    auto executor = ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(
        context.get_runtime_allocator(), mps::delegate::MPSExecutor);
    // NOTE: Since we use placement new and since this type is not trivially
    // destructible, we must call the destructor manually in destroy().
    new (executor) mps::delegate::MPSExecutor;
    Error err = mps::delegate::MPSCompiler::compileModel(
        processed->data(),
        processed->size(),
        executor,
        context.get_runtime_allocator(),
        compile_specs);
    if (err != Error::Ok) {
      // On failure the handle is never returned to the caller, so destroy()
      // can never be invoked for it. Run the destructor here so that anything
      // the executor acquired during construction/compilation is released.
      executor->~MPSExecutor();
    }
    ET_CHECK_OR_RETURN_ERROR(
        err == Error::Ok,
        Internal,
        "Failed to initialize the MPS delegate");

    // Free the flatbuffer.
    processed->Free();

    return executor;
  }

  /// Executes the compiled graph held by `handle`.
  ///
  /// `args` is walked in order. Every Tensor, and every element of every
  /// TensorList, is assigned to the inputs until getNumInputs() tensors have
  /// been collected. Everything after that goes to the outputs.
  ///
  /// @param context Unused.
  /// @param handle Executor handle previously returned by init().
  /// @param args Delegate arguments: inputs followed by outputs.
  /// @return Error::Internal for a null or non-tensor argument.
  ///     Error::InvalidArgument if the collected tensor counts do not match
  ///     the executor's expected input/output counts. Otherwise, the result
  ///     of set_inputs_outputs() or forward().
  Error execute(
      ET_UNUSED BackendExecutionContext& context,
      DelegateHandle* handle,
      EValue** args) const override {
    auto executor = static_cast<mps::delegate::MPSExecutor*>(handle);
    const size_t num_inputs = static_cast<size_t>(executor->getNumInputs());
    const size_t num_outputs = static_cast<size_t>(executor->getNumOutputs());
    const size_t total_placeholders = num_inputs + num_outputs;

    std::vector<const Tensor*> input_pointers;
    std::vector<const Tensor*> output_pointers;
    input_pointers.reserve(num_inputs);
    output_pointers.reserve(num_outputs);

    // Inputs are filled first; once they are complete, every further tensor
    // is treated as an output.
    auto route = [&](const Tensor* tensor) {
      if (input_pointers.size() < num_inputs) {
        input_pointers.push_back(tensor);
      } else {
        output_pointers.push_back(tensor);
      }
    };

    // A single arg may be a TensorList that contributes several tensors, so
    // stop as soon as both lists are complete rather than after a fixed
    // number of args.
    for (size_t i = 0; i < total_placeholders &&
         (input_pointers.size() != num_inputs ||
          output_pointers.size() != num_outputs);
         ++i) {
      ET_CHECK_OR_RETURN_ERROR(
          args[i] != nullptr,
          Internal,
          "Nullptr tensor received during graph execution");

      if (args[i]->isTensor()) {
        route(&args[i]->toTensor());
      } else if (args[i]->isTensorList()) {
        const auto tensor_list = args[i]->toTensorList();
        for (const Tensor& tensor : tensor_list) {
          route(&tensor);
        }
      } else {
        ET_CHECK_OR_RETURN_ERROR(
            false,
            Internal,
            "Unhandled tag during execution of the graph");
      }
    }

    // Guard against empty or oversized TensorLists producing a placeholder
    // count the compiled graph does not expect.
    ET_CHECK_OR_RETURN_ERROR(
        input_pointers.size() == num_inputs &&
            output_pointers.size() == num_outputs,
        InvalidArgument,
        "Expected %zu inputs and %zu outputs, got %zu and %zu",
        num_inputs,
        num_outputs,
        input_pointers.size(),
        output_pointers.size());

    Error err = executor->set_inputs_outputs(input_pointers, output_pointers);
    if (err != Error::Ok) {
      return err;
    }

    return executor->forward(output_pointers);
  }

  /// Tears down an executor handle previously returned by init().
  /// A null handle is ignored.
  void destroy(DelegateHandle* handle) const override {
    if (handle != nullptr) {
      auto executor = static_cast<mps::delegate::MPSExecutor*>(handle);
      // The executor was constructed with placement new in init(), so its
      // destructor must be invoked explicitly. Its storage came from the
      // runtime allocator and is not released by this call.
      executor->~MPSExecutor();
    }
  }
};

namespace {
// Registers a single MPSBackend instance under the name "MPSBackend" during
// static initialization of this translation unit.
auto cls = MPSBackend();
Backend backend{"MPSBackend", &cls};
static auto success_with_compiler = register_backend(backend);
} // namespace

} // namespace backends
} // namespace executorch