xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/VulkanBackend.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/backends/vulkan/runtime/VulkanDelegateHeader.h>
10*523fa7a6SAndroid Build Coastguard Worker #include <executorch/backends/vulkan/serialization/schema_generated.h>
11*523fa7a6SAndroid Build Coastguard Worker 
12*523fa7a6SAndroid Build Coastguard Worker #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
13*523fa7a6SAndroid Build Coastguard Worker 
14*523fa7a6SAndroid Build Coastguard Worker #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
15*523fa7a6SAndroid Build Coastguard Worker 
16*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/backend/interface.h>
17*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/error.h>
18*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/evalue.h>
19*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_EVENT_TRACER_ENABLED
20*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/event_tracer_hooks_delegate.h>
21*523fa7a6SAndroid Build Coastguard Worker #endif // ET_EVENT_TRACER_ENABLED
22*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
23*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/compiler.h>
24*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/profiler.h>
25*523fa7a6SAndroid Build Coastguard Worker 
26*523fa7a6SAndroid Build Coastguard Worker #include <cstdio>
27*523fa7a6SAndroid Build Coastguard Worker #include <cstdlib> /* strtol */
28*523fa7a6SAndroid Build Coastguard Worker #include <cstring>
29*523fa7a6SAndroid Build Coastguard Worker #include <memory>
30*523fa7a6SAndroid Build Coastguard Worker #include <type_traits>
31*523fa7a6SAndroid Build Coastguard Worker #include <vector>
32*523fa7a6SAndroid Build Coastguard Worker 
33*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
34*523fa7a6SAndroid Build Coastguard Worker namespace backends {
35*523fa7a6SAndroid Build Coastguard Worker namespace vulkan {
36*523fa7a6SAndroid Build Coastguard Worker namespace {
37*523fa7a6SAndroid Build Coastguard Worker 
38*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::ArrayRef;
39*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::Backend;
40*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::BackendExecutionContext;
41*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::BackendInitContext;
42*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::CompileSpec;
43*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::DelegateHandle;
44*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::Error;
45*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::EValue;
46*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::FreeableBuffer;
47*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::kTensorDimensionLimit;
48*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::Result;
49*523fa7a6SAndroid Build Coastguard Worker 
50*523fa7a6SAndroid Build Coastguard Worker using namespace vkcompute;
51*523fa7a6SAndroid Build Coastguard Worker 
52*523fa7a6SAndroid Build Coastguard Worker // Flatbuffer types
53*523fa7a6SAndroid Build Coastguard Worker using VkGraphPtr = const vkgraph::VkGraph*;
54*523fa7a6SAndroid Build Coastguard Worker using OpCallPtr = const vkgraph::OperatorCall*;
55*523fa7a6SAndroid Build Coastguard Worker using VkValuePtr = const vkgraph::VkValue*;
56*523fa7a6SAndroid Build Coastguard Worker using VkTensorPtr = const vkgraph::VkTensor*;
57*523fa7a6SAndroid Build Coastguard Worker using VkBytesPtr = const vkgraph::VkBytes*;
58*523fa7a6SAndroid Build Coastguard Worker 
59*523fa7a6SAndroid Build Coastguard Worker // Flatbuffer vector types
60*523fa7a6SAndroid Build Coastguard Worker using VkValuesVector =
61*523fa7a6SAndroid Build Coastguard Worker     const flatbuffers::Vector<flatbuffers::Offset<vkgraph::VkValue>>*;
62*523fa7a6SAndroid Build Coastguard Worker using BytesVector =
63*523fa7a6SAndroid Build Coastguard Worker     const flatbuffers::Vector<flatbuffers::Offset<vkgraph::VkBytes>>*;
64*523fa7a6SAndroid Build Coastguard Worker using UIntVector = const flatbuffers::Vector<uint32_t>*;
65*523fa7a6SAndroid Build Coastguard Worker 
get_constant_data_ptr(VkGraphPtr flatbuffer_graph,const int32_t buffer_idx,const uint8_t * constant_data)66*523fa7a6SAndroid Build Coastguard Worker const uint8_t* get_constant_data_ptr(
67*523fa7a6SAndroid Build Coastguard Worker     VkGraphPtr flatbuffer_graph,
68*523fa7a6SAndroid Build Coastguard Worker     const int32_t buffer_idx,
69*523fa7a6SAndroid Build Coastguard Worker     const uint8_t* constant_data) {
70*523fa7a6SAndroid Build Coastguard Worker   VkBytesPtr constant_bytes = flatbuffer_graph->constants()->Get(buffer_idx);
71*523fa7a6SAndroid Build Coastguard Worker   return constant_data + constant_bytes->offset();
72*523fa7a6SAndroid Build Coastguard Worker }
73*523fa7a6SAndroid Build Coastguard Worker 
get_scalar_type(const vkgraph::VkDataType & vk_datatype)74*523fa7a6SAndroid Build Coastguard Worker vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
75*523fa7a6SAndroid Build Coastguard Worker   switch (vk_datatype) {
76*523fa7a6SAndroid Build Coastguard Worker     case vkgraph::VkDataType::BOOL:
77*523fa7a6SAndroid Build Coastguard Worker       return vkapi::kBool;
78*523fa7a6SAndroid Build Coastguard Worker     case vkgraph::VkDataType::UINT8:
79*523fa7a6SAndroid Build Coastguard Worker       return vkapi::kByte;
80*523fa7a6SAndroid Build Coastguard Worker     case vkgraph::VkDataType::INT8:
81*523fa7a6SAndroid Build Coastguard Worker       return vkapi::kChar;
82*523fa7a6SAndroid Build Coastguard Worker     case vkgraph::VkDataType::INT32:
83*523fa7a6SAndroid Build Coastguard Worker       return vkapi::kInt;
84*523fa7a6SAndroid Build Coastguard Worker     case vkgraph::VkDataType::FLOAT16:
85*523fa7a6SAndroid Build Coastguard Worker       return vkapi::kHalf;
86*523fa7a6SAndroid Build Coastguard Worker     case vkgraph::VkDataType::FLOAT32:
87*523fa7a6SAndroid Build Coastguard Worker       return vkapi::kFloat;
88*523fa7a6SAndroid Build Coastguard Worker   }
89*523fa7a6SAndroid Build Coastguard Worker }
90*523fa7a6SAndroid Build Coastguard Worker 
get_storage_type(const vkgraph::VkStorageType & vk_storage_type)91*523fa7a6SAndroid Build Coastguard Worker utils::StorageType get_storage_type(
92*523fa7a6SAndroid Build Coastguard Worker     const vkgraph::VkStorageType& vk_storage_type) {
93*523fa7a6SAndroid Build Coastguard Worker   switch (vk_storage_type) {
94*523fa7a6SAndroid Build Coastguard Worker     case vkgraph::VkStorageType::BUFFER:
95*523fa7a6SAndroid Build Coastguard Worker       return utils::kBuffer;
96*523fa7a6SAndroid Build Coastguard Worker     case vkgraph::VkStorageType::TEXTURE_3D:
97*523fa7a6SAndroid Build Coastguard Worker       return utils::kTexture3D;
98*523fa7a6SAndroid Build Coastguard Worker     case vkgraph::VkStorageType::TEXTURE_2D:
99*523fa7a6SAndroid Build Coastguard Worker       return utils::kTexture2D;
100*523fa7a6SAndroid Build Coastguard Worker     default:
101*523fa7a6SAndroid Build Coastguard Worker       break;
102*523fa7a6SAndroid Build Coastguard Worker   }
103*523fa7a6SAndroid Build Coastguard Worker   VK_THROW("Invalid storage type encountered!");
104*523fa7a6SAndroid Build Coastguard Worker }
105*523fa7a6SAndroid Build Coastguard Worker 
get_memory_layout(const vkgraph::VkMemoryLayout & vk_memory_layout)106*523fa7a6SAndroid Build Coastguard Worker utils::GPUMemoryLayout get_memory_layout(
107*523fa7a6SAndroid Build Coastguard Worker     const vkgraph::VkMemoryLayout& vk_memory_layout) {
108*523fa7a6SAndroid Build Coastguard Worker   switch (vk_memory_layout) {
109*523fa7a6SAndroid Build Coastguard Worker     case vkgraph::VkMemoryLayout::TENSOR_WIDTH_PACKED:
110*523fa7a6SAndroid Build Coastguard Worker       return utils::kWidthPacked;
111*523fa7a6SAndroid Build Coastguard Worker     case vkgraph::VkMemoryLayout::TENSOR_HEIGHT_PACKED:
112*523fa7a6SAndroid Build Coastguard Worker       return utils::kHeightPacked;
113*523fa7a6SAndroid Build Coastguard Worker     case vkgraph::VkMemoryLayout::TENSOR_CHANNELS_PACKED:
114*523fa7a6SAndroid Build Coastguard Worker       return utils::kChannelsPacked;
115*523fa7a6SAndroid Build Coastguard Worker     default:
116*523fa7a6SAndroid Build Coastguard Worker       break;
117*523fa7a6SAndroid Build Coastguard Worker   }
118*523fa7a6SAndroid Build Coastguard Worker   VK_THROW("Invalid memory layout encountered!");
119*523fa7a6SAndroid Build Coastguard Worker }
120*523fa7a6SAndroid Build Coastguard Worker 
get_graph_config(ArrayRef<CompileSpec> & compile_specs)121*523fa7a6SAndroid Build Coastguard Worker GraphConfig get_graph_config(ArrayRef<CompileSpec>& compile_specs) {
122*523fa7a6SAndroid Build Coastguard Worker   GraphConfig config = GraphConfig();
123*523fa7a6SAndroid Build Coastguard Worker 
124*523fa7a6SAndroid Build Coastguard Worker   for (const CompileSpec& spec : compile_specs) {
125*523fa7a6SAndroid Build Coastguard Worker     const uint8_t* value_data = (const uint8_t*)spec.value.buffer;
126*523fa7a6SAndroid Build Coastguard Worker     const size_t value_size = spec.value.nbytes;
127*523fa7a6SAndroid Build Coastguard Worker     if (strcmp(spec.key, "storage_type_override") == 0) {
128*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_MSG(value_size == sizeof(int32_t), "Unexpected value size!");
129*523fa7a6SAndroid Build Coastguard Worker       int value_as_int = static_cast<int>(getUInt32LE(value_data));
130*523fa7a6SAndroid Build Coastguard Worker       utils::StorageType storage_type =
131*523fa7a6SAndroid Build Coastguard Worker           static_cast<utils::StorageType>(value_as_int);
132*523fa7a6SAndroid Build Coastguard Worker 
133*523fa7a6SAndroid Build Coastguard Worker       config.set_storage_type_override(storage_type);
134*523fa7a6SAndroid Build Coastguard Worker     }
135*523fa7a6SAndroid Build Coastguard Worker     if (strcmp(spec.key, "memory_layout_override") == 0) {
136*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_MSG(value_size == sizeof(uint32_t), "Unexpected value size!");
137*523fa7a6SAndroid Build Coastguard Worker       uint32_t value_as_int = getUInt32LE(value_data);
138*523fa7a6SAndroid Build Coastguard Worker       utils::GPUMemoryLayout memory_layout =
139*523fa7a6SAndroid Build Coastguard Worker           static_cast<utils::GPUMemoryLayout>(value_as_int);
140*523fa7a6SAndroid Build Coastguard Worker 
141*523fa7a6SAndroid Build Coastguard Worker       config.set_memory_layout_override(memory_layout);
142*523fa7a6SAndroid Build Coastguard Worker     }
143*523fa7a6SAndroid Build Coastguard Worker   }
144*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_EVENT_TRACER_ENABLED
145*523fa7a6SAndroid Build Coastguard Worker   config.enable_querypool = true;
146*523fa7a6SAndroid Build Coastguard Worker #endif // ET_EVENT_TRACER_ENABLED
147*523fa7a6SAndroid Build Coastguard Worker   return config;
148*523fa7a6SAndroid Build Coastguard Worker }
149*523fa7a6SAndroid Build Coastguard Worker 
150*523fa7a6SAndroid Build Coastguard Worker class GraphBuilder {
151*523fa7a6SAndroid Build Coastguard Worker   ComputeGraph* compute_graph_;
152*523fa7a6SAndroid Build Coastguard Worker   VkGraphPtr flatbuffer_;
153*523fa7a6SAndroid Build Coastguard Worker   const uint8_t* constant_data_;
154*523fa7a6SAndroid Build Coastguard Worker 
155*523fa7a6SAndroid Build Coastguard Worker   std::unordered_map<uint32_t, ValueRef> ref_mapping_;
156*523fa7a6SAndroid Build Coastguard Worker 
157*523fa7a6SAndroid Build Coastguard Worker  public:
GraphBuilder(ComputeGraph * compute_graph,VkGraphPtr flatbuffer,const uint8_t * constant_data)158*523fa7a6SAndroid Build Coastguard Worker   explicit GraphBuilder(
159*523fa7a6SAndroid Build Coastguard Worker       ComputeGraph* compute_graph,
160*523fa7a6SAndroid Build Coastguard Worker       VkGraphPtr flatbuffer,
161*523fa7a6SAndroid Build Coastguard Worker       const uint8_t* constant_data)
162*523fa7a6SAndroid Build Coastguard Worker       : compute_graph_(compute_graph),
163*523fa7a6SAndroid Build Coastguard Worker         flatbuffer_(flatbuffer),
164*523fa7a6SAndroid Build Coastguard Worker         constant_data_(constant_data),
165*523fa7a6SAndroid Build Coastguard Worker         ref_mapping_() {}
166*523fa7a6SAndroid Build Coastguard Worker 
fb_id_exists(const uint32_t fb_id)167*523fa7a6SAndroid Build Coastguard Worker   bool fb_id_exists(const uint32_t fb_id) {
168*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
169*523fa7a6SAndroid Build Coastguard Worker         ref_mapping_.find(fb_id);
170*523fa7a6SAndroid Build Coastguard Worker 
171*523fa7a6SAndroid Build Coastguard Worker     return found_ref != ref_mapping_.end();
172*523fa7a6SAndroid Build Coastguard Worker   }
173*523fa7a6SAndroid Build Coastguard Worker 
get_fb_id_valueref(const uint32_t fb_id)174*523fa7a6SAndroid Build Coastguard Worker   ValueRef get_fb_id_valueref(const uint32_t fb_id) {
175*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
176*523fa7a6SAndroid Build Coastguard Worker         ref_mapping_.find(fb_id);
177*523fa7a6SAndroid Build Coastguard Worker 
178*523fa7a6SAndroid Build Coastguard Worker     ET_CHECK_MSG(
179*523fa7a6SAndroid Build Coastguard Worker         found_ref != ref_mapping_.end(),
180*523fa7a6SAndroid Build Coastguard Worker         "Trying to extract a value that hasn't yet been added to the graph.");
181*523fa7a6SAndroid Build Coastguard Worker 
182*523fa7a6SAndroid Build Coastguard Worker     return found_ref->second;
183*523fa7a6SAndroid Build Coastguard Worker   }
184*523fa7a6SAndroid Build Coastguard Worker 
add_tensor_to_graph(const uint32_t fb_id,VkTensorPtr tensor_fb)185*523fa7a6SAndroid Build Coastguard Worker   void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) {
186*523fa7a6SAndroid Build Coastguard Worker     const vkapi::ScalarType& dtype = get_scalar_type(tensor_fb->datatype());
187*523fa7a6SAndroid Build Coastguard Worker     utils::StorageType storage_type =
188*523fa7a6SAndroid Build Coastguard Worker         tensor_fb->storage_type() == vkgraph::VkStorageType::DEFAULT_STORAGE
189*523fa7a6SAndroid Build Coastguard Worker         ? compute_graph_->suggested_storage_type()
190*523fa7a6SAndroid Build Coastguard Worker         : get_storage_type(tensor_fb->storage_type());
191*523fa7a6SAndroid Build Coastguard Worker 
192*523fa7a6SAndroid Build Coastguard Worker     UIntVector dims_fb = tensor_fb->dims();
193*523fa7a6SAndroid Build Coastguard Worker     const std::vector<int64_t> dims_vector(dims_fb->cbegin(), dims_fb->cend());
194*523fa7a6SAndroid Build Coastguard Worker 
195*523fa7a6SAndroid Build Coastguard Worker     utils::GPUMemoryLayout memory_layout =
196*523fa7a6SAndroid Build Coastguard Worker         tensor_fb->memory_layout() == vkgraph::VkMemoryLayout::DEFAULT_LAYOUT
197*523fa7a6SAndroid Build Coastguard Worker         ? compute_graph_->suggested_memory_layout(dims_vector)
198*523fa7a6SAndroid Build Coastguard Worker         : get_memory_layout(tensor_fb->memory_layout());
199*523fa7a6SAndroid Build Coastguard Worker 
200*523fa7a6SAndroid Build Coastguard Worker     ValueRef ref;
201*523fa7a6SAndroid Build Coastguard Worker     if (tensor_fb->constant_id() >= 0) {
202*523fa7a6SAndroid Build Coastguard Worker       const uint8_t* tensor_data = get_constant_data_ptr(
203*523fa7a6SAndroid Build Coastguard Worker           flatbuffer_, tensor_fb->constant_id(), constant_data_);
204*523fa7a6SAndroid Build Coastguard Worker 
205*523fa7a6SAndroid Build Coastguard Worker       ref = compute_graph_->add_tensorref(dims_vector, dtype, tensor_data);
206*523fa7a6SAndroid Build Coastguard Worker     } else {
207*523fa7a6SAndroid Build Coastguard Worker       ref = compute_graph_->add_tensor(
208*523fa7a6SAndroid Build Coastguard Worker           dims_vector,
209*523fa7a6SAndroid Build Coastguard Worker           dtype,
210*523fa7a6SAndroid Build Coastguard Worker           storage_type,
211*523fa7a6SAndroid Build Coastguard Worker           memory_layout,
212*523fa7a6SAndroid Build Coastguard Worker           tensor_fb->mem_obj_id());
213*523fa7a6SAndroid Build Coastguard Worker     }
214*523fa7a6SAndroid Build Coastguard Worker 
215*523fa7a6SAndroid Build Coastguard Worker     ref_mapping_[fb_id] = ref;
216*523fa7a6SAndroid Build Coastguard Worker   }
217*523fa7a6SAndroid Build Coastguard Worker 
add_none_to_graph(const uint32_t fb_id)218*523fa7a6SAndroid Build Coastguard Worker   void add_none_to_graph(const uint32_t fb_id) {
219*523fa7a6SAndroid Build Coastguard Worker     ValueRef ref = compute_graph_->add_none();
220*523fa7a6SAndroid Build Coastguard Worker     ref_mapping_[fb_id] = ref;
221*523fa7a6SAndroid Build Coastguard Worker   }
222*523fa7a6SAndroid Build Coastguard Worker 
223*523fa7a6SAndroid Build Coastguard Worker   template <typename T>
224*523fa7a6SAndroid Build Coastguard Worker   typename std::enable_if<is_valid_scalar_type<T>::value, void>::type
add_scalar_to_graph(const uint32_t fb_id,T value)225*523fa7a6SAndroid Build Coastguard Worker   add_scalar_to_graph(const uint32_t fb_id, T value) {
226*523fa7a6SAndroid Build Coastguard Worker     ValueRef ref = compute_graph_->add_scalar(value);
227*523fa7a6SAndroid Build Coastguard Worker     ref_mapping_[fb_id] = ref;
228*523fa7a6SAndroid Build Coastguard Worker   }
229*523fa7a6SAndroid Build Coastguard Worker 
230*523fa7a6SAndroid Build Coastguard Worker   template <typename T>
231*523fa7a6SAndroid Build Coastguard Worker   typename std::enable_if<is_valid_scalar_type<T>::value, void>::type
add_scalar_list_to_graph(const uint32_t fb_id,std::vector<T> && value)232*523fa7a6SAndroid Build Coastguard Worker   add_scalar_list_to_graph(const uint32_t fb_id, std::vector<T>&& value) {
233*523fa7a6SAndroid Build Coastguard Worker     ValueRef ref = compute_graph_->add_scalar_list(std::move(value));
234*523fa7a6SAndroid Build Coastguard Worker     ref_mapping_[fb_id] = ref;
235*523fa7a6SAndroid Build Coastguard Worker   }
236*523fa7a6SAndroid Build Coastguard Worker 
add_value_list_to_graph(const uint32_t fb_id,std::vector<ValueRef> && value)237*523fa7a6SAndroid Build Coastguard Worker   void add_value_list_to_graph(
238*523fa7a6SAndroid Build Coastguard Worker       const uint32_t fb_id,
239*523fa7a6SAndroid Build Coastguard Worker       std::vector<ValueRef>&& value) {
240*523fa7a6SAndroid Build Coastguard Worker     ValueRef ref = compute_graph_->add_value_list(std::move(value));
241*523fa7a6SAndroid Build Coastguard Worker     ref_mapping_[fb_id] = ref;
242*523fa7a6SAndroid Build Coastguard Worker   }
243*523fa7a6SAndroid Build Coastguard Worker 
add_string_to_graph(const uint32_t fb_id,VkValuePtr value)244*523fa7a6SAndroid Build Coastguard Worker   void add_string_to_graph(const uint32_t fb_id, VkValuePtr value) {
245*523fa7a6SAndroid Build Coastguard Worker     const auto fb_str = value->value_as_String()->string_val();
246*523fa7a6SAndroid Build Coastguard Worker     std::string string(fb_str->cbegin(), fb_str->cend());
247*523fa7a6SAndroid Build Coastguard Worker     ValueRef ref = compute_graph_->add_string(std::move(string));
248*523fa7a6SAndroid Build Coastguard Worker     ref_mapping_[fb_id] = ref;
249*523fa7a6SAndroid Build Coastguard Worker   }
250*523fa7a6SAndroid Build Coastguard Worker 
add_symint_to_graph(const uint32_t fb_id,VkValuePtr value)251*523fa7a6SAndroid Build Coastguard Worker   void add_symint_to_graph(const uint32_t fb_id, VkValuePtr value) {
252*523fa7a6SAndroid Build Coastguard Worker     const int32_t fb_symint = value->value_as_SymInt()->value();
253*523fa7a6SAndroid Build Coastguard Worker     ValueRef ref = compute_graph_->add_symint(fb_symint);
254*523fa7a6SAndroid Build Coastguard Worker     ref_mapping_[fb_id] = ref;
255*523fa7a6SAndroid Build Coastguard Worker   }
256*523fa7a6SAndroid Build Coastguard Worker 
add_value_to_graph(const uint32_t fb_id,VkValuePtr value)257*523fa7a6SAndroid Build Coastguard Worker   void add_value_to_graph(const uint32_t fb_id, VkValuePtr value) {
258*523fa7a6SAndroid Build Coastguard Worker     ET_CHECK_MSG(
259*523fa7a6SAndroid Build Coastguard Worker         !fb_id_exists(fb_id),
260*523fa7a6SAndroid Build Coastguard Worker         "Trying to add a value that has already been added to the graph.");
261*523fa7a6SAndroid Build Coastguard Worker 
262*523fa7a6SAndroid Build Coastguard Worker     switch (value->value_type()) {
263*523fa7a6SAndroid Build Coastguard Worker       case vkgraph::GraphTypes::Null:
264*523fa7a6SAndroid Build Coastguard Worker         add_none_to_graph(fb_id);
265*523fa7a6SAndroid Build Coastguard Worker         break;
266*523fa7a6SAndroid Build Coastguard Worker       case vkgraph::GraphTypes::Int:
267*523fa7a6SAndroid Build Coastguard Worker         add_scalar_to_graph(fb_id, value->value_as_Int()->int_val());
268*523fa7a6SAndroid Build Coastguard Worker         break;
269*523fa7a6SAndroid Build Coastguard Worker       case vkgraph::GraphTypes::Double:
270*523fa7a6SAndroid Build Coastguard Worker         add_scalar_to_graph(fb_id, value->value_as_Double()->double_val());
271*523fa7a6SAndroid Build Coastguard Worker         break;
272*523fa7a6SAndroid Build Coastguard Worker       case vkgraph::GraphTypes::Bool:
273*523fa7a6SAndroid Build Coastguard Worker         add_scalar_to_graph(fb_id, value->value_as_Bool()->bool_val());
274*523fa7a6SAndroid Build Coastguard Worker         break;
275*523fa7a6SAndroid Build Coastguard Worker       case vkgraph::GraphTypes::VkTensor:
276*523fa7a6SAndroid Build Coastguard Worker         add_tensor_to_graph(fb_id, value->value_as_VkTensor());
277*523fa7a6SAndroid Build Coastguard Worker         break;
278*523fa7a6SAndroid Build Coastguard Worker       case vkgraph::GraphTypes::IntList:
279*523fa7a6SAndroid Build Coastguard Worker         add_scalar_list_to_graph(
280*523fa7a6SAndroid Build Coastguard Worker             fb_id,
281*523fa7a6SAndroid Build Coastguard Worker             std::vector<int64_t>(
282*523fa7a6SAndroid Build Coastguard Worker                 value->value_as_IntList()->items()->cbegin(),
283*523fa7a6SAndroid Build Coastguard Worker                 value->value_as_IntList()->items()->cend()));
284*523fa7a6SAndroid Build Coastguard Worker         break;
285*523fa7a6SAndroid Build Coastguard Worker       case vkgraph::GraphTypes::DoubleList:
286*523fa7a6SAndroid Build Coastguard Worker         add_scalar_list_to_graph(
287*523fa7a6SAndroid Build Coastguard Worker             fb_id,
288*523fa7a6SAndroid Build Coastguard Worker             std::vector<double>(
289*523fa7a6SAndroid Build Coastguard Worker                 value->value_as_DoubleList()->items()->cbegin(),
290*523fa7a6SAndroid Build Coastguard Worker                 value->value_as_DoubleList()->items()->cend()));
291*523fa7a6SAndroid Build Coastguard Worker         break;
292*523fa7a6SAndroid Build Coastguard Worker       case vkgraph::GraphTypes::BoolList:
293*523fa7a6SAndroid Build Coastguard Worker         add_scalar_list_to_graph(
294*523fa7a6SAndroid Build Coastguard Worker             fb_id,
295*523fa7a6SAndroid Build Coastguard Worker             std::vector<bool>(
296*523fa7a6SAndroid Build Coastguard Worker                 value->value_as_BoolList()->items()->cbegin(),
297*523fa7a6SAndroid Build Coastguard Worker                 value->value_as_BoolList()->items()->cend()));
298*523fa7a6SAndroid Build Coastguard Worker         break;
299*523fa7a6SAndroid Build Coastguard Worker       case vkgraph::GraphTypes::ValueList:
300*523fa7a6SAndroid Build Coastguard Worker         add_value_list_to_graph(
301*523fa7a6SAndroid Build Coastguard Worker             fb_id,
302*523fa7a6SAndroid Build Coastguard Worker             std::vector<ValueRef>(
303*523fa7a6SAndroid Build Coastguard Worker                 value->value_as_ValueList()->items()->cbegin(),
304*523fa7a6SAndroid Build Coastguard Worker                 value->value_as_ValueList()->items()->cend()));
305*523fa7a6SAndroid Build Coastguard Worker         break;
306*523fa7a6SAndroid Build Coastguard Worker       case vkgraph::GraphTypes::String:
307*523fa7a6SAndroid Build Coastguard Worker         add_string_to_graph(fb_id, value);
308*523fa7a6SAndroid Build Coastguard Worker         break;
309*523fa7a6SAndroid Build Coastguard Worker       case vkgraph::GraphTypes::SymInt:
310*523fa7a6SAndroid Build Coastguard Worker         add_symint_to_graph(fb_id, value);
311*523fa7a6SAndroid Build Coastguard Worker         break;
312*523fa7a6SAndroid Build Coastguard Worker       default:
313*523fa7a6SAndroid Build Coastguard Worker         ET_CHECK_MSG(false, "Unsupported value type.");
314*523fa7a6SAndroid Build Coastguard Worker     }
315*523fa7a6SAndroid Build Coastguard Worker   }
316*523fa7a6SAndroid Build Coastguard Worker 
build_graph()317*523fa7a6SAndroid Build Coastguard Worker   void build_graph() {
318*523fa7a6SAndroid Build Coastguard Worker     // First, add all values to the graph
319*523fa7a6SAndroid Build Coastguard Worker     for (uint32_t fb_id = 0; fb_id < flatbuffer_->values()->size(); ++fb_id) {
320*523fa7a6SAndroid Build Coastguard Worker       VkValuePtr value = flatbuffer_->values()->Get(fb_id);
321*523fa7a6SAndroid Build Coastguard Worker       add_value_to_graph(fb_id, value);
322*523fa7a6SAndroid Build Coastguard Worker     }
323*523fa7a6SAndroid Build Coastguard Worker 
324*523fa7a6SAndroid Build Coastguard Worker     // Parse the inputs, which will be tensors most of the time but can also be
325*523fa7a6SAndroid Build Coastguard Worker     // symints and tensorrefs (which will be the case if the original graph had)
326*523fa7a6SAndroid Build Coastguard Worker     // mutable buffers.
327*523fa7a6SAndroid Build Coastguard Worker     for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
328*523fa7a6SAndroid Build Coastguard Worker       const ValueRef ref = get_fb_id_valueref(fb_id);
329*523fa7a6SAndroid Build Coastguard Worker       if (compute_graph_->val_is_tensor(ref)) {
330*523fa7a6SAndroid Build Coastguard Worker         compute_graph_->set_input_tensor(ref);
331*523fa7a6SAndroid Build Coastguard Worker       } else {
332*523fa7a6SAndroid Build Coastguard Worker         compute_graph_->set_val_as_input(ref);
333*523fa7a6SAndroid Build Coastguard Worker       }
334*523fa7a6SAndroid Build Coastguard Worker     }
335*523fa7a6SAndroid Build Coastguard Worker 
336*523fa7a6SAndroid Build Coastguard Worker     // Parse the operators
337*523fa7a6SAndroid Build Coastguard Worker     uint32_t last_prepack_node_ct = 0;
338*523fa7a6SAndroid Build Coastguard Worker     uint32_t last_execute_node_ct = 0;
339*523fa7a6SAndroid Build Coastguard Worker 
340*523fa7a6SAndroid Build Coastguard Worker     for (OpCallPtr op_call : *(flatbuffer_->chain())) {
341*523fa7a6SAndroid Build Coastguard Worker       std::string op_name = op_call->name()->str();
342*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_MSG(VK_HAS_OP(op_name), "Missing operator: %s", op_name.c_str());
343*523fa7a6SAndroid Build Coastguard Worker 
344*523fa7a6SAndroid Build Coastguard Worker       const std::vector<int> arg_fb_ids(
345*523fa7a6SAndroid Build Coastguard Worker           op_call->args()->cbegin(), op_call->args()->cend());
346*523fa7a6SAndroid Build Coastguard Worker 
347*523fa7a6SAndroid Build Coastguard Worker       std::vector<ValueRef> args;
348*523fa7a6SAndroid Build Coastguard Worker       for (const int arg_fb_id : arg_fb_ids) {
349*523fa7a6SAndroid Build Coastguard Worker         args.push_back(get_fb_id_valueref(arg_fb_id));
350*523fa7a6SAndroid Build Coastguard Worker       }
351*523fa7a6SAndroid Build Coastguard Worker 
352*523fa7a6SAndroid Build Coastguard Worker       auto vkFn = VK_GET_OP_FN(op_name);
353*523fa7a6SAndroid Build Coastguard Worker       vkFn(*compute_graph_, args);
354*523fa7a6SAndroid Build Coastguard Worker       if (compute_graph_->graphconfig().enable_querypool) {
355*523fa7a6SAndroid Build Coastguard Worker         for (uint32_t idx_prepack = last_prepack_node_ct;
356*523fa7a6SAndroid Build Coastguard Worker              idx_prepack < compute_graph_->prepack_nodes().size();
357*523fa7a6SAndroid Build Coastguard Worker              idx_prepack++) {
358*523fa7a6SAndroid Build Coastguard Worker           compute_graph_->prepack_nodes()[idx_prepack]->set_node_id(
359*523fa7a6SAndroid Build Coastguard Worker               op_call->node_id());
360*523fa7a6SAndroid Build Coastguard Worker         }
361*523fa7a6SAndroid Build Coastguard Worker         for (uint32_t idx_execute = last_execute_node_ct;
362*523fa7a6SAndroid Build Coastguard Worker              idx_execute < compute_graph_->execute_nodes().size();
363*523fa7a6SAndroid Build Coastguard Worker              idx_execute++) {
364*523fa7a6SAndroid Build Coastguard Worker           compute_graph_->execute_nodes()[idx_execute]->set_node_id(
365*523fa7a6SAndroid Build Coastguard Worker               op_call->node_id());
366*523fa7a6SAndroid Build Coastguard Worker         }
367*523fa7a6SAndroid Build Coastguard Worker         last_prepack_node_ct = compute_graph_->prepack_nodes().size();
368*523fa7a6SAndroid Build Coastguard Worker         last_execute_node_ct = compute_graph_->execute_nodes().size();
369*523fa7a6SAndroid Build Coastguard Worker       }
370*523fa7a6SAndroid Build Coastguard Worker     }
371*523fa7a6SAndroid Build Coastguard Worker 
372*523fa7a6SAndroid Build Coastguard Worker     // Parse the outputs, which will be mostly tensors.  For some reason,
373*523fa7a6SAndroid Build Coastguard Worker     // mutable buffers are shown to be returned in the fx.Graph but do not get
374*523fa7a6SAndroid Build Coastguard Worker     // returned by the delegate; this may be an implementation detail of how the
375*523fa7a6SAndroid Build Coastguard Worker     // executorch emitter handles mutable buffers.
376*523fa7a6SAndroid Build Coastguard Worker     for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
377*523fa7a6SAndroid Build Coastguard Worker       const ValueRef ref = get_fb_id_valueref(fb_id);
378*523fa7a6SAndroid Build Coastguard Worker       if (compute_graph_->val_is_tensor(ref)) {
379*523fa7a6SAndroid Build Coastguard Worker         compute_graph_->set_output_tensor(ref);
380*523fa7a6SAndroid Build Coastguard Worker       }
381*523fa7a6SAndroid Build Coastguard Worker     }
382*523fa7a6SAndroid Build Coastguard Worker   }
383*523fa7a6SAndroid Build Coastguard Worker };
384*523fa7a6SAndroid Build Coastguard Worker 
385*523fa7a6SAndroid Build Coastguard Worker //
386*523fa7a6SAndroid Build Coastguard Worker // Execution tools
387*523fa7a6SAndroid Build Coastguard Worker //
388*523fa7a6SAndroid Build Coastguard Worker 
maybe_resize_input(ComputeGraph * graph,const size_t input_i,executorch::aten::Tensor & et_tensor)389*523fa7a6SAndroid Build Coastguard Worker bool maybe_resize_input(
390*523fa7a6SAndroid Build Coastguard Worker     ComputeGraph* graph,
391*523fa7a6SAndroid Build Coastguard Worker     const size_t input_i,
392*523fa7a6SAndroid Build Coastguard Worker     executorch::aten::Tensor& et_tensor) {
393*523fa7a6SAndroid Build Coastguard Worker   ValueRef in_tensor_ref = graph->inputs()[input_i].value;
394*523fa7a6SAndroid Build Coastguard Worker   vTensorPtr in_tensor = graph->get_tensor(in_tensor_ref);
395*523fa7a6SAndroid Build Coastguard Worker 
396*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
397*523fa7a6SAndroid Build Coastguard Worker       et_tensor.dim() == in_tensor->sizes().size(),
398*523fa7a6SAndroid Build Coastguard Worker       "Cannot resize input tensor: old ndim %zu does not match new ndim %zu",
399*523fa7a6SAndroid Build Coastguard Worker       static_cast<size_t>(in_tensor->sizes().size()),
400*523fa7a6SAndroid Build Coastguard Worker       static_cast<size_t>(et_tensor.dim()));
401*523fa7a6SAndroid Build Coastguard Worker 
402*523fa7a6SAndroid Build Coastguard Worker   bool should_resize = false;
403*523fa7a6SAndroid Build Coastguard Worker   std::vector<int64_t> new_sizes(et_tensor.dim());
404*523fa7a6SAndroid Build Coastguard Worker   for (size_t i = 0; i < et_tensor.dim(); i++) {
405*523fa7a6SAndroid Build Coastguard Worker     if (in_tensor->sizes()[i] != et_tensor.sizes()[i]) {
406*523fa7a6SAndroid Build Coastguard Worker       should_resize = true;
407*523fa7a6SAndroid Build Coastguard Worker     }
408*523fa7a6SAndroid Build Coastguard Worker     new_sizes.at(i) = et_tensor.sizes()[i];
409*523fa7a6SAndroid Build Coastguard Worker   }
410*523fa7a6SAndroid Build Coastguard Worker 
411*523fa7a6SAndroid Build Coastguard Worker   if (should_resize) {
412*523fa7a6SAndroid Build Coastguard Worker     graph->resize_input(input_i, new_sizes);
413*523fa7a6SAndroid Build Coastguard Worker   }
414*523fa7a6SAndroid Build Coastguard Worker 
415*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
416*523fa7a6SAndroid Build Coastguard Worker       in_tensor->numel() == et_tensor.numel(),
417*523fa7a6SAndroid Build Coastguard Worker       "Vulkan tensor numel %zu does not match ET tensor numel %zu",
418*523fa7a6SAndroid Build Coastguard Worker       static_cast<size_t>(in_tensor->numel()),
419*523fa7a6SAndroid Build Coastguard Worker       static_cast<size_t>(et_tensor.numel()));
420*523fa7a6SAndroid Build Coastguard Worker 
421*523fa7a6SAndroid Build Coastguard Worker   return should_resize;
422*523fa7a6SAndroid Build Coastguard Worker }
423*523fa7a6SAndroid Build Coastguard Worker 
maybe_update_scalar_tensor(ComputeGraph * graph,const ValueRef ref,executorch::aten::Tensor & scalar_tensor_src)424*523fa7a6SAndroid Build Coastguard Worker bool maybe_update_scalar_tensor(
425*523fa7a6SAndroid Build Coastguard Worker     ComputeGraph* graph,
426*523fa7a6SAndroid Build Coastguard Worker     const ValueRef ref,
427*523fa7a6SAndroid Build Coastguard Worker     executorch::aten::Tensor& scalar_tensor_src) {
428*523fa7a6SAndroid Build Coastguard Worker   const int32_t cur_val = graph->read_symint(ref);
429*523fa7a6SAndroid Build Coastguard Worker   int32_t scalar_tensor_val = 0;
430*523fa7a6SAndroid Build Coastguard Worker   exec_aten::ScalarType dtype = scalar_tensor_src.scalar_type();
431*523fa7a6SAndroid Build Coastguard Worker   if (dtype == exec_aten::ScalarType::Int) {
432*523fa7a6SAndroid Build Coastguard Worker     scalar_tensor_val = *scalar_tensor_src.const_data_ptr<int32_t>();
433*523fa7a6SAndroid Build Coastguard Worker   } else if (dtype == exec_aten::ScalarType::Long) {
434*523fa7a6SAndroid Build Coastguard Worker     scalar_tensor_val = int32_t(*scalar_tensor_src.const_data_ptr<int64_t>());
435*523fa7a6SAndroid Build Coastguard Worker   }
436*523fa7a6SAndroid Build Coastguard Worker   bool was_updated = false;
437*523fa7a6SAndroid Build Coastguard Worker   if (scalar_tensor_val != cur_val) {
438*523fa7a6SAndroid Build Coastguard Worker     graph->set_symint(ref, scalar_tensor_val);
439*523fa7a6SAndroid Build Coastguard Worker     was_updated = true;
440*523fa7a6SAndroid Build Coastguard Worker   }
441*523fa7a6SAndroid Build Coastguard Worker   return was_updated;
442*523fa7a6SAndroid Build Coastguard Worker }
443*523fa7a6SAndroid Build Coastguard Worker 
maybe_resize_output(ComputeGraph * graph,const size_t output_i,executorch::aten::Tensor & et_tensor)444*523fa7a6SAndroid Build Coastguard Worker void maybe_resize_output(
445*523fa7a6SAndroid Build Coastguard Worker     ComputeGraph* graph,
446*523fa7a6SAndroid Build Coastguard Worker     const size_t output_i,
447*523fa7a6SAndroid Build Coastguard Worker     executorch::aten::Tensor& et_tensor) {
448*523fa7a6SAndroid Build Coastguard Worker   ValueRef out_tensor_ref = graph->outputs()[output_i].value;
449*523fa7a6SAndroid Build Coastguard Worker   vTensorPtr out_tensor = graph->get_tensor(out_tensor_ref);
450*523fa7a6SAndroid Build Coastguard Worker 
451*523fa7a6SAndroid Build Coastguard Worker   executorch::aten::SizesType new_output_size[kTensorDimensionLimit];
452*523fa7a6SAndroid Build Coastguard Worker   size_t ndim = out_tensor->sizes().size();
453*523fa7a6SAndroid Build Coastguard Worker   for (int i = 0; i < ndim; ++i) {
454*523fa7a6SAndroid Build Coastguard Worker     new_output_size[i] = out_tensor->sizes()[i];
455*523fa7a6SAndroid Build Coastguard Worker   }
456*523fa7a6SAndroid Build Coastguard Worker 
457*523fa7a6SAndroid Build Coastguard Worker   executorch::aten::ArrayRef<executorch::aten::SizesType> output_size{
458*523fa7a6SAndroid Build Coastguard Worker       new_output_size, ndim};
459*523fa7a6SAndroid Build Coastguard Worker   Error err = resize_tensor(et_tensor, output_size);
460*523fa7a6SAndroid Build Coastguard Worker 
461*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(err == Error::Ok, "Failed to resize output tensor.");
462*523fa7a6SAndroid Build Coastguard Worker }
463*523fa7a6SAndroid Build Coastguard Worker 
464*523fa7a6SAndroid Build Coastguard Worker //
465*523fa7a6SAndroid Build Coastguard Worker // VulkanBackend class
466*523fa7a6SAndroid Build Coastguard Worker //
467*523fa7a6SAndroid Build Coastguard Worker 
468*523fa7a6SAndroid Build Coastguard Worker class VulkanBackend final : public ::executorch::runtime::BackendInterface {
469*523fa7a6SAndroid Build Coastguard Worker  public:
470*523fa7a6SAndroid Build Coastguard Worker   ~VulkanBackend() override = default;
471*523fa7a6SAndroid Build Coastguard Worker 
is_available() const472*523fa7a6SAndroid Build Coastguard Worker   bool is_available() const override {
473*523fa7a6SAndroid Build Coastguard Worker     // TODO(ssjia): replace with an actual Vulkan runtime availability check
474*523fa7a6SAndroid Build Coastguard Worker     return true;
475*523fa7a6SAndroid Build Coastguard Worker   }
476*523fa7a6SAndroid Build Coastguard Worker 
477*523fa7a6SAndroid Build Coastguard Worker   ET_NODISCARD Error
compileModel(const void * buffer_pointer,ComputeGraph * compute_graph) const478*523fa7a6SAndroid Build Coastguard Worker   compileModel(const void* buffer_pointer, ComputeGraph* compute_graph) const {
479*523fa7a6SAndroid Build Coastguard Worker     Result<VulkanDelegateHeader> header =
480*523fa7a6SAndroid Build Coastguard Worker         VulkanDelegateHeader::parse(buffer_pointer);
481*523fa7a6SAndroid Build Coastguard Worker 
482*523fa7a6SAndroid Build Coastguard Worker     const uint8_t* flatbuffer_data = nullptr;
483*523fa7a6SAndroid Build Coastguard Worker     const uint8_t* constant_data = nullptr;
484*523fa7a6SAndroid Build Coastguard Worker 
485*523fa7a6SAndroid Build Coastguard Worker     if (header.ok()) {
486*523fa7a6SAndroid Build Coastguard Worker       const uint8_t* buffer_start =
487*523fa7a6SAndroid Build Coastguard Worker           reinterpret_cast<const uint8_t*>(buffer_pointer);
488*523fa7a6SAndroid Build Coastguard Worker       flatbuffer_data = buffer_start + header->flatbuffer_offset;
489*523fa7a6SAndroid Build Coastguard Worker       constant_data = buffer_start + header->bytes_offset;
490*523fa7a6SAndroid Build Coastguard Worker     } else {
491*523fa7a6SAndroid Build Coastguard Worker       ET_LOG(Error, "VulkanDelegateHeader may be corrupt");
492*523fa7a6SAndroid Build Coastguard Worker       return header.error();
493*523fa7a6SAndroid Build Coastguard Worker     }
494*523fa7a6SAndroid Build Coastguard Worker 
495*523fa7a6SAndroid Build Coastguard Worker     ET_CHECK_OR_RETURN_ERROR(
496*523fa7a6SAndroid Build Coastguard Worker         vkgraph::VkGraphBufferHasIdentifier(flatbuffer_data),
497*523fa7a6SAndroid Build Coastguard Worker         DelegateInvalidCompatibility,
498*523fa7a6SAndroid Build Coastguard Worker         "Vulkan Delegate Serialization Format version identifier '%.4s' != expected '%.4s'",
499*523fa7a6SAndroid Build Coastguard Worker         flatbuffers::GetBufferIdentifier(flatbuffer_data),
500*523fa7a6SAndroid Build Coastguard Worker         vkgraph::VkGraphIdentifier());
501*523fa7a6SAndroid Build Coastguard Worker 
502*523fa7a6SAndroid Build Coastguard Worker     VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data);
503*523fa7a6SAndroid Build Coastguard Worker 
504*523fa7a6SAndroid Build Coastguard Worker     GraphBuilder builder =
505*523fa7a6SAndroid Build Coastguard Worker         GraphBuilder(compute_graph, flatbuffer_graph, constant_data);
506*523fa7a6SAndroid Build Coastguard Worker 
507*523fa7a6SAndroid Build Coastguard Worker     builder.build_graph();
508*523fa7a6SAndroid Build Coastguard Worker 
509*523fa7a6SAndroid Build Coastguard Worker     compute_graph->prepare();
510*523fa7a6SAndroid Build Coastguard Worker 
511*523fa7a6SAndroid Build Coastguard Worker     compute_graph->encode_prepack();
512*523fa7a6SAndroid Build Coastguard Worker     compute_graph->prepack();
513*523fa7a6SAndroid Build Coastguard Worker 
514*523fa7a6SAndroid Build Coastguard Worker     compute_graph->encode_execute();
515*523fa7a6SAndroid Build Coastguard Worker 
516*523fa7a6SAndroid Build Coastguard Worker     return Error::Ok;
517*523fa7a6SAndroid Build Coastguard Worker   }
518*523fa7a6SAndroid Build Coastguard Worker 
init(BackendInitContext & context,FreeableBuffer * processed,ArrayRef<CompileSpec> compile_specs) const519*523fa7a6SAndroid Build Coastguard Worker   Result<DelegateHandle*> init(
520*523fa7a6SAndroid Build Coastguard Worker       BackendInitContext& context,
521*523fa7a6SAndroid Build Coastguard Worker       FreeableBuffer* processed,
522*523fa7a6SAndroid Build Coastguard Worker       ArrayRef<CompileSpec> compile_specs) const override {
523*523fa7a6SAndroid Build Coastguard Worker     ComputeGraph* compute_graph = ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(
524*523fa7a6SAndroid Build Coastguard Worker         context.get_runtime_allocator(), ComputeGraph);
525*523fa7a6SAndroid Build Coastguard Worker 
526*523fa7a6SAndroid Build Coastguard Worker     new (compute_graph) ComputeGraph(get_graph_config(compile_specs));
527*523fa7a6SAndroid Build Coastguard Worker 
528*523fa7a6SAndroid Build Coastguard Worker     Error err = compileModel(processed->data(), compute_graph);
529*523fa7a6SAndroid Build Coastguard Worker 
530*523fa7a6SAndroid Build Coastguard Worker     // This backend does not need its processed data after compiling the
531*523fa7a6SAndroid Build Coastguard Worker     // model.
532*523fa7a6SAndroid Build Coastguard Worker     processed->Free();
533*523fa7a6SAndroid Build Coastguard Worker 
534*523fa7a6SAndroid Build Coastguard Worker     if (err != Error::Ok) {
535*523fa7a6SAndroid Build Coastguard Worker       return err;
536*523fa7a6SAndroid Build Coastguard Worker     }
537*523fa7a6SAndroid Build Coastguard Worker 
538*523fa7a6SAndroid Build Coastguard Worker     return compute_graph;
539*523fa7a6SAndroid Build Coastguard Worker   }
540*523fa7a6SAndroid Build Coastguard Worker 
execute(ET_UNUSED BackendExecutionContext & context,DelegateHandle * handle,EValue ** args) const541*523fa7a6SAndroid Build Coastguard Worker   Error execute(
542*523fa7a6SAndroid Build Coastguard Worker       ET_UNUSED BackendExecutionContext& context,
543*523fa7a6SAndroid Build Coastguard Worker       DelegateHandle* handle,
544*523fa7a6SAndroid Build Coastguard Worker       EValue** args) const override {
545*523fa7a6SAndroid Build Coastguard Worker     EXECUTORCH_SCOPE_PROF("VulkanBackend::execute");
546*523fa7a6SAndroid Build Coastguard Worker 
547*523fa7a6SAndroid Build Coastguard Worker     ComputeGraph* compute_graph = static_cast<ComputeGraph*>(handle);
548*523fa7a6SAndroid Build Coastguard Worker 
549*523fa7a6SAndroid Build Coastguard Worker     const size_t num_inputs = compute_graph->inputs().size();
550*523fa7a6SAndroid Build Coastguard Worker     bool should_propagate_resize = false;
551*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < num_inputs; i++) {
552*523fa7a6SAndroid Build Coastguard Worker       const ValueRef iref = compute_graph->inputs()[i].value;
553*523fa7a6SAndroid Build Coastguard Worker       if (compute_graph->val_is_tensor(iref)) {
554*523fa7a6SAndroid Build Coastguard Worker         VK_CHECK_COND(args[i]->isTensor());
555*523fa7a6SAndroid Build Coastguard Worker         bool was_resized =
556*523fa7a6SAndroid Build Coastguard Worker             maybe_resize_input(compute_graph, i, args[i]->toTensor());
557*523fa7a6SAndroid Build Coastguard Worker         should_propagate_resize = should_propagate_resize || was_resized;
558*523fa7a6SAndroid Build Coastguard Worker         compute_graph->copy_into_staging(
559*523fa7a6SAndroid Build Coastguard Worker             compute_graph->inputs()[i].staging,
560*523fa7a6SAndroid Build Coastguard Worker             args[i]->toTensor().const_data_ptr(),
561*523fa7a6SAndroid Build Coastguard Worker             args[i]->toTensor().numel());
562*523fa7a6SAndroid Build Coastguard Worker       } else if (compute_graph->val_is_symint(iref)) {
563*523fa7a6SAndroid Build Coastguard Worker         VK_CHECK_COND(
564*523fa7a6SAndroid Build Coastguard Worker             args[i]->isTensor(),
565*523fa7a6SAndroid Build Coastguard Worker             "Cannot handle symint arg to graph that is not derived from a "
566*523fa7a6SAndroid Build Coastguard Worker             "scalar tensor at the moment.");
567*523fa7a6SAndroid Build Coastguard Worker         bool was_updated = maybe_update_scalar_tensor(
568*523fa7a6SAndroid Build Coastguard Worker             compute_graph, iref, args[i]->toTensor());
569*523fa7a6SAndroid Build Coastguard Worker         // Since symint inputs may impact tensor's sizes, trigger a resize if
570*523fa7a6SAndroid Build Coastguard Worker         // any symbolic integer shapes are updated.
571*523fa7a6SAndroid Build Coastguard Worker         should_propagate_resize = should_propagate_resize || was_updated;
572*523fa7a6SAndroid Build Coastguard Worker       } else {
573*523fa7a6SAndroid Build Coastguard Worker         VK_THROW(
574*523fa7a6SAndroid Build Coastguard Worker             "Could not handle input with type ",
575*523fa7a6SAndroid Build Coastguard Worker             compute_graph->get_val_type(iref));
576*523fa7a6SAndroid Build Coastguard Worker       }
577*523fa7a6SAndroid Build Coastguard Worker     }
578*523fa7a6SAndroid Build Coastguard Worker 
579*523fa7a6SAndroid Build Coastguard Worker     if (should_propagate_resize) {
580*523fa7a6SAndroid Build Coastguard Worker       compute_graph->propagate_resize();
581*523fa7a6SAndroid Build Coastguard Worker     }
582*523fa7a6SAndroid Build Coastguard Worker     compute_graph->execute();
583*523fa7a6SAndroid Build Coastguard Worker 
584*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
585*523fa7a6SAndroid Build Coastguard Worker       const ValueRef oref = compute_graph->outputs()[i].value;
586*523fa7a6SAndroid Build Coastguard Worker       if (compute_graph->val_is_tensor(oref)) {
587*523fa7a6SAndroid Build Coastguard Worker         VK_CHECK_COND(args[i]->isTensor());
588*523fa7a6SAndroid Build Coastguard Worker         maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
589*523fa7a6SAndroid Build Coastguard Worker         // args holds inputs directly followed by outputs, so the i'th output
590*523fa7a6SAndroid Build Coastguard Worker         // for compute_graph corresponds to the (i + num_inputs)'th arg
591*523fa7a6SAndroid Build Coastguard Worker         compute_graph->copy_from_staging(
592*523fa7a6SAndroid Build Coastguard Worker             compute_graph->outputs()[i].staging,
593*523fa7a6SAndroid Build Coastguard Worker             args[num_inputs + i]->toTensor().mutable_data_ptr(),
594*523fa7a6SAndroid Build Coastguard Worker             args[num_inputs + i]->toTensor().numel());
595*523fa7a6SAndroid Build Coastguard Worker       } else {
596*523fa7a6SAndroid Build Coastguard Worker         VK_THROW(
597*523fa7a6SAndroid Build Coastguard Worker             "Could not handle output with type ",
598*523fa7a6SAndroid Build Coastguard Worker             compute_graph->get_val_type(oref));
599*523fa7a6SAndroid Build Coastguard Worker       }
600*523fa7a6SAndroid Build Coastguard Worker     }
601*523fa7a6SAndroid Build Coastguard Worker 
602*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_EVENT_TRACER_ENABLED
603*523fa7a6SAndroid Build Coastguard Worker     runtime::EventTracer* event_tracer = context.event_tracer();
604*523fa7a6SAndroid Build Coastguard Worker     compute_graph->context()->querypool().extract_results();
605*523fa7a6SAndroid Build Coastguard Worker     for (const auto& tup :
606*523fa7a6SAndroid Build Coastguard Worker          compute_graph->context()->querypool().get_shader_timestamp_data()) {
607*523fa7a6SAndroid Build Coastguard Worker       std::string event_name =
608*523fa7a6SAndroid Build Coastguard Worker           std::get<0>(tup) + "_" + std::to_string(std::get<1>(tup));
609*523fa7a6SAndroid Build Coastguard Worker       event_tracer_log_profiling_delegate(
610*523fa7a6SAndroid Build Coastguard Worker           event_tracer,
611*523fa7a6SAndroid Build Coastguard Worker           event_name.c_str(),
612*523fa7a6SAndroid Build Coastguard Worker           -1,
613*523fa7a6SAndroid Build Coastguard Worker           std::get<2>(tup),
614*523fa7a6SAndroid Build Coastguard Worker           std::get<3>(tup));
615*523fa7a6SAndroid Build Coastguard Worker     }
616*523fa7a6SAndroid Build Coastguard Worker #endif // ET_EVENT_TRACER_ENABLED
617*523fa7a6SAndroid Build Coastguard Worker 
618*523fa7a6SAndroid Build Coastguard Worker     return Error::Ok;
619*523fa7a6SAndroid Build Coastguard Worker   }
620*523fa7a6SAndroid Build Coastguard Worker 
destroy(DelegateHandle * handle) const621*523fa7a6SAndroid Build Coastguard Worker   void destroy(DelegateHandle* handle) const override {
622*523fa7a6SAndroid Build Coastguard Worker     if (handle != nullptr) {
623*523fa7a6SAndroid Build Coastguard Worker       ComputeGraph* compute_graph = static_cast<ComputeGraph*>(handle);
624*523fa7a6SAndroid Build Coastguard Worker       compute_graph->context()
625*523fa7a6SAndroid Build Coastguard Worker           ->adapter_ptr()
626*523fa7a6SAndroid Build Coastguard Worker           ->compute_pipeline_cache()
627*523fa7a6SAndroid Build Coastguard Worker           .save_cache();
628*523fa7a6SAndroid Build Coastguard Worker       // ComputeGraph is not trivially destructible. Since
629*523fa7a6SAndroid Build Coastguard Worker       // this was constructed manually in init(), we must destroy it manually
630*523fa7a6SAndroid Build Coastguard Worker       // here.
631*523fa7a6SAndroid Build Coastguard Worker       compute_graph->~ComputeGraph();
632*523fa7a6SAndroid Build Coastguard Worker     }
633*523fa7a6SAndroid Build Coastguard Worker   }
634*523fa7a6SAndroid Build Coastguard Worker };
635*523fa7a6SAndroid Build Coastguard Worker 
// The VulkanBackend instance that is registered with the ExecuTorch runtime
// under the name "VulkanBackend".
auto cls = VulkanBackend();
Backend backend{"VulkanBackend", &cls};
// The result is stored in a namespace-scope static so that register_backend()
// runs during this translation unit's static initialization.
static auto success_with_compiler = register_backend(backend);
639*523fa7a6SAndroid Build Coastguard Worker 
640*523fa7a6SAndroid Build Coastguard Worker } // namespace
641*523fa7a6SAndroid Build Coastguard Worker } // namespace vulkan
642*523fa7a6SAndroid Build Coastguard Worker } // namespace backends
643*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
644