#pragma once

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/api.h>

#include <cstddef>
#include <cstdint>
#include <vector>

namespace at {
namespace native {
namespace vulkan {

/*
 * Maps a semantic dimension name to an integer that corresponds to its
 * innermost ordering in a 4D tensor in NCHW format. Width is the innermost
 * dimension, so it corresponds to 1, height is the next innermost, so it
 * corresponds to 2, and so on.
 */
struct Dim4D {
  static constexpr uint32_t Width = 1u;
  static constexpr uint32_t Height = 2u;
  static constexpr uint32_t Channel = 3u;
  static constexpr uint32_t Batch = 4u;
};

/*
 * Semantic dimension names for a 1D tensor.
 */
struct Dim1D {
  static constexpr uint32_t Length = 1u;
};

/*
 * Semantic dimension names for a 2D Convolution kernel.
 */
struct DimConv2DKernel {
  static constexpr uint32_t Width = 1u;
  static constexpr uint32_t Height = 2u;
  static constexpr uint32_t InChannels = 3u;
  static constexpr uint32_t OutChannels = 4u;
};

/*
 * The same as DimConv2DKernel, except for a 2D Transposed Convolution kernel:
 * the positions of the input and output channel dimensions are swapped.
 */
struct DimTConv2DKernel {
  static constexpr uint32_t Width = 1u;
  static constexpr uint32_t Height = 2u;
  static constexpr uint32_t OutChannels = 3u;
  static constexpr uint32_t InChannels = 4u;
};

/*
 * The functions below safely return the size of the dimension at the N-th
 * innermost index. If the dimensionality of the size array is not sufficient
 * then 1 will be returned. The structs above are intended to be used with
 * these functions.
 *
 * N is 1-based: N == 1 selects the last entry of `sizes`, N == 2 the entry
 * before it, and so on. N == 0 does not name a dimension and is rejected at
 * compile time, since it would otherwise index one past the end of `sizes`.
 *
 * The selected entry is converted to uint32_t via api::utils::safe_downcast.
 */
template <uint32_t N>
uint32_t dim_at(const std::vector<int64_t>& sizes) {
  static_assert(N > 0, "dim_at<N>: N is a 1-based innermost index, N >= 1");

  // Keep the element count as size_t so it is never silently narrowed.
  const size_t dims = sizes.size();
  if (dims < N) {
    return 1u;
  }
  return api::utils::safe_downcast<uint32_t>(sizes[dims - N]);
}

/*
 * Convenience overload: applies dim_at<N> to the sizes reported by
 * v_in.sizes().
 */
template <uint32_t N>
uint32_t dim_at(const vTensor& v_in) {
  return dim_at<N>(v_in.sizes());
}

/*
 * For most global work group sizes, returns {4, 4, 4}, but adjusts the size for
 * 2D global work group sizes. Always maintains a total of 64 invocations.
 */
api::utils::uvec3 adaptive_work_group_size(
    const api::utils::uvec3& global_work_group);

} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */