xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // @lint-ignore-every CLANGTIDY facebook-security-vulnerable-memcpy
10 
11 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
12 #include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
13 
14 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
15 
16 namespace vkcompute {
17 
is_bitw8(vkapi::ScalarType dtype)18 bool is_bitw8(vkapi::ScalarType dtype) {
19   return dtype == vkapi::kByte || dtype == vkapi::kChar ||
20       dtype == vkapi::kQInt8 || dtype == vkapi::kQUInt8;
21 }
22 
get_nchw_to_tensor_shader(const api::vTensor & v_dst,const bool int8_buffer_enabled)23 vkapi::ShaderInfo get_nchw_to_tensor_shader(
24     const api::vTensor& v_dst,
25     const bool int8_buffer_enabled) {
26   std::string kernel_name;
27   kernel_name.reserve(kShaderNameReserve);
28 
29   if (is_bitw8(v_dst.dtype()) && v_dst.storage_type() != utils::kBuffer &&
30       !int8_buffer_enabled) {
31     kernel_name = "nchw_to_bitw8_image_nobitw8buffer";
32     add_storage_type_suffix(kernel_name, v_dst);
33     add_dtype_suffix(kernel_name, v_dst);
34     return VK_KERNEL_FROM_STR(kernel_name);
35   }
36 
37   if (v_dst.storage_type() == utils::kBuffer) {
38     kernel_name = "nchw_to_buffer";
39     add_dtype_suffix(kernel_name, v_dst);
40     return VK_KERNEL_FROM_STR(kernel_name);
41   }
42 
43   kernel_name = "nchw_to_image";
44   add_storage_type_suffix(kernel_name, v_dst);
45   add_dtype_suffix(kernel_name, v_dst);
46 
47   return VK_KERNEL_FROM_STR(kernel_name);
48 }
49 
get_tensor_to_nchw_shader(const api::vTensor & v_src,bool int8_buffer_enabled)50 vkapi::ShaderInfo get_tensor_to_nchw_shader(
51     const api::vTensor& v_src,
52     bool int8_buffer_enabled) {
53   std::string kernel_name;
54   kernel_name.reserve(kShaderNameReserve);
55 
56   if (is_bitw8(v_src.dtype()) && v_src.storage_type() != utils::kBuffer &&
57       !int8_buffer_enabled) {
58     kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
59     add_storage_type_suffix(kernel_name, v_src);
60     add_dtype_suffix(kernel_name, v_src);
61     return VK_KERNEL_FROM_STR(kernel_name);
62   }
63 
64   if (v_src.storage_type() == utils::kBuffer) {
65     kernel_name = "buffer_to_nchw";
66     add_dtype_suffix(kernel_name, v_src);
67     return VK_KERNEL_FROM_STR(kernel_name);
68   }
69 
70   kernel_name = "image_to_nchw";
71   add_storage_type_suffix(kernel_name, v_src);
72   add_dtype_suffix(kernel_name, v_src);
73 
74   return VK_KERNEL_FROM_STR(kernel_name);
75 }
76 
77 } // namespace vkcompute
78