// xref: /aosp_15_r20/external/executorch/backends/qualcomm/aot/wrappers/ScalarParamWrapper.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#pragma once

#include <string>
#include <utility>

#include <executorch/backends/qualcomm/aot/wrappers/ParamWrapper.h>
#include <executorch/runtime/core/error.h>
12 namespace executorch {
13 namespace backends {
14 namespace qnn {
15 template <typename T>
16 class ScalarParamWrapper final : public ParamWrapper {
17  public:
ScalarParamWrapper(std::string name,Qnn_DataType_t data_type,T data)18   explicit ScalarParamWrapper(
19       std::string name,
20       Qnn_DataType_t data_type,
21       T data)
22       : ParamWrapper(QNN_PARAMTYPE_SCALAR, std::move(name)),
23         data_type_(data_type),
24         data_(data) {}
25 
26   // Populate appropriate field in Qnn scalarParam depending on the datatype
27   // of the scalar
PopulateQnnParam()28   executorch::runtime::Error PopulateQnnParam() override {
29     qnn_param_.scalarParam.dataType = data_type_;
30     switch (data_type_) {
31       case QNN_DATATYPE_BOOL_8:
32         qnn_param_.scalarParam.bool8Value = data_;
33         break;
34       case QNN_DATATYPE_UINT_8:
35         qnn_param_.scalarParam.uint8Value = data_;
36         break;
37       case QNN_DATATYPE_INT_8:
38         qnn_param_.scalarParam.int8Value = data_;
39         break;
40       case QNN_DATATYPE_UINT_16:
41         qnn_param_.scalarParam.uint16Value = data_;
42         break;
43       case QNN_DATATYPE_INT_16:
44         qnn_param_.scalarParam.int16Value = data_;
45         break;
46       case QNN_DATATYPE_UINT_32:
47         qnn_param_.scalarParam.uint32Value = data_;
48         break;
49       case QNN_DATATYPE_INT_32:
50         qnn_param_.scalarParam.int32Value = data_;
51         break;
52       case QNN_DATATYPE_FLOAT_32:
53         qnn_param_.scalarParam.floatValue = data_;
54         break;
55       default:
56         QNN_EXECUTORCH_LOG_ERROR(
57             "ScalarParamWrapper failed to assign scalarParam value - "
58             "invalid datatype %d",
59             data_type_);
60         return executorch::runtime::Error::Internal;
61     }
62     return executorch::runtime::Error::Ok;
63   }
64 
GetData()65   const T& GetData() const {
66     return data_;
67   };
68 
69  private:
70   Qnn_DataType_t data_type_;
71   T data_;
72 };
73 } // namespace qnn
74 } // namespace backends
75 } // namespace executorch
76