xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/impl/Common.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_VULKAN_API
4 
5 #include <ATen/native/vulkan/api/api.h>
6 
7 namespace at {
8 namespace native {
9 namespace vulkan {
10 
11 /*
12  * Maps a semantic dimension name to an integer that corresponds to its
13  * innermost ordering in a 4D tensor in NCHW format. Width is the innermost
14  * dimension, so it corresponds to 1, height is the next innermost, so it
15  * corresponds to 2, and so on.
16  */
struct Dim4D {
  // Values are 1-based positions counted from the innermost dimension of an
  // NCHW tensor; they are meant to be passed as the template argument of
  // dim_at<N>().
  static constexpr uint32_t Width{1u}; // innermost
  static constexpr uint32_t Height{2u}; // next innermost
  static constexpr uint32_t Channel{3u};
  static constexpr uint32_t Batch{4u}; // outermost of the four
};
23 
24 /*
25  * Semantic dimension names for a 1D tensor
26  */
struct Dim1D {
  // The only dimension of a 1D tensor, hence also its innermost one.
  static constexpr uint32_t Length{1u};
};
30 
31 /*
32  * Semantic dimension names for a 2D Convolution kernel.
33  */
struct DimConv2DKernel {
  // Each value is an "N-th innermost" index intended for use with
  // dim_at<N>(); 1 selects the last entry of the size array.
  static constexpr uint32_t Width{1u};
  static constexpr uint32_t Height{2u};
  static constexpr uint32_t InChannels{3u};
  static constexpr uint32_t OutChannels{4u};
};
40 
41 /*
42  * The same as the above, except for a 2D Transposed Convolution kernel.
43  */
struct DimTConv2DKernel {
  // Each value is an "N-th innermost" index intended for use with
  // dim_at<N>(). Note that the two channel indices are swapped relative to
  // DimConv2DKernel: OutChannels is 3 and InChannels is 4 here.
  static constexpr uint32_t Width{1u};
  static constexpr uint32_t Height{2u};
  static constexpr uint32_t OutChannels{3u};
  static constexpr uint32_t InChannels{4u};
};
50 
51 /*
52  * The functions below safely return the size of the dimension at the N-th
53  * innermost index. If the dimensionality of the size array is not sufficient
54  * then 1 will be returned. The structs above are intended to be used with
55  * these functions.
56  */
57 template <uint32_t N>
dim_at(const std::vector<int64_t> & sizes)58 uint32_t dim_at(const std::vector<int64_t>& sizes) {
59   const uint32_t dims = sizes.size();
60   return dims < N ? 1 : api::utils::safe_downcast<uint32_t>(sizes[dims - N]);
61 }
62 
63 template <uint32_t N>
dim_at(const vTensor & v_in)64 uint32_t dim_at(const vTensor& v_in) {
65   return dim_at<N>(v_in.sizes());
66 }
67 
68 /*
69  * For most global work group sizes, returns {4, 4, 4}, but adjusts the size for
70  * 2D global work group sizes. Always maintains a total of 64 invocations
71  */
// Declaration only: no definition appears in this header.
// NOTE(review): going by the comment above, the returned uvec3 is presumably
// the local work group size chosen for `global_work_group` -- confirm
// against the definition.
api::utils::uvec3 adaptive_work_group_size(
    const api::utils::uvec3& global_work_group);
74 
75 } // namespace vulkan
76 } // namespace native
77 } // namespace at
78 
79 #endif /* USE_VULKAN_API */
80