1 // Copyright (c) Meta Platforms, Inc. and affiliates. 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #pragma once 7 #include <xnnpack.h> 8 #include <memory> 9 #include <vector> 10 11 namespace torch { 12 namespace jit { 13 namespace xnnpack { 14 namespace delegate { 15 16 class XNNExecutor { 17 private: 18 std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> runtime_{ 19 nullptr, 20 &xnn_delete_runtime}; 21 std::vector<uint32_t> input_ids_; 22 std::vector<uint32_t> output_ids_; 23 std::vector<xnn_external_value> externals_; 24 25 public: 26 XNNExecutor() = default; 27 28 template <typename T> set_inputs(std::vector<T * > & inputs,std::vector<T * > & outputs)29 bool set_inputs(std::vector<T*>& inputs, std::vector<T*>& outputs) { 30 externals_.clear(); 31 32 if (inputs.size() != input_ids_.size()) { 33 return false; 34 } 35 36 for (int i = 0; i < inputs.size(); i++) { 37 externals_.emplace_back(xnn_external_value{input_ids_[i], inputs[i]}); 38 } 39 40 if (outputs.size() != output_ids_.size()) { 41 return false; 42 } 43 44 for (int i = 0; i < outputs.size(); i++) { 45 externals_.emplace_back(xnn_external_value{output_ids_[i], outputs[i]}); 46 } 47 48 return true; 49 } 50 forward()51 bool forward() { 52 xnn_status status = 53 xnn_setup_runtime(runtime_.get(), externals_.size(), externals_.data()); 54 55 if (status != xnn_status_success) { 56 return false; 57 } 58 59 status = xnn_invoke_runtime(runtime_.get()); 60 61 if (status != xnn_status_success) { 62 return false; 63 } 64 65 return true; 66 } 67 68 friend class XNNCompiler; 69 }; 70 71 } // namespace delegate 72 } // namespace xnnpack 73 } // namespace jit 74 } // namespace torch 75