/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/VulkanDelegateHeader.h>
#include <executorch/backends/vulkan/serialization/schema_generated.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#ifdef ET_EVENT_TRACER_ENABLED
#include <executorch/runtime/core/event_tracer_hooks_delegate.h>
#endif // ET_EVENT_TRACER_ENABLED
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/platform/compiler.h>
#include <executorch/runtime/platform/profiler.h>

#include <cstdio>
#include <cstdlib> /* strtol */
#include <cstring>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

namespace executorch {
namespace backends {
namespace vulkan {
namespace {

using executorch::runtime::ArrayRef;
using executorch::runtime::Backend;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
using executorch::runtime::CompileSpec;
using executorch::runtime::DelegateHandle;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::kTensorDimensionLimit;
using executorch::runtime::Result;

using namespace vkcompute;

// Flatbuffer types
using VkGraphPtr = const vkgraph::VkGraph*;
using OpCallPtr = const vkgraph::OperatorCall*;
using VkValuePtr = const vkgraph::VkValue*;
using VkTensorPtr = const vkgraph::VkTensor*;
using VkBytesPtr = const vkgraph::VkBytes*;

// Flatbuffer vector types
using VkValuesVector =
    const flatbuffers::Vector<flatbuffers::Offset<vkgraph::VkValue>>*;
using BytesVector =
    const flatbuffers::Vector<flatbuffers::Offset<vkgraph::VkBytes>>*;
using UIntVector = const flatbuffers::Vector<uint32_t>*;

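// Resolve a pointer to a constant tensor's data by looking up its offset in
// the serialized graph's constants table and adding it to the base of the
// constant data segment.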
const uint8_t* get_constant_data_ptr(
    VkGraphPtr flatbuffer_graph,
    const int32_t buffer_idx,
    const uint8_t* constant_data) {
  VkBytesPtr constant_bytes = flatbuffer_graph->constants()->Get(buffer_idx);
  return constant_data + constant_bytes->offset();
}

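// The following helpers translate serialized vkgraph enums into their
// vkcompute equivalents.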
vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
  switch (vk_datatype) {
    case vkgraph::VkDataType::BOOL:
      return vkapi::kBool;
    case vkgraph::VkDataType::UINT8:
      return vkapi::kByte;
    case vkgraph::VkDataType::INT8:
      return vkapi::kChar;
    case vkgraph::VkDataType::INT32:
      return vkapi::kInt;
    case vkgraph::VkDataType::FLOAT16:
      return vkapi::kHalf;
    case vkgraph::VkDataType::FLOAT32:
      return vkapi::kFloat;
  }
}

utils::StorageType get_storage_type(
    const vkgraph::VkStorageType& vk_storage_type) {
  switch (vk_storage_type) {
    case vkgraph::VkStorageType::BUFFER:
      return utils::kBuffer;
    case vkgraph::VkStorageType::TEXTURE_3D:
      return utils::kTexture3D;
    case vkgraph::VkStorageType::TEXTURE_2D:
      return utils::kTexture2D;
    default:
      break;
  }
  VK_THROW("Invalid storage type encountered!");
}

utils::GPUMemoryLayout get_memory_layout(
    const vkgraph::VkMemoryLayout& vk_memory_layout) {
  switch (vk_memory_layout) {
    case vkgraph::VkMemoryLayout::TENSOR_WIDTH_PACKED:
      return utils::kWidthPacked;
    case vkgraph::VkMemoryLayout::TENSOR_HEIGHT_PACKED:
      return utils::kHeightPacked;
    case vkgraph::VkMemoryLayout::TENSOR_CHANNELS_PACKED:
      return utils::kChannelsPacked;
    default:
      break;
  }
  VK_THROW("Invalid memory layout encountered!");
}

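// Build a GraphConfig from the compile specs passed at export time. The
// recognized keys are "storage_type_override" and "memory_layout_override",
// each encoded as a 4-byte little-endian integer.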
GraphConfig get_graph_config(ArrayRef<CompileSpec>& compile_specs) {
  GraphConfig config = GraphConfig();

  for (const CompileSpec& spec : compile_specs) {
    const uint8_t* value_data = (const uint8_t*)spec.value.buffer;
    const size_t value_size = spec.value.nbytes;
    if (strcmp(spec.key, "storage_type_override") == 0) {
      ET_CHECK_MSG(value_size == sizeof(int32_t), "Unexpected value size!");
      int value_as_int = static_cast<int>(getUInt32LE(value_data));
      utils::StorageType storage_type =
          static_cast<utils::StorageType>(value_as_int);

      config.set_storage_type_override(storage_type);
    }
    if (strcmp(spec.key, "memory_layout_override") == 0) {
      ET_CHECK_MSG(value_size == sizeof(uint32_t), "Unexpected value size!");
      uint32_t value_as_int = getUInt32LE(value_data);
      utils::GPUMemoryLayout memory_layout =
          static_cast<utils::GPUMemoryLayout>(value_as_int);

      config.set_memory_layout_override(memory_layout);
    }
  }
#ifdef ET_EVENT_TRACER_ENABLED
  config.enable_querypool = true;
#endif // ET_EVENT_TRACER_ENABLED
  return config;
}

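// GraphBuilder parses a serialized VkGraph flatbuffer and populates a
// ComputeGraph with the corresponding values and operator nodes. It maintains
// a mapping from flatbuffer value ids to ValueRefs in the ComputeGraph.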
class GraphBuilder {
  ComputeGraph* compute_graph_;
  VkGraphPtr flatbuffer_;
  const uint8_t* constant_data_;

  std::unordered_map<uint32_t, ValueRef> ref_mapping_;

 public:
  explicit GraphBuilder(
      ComputeGraph* compute_graph,
      VkGraphPtr flatbuffer,
      const uint8_t* constant_data)
      : compute_graph_(compute_graph),
        flatbuffer_(flatbuffer),
        constant_data_(constant_data),
        ref_mapping_() {}

  bool fb_id_exists(const uint32_t fb_id) {
    const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
        ref_mapping_.find(fb_id);

    return found_ref != ref_mapping_.end();
  }

  ValueRef get_fb_id_valueref(const uint32_t fb_id) {
    const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
        ref_mapping_.find(fb_id);

    ET_CHECK_MSG(
        found_ref != ref_mapping_.end(),
        "Trying to extract a value that hasn't yet been added to the graph.");

    return found_ref->second;
  }

  void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) {
    const vkapi::ScalarType& dtype = get_scalar_type(tensor_fb->datatype());
    utils::StorageType storage_type =
        tensor_fb->storage_type() == vkgraph::VkStorageType::DEFAULT_STORAGE
        ? compute_graph_->suggested_storage_type()
        : get_storage_type(tensor_fb->storage_type());

    UIntVector dims_fb = tensor_fb->dims();
    const std::vector<int64_t> dims_vector(dims_fb->cbegin(), dims_fb->cend());

    utils::GPUMemoryLayout memory_layout =
        tensor_fb->memory_layout() == vkgraph::VkMemoryLayout::DEFAULT_LAYOUT
        ? compute_graph_->suggested_memory_layout(dims_vector)
        : get_memory_layout(tensor_fb->memory_layout());

    ValueRef ref;
    if (tensor_fb->constant_id() >= 0) {
      const uint8_t* tensor_data = get_constant_data_ptr(
          flatbuffer_, tensor_fb->constant_id(), constant_data_);

      ref = compute_graph_->add_tensorref(dims_vector, dtype, tensor_data);
    } else {
      ref = compute_graph_->add_tensor(
          dims_vector,
          dtype,
          storage_type,
          memory_layout,
          tensor_fb->mem_obj_id());
    }

    ref_mapping_[fb_id] = ref;
  }

  void add_none_to_graph(const uint32_t fb_id) {
    ValueRef ref = compute_graph_->add_none();
    ref_mapping_[fb_id] = ref;
  }

  template <typename T>
  typename std::enable_if<is_valid_scalar_type<T>::value, void>::type
  add_scalar_to_graph(const uint32_t fb_id, T value) {
    ValueRef ref = compute_graph_->add_scalar(value);
    ref_mapping_[fb_id] = ref;
  }

  template <typename T>
  typename std::enable_if<is_valid_scalar_type<T>::value, void>::type
  add_scalar_list_to_graph(const uint32_t fb_id, std::vector<T>&& value) {
    ValueRef ref = compute_graph_->add_scalar_list(std::move(value));
    ref_mapping_[fb_id] = ref;
  }

  void add_value_list_to_graph(
      const uint32_t fb_id,
      std::vector<ValueRef>&& value) {
    ValueRef ref = compute_graph_->add_value_list(std::move(value));
    ref_mapping_[fb_id] = ref;
  }

  void add_string_to_graph(const uint32_t fb_id, VkValuePtr value) {
    const auto fb_str = value->value_as_String()->string_val();
    std::string string(fb_str->cbegin(), fb_str->cend());
    ValueRef ref = compute_graph_->add_string(std::move(string));
    ref_mapping_[fb_id] = ref;
  }

  void add_symint_to_graph(const uint32_t fb_id, VkValuePtr value) {
    const int32_t fb_symint = value->value_as_SymInt()->value();
    ValueRef ref = compute_graph_->add_symint(fb_symint);
    ref_mapping_[fb_id] = ref;
  }

  void add_value_to_graph(const uint32_t fb_id, VkValuePtr value) {
    ET_CHECK_MSG(
        !fb_id_exists(fb_id),
        "Trying to add a value that has already been added to the graph.");

    switch (value->value_type()) {
      case vkgraph::GraphTypes::Null:
        add_none_to_graph(fb_id);
        break;
      case vkgraph::GraphTypes::Int:
        add_scalar_to_graph(fb_id, value->value_as_Int()->int_val());
        break;
      case vkgraph::GraphTypes::Double:
        add_scalar_to_graph(fb_id, value->value_as_Double()->double_val());
        break;
      case vkgraph::GraphTypes::Bool:
        add_scalar_to_graph(fb_id, value->value_as_Bool()->bool_val());
        break;
      case vkgraph::GraphTypes::VkTensor:
        add_tensor_to_graph(fb_id, value->value_as_VkTensor());
        break;
      case vkgraph::GraphTypes::IntList:
        add_scalar_list_to_graph(
            fb_id,
            std::vector<int64_t>(
                value->value_as_IntList()->items()->cbegin(),
                value->value_as_IntList()->items()->cend()));
        break;
      case vkgraph::GraphTypes::DoubleList:
        add_scalar_list_to_graph(
            fb_id,
            std::vector<double>(
                value->value_as_DoubleList()->items()->cbegin(),
                value->value_as_DoubleList()->items()->cend()));
        break;
      case vkgraph::GraphTypes::BoolList:
        add_scalar_list_to_graph(
            fb_id,
            std::vector<bool>(
                value->value_as_BoolList()->items()->cbegin(),
                value->value_as_BoolList()->items()->cend()));
        break;
      case vkgraph::GraphTypes::ValueList:
        add_value_list_to_graph(
            fb_id,
            std::vector<ValueRef>(
                value->value_as_ValueList()->items()->cbegin(),
                value->value_as_ValueList()->items()->cend()));
        break;
      case vkgraph::GraphTypes::String:
        add_string_to_graph(fb_id, value);
        break;
      case vkgraph::GraphTypes::SymInt:
        add_symint_to_graph(fb_id, value);
        break;
      default:
        ET_CHECK_MSG(false, "Unsupported value type.");
    }
  }

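  // Add all values, inputs, operators, and outputs from the flatbuffer to the
  // ComputeGraph, in that order.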
  void build_graph() {
    // First, add all values to the graph
    for (uint32_t fb_id = 0; fb_id < flatbuffer_->values()->size(); ++fb_id) {
      VkValuePtr value = flatbuffer_->values()->Get(fb_id);
      add_value_to_graph(fb_id, value);
    }

    // Parse the inputs, which will be tensors most of the time but can also be
    // symints and tensorrefs (which will be the case if the original graph had
    // mutable buffers).
    for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
      const ValueRef ref = get_fb_id_valueref(fb_id);
      if (compute_graph_->val_is_tensor(ref)) {
        compute_graph_->set_input_tensor(ref);
      } else {
        compute_graph_->set_val_as_input(ref);
      }
    }

    // Parse the operators
    uint32_t last_prepack_node_ct = 0;
    uint32_t last_execute_node_ct = 0;

    for (OpCallPtr op_call : *(flatbuffer_->chain())) {
      std::string op_name = op_call->name()->str();
      ET_CHECK_MSG(VK_HAS_OP(op_name), "Missing operator: %s", op_name.c_str());

      const std::vector<int> arg_fb_ids(
          op_call->args()->cbegin(), op_call->args()->cend());

      std::vector<ValueRef> args;
      for (const int arg_fb_id : arg_fb_ids) {
        args.push_back(get_fb_id_valueref(arg_fb_id));
      }

      auto vkFn = VK_GET_OP_FN(op_name);
      vkFn(*compute_graph_, args);
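      // When the querypool is enabled, tag any prepack/execute nodes created
      // by this operator call with its node id so that profiling results can
      // be attributed back to the original graph node.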
      if (compute_graph_->graphconfig().enable_querypool) {
        for (uint32_t idx_prepack = last_prepack_node_ct;
             idx_prepack < compute_graph_->prepack_nodes().size();
             idx_prepack++) {
          compute_graph_->prepack_nodes()[idx_prepack]->set_node_id(
              op_call->node_id());
        }
        for (uint32_t idx_execute = last_execute_node_ct;
             idx_execute < compute_graph_->execute_nodes().size();
             idx_execute++) {
          compute_graph_->execute_nodes()[idx_execute]->set_node_id(
              op_call->node_id());
        }
        last_prepack_node_ct = compute_graph_->prepack_nodes().size();
        last_execute_node_ct = compute_graph_->execute_nodes().size();
      }
    }

    // Parse the outputs, which will be mostly tensors. Mutable buffers appear
    // as outputs in the fx.Graph but do not get returned by the delegate;
    // this may be an implementation detail of how the ExecuTorch emitter
    // handles mutable buffers.
    for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
      const ValueRef ref = get_fb_id_valueref(fb_id);
      if (compute_graph_->val_is_tensor(ref)) {
        compute_graph_->set_output_tensor(ref);
      }
    }
  }
};

//
// Execution tools
//

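// Resize the graph's input tensor at index input_i to match the sizes of the
// incoming ExecuTorch tensor, if they differ. Returns true if a resize was
// performed.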
bool maybe_resize_input(
    ComputeGraph* graph,
    const size_t input_i,
    executorch::aten::Tensor& et_tensor) {
  ValueRef in_tensor_ref = graph->inputs()[input_i].value;
  vTensorPtr in_tensor = graph->get_tensor(in_tensor_ref);

  ET_CHECK_MSG(
      et_tensor.dim() == in_tensor->sizes().size(),
      "Cannot resize input tensor: old ndim %zu does not match new ndim %zu",
      static_cast<size_t>(in_tensor->sizes().size()),
      static_cast<size_t>(et_tensor.dim()));

  bool should_resize = false;
  std::vector<int64_t> new_sizes(et_tensor.dim());
  for (size_t i = 0; i < et_tensor.dim(); i++) {
    if (in_tensor->sizes()[i] != et_tensor.sizes()[i]) {
      should_resize = true;
    }
    new_sizes.at(i) = et_tensor.sizes()[i];
  }

  if (should_resize) {
    graph->resize_input(input_i, new_sizes);
  }

  ET_CHECK_MSG(
      in_tensor->numel() == et_tensor.numel(),
      "Vulkan tensor numel %zu does not match ET tensor numel %zu",
      static_cast<size_t>(in_tensor->numel()),
      static_cast<size_t>(et_tensor.numel()));

  return should_resize;
}

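// Update the graph's symbolic integer at ref from the value held in a scalar
// tensor argument. Returns true if the stored value changed.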
bool maybe_update_scalar_tensor(
    ComputeGraph* graph,
    const ValueRef ref,
    executorch::aten::Tensor& scalar_tensor_src) {
  const int32_t cur_val = graph->read_symint(ref);
  int32_t scalar_tensor_val = 0;
  exec_aten::ScalarType dtype = scalar_tensor_src.scalar_type();
  if (dtype == exec_aten::ScalarType::Int) {
    scalar_tensor_val = *scalar_tensor_src.const_data_ptr<int32_t>();
  } else if (dtype == exec_aten::ScalarType::Long) {
    scalar_tensor_val = int32_t(*scalar_tensor_src.const_data_ptr<int64_t>());
  }
  bool was_updated = false;
  if (scalar_tensor_val != cur_val) {
    graph->set_symint(ref, scalar_tensor_val);
    was_updated = true;
  }
  return was_updated;
}

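// Resize the ExecuTorch output tensor to match the current sizes of the
// corresponding graph output tensor, which may have changed after shape
// propagation.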
void maybe_resize_output(
    ComputeGraph* graph,
    const size_t output_i,
    executorch::aten::Tensor& et_tensor) {
  ValueRef out_tensor_ref = graph->outputs()[output_i].value;
  vTensorPtr out_tensor = graph->get_tensor(out_tensor_ref);

  executorch::aten::SizesType new_output_size[kTensorDimensionLimit];
  size_t ndim = out_tensor->sizes().size();
  for (int i = 0; i < ndim; ++i) {
    new_output_size[i] = out_tensor->sizes()[i];
  }

  executorch::aten::ArrayRef<executorch::aten::SizesType> output_size{
      new_output_size, ndim};
  Error err = resize_tensor(et_tensor, output_size);

  ET_CHECK_MSG(err == Error::Ok, "Failed to resize output tensor.");
}

//
// VulkanBackend class
//

class VulkanBackend final : public ::executorch::runtime::BackendInterface {
 public:
  ~VulkanBackend() override = default;

  bool is_available() const override {
    // TODO(ssjia): replace with an actual Vulkan runtime availability check
    return true;
  }

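  // Parse the delegate payload (header, flatbuffer, and constant data), build
  // the ComputeGraph, and run the prepacking and execution encoding passes.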
  ET_NODISCARD Error
  compileModel(const void* buffer_pointer, ComputeGraph* compute_graph) const {
    Result<VulkanDelegateHeader> header =
        VulkanDelegateHeader::parse(buffer_pointer);

    const uint8_t* flatbuffer_data = nullptr;
    const uint8_t* constant_data = nullptr;

    if (header.ok()) {
      const uint8_t* buffer_start =
          reinterpret_cast<const uint8_t*>(buffer_pointer);
      flatbuffer_data = buffer_start + header->flatbuffer_offset;
      constant_data = buffer_start + header->bytes_offset;
    } else {
      ET_LOG(Error, "VulkanDelegateHeader may be corrupt");
      return header.error();
    }

    ET_CHECK_OR_RETURN_ERROR(
        vkgraph::VkGraphBufferHasIdentifier(flatbuffer_data),
        DelegateInvalidCompatibility,
        "Vulkan Delegate Serialization Format version identifier '%.4s' != expected '%.4s'",
        flatbuffers::GetBufferIdentifier(flatbuffer_data),
        vkgraph::VkGraphIdentifier());

    VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data);

    GraphBuilder builder =
        GraphBuilder(compute_graph, flatbuffer_graph, constant_data);

    builder.build_graph();

    compute_graph->prepare();

    compute_graph->encode_prepack();
    compute_graph->prepack();

    compute_graph->encode_execute();

    return Error::Ok;
  }

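  // Allocate a ComputeGraph on the runtime allocator, compile the processed
  // delegate blob into it, and return it as the delegate handle.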
  Result<DelegateHandle*> init(
      BackendInitContext& context,
      FreeableBuffer* processed,
      ArrayRef<CompileSpec> compile_specs) const override {
    ComputeGraph* compute_graph = ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(
        context.get_runtime_allocator(), ComputeGraph);

    new (compute_graph) ComputeGraph(get_graph_config(compile_specs));

    Error err = compileModel(processed->data(), compute_graph);

    // This backend does not need its processed data after compiling the
    // model.
    processed->Free();

    if (err != Error::Ok) {
      return err;
    }

    return compute_graph;
  }

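  // Copy input EValues into staging buffers (resizing tensors and updating
  // symints as needed), propagate any shape changes, execute the graph, then
  // copy results from staging into the output EValues.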
  Error execute(
      ET_UNUSED BackendExecutionContext& context,
      DelegateHandle* handle,
      EValue** args) const override {
    EXECUTORCH_SCOPE_PROF("VulkanBackend::execute");

    ComputeGraph* compute_graph = static_cast<ComputeGraph*>(handle);

    const size_t num_inputs = compute_graph->inputs().size();
    bool should_propagate_resize = false;
    for (size_t i = 0; i < num_inputs; i++) {
      const ValueRef iref = compute_graph->inputs()[i].value;
      if (compute_graph->val_is_tensor(iref)) {
        VK_CHECK_COND(args[i]->isTensor());
        bool was_resized =
            maybe_resize_input(compute_graph, i, args[i]->toTensor());
        should_propagate_resize = should_propagate_resize || was_resized;
        compute_graph->copy_into_staging(
            compute_graph->inputs()[i].staging,
            args[i]->toTensor().const_data_ptr(),
            args[i]->toTensor().numel());
      } else if (compute_graph->val_is_symint(iref)) {
        VK_CHECK_COND(
            args[i]->isTensor(),
            "Cannot handle symint arg to graph that is not derived from a "
            "scalar tensor at the moment.");
        bool was_updated = maybe_update_scalar_tensor(
            compute_graph, iref, args[i]->toTensor());
        // Since symint inputs may impact tensor sizes, trigger a resize if
        // any symbolic integer shapes are updated.
        should_propagate_resize = should_propagate_resize || was_updated;
      } else {
        VK_THROW(
            "Could not handle input with type ",
            compute_graph->get_val_type(iref));
      }
    }

    if (should_propagate_resize) {
      compute_graph->propagate_resize();
    }
    compute_graph->execute();

    for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
      const ValueRef oref = compute_graph->outputs()[i].value;
      if (compute_graph->val_is_tensor(oref)) {
        // args holds inputs directly followed by outputs, so the i'th output
        // for compute_graph corresponds to the (i + num_inputs)'th arg
        VK_CHECK_COND(args[num_inputs + i]->isTensor());
        maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
        compute_graph->copy_from_staging(
            compute_graph->outputs()[i].staging,
            args[num_inputs + i]->toTensor().mutable_data_ptr(),
            args[num_inputs + i]->toTensor().numel());
      } else {
        VK_THROW(
            "Could not handle output with type ",
            compute_graph->get_val_type(oref));
      }
    }

#ifdef ET_EVENT_TRACER_ENABLED
    runtime::EventTracer* event_tracer = context.event_tracer();
    compute_graph->context()->querypool().extract_results();
    for (const auto& tup :
         compute_graph->context()->querypool().get_shader_timestamp_data()) {
      std::string event_name =
          std::get<0>(tup) + "_" + std::to_string(std::get<1>(tup));
      event_tracer_log_profiling_delegate(
          event_tracer,
          event_name.c_str(),
          -1,
          std::get<2>(tup),
          std::get<3>(tup));
    }
#endif // ET_EVENT_TRACER_ENABLED

    return Error::Ok;
  }

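  // Persist the compute pipeline cache and destroy the ComputeGraph that was
  // placement-constructed in init().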
  void destroy(DelegateHandle* handle) const override {
    if (handle != nullptr) {
      ComputeGraph* compute_graph = static_cast<ComputeGraph*>(handle);
      compute_graph->context()
          ->adapter_ptr()
          ->compute_pipeline_cache()
          .save_cache();
      // ComputeGraph is not trivially destructible. Since
      // this was constructed manually in init(), we must destroy it manually
      // here.
      compute_graph->~ComputeGraph();
    }
  }
};

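// Register the Vulkan backend under the name "VulkanBackend" at static
// initialization time.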
auto cls = VulkanBackend();
Backend backend{"VulkanBackend", &cls};
static auto success_with_compiler = register_backend(backend);

} // namespace
} // namespace vulkan
} // namespace backends
} // namespace executorch