/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/qualcomm/aot/ir/qcir_utils.h>
#include <executorch/backends/qualcomm/qc_binary_info_generated.h>
#include <executorch/backends/qualcomm/runtime/backends/QnnBackendCache.h>

namespace executorch {
namespace backends {
namespace qnn {

using executorch::runtime::Error;

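// Parse a QNN context binary with the QNN System API and collect the graph
// names and input/output tensor metadata it contains.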
Error QnnBackendCache::GetQnnGraphInfoFromBinary(
    void* buffer,
    uint32_t nbytes) {
  const QnnSystemInterface& qnn_sys_interface =
      qnn_sys_impl_.GetQnnSystemInterface();
  std::uint32_t num_graphs;
  QnnSystemContext_GraphInfo_t* graphs = nullptr;
  const QnnSystemContext_BinaryInfo_t* binaryinfo{nullptr};
  Qnn_ContextBinarySize_t binaryinfo_size = 0;
  Qnn_ErrorHandle_t error = QNN_SUCCESS;

  error = qnn_sys_interface.qnn_system_context_get_binary_info(
      sys_context_handle_, buffer, nbytes, &binaryinfo, &binaryinfo_size);

  if (error != QNN_SUCCESS) {
    QNN_EXECUTORCH_LOG_WARN(
        "Failed to interpret QNN context "
        "binary. Error code %d. "
        "Try verifying binary with online-prepare format.",
        QNN_GET_ERROR_CODE(error));
    return Error::Internal;
  }

  Error status = RetrieveBackendBinaryInfo(binaryinfo);
  if (status == Error::Internal) {
    QNN_EXECUTORCH_LOG_ERROR(
        "Failed to retrieve backend binary info from QNN context binary.");
    return Error::Internal;
  }

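  // The layout of the graph metadata depends on the context binary info
  // version reported by the QNN System API.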
  if (binaryinfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
    num_graphs = binaryinfo->contextBinaryInfoV1.numGraphs;
    graphs = binaryinfo->contextBinaryInfoV1.graphs;
  } else if (binaryinfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
    num_graphs = binaryinfo->contextBinaryInfoV2.numGraphs;
    graphs = binaryinfo->contextBinaryInfoV2.graphs;
  } else {
    QNN_EXECUTORCH_LOG_WARN(
        "Unknown QNN BinaryInfo version %d.", binaryinfo->version);
    return Error::Internal;
  }

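  // Each graph entry carries its own version tag; dispatch on it to read the
  // matching metadata layout.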
  for (std::uint32_t i = 0; i < num_graphs; ++i) {
    if (graphs[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) {
      RetrieveGraphInfo<QnnSystemContext_GraphInfoV1_t>(graphs[i].graphInfoV1);
    } else if (graphs[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2) {
      RetrieveGraphInfo<QnnSystemContext_GraphInfoV2_t>(graphs[i].graphInfoV2);
    } else {
      QNN_EXECUTORCH_LOG_WARN(
          "Unknown QNN GraphInfo version %d.", graphs[i].version);
      return Error::Internal;
    }
  }

  return Error::Ok;
}

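// Determine the caching state: SERIALIZE when no context blob is provided
// (the graph is lowered on the host), otherwise DESERIALIZE graph metadata
// from the blob, falling back to ONLINE_PREPARE when the blob holds a qcir
// flatbuffer instead of a QNN context binary.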
Error QnnBackendCache::Configure() {
  if (qnn_context_blob_.buffer == nullptr) {
    state_ = SERIALIZE;
    // use aot_graph_name if we're lowering graph on host side
    graph_names_.push_back(aot_graph_name_);
    QNN_EXECUTORCH_LOG_INFO("Caching: Caching is in SAVE MODE.");
    return Error::Ok;
  }

  if (qnn_sys_impl_.Load() != Error::Ok) {
    QNN_EXECUTORCH_LOG_ERROR(
        "Failed to Load QnnSystem "
        "APIs. Caching mechanism is being disabled.");
    return Error::Internal;
  }

  Qnn_ErrorHandle_t error = QNN_SUCCESS;

  // create QNN SystemContext
  const QnnSystemInterface& qnn_sys_interface =
      qnn_sys_impl_.GetQnnSystemInterface();
  error = qnn_sys_interface.qnn_system_context_create(&sys_context_handle_);

  if (error != QNN_SUCCESS) {
    QNN_EXECUTORCH_LOG_ERROR(
        "Failed to create Qnn "
        "SystemContext. Caching mechanism will be disabled. Error code %d",
        QNN_GET_ERROR_CODE(error));
    return Error::Internal;
  }

  // DO DESERIALIZE
  state_ = DESERIALIZE;
  QNN_EXECUTORCH_LOG_INFO("Caching: Caching is in RESTORE MODE.");
  flatbuffers::Verifier verifier_binary_info(
      static_cast<const uint8_t* const>(qnn_context_blob_.buffer),
      qnn_context_blob_.nbytes);
  if (!qnn_delegate::VerifyBinaryInfoBuffer(verifier_binary_info)) {
    QNN_EXECUTORCH_LOG_ERROR("Failed to verify binary info");
    return Error::Internal;
  }

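  // The context blob wraps the raw QNN context binary in a BinaryInfo
  // flatbuffer; hand its payload to the QNN System API for parsing.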
  auto binary_info = GetBinaryInfo(qnn_context_blob_.buffer);
  Error status = GetQnnGraphInfoFromBinary(
      const_cast<uint8_t*>(binary_info->data()->data()),
      binary_info->data()->size());

  if (status == Error::Internal) {
    // check if context binary came from flatbuffer
    flatbuffers::Verifier verifier(
        binary_info->data()->data(), binary_info->data()->size());

    if (qcir::VerifyContextBuffer(verifier)) {
      state_ = ONLINE_PREPARE;
      auto context = qcir::GetContext(binary_info->data()->data());
      for (const auto& graph : *context->graphs()) {
        graph_names_.emplace_back(graph->name()->str());
      }
      return Error::Ok;
    }

    QNN_EXECUTORCH_LOG_ERROR(
        "Failed to parse QNN Graph Info. The cache "
        "might be broken. Please consider re-generating the "
        "cache.");
    InvalidateCache();
  }
  return Error::Ok;
}

QnnBackendCache::~QnnBackendCache() {
  Qnn_ErrorHandle_t error = QNN_SUCCESS;
  if (sys_context_handle_ != nullptr) {
    const QnnSystemInterface& qnn_sys_interface =
        qnn_sys_impl_.GetQnnSystemInterface();
    error = qnn_sys_interface.qnn_system_context_free(sys_context_handle_);
    if (error != QNN_SUCCESS) {
      QNN_EXECUTORCH_LOG_WARN("Failed to free QNN system context.");
    }
    sys_context_handle_ = nullptr;
  }
  qnn_sys_impl_.Unload();
}

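// Graph I/O metadata is only available after a successful restore
// (DESERIALIZE); in any other state an empty vector is returned.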
std::vector<Qnn_Tensor_t> QnnBackendCache::GetGraphInputs(
    const std::string& graph_name) {
  if (state_ != DESERIALIZE)
    return {};

  return input_tensor_structs_[graph_name];
}

std::vector<Qnn_Tensor_t> QnnBackendCache::GetGraphOutputs(
    const std::string& graph_name) {
  if (state_ != DESERIALIZE)
    return {};

  return output_tensor_structs_[graph_name];
}

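// Extract the graph name and input/output tensor metadata from a versioned
// QnnSystemContext graph info struct and index them by graph name.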
template <typename INFO>
void QnnBackendCache::RetrieveGraphInfo(const INFO& info) {
  // get graph name from metadata
  graph_names_.push_back(info.graphName);
  // get graph inputs from metadata
  uint32_t numGraphInputs = info.numGraphInputs;
  input_tensor_structs_[graph_names_.back()].reserve(numGraphInputs);
  for (std::uint32_t i = 0; i < numGraphInputs; ++i) {
    input_tensor_structs_[graph_names_.back()].emplace_back(
        info.graphInputs[i]);
  }
  // get graph outputs from metadata
  uint32_t numGraphOutputs = info.numGraphOutputs;
  output_tensor_structs_[graph_names_.back()].reserve(numGraphOutputs);
  for (std::uint32_t i = 0; i < numGraphOutputs; ++i) {
    output_tensor_structs_[graph_names_.back()].emplace_back(
        info.graphOutputs[i]);
  }
}

} // namespace qnn
} // namespace backends
} // namespace executorch