xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Meta Platforms, Inc. and affiliates.
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 #include <xnnpack.h>
8 #include <memory>
9 #include <vector>
10 
11 namespace torch {
12 namespace jit {
13 namespace xnnpack {
14 namespace delegate {
15 
16 class XNNExecutor {
17  private:
18   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> runtime_{
19       nullptr,
20       &xnn_delete_runtime};
21   std::vector<uint32_t> input_ids_;
22   std::vector<uint32_t> output_ids_;
23   std::vector<xnn_external_value> externals_;
24 
25  public:
26   XNNExecutor() = default;
27 
28   template <typename T>
set_inputs(std::vector<T * > & inputs,std::vector<T * > & outputs)29   bool set_inputs(std::vector<T*>& inputs, std::vector<T*>& outputs) {
30     externals_.clear();
31 
32     if (inputs.size() != input_ids_.size()) {
33       return false;
34     }
35 
36     for (int i = 0; i < inputs.size(); i++) {
37       externals_.emplace_back(xnn_external_value{input_ids_[i], inputs[i]});
38     }
39 
40     if (outputs.size() != output_ids_.size()) {
41       return false;
42     }
43 
44     for (int i = 0; i < outputs.size(); i++) {
45       externals_.emplace_back(xnn_external_value{output_ids_[i], outputs[i]});
46     }
47 
48     return true;
49   }
50 
forward()51   bool forward() {
52     xnn_status status =
53         xnn_setup_runtime(runtime_.get(), externals_.size(), externals_.data());
54 
55     if (status != xnn_status_success) {
56       return false;
57     }
58 
59     status = xnn_invoke_runtime(runtime_.get());
60 
61     if (status != xnn_status_success) {
62       return false;
63     }
64 
65     return true;
66   }
67 
68   friend class XNNCompiler;
69 };
70 
71 } // namespace delegate
72 } // namespace xnnpack
73 } // namespace jit
74 } // namespace torch
75