1 #pragma once
2
3 #ifdef USE_VULKAN_API
4
5 #include <ATen/native/vulkan/ops/Common.h>
6
7 namespace at {
8 namespace native {
9 namespace vulkan {
10 namespace ops {
11
12 namespace utils {
13
// Layout-conversion and copy helpers for Vulkan tensors. Only declarations
// appear here; the definitions live in a separate translation unit.

// NOTE(review): presumably converts a tensor from NCHW layout to NC4HW
// (channels grouped in fours) -- inferred from the name; confirm against the
// definition.
Tensor nchw_to_nc4hw(const Tensor&);

// Creates a Tensor intended to act as a staging area for the given vTensor.
// NOTE(review): sizing/dtype rules are defined elsewhere -- confirm.
Tensor create_staging_tensor(const vTensor&);

// NOTE(review): presumably the inverse of nchw_to_nc4hw, using the
// IntArrayRef argument as the target sizes -- confirm against the definition.
Tensor nc4hw_to_nchw(const Tensor&, IntArrayRef);

// Copies `src` into `dst` through the given context. `fence_handle` is
// supplied by the caller (presumably signalled on completion -- confirm).
void copy_buffer_to_buffer(
    api::Context* const context,
    api::StorageBuffer& src,
    api::StorageBuffer& dst,
    VkFence fence_handle);

// Copies the contents of a VulkanBuffer into a vTensor. The PipelineBarrier
// is taken by non-const reference, so the call may modify it.
void copy_buffer_to_vtensor(
    api::VulkanBuffer&,
    vTensor&,
    api::PipelineBarrier&);

// Copies the contents of a vTensor into a VulkanBuffer. The PipelineBarrier
// may be modified by the call; `fence_handle` defaults to VK_NULL_HANDLE.
void copy_vtensor_to_buffer(
    vTensor&,
    api::VulkanBuffer&,
    api::PipelineBarrier&,
    const VkFence fence_handle = VK_NULL_HANDLE);
36
// Wraps `dimension` into the range defined by `n` using floored modulo:
// the result lies in [0, n) for n > 0 and in (n, 0] for n < 0, so e.g.
// normalize(-1, 4) == 3 and normalize(5, 4) == 1.
//
// This is equivalent to `(dimension % n + n) % n`, but it never forms the
// sum `dimension % n + n`, which overflows (undefined behavior) when
// n > INT64_MAX / 2.
//
// Preconditions: n != 0, and not (dimension == INT64_MIN && n == -1);
// both cases are undefined for the built-in `%` operator.
inline int64_t normalize(const int64_t dimension, const int64_t n) {
  const int64_t rem = dimension % n;
  // Built-in `%` truncates toward zero, so `rem` carries the sign of
  // `dimension`. Shift by `n` only when the signs disagree; because
  // |rem| < |n| and the signs differ, `rem + n` cannot overflow.
  return (rem != 0 && ((rem < 0) != (n < 0))) ? rem + n : rem;
}
40
// Packs the contents of a VulkanBuffer into a vTensor. The PipelineBarrier
// is taken by non-const reference, so the call may modify it.
void pack_buffer_to_vtensor(
    api::VulkanBuffer&,
    vTensor&,
    api::PipelineBarrier&);

// Packs a staging VulkanBuffer into a vTensor (no barrier argument).
void pack_staging_to_vtensor(api::VulkanBuffer&, vTensor&);

// Packs a vTensor into a staging VulkanBuffer. `fence_handle` defaults to
// VK_NULL_HANDLE. NOTE(review): the meaning of the returned bool is defined
// by the implementation (presumably whether work was submitted against the
// fence) -- confirm.
bool pack_vtensor_to_staging(
    vTensor&,
    api::VulkanBuffer&,
    const VkFence fence_handle = VK_NULL_HANDLE);

// Broadcasting Utils

// Checks whether input1 and input2 are broadcast-compatible. It returns void,
// so presumably it reports incompatibility by raising an error -- confirm in
// the definition.
void is_broadcastable(const Tensor& input1, const Tensor& input2);

// Presumably returns the sizes that result from broadcasting t1 with t2.
std::vector<int64_t> broadcast_size(const Tensor& t1, const Tensor& t2);

// This function returns the value of the underlying texel at pos of the given
// tensor. It is useful for debugging and for unit tests in which we want to
// verify the actual tensor layout. This function is very slow, as it involves
// a fence just to extract a single value.
api::utils::vec4 extract_texel(
    const Tensor& tensor,
    const api::utils::ivec3& pos);
64
65 inline api::utils::ivec2 make_ivec2(
66 const IntArrayRef ints,
67 bool reverse = false) {
68 VK_CHECK_COND(ints.size() == 2);
69 if (reverse) {
70 return {
71 api::utils::safe_downcast<int32_t>(ints[1]),
72 api::utils::safe_downcast<int32_t>(ints[0])};
73 } else {
74 return {
75 api::utils::safe_downcast<int32_t>(ints[0]),
76 api::utils::safe_downcast<int32_t>(ints[1])};
77 }
78 }
79
80 inline api::utils::ivec4 make_ivec4(
81 const IntArrayRef ints,
82 bool reverse = false) {
83 VK_CHECK_COND(ints.size() == 4);
84 if (reverse) {
85 return {
86 api::utils::safe_downcast<int32_t>(ints[3]),
87 api::utils::safe_downcast<int32_t>(ints[2]),
88 api::utils::safe_downcast<int32_t>(ints[1]),
89 api::utils::safe_downcast<int32_t>(ints[0]),
90 };
91 } else {
92 return {
93 api::utils::safe_downcast<int32_t>(ints[0]),
94 api::utils::safe_downcast<int32_t>(ints[1]),
95 api::utils::safe_downcast<int32_t>(ints[2]),
96 api::utils::safe_downcast<int32_t>(ints[3]),
97 };
98 }
99 }
100
101 } // namespace utils
102 } // namespace ops
103 } // namespace vulkan
104 } // namespace native
105 } // namespace at
106
107 #endif /* USE_VULKAN_API */
108