xref: /aosp_15_r20/external/executorch/backends/qualcomm/aot/wrappers/OpWrapper.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Qualcomm Innovation Center, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 #pragma once
9 
#include <executorch/backends/qualcomm/aot/wrappers/ParamWrapper.h>
#include <executorch/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h>
#include <executorch/backends/qualcomm/aot/wrappers/ScalarParamWrapper.h>
#include <executorch/backends/qualcomm/aot/wrappers/TensorParamWrapper.h>
#include <executorch/backends/qualcomm/aot/wrappers/TensorWrapper.h>

#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <typeinfo>
#include <utility>
#include <vector>
20 namespace executorch {
21 namespace backends {
22 namespace qnn {
23 class OpWrapper final {
24  public:
OpWrapper(std::string name,std::string package_name,std::string op_type)25   explicit OpWrapper(
26       std::string name,
27       std::string package_name,
28       std::string op_type)
29       : name_(std::move(name)),
30         package_name_(std::move(package_name)),
31         op_type_(std::move(op_type)) {}
32 
OpWrapper(OpWrapper && other)33   OpWrapper(OpWrapper&& other) noexcept
34       : name_(std::move(other.name_)),
35         package_name_(std::move(other.package_name_)),
36         op_type_(std::move(other.op_type_)),
37         input_tensors_(std::move(other.input_tensors_)),
38         output_tensors_(std::move(other.output_tensors_)),
39         params_(std::move(other.params_)),
40         input_tensor_structs_(std::move(other.input_tensor_structs_)),
41         output_tensor_structs_(std::move(other.output_tensor_structs_)) {}
42 
43   OpWrapper(const OpWrapper& other) = delete;
44 
45   OpWrapper& operator=(const OpWrapper& other) = delete;
46 
47   OpWrapper& operator=(OpWrapper&& other) = delete;
48 
49   ~OpWrapper() = default;
50 
AddInputTensors(const std::vector<std::shared_ptr<TensorWrapper>> & tensors)51   void AddInputTensors(
52       const std::vector<std::shared_ptr<TensorWrapper>>& tensors) {
53     input_tensors_ = tensors;
54   }
55 
AddOutputTensors(const std::vector<std::shared_ptr<TensorWrapper>> & tensors)56   void AddOutputTensors(
57       const std::vector<std::shared_ptr<TensorWrapper>>& tensors) {
58     output_tensors_ = tensors;
59   }
60 
61   void AddTensorParam(
62       const std::string& name,
63       Qnn_DataType_t data_type,
64       std::uint32_t rank,
65       const std::uint32_t dims[],
66       const void* data,
67       bool copy_data = false) {
68     std::unique_ptr<QuantizeParamsWrapper> quantize_param_wrapper =
69         std::make_unique<UndefinedQuantizeParamsWrapper>();
70     constexpr std::uint32_t kBytes = 0;
71     std::shared_ptr<TensorWrapper> tensor_wrapper = CreateTensorWrapper(
72         QNN_TENSOR_TYPE_STATIC,
73         data_type,
74         std::move(quantize_param_wrapper),
75         rank,
76         dims,
77         kBytes,
78         data,
79         copy_data);
80     params_.emplace_back(
81         std::make_unique<TensorParamWrapper>(name, tensor_wrapper));
82   }
83 
84   template <typename T>
85   void
AddScalarParam(const std::string & name,Qnn_DataType_t data_type,T data)86   AddScalarParam(const std::string& name, Qnn_DataType_t data_type, T data) {
87     params_.emplace_back(
88         std::make_unique<ScalarParamWrapper<T>>(name, data_type, data));
89   }
90 
GetParams()91   const std::vector<std::unique_ptr<ParamWrapper>>& GetParams() {
92     return params_;
93   }
94 
GetInputTensors()95   const std::vector<std::shared_ptr<TensorWrapper>>& GetInputTensors() {
96     return input_tensors_;
97   }
98 
GetOutputTensors()99   const std::vector<std::shared_ptr<TensorWrapper>>& GetOutputTensors() {
100     return output_tensors_;
101   }
GetOpType()102   const std::string GetOpType() {
103     return op_type_;
104   }
105   Qnn_OpConfig_t GetOpConfig();
106 
107  private:
108   std::string name_;
109   std::string package_name_;
110   std::string op_type_;
111   std::vector<std::shared_ptr<TensorWrapper>> input_tensors_;
112   std::vector<std::shared_ptr<TensorWrapper>> output_tensors_;
113   std::vector<std::unique_ptr<ParamWrapper>> params_;
114   std::vector<Qnn_Tensor_t> input_tensor_structs_;
115   std::vector<Qnn_Tensor_t> output_tensor_structs_;
116   std::vector<Qnn_Param_t> param_types_;
117 };
118 } // namespace qnn
119 } // namespace backends
120 } // namespace executorch
121