1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 // @lint-ignore-every CLANGTIDY facebook-security-vulnerable-memcpy
10
11 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
12 #include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
13
14 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
15
16 namespace vkcompute {
17
is_bitw8(vkapi::ScalarType dtype)18 bool is_bitw8(vkapi::ScalarType dtype) {
19 return dtype == vkapi::kByte || dtype == vkapi::kChar ||
20 dtype == vkapi::kQInt8 || dtype == vkapi::kQUInt8;
21 }
22
get_nchw_to_tensor_shader(const api::vTensor & v_dst,const bool int8_buffer_enabled)23 vkapi::ShaderInfo get_nchw_to_tensor_shader(
24 const api::vTensor& v_dst,
25 const bool int8_buffer_enabled) {
26 std::string kernel_name;
27 kernel_name.reserve(kShaderNameReserve);
28
29 if (is_bitw8(v_dst.dtype()) && v_dst.storage_type() != utils::kBuffer &&
30 !int8_buffer_enabled) {
31 kernel_name = "nchw_to_bitw8_image_nobitw8buffer";
32 add_storage_type_suffix(kernel_name, v_dst);
33 add_dtype_suffix(kernel_name, v_dst);
34 return VK_KERNEL_FROM_STR(kernel_name);
35 }
36
37 if (v_dst.storage_type() == utils::kBuffer) {
38 kernel_name = "nchw_to_buffer";
39 add_dtype_suffix(kernel_name, v_dst);
40 return VK_KERNEL_FROM_STR(kernel_name);
41 }
42
43 kernel_name = "nchw_to_image";
44 add_storage_type_suffix(kernel_name, v_dst);
45 add_dtype_suffix(kernel_name, v_dst);
46
47 return VK_KERNEL_FROM_STR(kernel_name);
48 }
49
get_tensor_to_nchw_shader(const api::vTensor & v_src,bool int8_buffer_enabled)50 vkapi::ShaderInfo get_tensor_to_nchw_shader(
51 const api::vTensor& v_src,
52 bool int8_buffer_enabled) {
53 std::string kernel_name;
54 kernel_name.reserve(kShaderNameReserve);
55
56 if (is_bitw8(v_src.dtype()) && v_src.storage_type() != utils::kBuffer &&
57 !int8_buffer_enabled) {
58 kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
59 add_storage_type_suffix(kernel_name, v_src);
60 add_dtype_suffix(kernel_name, v_src);
61 return VK_KERNEL_FROM_STR(kernel_name);
62 }
63
64 if (v_src.storage_type() == utils::kBuffer) {
65 kernel_name = "buffer_to_nchw";
66 add_dtype_suffix(kernel_name, v_src);
67 return VK_KERNEL_FROM_STR(kernel_name);
68 }
69
70 kernel_name = "image_to_nchw";
71 add_storage_type_suffix(kernel_name, v_src);
72 add_dtype_suffix(kernel_name, v_src);
73
74 return VK_KERNEL_FROM_STR(kernel_name);
75 }
76
77 } // namespace vkcompute
78