#include <ATen/native/vulkan/api/Types.h>
#include <ATen/native/vulkan/api/Utils.h>
#include <ATen/native/vulkan/impl/Common.h>
#include <ATen/native/vulkan/impl/Packing.h>

namespace at {
namespace native {
namespace vulkan {
namespace packing {
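// Selects the compute shader that unpacks a contiguous NCHW staging buffer
// into the texture backing v_dst. The choice depends on the destination's
// storage type and dtype; quantized and bool tensors use dedicated variants,
// and unsupported combinations throw.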
api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) {
  if (v_dst.is_quantized()) {
    switch (v_dst.storage_type()) {
      case api::StorageType::TEXTURE_3D:
        switch (v_dst.dtype()) {
          case api::ScalarType::QUInt8:
            return VK_KERNEL(nchw_to_image_uint8);
          case api::ScalarType::QInt8:
            return VK_KERNEL(nchw_to_image_int8);
          case api::ScalarType::QInt32:
            return VK_KERNEL(nchw_to_image_int32);
          default:
            VK_THROW(
                "Vulkan quantization currently not supported for dtype ",
                v_dst.dtype());
        }
      case api::StorageType::TEXTURE_2D:
        switch (v_dst.dtype()) {
          case api::ScalarType::QUInt8:
            return VK_KERNEL(nchw_to_image2d_uint8);
          case api::ScalarType::QInt8:
            return VK_KERNEL(nchw_to_image2d_int8);
          case api::ScalarType::QInt32:
            return VK_KERNEL(nchw_to_image2d_int32);
          default:
            VK_THROW(
                "Vulkan quantization currently not supported for dtype ",
                v_dst.dtype());
        }
      default:
        VK_THROW("No kernel available!");
      case api::StorageType::BUFFER:
      case api::StorageType::UNKNOWN:
        VK_THROW("Requested storage type must be a texture type.");
    }
  }

  if (v_dst.dtype() == api::kFloat) {
    switch (v_dst.storage_type()) {
      case api::StorageType::TEXTURE_3D:
        return VK_KERNEL(nchw_to_image);
      case api::StorageType::TEXTURE_2D:
        return VK_KERNEL(nchw_to_image2d);
      default:
        VK_THROW("No kernel available!");
    }
  } else if (v_dst.dtype() == api::kBool) {
    switch (v_dst.storage_type()) {
      case api::StorageType::TEXTURE_3D:
        return VK_KERNEL(nchw_to_image_bool);
      default:
        VK_THROW("No kernel available!");
    }
  } else {
    VK_THROW("Unsupported dtype!");
  }
}
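// Selects the compute shader that packs the texture backing v_src into a
// contiguous NCHW staging buffer. For 8-bit quantized and bool tensors an
// optimized variant (image_to_nchw_quantized_mul4) is used when the H * W
// plane size is a multiple of 4; otherwise the generic uint shader is used.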
api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src) {
  if (v_src.is_quantized() || v_src.dtype() == api::kBool) {
    auto plane_size =
        dim_at<Dim4D::Height>(v_src) * dim_at<Dim4D::Width>(v_src);
    switch (v_src.storage_type()) {
      case api::StorageType::TEXTURE_3D:
        switch (v_src.dtype()) {
          case api::ScalarType::QUInt8:
          case api::ScalarType::QInt8:
          case api::kBool:
            return plane_size % 4 == 0 ? VK_KERNEL(image_to_nchw_quantized_mul4)
                                       : VK_KERNEL(image_to_nchw_uint);
          case api::ScalarType::QInt32:
            return VK_KERNEL(image_to_nchw_int32);
          default:
            VK_THROW(
                "Vulkan quantization currently not supported for dtype ",
                v_src.dtype());
        }
      default:
        VK_THROW("No kernel available!");
      case api::StorageType::BUFFER:
      case api::StorageType::UNKNOWN:
        VK_THROW("Requested storage type must be a texture type.");
    }
  }

  if (v_src.dtype() == api::kFloat) {
    switch (v_src.storage_type()) {
      case api::StorageType::TEXTURE_3D:
        return VK_KERNEL(image_to_nchw);
      case api::StorageType::TEXTURE_2D:
        return VK_KERNEL(image2d_to_nchw);
      default:
        VK_THROW("No kernel available!");
    }
  } else {
    VK_THROW("Unsupported dtype!");
  }
}
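// Uniform parameters shared by the nchw <-> image shaders:
//   extents     - texture extents of the tensor's image
//   planeSize   - H * W of the tensor
//   channelInfo - {ceil(C / 4), C}: the packed channel depth and the original
//                 channel count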
struct ToFromTextureParams final {
  api::utils::ivec3 extents;
  int32_t planeSize;
  api::utils::ivec2 channelInfo;
};
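// Records a compute job that unpacks src_buffer (NCHW data) into the texture
// backing v_dst, using a shader obtained from get_nchw_to_image_shader.
//
// A hedged sketch of how a caller might drive this helper (staging_buffer is
// an illustrative name, not something defined in this file):
//
//   api::Context* const context = api::context();
//   api::PipelineBarrier pipeline_barrier{};
//   api::ShaderInfo shader = get_nchw_to_image_shader(v_dst);
//   record_nchw_to_image_op(
//       context, shader, staging_buffer, v_dst, pipeline_barrier, VK_NULL_HANDLE);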
void record_nchw_to_image_op(
    api::Context* const context,
    api::ShaderInfo& compute_shader,
    api::VulkanBuffer& src_buffer,
    vTensor& v_dst,
    api::PipelineBarrier pipeline_barrier,
    VkFence fence_handle) {
  api::utils::uvec3 global_size = v_dst.extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  int32_t height =
      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Height>(v_dst));
  int32_t width =
      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Width>(v_dst));
  int32_t channels =
      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Channel>(v_dst));

  int32_t plane_size = height * width;
  int32_t c_depth = api::utils::div_up(channels, 4);

  ToFromTextureParams block{
      api::utils::make_ivec3(v_dst.extents()),
      plane_size,
      {c_depth, channels},
  };

  api::UniformParamsBuffer params(context, block);
  context->submit_compute_job(
      // shader descriptor
      compute_shader,
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      fence_handle,
      // shader arguments
      v_dst.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      src_buffer,
      // params buffer
      params.buffer());
}
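// Records a compute job that packs the texture backing v_src into dst_buffer
// as contiguous NCHW data. For 8-bit quantized and bool tensors the global and
// local work group sizes are reshaped to match the 1D layouts expected by the
// optimized and fallback shaders. The result of submit_compute_job is
// propagated to the caller.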
bool record_image_to_nchw_op(
    api::Context* const context,
    api::ShaderInfo& compute_shader,
    vTensor& v_src,
    api::VulkanBuffer& dst_buffer,
    api::PipelineBarrier pipeline_barrier,
    VkFence fence_handle) {
  api::utils::uvec3 global_size = v_src.extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  int32_t height =
      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Height>(v_src));
  int32_t width =
      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Width>(v_src));
  int32_t channels =
      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Channel>(v_src));

  int32_t plane_size = height * width;
  int32_t c_depth = api::utils::div_up(channels, 4);

  ToFromTextureParams block{
      api::utils::make_ivec3(v_src.extents()),
      plane_size,
      {c_depth, channels},
  };

  if (v_src.dtype() == api::ScalarType::QUInt8 ||
      v_src.dtype() == api::ScalarType::QInt8 || v_src.dtype() == api::kBool) {
    // Special case using optimized shader, image_to_nchw_quantized_mul4
    if (plane_size % 4 == 0) {
      global_size.data[0u] = plane_size / 4;
      global_size.data[1u] = 1;
      local_size.data[0u] *= local_size.data[1u];
      local_size.data[1u] = 1;
    }
    // Global and local size for regular 1D buffer.
    else {
      uint32_t numel = v_src.numel();
      global_size = {api::utils::div_up(numel, uint32_t(4)), 1u, 1u};
      local_size = {64u, 1u, 1u};
    }
  }

  api::UniformParamsBuffer params(context, block);
  return context->submit_compute_job(
      // shader descriptor
      compute_shader,
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      fence_handle,
      // shader arguments
      v_src.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      dst_buffer,
      // params buffer
      params.buffer());
}
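// Records a buffer_to_buffer compute job that copies src_buffer (NCHW data)
// into the GPU buffer backing v_dst, one invocation per element of the GPU
// buffer.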
void record_nchw_to_buffer_op(
    api::Context* const context,
    api::VulkanBuffer& src_buffer,
    vTensor& v_dst,
    api::PipelineBarrier pipeline_barrier,
    VkFence fence_handle) {
  uint32_t gpu_buf_len = api::utils::safe_downcast<uint32_t>(v_dst.gpu_numel());

  api::utils::uvec3 global_size = {gpu_buf_len, 1u, 1u};
  api::utils::uvec3 local_size = {32u, 1u, 1u};

  api::UniformParamsBuffer cpu_buffer_metadata(
      context, v_dst.get_cpu_buffer_metadata());

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(buffer_to_buffer),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      fence_handle,
      // shader arguments
      v_dst.buffer(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_dst.buffer_metadata(),
      src_buffer,
      cpu_buffer_metadata.buffer());
}
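// Records a buffer_to_buffer compute job that copies the GPU buffer backing
// v_src into dst_buffer as contiguous NCHW data, one invocation per logical
// element; the result of submit_compute_job is propagated to the caller.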
bool record_buffer_to_nchw_op(
    api::Context* const context,
    vTensor& v_src,
    api::VulkanBuffer& dst_buffer,
    api::PipelineBarrier pipeline_barrier,
    VkFence fence_handle) {
  uint32_t buf_len = api::utils::safe_downcast<uint32_t>(v_src.numel());

  api::utils::uvec3 global_size = {buf_len, 1u, 1u};
  api::utils::uvec3 local_size = {4u, 1u, 1u};

  api::UniformParamsBuffer cpu_buffer_metadata(
      context, v_src.get_cpu_buffer_metadata());

  return context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(buffer_to_buffer),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      fence_handle,
      // shader arguments
      dst_buffer,
      cpu_buffer_metadata.buffer(),
      v_src.buffer(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_src.buffer_metadata());
}
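// Allocates an output tensor with the requested memory layout and dispatches
// the given repacking shader to copy v_input's channels-packed image into it.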
static vTensor channel_image_repacking(
    const vTensor& v_input,
    api::GPUMemoryLayout target_layout,
    const api::ShaderInfo& shader_descriptor) {
  api::Context* const context = api::context();

  vTensor v_output{
      context,
      v_input.sizes(),
      v_input.dtype(),
      v_input.storage_type(),
      target_layout,
  };

  // Required to determine how to insert memory barriers in the command buffer
  api::PipelineBarrier pipeline_barrier{};

  // The shader assumes a 4d nchw to calculate the lookup coordinate.
  // If the input is not 4d, we need to pad it with 1's on the front.
  const struct Block final {
    api::utils::ivec4 sizes;
  } block{
      api::utils::make_ivec4_prepadded1(v_input.sizes()),
  };

  api::UniformParamsBuffer params(context, block);

  context->submit_compute_job(
      // shader descriptor
      // VK_KERNEL(packing_channel_to_height),
      shader_descriptor,
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_output.extents(),
      // local work group size
      adaptive_work_group_size(v_output.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return v_output;
}
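// Thin wrappers around channel_image_repacking for converting a
// channels-packed image into height- and width-packed layouts.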
vTensor convert_image_channels_packed_to_height_packed(const vTensor& v_input) {
  return channel_image_repacking(
      v_input,
      api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED,
      VK_KERNEL(convert_channels_to_height_packed));
}

vTensor convert_image_channels_packed_to_width_packed(const vTensor& v_input) {
  return channel_image_repacking(
      v_input,
      api::GPUMemoryLayout::TENSOR_WIDTH_PACKED,
      VK_KERNEL(convert_channels_to_width_packed));
}
} // namespace packing
} // namespace vulkan
} // namespace native
} // namespace at