
#include <ATen/Context.h>

#include <ATen/native/ConvUtils.h>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/native/vulkan/api/Utils.h>
#include <ATen/native/vulkan/impl/Packing.h>
#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Convolution.h>
#include <ATen/native/vulkan/ops/Copy.h>
#include <ATen/native/vulkan/ops/Utils.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/dequantize.h>
#include <ATen/ops/pad.h>
#include <ATen/ops/permute.h>
#include <ATen/ops/quantize_per_tensor.h>
#include <ATen/ops/zeros.h>
#endif

namespace at {
namespace native {
namespace vulkan {
namespace ops {

namespace conv2d {

//
// Convolution type classification
//

inline bool is_depthwise(const IntArrayRef weight_size, const int64_t groups) {
  uint32_t groups_uint = api::utils::safe_downcast<uint32_t>(groups);
  if (get_dim<DimConv2DKernel::OutChannels>(weight_size) != groups_uint) {
    return false;
  }
  if (get_dim<DimConv2DKernel::InChannels>(weight_size) != 1) {
    return false;
  }
  return true;
}

inline bool is_pointwise(const IntArrayRef weight_size) {
  if (get_dim<DimConv2DKernel::Width>(weight_size) != 1) {
    return false;
  }
  if (get_dim<DimConv2DKernel::Height>(weight_size) != 1) {
    return false;
  }
  return true;
}

static Conv2dMethod determine_method(
    const IntArrayRef weight_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool transposed,
    const bool quantized) {
  if (transposed) {
    return Conv2dSlidingWindow;
  }
  if (is_depthwise(weight_size, groups)) {
    return Conv2dDepthwise;
  }
  if (is_pointwise(weight_size)) {
    return Conv2dPointwise;
  }
  return Conv2dSlidingWindow;
}
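
/*
 * Illustrative sketch (not part of the build): how a few common weight shapes
 * would be classified, assuming the standard ATen conv2d weight layout of
 * {out_channels, in_channels / groups, kH, kW}. Non-transposed, non-quantized
 * convolutions are assumed; the shapes and group counts below are made up for
 * illustration only.
 *
 *   determine_method({32, 32, 3, 3}, stride, padding, dilation, 1, false, false);
 *   // -> Conv2dSlidingWindow (regular convolution)
 *   determine_method({32, 1, 3, 3}, stride, padding, dilation, 32, false, false);
 *   // -> Conv2dDepthwise (out_channels == groups, in_channels / groups == 1)
 *   determine_method({64, 32, 1, 1}, stride, padding, dilation, 1, false, false);
 *   // -> Conv2dPointwise (1x1 kernel)
 */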

//
// Rearrangement functions for pre-packing
//

/*
 * Rearranges a convolution weight tensor to a layout that can be used by
 * convolution compute shaders. The goal of this packing is to arrange the data
 * such that data access in the compute shader is as linear as possible. The
 * reasoning behind the packing pattern will be described in the shader kernel
 * code.
 *
 * To understand the transformations performed by this function, consider an
 * example input of size {11, 1, 3, 3}. The following transformations will be
 * applied to this weight tensor:
 *
 * 1. First, apply padding to the N dim so that it is a multiple of 4.
 * In this case, 1 batch is added, producing a tensor of size {12,1,3,3}.
 *
 * 2. Next, flatten the last two dims of the tensor. This is done by reshaping
 * the tensor to size {12,1,9}.
 *
 * 3. Finally, we want to "fold" the batch dim into the channel dim. We start by
 * splitting the tensor along the N dim so that each split has 4 batches. This
 * is done by reshaping the tensor to size {3,4,1,9}.
 *
 * 4. Normally, we would be done, but we also want to stack the splits
 * vertically. This is done by permuting the N and C dims and reshaping the
 * tensor to size {4,3,9}.
 */
at::Tensor rearrange_weights_dw(const Tensor& weight_in) {
  at::Tensor weight = weight_in.clone();

  uint32_t N = ops::get_dim<DimConv2DKernel::OutChannels>(weight);
  uint32_t C = ops::get_dim<DimConv2DKernel::InChannels>(weight);
  uint32_t H = ops::get_dim<DimConv2DKernel::Height>(weight);
  uint32_t W = ops::get_dim<DimConv2DKernel::Width>(weight);

  uint32_t N_aligned = api::utils::align_up(N, 4u);

  // Add padding to the N dimension so that it's a multiple of 4
  uint32_t N_padding_needed = N_aligned - N;
  weight =
      at::pad(weight, {0, 0, 0, 0, 0, 0, 0, N_padding_needed}, "constant", 0);

  // Flatten so the H and W dims are on one row
  weight = weight.reshape({N_aligned, C, H * W});

  // Split batch dim to make groups of 4
  uint32_t N4 = N_aligned / 4u;
  weight = weight.reshape({N4, 4, C, H * W});

  // Permute the groups of 4 so they are arranged along the channel dim, then
  // reshape to stack the resulting batches vertically
  weight = weight.permute({1, 0, 2, 3}).reshape({4, N4 * C, H * W});

  return weight.contiguous();
}
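
/*
 * Illustrative sketch (not part of the build): the shape progression of
 * rearrange_weights_dw for the {11, 1, 3, 3} example described in the comment
 * above. The tensor contents are random; only the shapes matter here.
 *
 *   at::Tensor w = at::rand({11, 1, 3, 3});
 *   // 1. pad N up to a multiple of 4:        {12, 1, 3, 3}
 *   // 2. flatten H and W:                    {12, 1, 9}
 *   // 3. split N into groups of 4:           {3, 4, 1, 9}
 *   // 4. permute + reshape (final layout):   {4, 3, 9}
 *   at::Tensor packed = rearrange_weights_dw(w);
 *   // packed.sizes() == {4, 3, 9}
 */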

/*
 * Rearranges a convolution weight tensor to a layout that can be used by
 * convolution compute shaders. The goal of this packing is to arrange the data
 * such that data access in the compute shader is as linear as possible. The
 * reasoning behind the packing pattern will be described in the shader kernel
 * code.
 *
 * To understand the transformations performed by this function, consider an
 * example input of size {10, 7, 3, 3}. The following transformations will be
 * applied to this weight tensor:
 *
 * 1. First, apply padding to the N and C dims so that both are a multiple of 4.
 * In this case, 2 batches and 1 channel of padding are added, producing a
 * tensor of size {12,8,3,3}.
 *
 * 2. Next, split the tensor along the C dim so that each split has 4 channels.
 * This is done by reshaping the tensor to size {12,2,(4,3,3)}, where the ()
 * brackets denote the size of each split.
 *
 * 3. For each split, we want to "fold" the C dim into the W dim. So suppose the
 * rows at H=0 of the split have the values
 *
 *    0,1,2 | 10,11,12 | 20,21,22 | 30,31,32
 *
 *    where | denotes a channel boundary, then the goal is to combine those rows
 * into one row with the values
 *
 *    0, 10, 20, 30, 1, 11, 21, 31, 2, 12, 22, 32
 *
 *    This is done in code by permuting and reshaping the tensor, producing a
 * tensor of size {12,2,(3,12)}.
 *
 * 4. Next, we want to stack the splits belonging to the same batch horizontally,
 * which is done by swapping the C and H dims of the intermediate tensor and
 * reshaping to produce a tensor of size {12,3,24}.
 *
 * 5. Now we will repeat a similar process of "folding" the N dim into the C
 * dim. We start by splitting along the N dim so that each split has 4 batches.
 * To do this, the tensor is reshaped to {3,4,3,24}.
 *
 * 6. Normally, we would be done, but we also want to stack the batches on top
 * of each other vertically. Therefore, the final step is another permute
 * swapping the N and C dims and reshaping to the output shape of {4, 9, 24}.
 *
 * For transposed convolutions, there are some slight differences to reflect the
 * data access pattern in the shader. The first major difference is that the
 * weight tensor is flipped along the H and W dims. The second major difference
 * is that steps 3 and 4 are slightly different so that the splits are
 * interleaved.
 */
at::Tensor rearrange_weights_2d(const Tensor& weight_in, bool tconv) {
  at::Tensor weight = weight_in.clone();

  // Flip values along the H and W axes for transposed convolutions
  if (tconv) {
    weight = weight.flip(3).flip(2);
  }

  uint32_t N = get_dim<DimConv2DKernel::OutChannels>(weight);
  uint32_t C = get_dim<DimConv2DKernel::InChannels>(weight);
  uint32_t H = get_dim<DimConv2DKernel::Height>(weight);
  uint32_t W = get_dim<DimConv2DKernel::Width>(weight);

  uint32_t N_aligned = api::utils::align_up(N, 4u);
  uint32_t C_aligned = api::utils::align_up(C, 4u);

  // Add padding to the N and C dimensions so that they are multiples of 4
  uint32_t C_padding_needed = C_aligned - C;
  uint32_t N_padding_needed = N_aligned - N;
  weight = at::pad(
      weight,
      {0, 0, 0, 0, 0, C_padding_needed, 0, N_padding_needed},
      "constant",
      0);

  // Split the C dim into groups of 4
  uint32_t C4 = C_aligned / 4u;
  weight = weight.reshape({N_aligned, C4, 4, H, W});

  if (!tconv) {
    // Collapse each group of 4 channels onto the width axis
    weight = weight.permute({0, 1, 3, 4, 2}).reshape({N_aligned, C4, H, 4 * W});
    // Next, collapse the channel-group dim onto the width axis as well
    weight =
        weight.permute({0, 2, 1, 3}).reshape({N_aligned, H, C_aligned * W});
  } else {
    // For tconv, do the same thing as above, but interleave batches of 4 from
    // each of the channels
    weight = weight.permute({0, 3, 4, 1, 2}).reshape({N_aligned, H, W, 4 * C4});
    // Next reshape to combine the last two dims into a single row
    weight = weight.reshape({N_aligned, H, C_aligned * W});
  }

  // Split the N dim into groups of 4
  uint32_t N4 = N_aligned / 4u;
  weight = weight.reshape({N4, 4, H, C_aligned * W});

  // Collapse the outermost dim so that each group of 4 is stacked vertically
  weight = weight.permute({1, 0, 2, 3}).reshape({4, N4 * H, C_aligned * W});

  return weight.contiguous();
}
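
/*
 * Illustrative sketch (not part of the build): the shape progression of
 * rearrange_weights_2d for the {10, 7, 3, 3} example described in the comment
 * above, with tconv == false. Only the shapes matter here.
 *
 *   at::Tensor w = at::rand({10, 7, 3, 3});
 *   // 1. pad N and C up to multiples of 4:              {12, 8, 3, 3}
 *   // 2. split C into groups of 4:                      {12, 2, 4, 3, 3}
 *   // 3. fold each group of 4 channels into the width:  {12, 2, 3, 12}
 *   // 4. stack the splits of a batch horizontally:      {12, 3, 24}
 *   // 5. split N into groups of 4:                      {3, 4, 3, 24}
 *   // 6. permute + reshape (final layout):              {4, 9, 24}
 *   at::Tensor packed = rearrange_weights_2d(w, false);  // tconv = false
 *   // packed.sizes() == {4, 9, 24}
 */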

/*
 * Rearranges a convolution bias tensor to a layout that can be used by
 * convolution compute shaders. The goal of this packing is to arrange the data
 * such that data access in the compute shader is as linear as possible. The
 * reasoning behind the packing pattern will be described in the shader kernel
 * code.
 *
 * The rearrangement structure is quite straightforward. Essentially we are
 * taking each texel and arranging it along the x axis.
 */
at::Tensor rearrange_bias(
    const std::optional<Tensor>& bias_in,
    const at::Tensor& weight_in,
    bool tconv) {
  // If optional is empty, just return zeros
  if (!bias_in) {
    uint32_t L = tconv ? get_dim<DimTConv2DKernel::OutChannels>(weight_in)
                       : get_dim<DimConv2DKernel::OutChannels>(weight_in);
    const uint32_t L4 = api::utils::div_up(L, 4u);

    at::Tensor bias = at::zeros({4, 1, L4}, weight_in.options());
    return bias;
  }

  at::Tensor bias = bias_in->clone();

  // Bias should just be a 1D tensor
  uint32_t L = get_dim<Dim1D::Length>(bias);

  uint32_t L_aligned = api::utils::align_up(L, 4u);

  // Add padding so that the length is a multiple of 4
  uint32_t padding_needed = L_aligned - L;
  bias = at::pad(bias, {0, padding_needed}, "constant", 0);

  // Reshape + permute to group every 4 consecutive elements along the same
  // channel
  uint32_t L4 = L_aligned / 4u;
  bias = bias.reshape({L4, 4}).permute({1, 0});
  bias = bias.reshape({4, 1, L4});

  return bias.contiguous();
}
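
/*
 * Illustrative sketch (not part of the build): the shape progression of
 * rearrange_bias for a length-10 bias vector. Only the shapes matter here.
 *
 *   at::Tensor w = at::rand({10, 7, 3, 3});
 *   at::Tensor b = at::rand({10});
 *   // 1. pad the length up to a multiple of 4:   {12}
 *   // 2. reshape into rows of 4:                 {3, 4}
 *   // 3. permute + reshape (final layout):       {4, 1, 3}
 *   at::Tensor packed = rearrange_bias(b, w, false);  // tconv = false
 *   // packed.sizes() == {4, 1, 3}
 *
 *   // With an empty optional, a zero-filled tensor of the same layout is
 *   // returned, sized from the weight's output channels: {4, 1, div_up(10, 4)}.
 */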

//
// Shader and Workgroup size determination
//

static api::ShaderInfo get_shader(
    const IntArrayRef kernel_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const Conv2dMethod method,
    const bool transposed,
    const bool quantized) {
  api::ShaderInfo shader;

  if (quantized) {
    if (transposed) {
      shader = VK_KERNEL(quantized_conv_transpose2d);
      return shader;
    }

    switch (method) {
      case Conv2dSlidingWindow:
        shader = VK_KERNEL(quantized_conv2d);
        break;
      case Conv2dDepthwise:
        shader = VK_KERNEL(quantized_conv2d_dw);
        break;
      case Conv2dPointwise:
        shader = VK_KERNEL(quantized_conv2d_pw_2x2);
        break;
        // TODO: fail for quantized transposed conv
    }
    return shader;
  }

  if (transposed) {
    shader = VK_KERNEL(conv_transpose2d);
    return shader;
  }

  switch (method) {
    case Conv2dSlidingWindow:
      shader = VK_KERNEL(conv2d);
      break;
    case Conv2dDepthwise:
      shader = VK_KERNEL(conv2d_dw);
      if (kernel_size.size() == 4 && kernel_size[2] == 3 &&
          kernel_size[3] == 3) {
        // Use the shader specialized for 3x3 depthwise kernels
        shader = VK_KERNEL(conv2d_dw_output_tile_3x3);
      }
      if (kernel_size.size() == 4 && kernel_size[2] == 5 &&
          kernel_size[3] == 5) {
        // Use the shader specialized for 5x5 depthwise kernels
        shader = VK_KERNEL(conv2d_dw_output_tile_5x5);
      }
      break;
    case Conv2dPointwise:
      shader = VK_KERNEL(conv2d_pw_output_tile_2x2);
      break;
  }
  return shader;
}

//
// Op Recording
//

struct Params final {
  api::utils::ivec3 out_extents;
  int32_t fill0;
  api::utils::ivec3 in_extents;
  int32_t fill1;
  api::utils::ivec4 overlay_region;
  api::utils::ivec2 kernel_size;
  api::utils::ivec2 stride;
  api::utils::ivec2 padding;
  api::utils::ivec2 dilate;
  api::utils::vec2 clamp;
};

static void record_op(
    api::Context* const context,
    api::ShaderInfo& compute_shader,
    vTensor& v_output,
    const vTensor& v_input,
    const vTensor& v_weight,
    const vTensor& v_bias,
    const IntArrayRef overlay_region,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const float output_min,
    const float output_max,
    const IntArrayRef kernel_size,
    const Conv2dMethod method,
    const bool transposed) {
  api::PipelineBarrier pipeline_barrier{};

  api::utils::uvec3 global_size = v_output.extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  Params block{
      api::utils::make_ivec3(v_output.extents()),
      0u,
      api::utils::make_ivec3(v_input.extents()),
      0u,
      utils::make_ivec4(overlay_region, /*reverse=*/true),
      utils::make_ivec2({kernel_size[3], kernel_size[2]}),
      utils::make_ivec2(stride, /*reverse=*/true),
      utils::make_ivec2(padding, /*reverse=*/true),
      utils::make_ivec2(dilation, /*reverse=*/true),
      {output_min, output_max},
  };
  api::UniformParamsBuffer params(context, block);

  context->submit_compute_job(
      // shader descriptor
      compute_shader,
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());
}

struct QParams final {
  api::utils::vec4 scales;
  api::utils::ivec4 zero_points;
  api::utils::ivec3 out_extents;
  int32_t fill0;
  api::utils::ivec3 in_extents;
  int32_t fill1;
  api::utils::ivec4 overlay_region;
  api::utils::ivec2 kernel_size;
  api::utils::ivec2 stride;
  api::utils::ivec2 padding;
  api::utils::ivec2 dilate;
  api::utils::vec2 clamp;
};

static void record_quantized_op(
    api::Context* const context,
    api::ShaderInfo& compute_shader,
    vTensor& v_output,
    const vTensor& v_input,
    const vTensor& v_weight,
    const vTensor& v_bias,
    const IntArrayRef overlay_region,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const float output_min,
    const float output_max,
    const IntArrayRef kernel_size,
    const Conv2dMethod method,
    const bool transposed) {
  api::PipelineBarrier pipeline_barrier{};

  api::utils::uvec3 global_size = v_output.extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  QParams block{
      {
          v_output.get_scale_float(),
          v_input.get_scale_float(),
          v_weight.get_scale_float(),
          v_bias.get_scale_float(),
      },
      {
          v_output.get_zero_point_int32(),
          v_input.get_zero_point_int32(),
          v_weight.get_zero_point_int32(),
          v_bias.get_zero_point_int32(),
      },
      api::utils::make_ivec3(v_output.extents()),
      0u,
      api::utils::make_ivec3(v_input.extents()),
      0u,
      utils::make_ivec4(overlay_region, /*reverse=*/true),
      utils::make_ivec2({kernel_size[3], kernel_size[2]}),
      utils::make_ivec2(stride, /*reverse=*/true),
      utils::make_ivec2(padding, /*reverse=*/true),
      utils::make_ivec2(dilation, /*reverse=*/true),
      {output_min, output_max},
  };
  api::UniformParamsBuffer params(context, block);

  context->submit_compute_job(
      // shader descriptor
      compute_shader,
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());
}

} // namespace conv2d

namespace {

using namespace api::utils;

vTensor pack_weights(
    const Tensor& weight_inp,
    const bool transposed,
    const bool quantized,
    const Conv2dMethod conv_method) {
  if (weight_inp.is_vulkan()) {
    return convert(weight_inp);
  }

  const Tensor weight_arg = quantized ? at::dequantize(weight_inp) : weight_inp;

  const Tensor weight = transposed
      ? at::permute(weight_arg, {1, 0, 2, 3}).contiguous()
      : weight_arg.contiguous();

  at::Tensor weight_rearranged;
  if (conv_method == Conv2dDepthwise) {
    weight_rearranged = conv2d::rearrange_weights_dw(weight);
  } else {
    weight_rearranged = conv2d::rearrange_weights_2d(weight, transposed);
  }

  vTensor v_weight{
      api::context(),
      weight_rearranged.sizes().vec(),
      convert_dtype(weight_rearranged.scalar_type()),
      api::StorageType::TEXTURE_2D,
  };

  pack_cpu_to_vulkan(weight_rearranged, v_weight);

  return v_weight;
}

vTensor pack_biases(
    const std::optional<Tensor>& bias,
    const Tensor& weight,
    const bool transposed,
    const bool quantized) {
  at::Tensor bias_arg = conv2d::rearrange_bias(bias, weight, transposed);
  at::Tensor bias_rearranged =
      (quantized &&
       (bias_arg.scalar_type() == kQUInt8 || bias_arg.scalar_type() == kQInt8 ||
        bias_arg.scalar_type() == kQInt32))
      ? at::dequantize(bias_arg)
      : bias_arg;

  vTensor v_bias{
      api::context(),
      bias_rearranged.sizes().vec(),
      convert_dtype(bias_rearranged.scalar_type()),
      api::StorageType::TEXTURE_2D,
  };

  pack_cpu_to_vulkan(bias_rearranged, v_bias);

  return v_bias;
}

/*
 * Computes the size of the overlay region when computing a convolution output.
 */
std::array<int64_t, 4> compute_overlay_region(
    const Tensor& weight,
    const IntArrayRef dilation,
    const bool transposed) {
  const IntArrayRef filter = weight.sizes();

  const auto overlay_length = [](const int64_t k, const int64_t d) {
    return k + (k - 1) * (d - 1);
  };

  return {
      align_up(
          transposed ? filter[Layout::TransposedFilter::output]
                     : filter[Layout::Filter::output],
          INT64_C(4)),
      align_up(
          transposed ? filter[Layout::TransposedFilter::input]
                     : filter[Layout::Filter::input],
          INT64_C(4)),
      overlay_length(
          filter[Layout::Filter::height], dilation[Layout::Parameter::height]),
      overlay_length(
          filter[Layout::Filter::width], dilation[Layout::Parameter::width]),
  };
}
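
/*
 * Illustrative sketch (not part of the build): overlay_length(k, d) is the
 * effective extent of a k-tap kernel with dilation d, i.e. k + (k - 1) * (d - 1).
 * For a non-transposed weight of size {10, 7, 3, 3} with dilation {2, 2}:
 *
 *   // output channels aligned up to 4:  align_up(10, 4) = 12
 *   // input channels aligned up to 4:   align_up(7, 4)  = 8
 *   // overlay height:                   3 + (3 - 1) * (2 - 1) = 5
 *   // overlay width:                    3 + (3 - 1) * (2 - 1) = 5
 *   // => compute_overlay_region(...) == {12, 8, 5, 5}
 */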

std::array<int64_t, 2> pack_params(const std::vector<int64_t>& vector) {
  TORCH_INTERNAL_ASSERT(2u == vector.size(), "Invalid usage!");

  return {
      vector[0],
      vector[1],
  };
}

bool weight_valid(const Tensor& weight, const bool quantized) {
  if (4 != weight.ndimension()) {
    return false;
  }
  if (get_dim<DimConv2DKernel::Height>(weight) == 0) {
    return false;
  }
  if (get_dim<DimConv2DKernel::Width>(weight) == 0) {
    return false;
  }
  if (!weight.device().is_cpu() &&
      weight.device().type() != c10::DeviceType::Vulkan) {
    return false;
  }
  if (quantized &&
      (weight.scalar_type() != c10::kQUInt8 &&
       weight.scalar_type() != c10::kQInt8)) {
    return false;
  }

  return true;
}

bool bias_valid(
    const std::optional<Tensor>& bias,
    const Tensor& weight,
    const bool transposed,
    const bool quantized) {
  if (!bias) {
    return true;
  }

  if (bias->ndimension() != 1) {
    return false;
  }
  if (!bias->device().is_cpu() &&
      bias->device().type() != c10::DeviceType::Vulkan) {
    return false;
  }
  uint32_t L = get_dim<Dim1D::Length>(*bias);
  uint32_t OC = transposed ? get_dim<DimTConv2DKernel::OutChannels>(weight)
                           : get_dim<DimConv2DKernel::OutChannels>(weight);
  if (L != OC) {
    return false;
  }

  return true;
}

bool available(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const bool transposed,
    const bool quantized,
    const IntArrayRef /* output_padding */,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  if (!weight_valid(weight, quantized)) {
    return false;
  }
  if (!bias_valid(bias, weight, transposed, quantized)) {
    return false;
  }
  if (get_dim<Dim4D::Height>(stride) == 0 ||
      get_dim<Dim4D::Width>(stride) == 0) {
    return false;
  }
  if (transposed) {
    if (get_dim<Dim4D::Height>(dilation) != 1 ||
        get_dim<Dim4D::Width>(dilation) != 1) {
      return false;
    }
  } else {
    if (get_dim<Dim4D::Height>(dilation) == 0 ||
        get_dim<Dim4D::Width>(dilation) == 0) {
      return false;
    }
  }
  if (groups <= 0) {
    return false;
  }
  if (transposed) {
    if ((get_dim<DimTConv2DKernel::OutChannels>(weight) % groups) != 0) {
      return false;
    }
  } else {
    if ((get_dim<DimConv2DKernel::OutChannels>(weight) % groups) != 0) {
      return false;
    }
  }
  if (get_dim<DimConv2DKernel::InChannels>(weight) == 0 ||
      get_dim<DimConv2DKernel::OutChannels>(weight) == 0) {
    return false;
  }
  if (output_min && !output_min->isFloatingPoint()) {
    return false;
  }
  if (output_max && !output_max->isFloatingPoint()) {
    return false;
  }
  return true;
}

bool usable(const Tensor& input, const bool quantized) {
  if (input.ndimension() != 4) {
    return false;
  }
  if (input.device().type() != c10::DeviceType::Vulkan) {
    return false;
  }
  if (!quantized && input.scalar_type() != at::kFloat) {
    return false;
  }
  if (quantized && input.scalar_type() != c10::kQUInt8) {
    return false;
  }
  if (get_dim<Dim4D::Batch>(input) == 0) {
    return false;
  }
  if (get_dim<Dim4D::Channel>(input) == 0) {
    return false;
  }
  if (get_dim<Dim4D::Height>(input) == 0) {
    return false;
  }
  if (get_dim<Dim4D::Width>(input) == 0) {
    return false;
  }
  if (input.requires_grad()) {
    return false;
  }

  return true;
}

static inline std::vector<int64_t> get_conv_transpose_output_size(
    IntArrayRef input_size,
    IntArrayRef weight_size,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation = IntArrayRef()) {
  auto dim = input_size.size();
  std::vector<int64_t> output_size(dim);
  output_size[0] = input_size[input_batch_size_dim];
  output_size[1] = weight_size[weight_input_channels_dim];
  for (const auto d : c10::irange(2, dim)) {
    output_size[d] = stride[d - 2] * (input_size[d] - 1) + weight_size[d] -
        2 * padding[d - 2] + output_padding[d - 2];
  }
  return output_size;
}
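
/*
 * Illustrative sketch (not part of the build): for each spatial dim the formula
 * above is stride * (in - 1) + kernel - 2 * padding + output_padding, and the
 * output channel count is taken directly from weight_size[1]. Note that the
 * dilation argument is not used here, which is consistent with available()
 * requiring dilation == 1 for transposed convolutions.
 *
 *   // input {1, 32, 8, 8}, weight {32, 16, 3, 3}, padding {1, 1},
 *   // output_padding {1, 1}, stride {2, 2}:
 *   //   H_out = 2 * (8 - 1) + 3 - 2 * 1 + 1 = 16
 *   //   W_out = 2 * (8 - 1) + 3 - 2 * 1 + 1 = 16
 *   // => output size {1, 16, 16, 16}
 */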

Tensor convolution(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const bool transposed,
    const IntArrayRef output_padding,
    const int64_t groups) {
  Conv2dPackedContext conv_context = Conv2dPackedContext(
      weight,
      bias,
      stride,
      padding,
      dilation,
      transposed,
      false,
      output_padding,
      groups);

  return run_conv2d_context(
      input, c10::make_intrusive<Conv2dPackedContext>(conv_context));
}

} // namespace

namespace conv1d {

static vTensor pack_weights_using_width_packing(const Tensor& weight_arg) {
  Tensor weight = weight_arg;

  if (weight.is_cpu()) {
    weight = weight.vulkan();
  }

  TORCH_CHECK(weight.is_vulkan(), "Weight must be on Vulkan device!");

  vTensor v_weight = convert(weight);
  if (v_weight.gpu_memory_layout() ==
      api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED) {
    v_weight = packing::convert_image_channels_packed_to_width_packed(v_weight);
  }

  TORCH_CHECK(
      v_weight.gpu_memory_layout() == api::GPUMemoryLayout::TENSOR_WIDTH_PACKED,
      "After packing, the v_weight must be in TENSOR_WIDTH_PACKED format");

  return v_weight;
}

/*
 * This is a full implementation. For algorithm details, refer to the shader
 * kernel code.
 */
static Tensor run_conv1d_context_impl(
    const Tensor& input_arg,
    const Tensor& weight_arg,
    const std::optional<Tensor>& bias_arg_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups) {
  api::Context* const context = api::context();
  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
  const Tensor weight =
      weight_arg.is_vulkan() ? weight_arg : weight_arg.vulkan();

  const IntArrayRef& input_sizes = input.sizes();
  const IntArrayRef& weight_sizes = weight.sizes();

  int32_t in_channels = static_cast<int32_t>(input_sizes[1]);
  int32_t out_channels = static_cast<int32_t>(weight_sizes[0]);
  int32_t kernel_size = static_cast<int32_t>(weight_sizes[2]);

  Tensor bias;
  if (bias_arg_opt) {
    if (bias_arg_opt->is_vulkan()) {
      bias = bias_arg_opt.value();
    } else {
      bias = bias_arg_opt.value().vulkan();
    }
  } else {
    bias = at::zeros({out_channels}).vulkan();
  }

  TORCH_CHECK(input.dim() == 3, "input must be a 3-dim tensor");
  TORCH_CHECK(weight.dim() == 3, "weight must be a 3-dim tensor");
  TORCH_CHECK(
      in_channels % groups == 0, "in_channels must be divisible by groups");
  TORCH_CHECK(
      out_channels % groups == 0, "out_channels must be divisible by groups");

  const vTensor& v_input = convert(input);
  const vTensor& v_weight = convert(weight);
  const vTensor& v_bias = convert(bias);

  vTensor v_output{
      context,
      conv_output_size(input_sizes, weight_sizes, padding, stride, dilation),
      v_input.dtype(),
  };

  const struct Block final {
    int32_t in_length;
    int32_t kernel_size;
    int32_t stride;
    int32_t padding;
    int32_t dilation;
    int32_t in_group_size;
    int32_t out_group_size;
    int32_t batch_size;
  } block{
      static_cast<int32_t>(input_sizes[2]),
      kernel_size,
      static_cast<int32_t>(stride[0]),
      static_cast<int32_t>(padding[0]),
      static_cast<int32_t>(dilation[0]),
      static_cast<int32_t>(in_channels / groups),
      static_cast<int32_t>(out_channels / groups),
      static_cast<int32_t>(input_sizes[0]),
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(conv1d),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      {1, static_cast<uint32_t>(out_channels), 1},
      // local work group size
      {1, 1, 1},
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert(v_output);
}

} // namespace conv1d

Conv2dPackedContext::Conv2dPackedContext(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const IntArrayRef stride_arg,
    const IntArrayRef padding_arg,
    const IntArrayRef dilation_arg,
    const bool transposed,
    const bool quantized,
    const IntArrayRef output_padding_arg,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max)
    : unpacked_{c10::AnyType::get()} {
  const auto stride = expand_param_if_needed(stride_arg, "stride", 2);
  const auto padding = expand_param_if_needed(padding_arg, "padding", 2);
  const auto dilation = expand_param_if_needed(dilation_arg, "dilation", 2);
  const auto output_padding =
      expand_param_if_needed(output_padding_arg, "output_padding", 2);

  TORCH_CHECK(
      available(
          weight,
          bias,
          stride,
          padding,
          dilation,
          transposed,
          quantized,
          output_padding,
          groups,
          output_min,
          output_max),
      "Vulkan::convolution not available! "
      "Reason: The provided (weight, bias, stride, padding, dilation, groups, "
      "transposed, output_padding, output_min, output_max) parameters are either "
      "invalid individually or their combination is not supported by Vulkan impl.");

  const auto method = conv2d::determine_method(
      weight.sizes(), stride, padding, dilation, groups, transposed, quantized);

  packed_.reserve(Packed::NumArgs);
  packed_.emplace_back(
      convert(pack_weights(weight, transposed, quantized, method)));
  packed_.emplace_back(
      convert(pack_biases(bias, weight, transposed, quantized)));
  packed_.emplace_back(compute_overlay_region(weight, dilation, transposed));
  packed_.emplace_back(pack_params(stride));
  packed_.emplace_back(pack_params(padding));
  packed_.emplace_back(output_padding);
  packed_.emplace_back(pack_params(dilation));
  packed_.emplace_back(transposed);
  packed_.emplace_back(quantized);
  packed_.emplace_back(safe_downcast<int32_t>(groups));
  packed_.emplace_back(
      output_min ? output_min->template to<float>()
                 : -std::numeric_limits<float>::infinity());
  packed_.emplace_back(
      output_max ? output_max->template to<float>()
                 : +std::numeric_limits<float>::infinity());
  packed_.emplace_back(method);
  packed_.emplace_back(weight.sizes().vec());

  compute_shader_ = conv2d::get_shader(
      weight.sizes(), stride, padding, dilation, method, transposed, quantized);

  if (!at::globalContext().releaseWeightsWhenPrepacking()) {
    unpacked_.reserve(Unpacked::NumArgs);
    unpacked_.emplace_back(weight);
    unpacked_.emplace_back(bias);
    unpacked_.emplace_back(stride_arg.vec());
    unpacked_.emplace_back(padding_arg.vec());
    unpacked_.emplace_back(dilation_arg.vec());
    unpacked_.emplace_back(transposed);
    unpacked_.emplace_back(quantized);
    unpacked_.emplace_back(output_padding_arg.vec());
    unpacked_.emplace_back(groups);
    unpacked_.emplace_back(output_min);
    unpacked_.emplace_back(output_max);
  }
}

Conv2dPackedContext Conv2dPackedContext::pack(c10::impl::GenericList unpacked) {
  return Conv2dPackedContext(
      unpacked.get(Unpacked::Weight).toTensor(),
      get_optional_tensor(unpacked, Unpacked::Bias),
      unpacked.get(Unpacked::Stride).toIntVector(),
      unpacked.get(Unpacked::Padding).toIntVector(),
      unpacked.get(Unpacked::Dilation).toIntVector(),
      unpacked.get(Unpacked::isTransposed).toBool(),
      unpacked.get(Unpacked::isQuantized).toBool(),
      unpacked.get(Unpacked::OutputPadding).toIntVector(),
      unpacked.get(Unpacked::Groups).toInt(),
      get_optional_scalar(unpacked, Unpacked::OutputMin),
      get_optional_scalar(unpacked, Unpacked::OutputMax));
}

c10::intrusive_ptr<Conv2dPackedContext> create_conv2d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return c10::make_intrusive<Conv2dPackedContext>(Conv2dPackedContext(
      weight,
      bias,
      stride,
      padding,
      dilation,
      /* transposed = */ false,
      /* quantized = */ false,
      /* output_padding_arg = */ {0},
      groups,
      output_min,
      output_max));
}

c10::intrusive_ptr<Conv2dPackedContext> create_tconv2d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& output_padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return c10::make_intrusive<Conv2dPackedContext>(Conv2dPackedContext(
      weight,
      bias,
      stride,
      padding,
      dilation,
      /* transposed = */ true,
      /* quantized = */ false,
      output_padding,
      groups,
      output_min,
      output_max));
}

c10::intrusive_ptr<Conv2dPackedContext> create_qconv2d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return c10::make_intrusive<Conv2dPackedContext>(Conv2dPackedContext(
      weight,
      bias,
      stride,
      padding,
      dilation,
      /* transposed = */ false,
      /* quantized = */ true,
      /* output_padding_arg = */ {0},
      groups,
      output_min,
      output_max));
}

c10::intrusive_ptr<Conv2dPackedContext> create_qtconv2d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& output_padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return c10::make_intrusive<Conv2dPackedContext>(Conv2dPackedContext(
      weight,
      bias,
      stride,
      padding,
      dilation,
      /* transposed = */ true,
      /* quantized = */ true,
      output_padding,
      groups,
      output_min,
      output_max));
}

static Tensor run_conv2d_context_impl(
    const Tensor& input_arg,
    const c10::intrusive_ptr<Conv2dPackedContext>& conv_context,
    double scale,
    int64_t zero_point) {
  api::Context* const context = api::context();
  // Validate input tensor is a Vulkan tensor, then convert to vTensor
  TORCH_CHECK(input_arg.is_vulkan(), "Input tensor must be Vulkan!");
  const vTensor& v_input = convert(input_arg);

  // Extract everything from the PackedContext
  const Tensor weight =
      conv_context->get_val(Conv2dPackedContext::Packed::Weight).toTensor();
  const vTensor& v_weight = convert(weight);

  const auto quantized =
      conv_context->get_val(Conv2dPackedContext::Packed::isQuantized).toBool();

  Tensor bias =
      conv_context->get_val(Conv2dPackedContext::Packed::Bias).toTensor();

  const vTensor& v_bias = convert(bias);

  const auto overlay_region =
      conv_context->get_val(Conv2dPackedContext::Packed::OverlayRegion)
          .toIntVector();

  const auto stride =
      conv_context->get_val(Conv2dPackedContext::Packed::Stride).toIntVector();
  const auto padding =
      conv_context->get_val(Conv2dPackedContext::Packed::Padding).toIntVector();
  const auto output_padding =
      conv_context->get_val(Conv2dPackedContext::Packed::OutputPadding)
          .toIntVector();
  const auto dilation =
      conv_context->get_val(Conv2dPackedContext::Packed::Dilation)
          .toIntVector();

  const auto transposed =
      conv_context->get_val(Conv2dPackedContext::Packed::isTransposed).toBool();

  const float output_min = safe_downcast<float>(
      conv_context->get_val(Conv2dPackedContext::Packed::OutputMin).toDouble());
  const float output_max = safe_downcast<float>(
      conv_context->get_val(Conv2dPackedContext::Packed::OutputMax).toDouble());

  const Conv2dMethod method_ = static_cast<Conv2dMethod>(
      conv_context->get_val(Conv2dPackedContext::Packed::ConvMethod).toInt());

  const auto kernel_size =
      conv_context->get_val(Conv2dPackedContext::Packed::WeightSizes)
          .toIntVector();

  TORCH_CHECK(
      usable(input_arg, quantized), "Input tensor not usable for convolution!");

  std::vector<int64_t> output_size;
  if (transposed) {
    output_size = get_conv_transpose_output_size(
        v_input.sizes(),
        kernel_size,
        padding,
        output_padding,
        stride,
        dilation);
  } else {
    output_size = conv_output_size(
        v_input.sizes(), kernel_size, padding, stride, dilation);
  }

  vTensor v_output{
      context,
      output_size,
      v_input.dtype(),
  };

  if (quantized) {
    v_output.set_is_quantized();
    v_output.set_scale(scale);
    v_output.set_zero_point(zero_point);
  }

  if (quantized) {
    conv2d::record_quantized_op(
        context,
        conv_context->compute_shader(),
        v_output,
        v_input,
        v_weight,
        v_bias,
        overlay_region,
        stride,
        padding,
        dilation,
        output_min,
        output_max,
        kernel_size,
        method_,
        transposed);
  } else {
    conv2d::record_op(
        context,
        conv_context->compute_shader(),
        v_output,
        v_input,
        v_weight,
        v_bias,
        overlay_region,
        stride,
        padding,
        dilation,
        output_min,
        output_max,
        kernel_size,
        method_,
        transposed);
  }

  return convert(v_output);
}

Tensor run_conv2d_context(
    const Tensor& input_arg,
    const c10::intrusive_ptr<Conv2dPackedContext>& conv_context) {
  return run_conv2d_context_impl(input_arg, conv_context, 1.0f, 0u);
}
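
/*
 * Illustrative usage sketch (not part of the build), assuming a float CPU
 * weight/bias and a Vulkan input tensor. The shapes are made up; the point is
 * the prepack-then-run flow.
 *
 *   c10::intrusive_ptr<Conv2dPackedContext> ctx = create_conv2d_context(
 *       at::rand({64, 32, 3, 3}),   // weight
 *       at::rand({64}),             // bias
 *       {1, 1},                     // stride
 *       {1, 1},                     // padding
 *       {1, 1},                     // dilation
 *       1,                          // groups
 *       std::nullopt,               // output_min (no clamp)
 *       std::nullopt);              // output_max (no clamp)
 *   at::Tensor input = at::rand({1, 32, 56, 56}).vulkan();
 *   at::Tensor output = run_conv2d_context(input, ctx);  // {1, 64, 56, 56}
 */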

Tensor run_tconv2d_context(
    const Tensor& input_arg,
    const c10::intrusive_ptr<Conv2dPackedContext>& conv_context) {
  return run_conv2d_context_impl(input_arg, conv_context, 1.0f, 0u);
}

Tensor run_qconv2d_context(
    const Tensor& input_arg,
    double scale,
    int64_t zero_point,
    const c10::intrusive_ptr<Conv2dPackedContext>& conv_context) {
  return run_conv2d_context_impl(input_arg, conv_context, scale, zero_point);
}

/* Backwards compatibility */
Conv2dOpContext::Conv2dOpContext(Conv2dPackedContext conv_context)
    : conv_context_{std::move(conv_context)} {}

Conv2dOpContext Conv2dOpContext::create(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const IntArrayRef stride_arg,
    const IntArrayRef padding_arg,
    const IntArrayRef dilation_arg,
    const bool transposed,
    const IntArrayRef output_padding_arg,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return Conv2dOpContext{Conv2dPackedContext(
      weight,
      bias,
      stride_arg,
      padding_arg,
      dilation_arg,
      transposed,
      /* quantized = */ false,
      output_padding_arg,
      groups,
      output_min,
      output_max)};
}

Tensor Conv2dOpContext::run(const Tensor& input_arg) const {
  return run_conv2d_context(
      input_arg, c10::make_intrusive<Conv2dPackedContext>(conv_context_));
}

Conv2dOpContext::State Conv2dOpContext::unpack() const {
  const c10::impl::GenericList unpacked_ = conv_context_.unpack();

  TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");

  return Conv2dOpContext::State(
      unpacked_.get(Conv2dPackedContext::Unpacked::Weight).toTensor(),
      get_optional_tensor(unpacked_, Conv2dPackedContext::Unpacked::Bias),
      unpacked_.get(Conv2dPackedContext::Unpacked::Stride).toIntVector(),
      unpacked_.get(Conv2dPackedContext::Unpacked::Padding).toIntVector(),
      unpacked_.get(Conv2dPackedContext::Unpacked::Dilation).toIntVector(),
      unpacked_.get(Conv2dPackedContext::Unpacked::Groups).toInt(),
      get_optional_scalar(unpacked_, Conv2dPackedContext::Unpacked::OutputMin),
      get_optional_scalar(unpacked_, Conv2dPackedContext::Unpacked::OutputMax));
}

c10::intrusive_ptr<Conv2dOpContext> conv2d_clamp_prepack(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return c10::make_intrusive<Conv2dOpContext>(Conv2dOpContext::create(
      std::move(weight),
      std::move(bias),
      std::move(stride),
      std::move(padding),
      std::move(dilation),
      /* transposed = */ false,
      /* output_padding = */ {0},
      groups,
      output_min,
      output_max));
}

Tensor conv2d_clamp_run(
    const Tensor& input,
    const c10::intrusive_ptr<Conv2dOpContext>& context) {
  return context->run(input);
}

Conv1dPackedContext::Conv1dPackedContext(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const IntArrayRef stride_arg,
    const IntArrayRef padding_arg,
    const IntArrayRef dilation_arg,
    const int64_t groups)
    : unpacked_{c10::AnyType::get()} {
  packed_.reserve(Packed::NumArgs);
  packed_.emplace_back(
      convert(conv1d::pack_weights_using_width_packing(weight.vulkan())));
  packed_.emplace_back(bias->vulkan());
  packed_.emplace_back(stride_arg);
  packed_.emplace_back(padding_arg);
  packed_.emplace_back(dilation_arg);
  packed_.emplace_back(safe_downcast<int32_t>(groups));

  compute_shader_ = VK_KERNEL(conv1d);

  if (!at::globalContext().releaseWeightsWhenPrepacking()) {
    unpacked_.reserve(Unpacked::NumArgs);
    unpacked_.emplace_back(weight);
    unpacked_.emplace_back(bias);
    unpacked_.emplace_back(stride_arg.vec());
    unpacked_.emplace_back(padding_arg.vec());
    unpacked_.emplace_back(dilation_arg.vec());
    unpacked_.emplace_back(safe_downcast<int32_t>(groups));
  }
}

Conv1dPackedContext Conv1dPackedContext::pack(c10::impl::GenericList unpacked) {
  return Conv1dPackedContext(
      unpacked.get(Unpacked::Weight).toTensor(),
      get_optional_tensor(unpacked, Unpacked::Bias),
      unpacked.get(Unpacked::Stride).toIntVector(),
      unpacked.get(Unpacked::Padding).toIntVector(),
      unpacked.get(Unpacked::Dilation).toIntVector(),
      unpacked.get(Unpacked::Groups).toInt());
}

c10::intrusive_ptr<Conv1dPackedContext> create_conv1d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups) {
  return c10::make_intrusive<Conv1dPackedContext>(
      Conv1dPackedContext(weight, bias, stride, padding, dilation, groups));
}

static Tensor convolution1d(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const int64_t groups) {
  Conv1dPackedContext conv1d_context =
      Conv1dPackedContext(weight, bias, stride, padding, dilation, groups);

  return run_conv1d_context(
      input, c10::make_intrusive<Conv1dPackedContext>(conv1d_context));
}

Tensor run_conv1d_context(
    const Tensor& input,
    const c10::intrusive_ptr<Conv1dPackedContext>& context) {
  const Tensor weight =
      context->get_val(Conv1dPackedContext::Packed::Weight).toTensor();
  const std::optional<Tensor>& bias_opt =
      context->get_val(Conv1dPackedContext::Packed::Bias).toTensor();
  const auto stride =
      context->get_val(Conv1dPackedContext::Packed::Stride).toIntVector();
  const auto padding =
      context->get_val(Conv1dPackedContext::Packed::Padding).toIntVector();
  const auto dilation =
      context->get_val(Conv1dPackedContext::Packed::Dilation).toIntVector();
  const auto groups =
      context->get_val(Conv1dPackedContext::Packed::Groups).toInt();
  return conv1d::run_conv1d_context_impl(
      input, weight, bias_opt, stride, padding, dilation, groups);
}

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl("convolution_overrideable", convolution);
  m.impl(TORCH_SELECTIVE_NAME("aten::conv1d"), TORCH_FN(convolution1d));
}

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at