1 /* 2 * Copyright (c) 2024 MediaTek Inc. 3 * 4 * Licensed under the BSD License (the "License"); you may not use this file 5 * except in compliance with the License. See the license file in the root 6 * directory of this source tree for more details. 7 */ 8 9 #pragma once 10 11 #include "NeuronLog.h" 12 #include "api/NeuronAdapter.h" 13 #include "api/NeuronAdapterShim.h" 14 15 #include <cstddef> 16 #include <cstdint> 17 #include <memory> 18 #include <string> 19 #include <vector> 20 21 namespace executorch { 22 namespace backends { 23 namespace neuron { 24 25 struct NeuronDeleter { operatorNeuronDeleter26 void operator()(NeuronModel* model) { 27 if (model != nullptr) { 28 NeuronModel_free(model); 29 } 30 } operatorNeuronDeleter31 void operator()(NeuronCompilation* compilation) { 32 if (compilation != nullptr) { 33 NeuronCompilation_free(compilation); 34 } 35 } operatorNeuronDeleter36 void operator()(NeuronExecution* execution) { 37 if (execution != nullptr) { 38 NeuronExecution_free(execution); 39 } 40 } operatorNeuronDeleter41 void operator()(NeuronMemory* memory) { 42 if (memory != nullptr) { 43 NeuronMemory_free(memory); 44 } 45 } 46 }; 47 48 class NeuronExecutor { 49 public: 50 explicit NeuronExecutor(); 51 52 int LoadFromCompiledNetwork( 53 const void* buffer, 54 size_t size, 55 int inputCount, 56 int outputCount, 57 std::string& runtimeOption); 58 59 template <bool isInput> SetInputOutput(uint32_t index,void * buffer,size_t length)60 int SetInputOutput(uint32_t index, void* buffer, size_t length) const { 61 CHECK_VALID_PTR(buffer); 62 CHECK_VALID_PTR(mExecution); 63 return isInput ? NeuronExecution_setInput( 64 mExecution.get(), index, nullptr, buffer, length) 65 : NeuronExecution_setOutput( 66 mExecution.get(), index, nullptr, buffer, length); 67 } 68 69 template <bool isInput> SetInputOutputFromMemory(uint32_t index,const NeuronMemory * memory,size_t offset,size_t length)70 int SetInputOutputFromMemory( 71 uint32_t index, 72 const NeuronMemory* memory, 73 size_t offset, 74 size_t length) const { 75 CHECK_VALID_PTR(memory); 76 CHECK_VALID_PTR(mExecution); 77 return isInput 78 ? NeuronExecution_setInputFromMemory( 79 mExecution.get(), index, nullptr, memory, offset, length) 80 : NeuronExecution_setOutputFromMemory( 81 mExecution.get(), index, nullptr, memory, offset, length); 82 } 83 84 template <bool isInput> GetInputOutputPaddedSize(int32_t index)85 size_t GetInputOutputPaddedSize(int32_t index) const { 86 CHECK_VALID_PTR(mCompilation); 87 size_t size = 0; 88 auto res = isInput 89 ? NeuronCompilation_getInputPaddedSize(mCompilation.get(), index, &size) 90 : NeuronCompilation_getOutputPaddedSize( 91 mCompilation.get(), index, &size); 92 return res == NEURON_NO_ERROR ? size : 0; 93 } 94 Compute()95 int Compute() const { 96 CHECK_VALID_PTR(mExecution); 97 return NeuronExecution_compute(mExecution.get()); 98 } 99 IsValid()100 bool IsValid() const { 101 return mExecution != nullptr; 102 } 103 104 private: 105 std::unique_ptr<NeuronModel, NeuronDeleter> mModel; 106 107 std::unique_ptr<NeuronCompilation, NeuronDeleter> mCompilation; 108 109 std::unique_ptr<NeuronExecution, NeuronDeleter> mExecution; 110 111 std::vector<size_t> mInputSizes; 112 113 std::vector<size_t> mOutputSizes; 114 115 private: 116 NeuronExecutor(const NeuronExecutor&); 117 118 NeuronExecutor operator=(const NeuronExecutor&); 119 }; 120 121 } // namespace neuron 122 } // namespace backends 123 } // namespace executorch 124