xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/ops/Common.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_VULKAN_API
4 
#include <cmath>

#include <c10/util/ArrayRef.h>

#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/vulkan/api/api.h>
#include <ATen/native/vulkan/impl/Common.h>
#include <ATen/native/vulkan/ops/Convert.h>
12 
13 namespace at {
14 namespace native {
15 namespace vulkan {
16 namespace ops {
17 
// Named dimension positions for the tensor layouts used by the Vulkan ops.
// Every constant is a 0-based index counted from the outermost (leftmost)
// dimension of a size array; e.g. Activation4D follows NCHW order.
struct Layout final {
  // 4D Activation Maps: [batch, channels, height, width]
  struct Activation4D final {
    static constexpr size_t batch = 0u;
    static constexpr size_t channels = 1u;
    static constexpr size_t height = 2u;
    static constexpr size_t width = 3u;
  };

  // Convolution Filters: [output, input, height, width]
  struct Filter final {
    static constexpr size_t output = 0u;
    static constexpr size_t input = 1u;
    static constexpr size_t height = 2u;
    static constexpr size_t width = 3u;
  };

  // Transposed Convolution Filters: the input and output positions are
  // swapped relative to Filter: [input, output, height, width]
  struct TransposedFilter final {
    static constexpr size_t input = 0u;
    static constexpr size_t output = 1u;
    static constexpr size_t height = 2u;
    static constexpr size_t width = 3u;
  };

  // 2D Parameters (Pooling Kernels, Dilation, Padding, Stride, etc.):
  // [height, width]
  struct Parameter final {
    static constexpr size_t height = 0u;
    static constexpr size_t width = 1u;
  };

  // Batched Matrices (a stack of 2D matrices): [batch, height, width]
  struct BatchMatrices final {
    static constexpr size_t batch = 0u;
    static constexpr size_t height = 1u;
    static constexpr size_t width = 2u;
  };
};
56 
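/*
 * The functions below safely return the size of the N-th innermost dimension
 * of a size array. N is 1-based and counted from the innermost (last)
 * dimension: N = 1 selects sizes[sizes.size() - 1]. If the size array has
 * fewer than N dimensions, 1 is returned instead.
 *
 * Note: the Layout constants above are 0-based indices counted from the
 * outermost dimension, so they are not interchangeable with N here (e.g.
 * get_dim<Layout::Activation4D::width> would select the channels dimension of
 * a 4D size array).
 */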
63 template <uint32_t N>
get_dim(const IntArrayRef sizes)64 uint32_t get_dim(const IntArrayRef sizes) {
65   const uint32_t dims = sizes.size();
66   return dims < N ? 1 : api::utils::safe_downcast<uint32_t>(sizes[dims - N]);
67 }
68 
69 template <uint32_t N>
get_dim(const Tensor & t_in)70 uint32_t get_dim(const Tensor& t_in) {
71   return get_dim<N>(t_in.sizes());
72 }
73 
74 template <uint32_t N>
get_dim(const vTensor & v_in)75 uint32_t get_dim(const vTensor& v_in) {
76   return get_dim<N>(v_in.sizes());
77 }
78 
get_optional_tensor(const c10::impl::GenericList & gen_list,const uint32_t idx)79 inline std::optional<Tensor> get_optional_tensor(
80     const c10::impl::GenericList& gen_list,
81     const uint32_t idx) {
82   return gen_list.get(idx).isTensor() ? gen_list.get(idx).toTensor()
83                                       : std::optional<Tensor>();
84 }
85 
get_optional_scalar(const c10::impl::GenericList & gen_list,const uint32_t idx)86 inline std::optional<Scalar> get_optional_scalar(
87     const c10::impl::GenericList& gen_list,
88     const uint32_t idx) {
89   return gen_list.get(idx).isScalar() ? gen_list.get(idx).toScalar()
90                                       : std::optional<Scalar>();
91 }
92 
// Rounds `v` to an integral value with std::nearbyint.
//
// Under the default floating-point rounding mode (FE_TONEAREST), halfway
// cases round to even: 2.5f -> 2.0f, 3.5f -> 4.0f, -2.5f -> -2.0f.
// nearbyint honors the *current* rounding mode, so the result differs if the
// caller has changed it (e.g. via std::fesetround).
inline float roundevenf(float v) {
  // The float overload of std::nearbyint returns float directly, so no
  // C-style cast is needed and no float -> double promotion takes place,
  // regardless of which global `nearbyint` overloads happen to be visible.
  return std::nearbyint(v);
}
96 
97 } // namespace ops
98 } // namespace vulkan
99 } // namespace native
100 } // namespace at
101 
102 #endif /* USE_VULKAN_API */
103