xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/ops/Convolution.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_VULKAN_API
4 
5 #include <ATen/native/vulkan/ops/Common.h>
6 #include <ATen/native/vulkan/ops/VulkanPackedContext.h>
7 
8 namespace at {
9 namespace native {
10 namespace vulkan {
11 namespace ops {
12 
/*
 * Enumerates the 2D convolution methods named in this header. The values are
 * written out explicitly and equal the implicit numbering produced by the
 * declaration order (0, 1, 2).
 */
enum Conv2dMethod {
  Conv2dDepthwise = 0,
  Conv2dPointwise = 1,
  Conv2dSlidingWindow = 2,
};
18 
19 namespace conv2d {
20 
21 Tensor rearrange_weights_dw(const Tensor& weight_in);
22 Tensor rearrange_weights_2d(const Tensor& weight_in, bool tconv);
23 Tensor rearrange_bias(
24     const std::optional<Tensor>& bias_in,
25     const at::Tensor& weight_in,
26     bool tconv);
27 
28 } // namespace conv2d
29 
30 namespace qconv2d_vk {
31 
32 struct QParams final {
33   api::utils::uvec3 out_extents;
34   int32_t ic4;
35   api::utils::ivec4 sizes_2d;
36   float output_scale;
37   float input_scale;
38   int32_t output_zero_point;
39   int32_t input_zero_point;
40   float weight_scale;
41   float bias_scale;
42   int32_t weight_zero_point;
43   int32_t bias_zero_point;
44   api::utils::ivec2 kernel_size;
45   api::utils::ivec2 stride;
46   api::utils::ivec2 padding;
47   api::utils::ivec2 dilate;
48   api::utils::vec2 clamp;
49   api::utils::ivec4 src_filter;
50 };
51 
52 } // namespace qconv2d_vk
53 
54 class Conv2dPackedContext final : virtual public VulkanPackedContext,
55                                   public torch::jit::CustomClassHolder {
56  private:
57   c10::impl::GenericList unpacked_;
58   api::ShaderInfo compute_shader_{};
59 
60  public:
61   Conv2dPackedContext(
62       const Tensor& weight,
63       const std::optional<Tensor>& bias,
64       const IntArrayRef stride_arg,
65       const IntArrayRef padding_arg,
66       const IntArrayRef dilation_arg,
67       const bool transposed,
68       const bool quantized,
69       const IntArrayRef output_padding_arg,
70       const int64_t groups,
71       const std::optional<Scalar>& output_min = std::nullopt,
72       const std::optional<Scalar>& output_max = std::nullopt);
73 
74   /*
75    * Assigns a name to each index in the unpacked list.
76    */
77   struct Unpacked final {
78     static constexpr uint32_t Weight = 0u;
79     static constexpr uint32_t Bias = 1u;
80     static constexpr uint32_t Stride = 2u;
81     static constexpr uint32_t Padding = 3u;
82     static constexpr uint32_t Dilation = 4u;
83     static constexpr uint32_t isTransposed = 5u;
84     static constexpr uint32_t isQuantized = 6u;
85     static constexpr uint32_t OutputPadding = 7u;
86     static constexpr uint32_t Groups = 8u;
87     static constexpr uint32_t OutputMin = 9u;
88     static constexpr uint32_t OutputMax = 10u;
89 
90     static constexpr uint32_t NumArgs = 11u;
91   };
92 
93   /*
94    * Assigns a name to each index in the packed list.
95    */
96   struct Packed final {
97     static constexpr uint32_t Weight = 0u;
98     static constexpr uint32_t Bias = 1u;
99     static constexpr uint32_t OverlayRegion = 2u;
100     static constexpr uint32_t Stride = 3u;
101     static constexpr uint32_t Padding = 4u;
102     static constexpr uint32_t OutputPadding = 5u;
103     static constexpr uint32_t Dilation = 6u;
104     static constexpr uint32_t isTransposed = 7u;
105     static constexpr uint32_t isQuantized = 8u;
106     static constexpr uint32_t Groups = 9u;
107     static constexpr uint32_t OutputMin = 10u;
108     static constexpr uint32_t OutputMax = 11u;
109     static constexpr uint32_t ConvMethod = 12u;
110     static constexpr uint32_t WeightSizes = 13u;
111 
112     static constexpr uint32_t NumArgs = 14u;
113   };
114 
115   static Conv2dPackedContext pack(c10::impl::GenericList);
116 
unpack()117   const c10::impl::GenericList unpack() const override {
118     TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
119 
120     return unpacked_;
121   }
122 
compute_shader()123   inline api::ShaderInfo& compute_shader() {
124     return compute_shader_;
125   }
126 };
127 
128 c10::intrusive_ptr<Conv2dPackedContext> create_conv2d_context(
129     Tensor&& weight,
130     std::optional<Tensor>&& bias,
131     std::vector<int64_t>&& stride,
132     std::vector<int64_t>&& padding,
133     std::vector<int64_t>&& dilation,
134     const int64_t groups,
135     const std::optional<Scalar>& output_min = std::nullopt,
136     const std::optional<Scalar>& output_max = std::nullopt);
137 
138 Tensor run_conv2d_context(
139     const Tensor& input,
140     const c10::intrusive_ptr<Conv2dPackedContext>& context);
141 
142 c10::intrusive_ptr<Conv2dPackedContext> create_tconv2d_context(
143     Tensor&& weight,
144     std::optional<Tensor>&& bias,
145     std::vector<int64_t>&& stride,
146     std::vector<int64_t>&& padding,
147     std::vector<int64_t>&& output_padding,
148     std::vector<int64_t>&& dilation,
149     const int64_t groups,
150     const std::optional<Scalar>& output_min = std::nullopt,
151     const std::optional<Scalar>& output_max = std::nullopt);
152 
153 Tensor run_tconv2d_context(
154     const Tensor& input,
155     const c10::intrusive_ptr<Conv2dPackedContext>& context);
156 
157 c10::intrusive_ptr<Conv2dPackedContext> create_qconv2d_context(
158     Tensor&& weight,
159     std::optional<Tensor>&& bias,
160     std::vector<int64_t>&& stride,
161     std::vector<int64_t>&& padding,
162     std::vector<int64_t>&& dilation,
163     const int64_t groups,
164     const std::optional<Scalar>& output_min = std::nullopt,
165     const std::optional<Scalar>& output_max = std::nullopt);
166 
167 Tensor run_qconv2d_context(
168     const Tensor& input_arg,
169     double scale,
170     int64_t zero_point,
171     const c10::intrusive_ptr<Conv2dPackedContext>& conv_context);
172 
173 c10::intrusive_ptr<Conv2dPackedContext> create_qtconv2d_context(
174     Tensor&& weight,
175     std::optional<Tensor>&& bias,
176     std::vector<int64_t>&& stride,
177     std::vector<int64_t>&& padding,
178     std::vector<int64_t>&& output_padding,
179     std::vector<int64_t>&& dilation,
180     const int64_t groups,
181     const std::optional<Scalar>& output_min = std::nullopt,
182     const std::optional<Scalar>& output_max = std::nullopt);
183 
184 // Backwards compatibility
185 class Conv2dOpContext final : public torch::jit::CustomClassHolder {
186  public:
187   static Conv2dOpContext create(
188       const Tensor& weight,
189       const std::optional<Tensor>& bias,
190       IntArrayRef stride,
191       IntArrayRef padding,
192       IntArrayRef dilation,
193       bool transposed,
194       IntArrayRef output_padding,
195       int64_t groups,
196       const std::optional<Scalar>& output_min = std::nullopt,
197       const std::optional<Scalar>& output_max = std::nullopt);
198 
199   using State = std::tuple<
200       Tensor,
201       std::optional<Tensor>,
202       std::vector<int64_t>,
203       std::vector<int64_t>,
204       std::vector<int64_t>,
205       int64_t,
206       std::optional<Scalar>,
207       std::optional<Scalar>>;
208 
209   Tensor run(const Tensor& input) const;
210   State unpack() const;
211 
212  private:
213   explicit Conv2dOpContext(Conv2dPackedContext conv_context);
214   Conv2dPackedContext conv_context_;
215 };
216 
217 Tensor conv2d_clamp_run(
218     const Tensor& input,
219     const c10::intrusive_ptr<Conv2dOpContext>& context);
220 
221 c10::intrusive_ptr<Conv2dOpContext> conv2d_clamp_prepack(
222     Tensor&& weight,
223     std::optional<Tensor>&& bias,
224     std::vector<int64_t>&& stride,
225     std::vector<int64_t>&& padding,
226     std::vector<int64_t>&& dilation,
227     const int64_t groups,
228     const std::optional<Scalar>& output_min,
229     const std::optional<Scalar>& output_max);
230 
231 class Conv1dPackedContext final : virtual public VulkanPackedContext,
232                                   public torch::jit::CustomClassHolder {
233  private:
234   c10::impl::GenericList unpacked_;
235   api::ShaderInfo compute_shader_{};
236 
237  public:
238   Conv1dPackedContext(
239       const Tensor& weight,
240       const std::optional<Tensor>& bias,
241       const IntArrayRef stride_arg,
242       const IntArrayRef padding_arg,
243       const IntArrayRef dilation_arg,
244       const int64_t groups);
245 
246   /*
247    * Assigns a name to each index in the unpacked list.
248    */
249   struct Unpacked final {
250     static constexpr uint32_t Weight = 0u;
251     static constexpr uint32_t Bias = 1u;
252     static constexpr uint32_t Stride = 2u;
253     static constexpr uint32_t Padding = 3u;
254     static constexpr uint32_t Dilation = 4u;
255     static constexpr uint32_t Groups = 5u;
256 
257     static constexpr uint32_t NumArgs = 6u;
258   };
259 
260   /*
261    * Assigns a name to each index in the packed list.
262    */
263   struct Packed final {
264     static constexpr uint32_t Weight = 0u;
265     static constexpr uint32_t Bias = 1u;
266     static constexpr uint32_t Stride = 2u;
267     static constexpr uint32_t Padding = 3u;
268     static constexpr uint32_t Dilation = 4u;
269     static constexpr uint32_t Groups = 5u;
270     static constexpr uint32_t WeightSizes = 6u;
271 
272     static constexpr uint32_t NumArgs = 7u;
273   };
274 
275   static Conv1dPackedContext pack(c10::impl::GenericList);
276 
unpack()277   const c10::impl::GenericList unpack() const override {
278     TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
279 
280     return unpacked_;
281   }
282 
compute_shader()283   inline api::ShaderInfo& compute_shader() {
284     return compute_shader_;
285   }
286 };
287 
288 c10::intrusive_ptr<Conv1dPackedContext> create_conv1d_context(
289     Tensor&& weight,
290     std::optional<Tensor>&& bias,
291     std::vector<int64_t>&& stride,
292     std::vector<int64_t>&& padding,
293     std::vector<int64_t>&& dilation,
294     const int64_t groups);
295 
296 Tensor run_conv1d_context(
297     const Tensor& input,
298     const c10::intrusive_ptr<Conv1dPackedContext>& context);
299 
300 } // namespace ops
301 } // namespace vulkan
302 } // namespace native
303 } // namespace at
304 
305 #endif /* USE_VULKAN_API */
306