xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/ops/Utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_VULKAN_API
4 
5 #include <ATen/native/vulkan/ops/Common.h>
6 
7 namespace at {
8 namespace native {
9 namespace vulkan {
10 namespace ops {
11 
12 namespace utils {
13 
14 Tensor nchw_to_nc4hw(const Tensor&);
15 
16 Tensor create_staging_tensor(const vTensor&);
17 
18 Tensor nc4hw_to_nchw(const Tensor&, IntArrayRef);
19 
20 void copy_buffer_to_buffer(
21     api::Context* const context,
22     api::StorageBuffer& src,
23     api::StorageBuffer& dst,
24     VkFence fence_handle);
25 
26 void copy_buffer_to_vtensor(
27     api::VulkanBuffer&,
28     vTensor&,
29     api::PipelineBarrier&);
30 
31 void copy_vtensor_to_buffer(
32     vTensor&,
33     api::VulkanBuffer&,
34     api::PipelineBarrier&,
35     const VkFence fence_handle = VK_NULL_HANDLE);
36 
normalize(const int64_t dimension,const int64_t n)37 inline int64_t normalize(const int64_t dimension, const int64_t n) {
38   return (dimension % n + n) % n;
39 }
40 
41 void pack_buffer_to_vtensor(
42     api::VulkanBuffer&,
43     vTensor&,
44     api::PipelineBarrier&);
45 
46 void pack_staging_to_vtensor(api::VulkanBuffer&, vTensor&);
47 
48 bool pack_vtensor_to_staging(
49     vTensor&,
50     api::VulkanBuffer&,
51     const VkFence fence_handle = VK_NULL_HANDLE);
52 
53 // Broadcasting Utils
54 void is_broadcastable(const Tensor& input1, const Tensor& input2);
55 std::vector<int64_t> broadcast_size(const Tensor& t1, const Tensor& t2);
56 
57 // This function returns the value of the underlying texel at pos of the given
58 // tensor. It is useful for debugging and unit test at which we want to verify
59 // the actual tensor layout. This function is very slow as it involves a fench
60 // to extract just one value.
61 api::utils::vec4 extract_texel(
62     const Tensor& tensor,
63     const api::utils::ivec3& pos);
64 
65 inline api::utils::ivec2 make_ivec2(
66     const IntArrayRef ints,
67     bool reverse = false) {
68   VK_CHECK_COND(ints.size() == 2);
69   if (reverse) {
70     return {
71         api::utils::safe_downcast<int32_t>(ints[1]),
72         api::utils::safe_downcast<int32_t>(ints[0])};
73   } else {
74     return {
75         api::utils::safe_downcast<int32_t>(ints[0]),
76         api::utils::safe_downcast<int32_t>(ints[1])};
77   }
78 }
79 
80 inline api::utils::ivec4 make_ivec4(
81     const IntArrayRef ints,
82     bool reverse = false) {
83   VK_CHECK_COND(ints.size() == 4);
84   if (reverse) {
85     return {
86         api::utils::safe_downcast<int32_t>(ints[3]),
87         api::utils::safe_downcast<int32_t>(ints[2]),
88         api::utils::safe_downcast<int32_t>(ints[1]),
89         api::utils::safe_downcast<int32_t>(ints[0]),
90     };
91   } else {
92     return {
93         api::utils::safe_downcast<int32_t>(ints[0]),
94         api::utils::safe_downcast<int32_t>(ints[1]),
95         api::utils::safe_downcast<int32_t>(ints[2]),
96         api::utils::safe_downcast<int32_t>(ints[3]),
97     };
98   }
99 }
100 
101 } // namespace utils
102 } // namespace ops
103 } // namespace vulkan
104 } // namespace native
105 } // namespace at
106 
107 #endif /* USE_VULKAN_API */
108