1 #pragma once
2
3 #ifdef USE_VULKAN_API
4
#include <cmath>
#include <optional>
#include <utility>

#include <c10/util/ArrayRef.h>

#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/vulkan/api/api.h>
#include <ATen/native/vulkan/impl/Common.h>
#include <ATen/native/vulkan/ops/Convert.h>
12
13 namespace at {
14 namespace native {
15 namespace vulkan {
16 namespace ops {
17
/*
 * Named dimension positions for the layouts used by the Vulkan ops.
 *
 * Every constant is a 0-based position counted from the first (outermost)
 * entry of a size/parameter array. NOTE(review): these are presumably used to
 * index `sizes()`-style arrays directly; they are NOT the 1-based,
 * innermost-first N expected by get_dim<N>().
 */
struct Layout final {
  // 4D Activation Maps: batch, channels, height, width.
  struct Activation4D final {
    static constexpr size_t batch = 0u;
    static constexpr size_t channels = 1u;
    static constexpr size_t height = 2u;
    static constexpr size_t width = 3u;
  };

  // Convolution Filters: output channels first, then input channels.
  struct Filter final {
    static constexpr size_t output = 0u;
    static constexpr size_t input = 1u;
    static constexpr size_t height = 2u;
    static constexpr size_t width = 3u;
  };

  // Transposed Convolution Filters: input and output are swapped relative to
  // Filter (input channels come first).
  struct TransposedFilter final {
    static constexpr size_t input = 0u;
    static constexpr size_t output = 1u;
    static constexpr size_t height = 2u;
    static constexpr size_t width = 3u;
  };

  // 2D Parameters (Pooling Kernels, Dilation, Padding, Stride, etc.)
  struct Parameter final {
    static constexpr size_t height = 0u;
    static constexpr size_t width = 1u;
  };

  // Batched Matrices (3D): batch, height, width.
  struct BatchMatrices final {
    static constexpr size_t batch = 0u;
    static constexpr size_t height = 1u;
    static constexpr size_t width = 2u;
  };
};
56
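/*
 * The functions below safely return the size of the N-th innermost dimension.
 * N is 1-based and counted from the end of the size array, so N = 1 selects
 * the last (innermost) entry and N = 2 the one before it. If the size array
 * has fewer than N dimensions, 1 is returned. N must be at least 1.
 *
 * Note that this indexing differs from the 0-based, outermost-first positions
 * in the Layout structs above. Layout members cannot be passed as N directly;
 * use innermost-ordered dimension constants instead.
 */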
63 template <uint32_t N>
get_dim(const IntArrayRef sizes)64 uint32_t get_dim(const IntArrayRef sizes) {
65 const uint32_t dims = sizes.size();
66 return dims < N ? 1 : api::utils::safe_downcast<uint32_t>(sizes[dims - N]);
67 }
68
69 template <uint32_t N>
get_dim(const Tensor & t_in)70 uint32_t get_dim(const Tensor& t_in) {
71 return get_dim<N>(t_in.sizes());
72 }
73
74 template <uint32_t N>
get_dim(const vTensor & v_in)75 uint32_t get_dim(const vTensor& v_in) {
76 return get_dim<N>(v_in.sizes());
77 }
78
get_optional_tensor(const c10::impl::GenericList & gen_list,const uint32_t idx)79 inline std::optional<Tensor> get_optional_tensor(
80 const c10::impl::GenericList& gen_list,
81 const uint32_t idx) {
82 return gen_list.get(idx).isTensor() ? gen_list.get(idx).toTensor()
83 : std::optional<Tensor>();
84 }
85
get_optional_scalar(const c10::impl::GenericList & gen_list,const uint32_t idx)86 inline std::optional<Scalar> get_optional_scalar(
87 const c10::impl::GenericList& gen_list,
88 const uint32_t idx) {
89 return gen_list.get(idx).isScalar() ? gen_list.get(idx).toScalar()
90 : std::optional<Scalar>();
91 }
92
/*
 * Rounds `v` to an integral float using std::nearbyint, i.e. according to the
 * current floating-point rounding mode. Under the default round-to-nearest
 * mode (FE_TONEAREST), halfway cases go to the even neighbour
 * (2.5f -> 2.0f, 3.5f -> 4.0f, -2.5f -> -2.0f), which is what the name
 * refers to. NOTE(review): if a caller changes the rounding mode with
 * fesetround, this function follows that mode instead.
 */
inline float roundevenf(float v) {
  // Call the float overload directly. This avoids promoting to double and
  // the C-style cast back to float; the result is identical because any
  // integral value nearbyint produces from a float is exactly representable
  // as a float.
  return std::nearbyint(v);
}
96
97 } // namespace ops
98 } // namespace vulkan
99 } // namespace native
100 } // namespace at
101
102 #endif /* USE_VULKAN_API */
103