#include <ATen/native/vulkan/api/Types.h>
#include <ATen/native/vulkan/api/Utils.h>
#include <ATen/native/vulkan/impl/Common.h>
#include <ATen/native/vulkan/impl/Packing.h>

namespace at {
namespace native {
namespace vulkan {
namespace packing {

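// Selects the compute shader that copies a contiguous NCHW staging buffer
// into the texture backing v_dst, based on the tensor's storage type and
// dtype. Throws if no suitable kernel exists.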
api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) {
  if (v_dst.is_quantized()) {
    switch (v_dst.storage_type()) {
      case api::StorageType::TEXTURE_3D:
        switch (v_dst.dtype()) {
          case api::ScalarType::QUInt8:
            return VK_KERNEL(nchw_to_image_uint8);
          case api::ScalarType::QInt8:
            return VK_KERNEL(nchw_to_image_int8);
          case api::ScalarType::QInt32:
            return VK_KERNEL(nchw_to_image_int32);
          default:
            VK_THROW(
                "Vulkan quantization currently not supported for dtype ",
                v_dst.dtype());
        }
      case api::StorageType::TEXTURE_2D:
        switch (v_dst.dtype()) {
          case api::ScalarType::QUInt8:
            return VK_KERNEL(nchw_to_image2d_uint8);
          case api::ScalarType::QInt8:
            return VK_KERNEL(nchw_to_image2d_int8);
          case api::ScalarType::QInt32:
            return VK_KERNEL(nchw_to_image2d_int32);
          default:
            VK_THROW(
                "Vulkan quantization currently not supported for dtype ",
                v_dst.dtype());
        }
      case api::StorageType::BUFFER:
      case api::StorageType::UNKNOWN:
        VK_THROW("Requested storage type must be a texture type.");
      default:
        VK_THROW("No kernel available!");
    }
  }

  if (v_dst.dtype() == api::kFloat) {
    switch (v_dst.storage_type()) {
      case api::StorageType::TEXTURE_3D:
        return VK_KERNEL(nchw_to_image);
      case api::StorageType::TEXTURE_2D:
        return VK_KERNEL(nchw_to_image2d);
      default:
        VK_THROW("No kernel available!");
    }
  } else if (v_dst.dtype() == api::kBool) {
    switch (v_dst.storage_type()) {
      case api::StorageType::TEXTURE_3D:
        return VK_KERNEL(nchw_to_image_bool);
      default:
        VK_THROW("No kernel available!");
    }
  } else {
    VK_THROW("Unsupported dtype!");
  }
}

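// Selects the compute shader that copies the texture backing v_src back into
// a contiguous NCHW staging buffer. Quantized and bool tensors use dedicated
// kernels; an optimized variant is chosen when the H*W plane size is a
// multiple of 4.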
api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src) {
  if (v_src.is_quantized() || v_src.dtype() == api::kBool) {
    auto plane_size =
        dim_at<Dim4D::Height>(v_src) * dim_at<Dim4D::Width>(v_src);
    switch (v_src.storage_type()) {
      case api::StorageType::TEXTURE_3D:
        switch (v_src.dtype()) {
          case api::ScalarType::QUInt8:
          case api::ScalarType::QInt8:
          case api::kBool:
            return plane_size % 4 == 0 ? VK_KERNEL(image_to_nchw_quantized_mul4)
                                       : VK_KERNEL(image_to_nchw_uint);
          case api::ScalarType::QInt32:
            return VK_KERNEL(image_to_nchw_int32);
          default:
            VK_THROW(
                "Vulkan quantization currently not supported for dtype ",
                v_src.dtype());
        }
      case api::StorageType::BUFFER:
      case api::StorageType::UNKNOWN:
        VK_THROW("Requested storage type must be a texture type.");
      default:
        VK_THROW("No kernel available!");
    }
  }

  if (v_src.dtype() == api::kFloat) {
    switch (v_src.storage_type()) {
      case api::StorageType::TEXTURE_3D:
        return VK_KERNEL(image_to_nchw);
      case api::StorageType::TEXTURE_2D:
        return VK_KERNEL(image2d_to_nchw);
      default:
        VK_THROW("No kernel available!");
    }
  } else {
    VK_THROW("Unsupported dtype!");
  }
}

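// Uniform parameters shared by the NCHW <-> texture shaders: the texture
// extents, the H*W plane size, and channel info packed as (ceil(C / 4), C).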
struct ToFromTextureParams final {
  api::utils::ivec3 extents;
  int32_t planeSize;
  api::utils::ivec2 channelInfo;
};

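// Records a compute job that unpacks src_buffer (contiguous NCHW data) into
// the texture backing v_dst.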
void record_nchw_to_image_op(
    api::Context* const context,
    api::ShaderInfo& compute_shader,
    api::VulkanBuffer& src_buffer,
    vTensor& v_dst,
    api::PipelineBarrier pipeline_barrier,
    VkFence fence_handle) {
  api::utils::uvec3 global_size = v_dst.extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  int32_t height =
      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Height>(v_dst));
  int32_t width =
      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Width>(v_dst));
  int32_t channels =
      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Channel>(v_dst));

  int32_t plane_size = height * width;
  int32_t c_depth = api::utils::div_up(channels, 4);

  ToFromTextureParams block{
      api::utils::make_ivec3(v_dst.extents()),
      plane_size,
      {c_depth, channels},
  };

  api::UniformParamsBuffer params(context, block);
  context->submit_compute_job(
      // shader descriptor
      compute_shader,
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      fence_handle,
      // shader arguments
      v_dst.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      src_buffer,
      // params buffer
      params.buffer());
}

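// Records a compute job that packs the texture backing v_src into dst_buffer
// in NCHW layout, and returns the result of submitting the job. For 8-bit
// quantized and bool tensors the dispatch dimensions are adjusted to match
// the specialized shaders.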
bool record_image_to_nchw_op(
    api::Context* const context,
    api::ShaderInfo& compute_shader,
    vTensor& v_src,
    api::VulkanBuffer& dst_buffer,
    api::PipelineBarrier pipeline_barrier,
    VkFence fence_handle) {
  api::utils::uvec3 global_size = v_src.extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  int32_t height =
      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Height>(v_src));
  int32_t width =
      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Width>(v_src));
  int32_t channels =
      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Channel>(v_src));

  int32_t plane_size = height * width;
  int32_t c_depth = api::utils::div_up(channels, 4);

  ToFromTextureParams block{
      api::utils::make_ivec3(v_src.extents()),
      plane_size,
      {c_depth, channels},
  };

  if (v_src.dtype() == api::ScalarType::QUInt8 ||
      v_src.dtype() == api::ScalarType::QInt8 || v_src.dtype() == api::kBool) {
    // Special case using optimized shader, image_to_nchw_quantized_mul4
    if (plane_size % 4 == 0) {
      global_size.data[0u] = plane_size / 4;
      global_size.data[1u] = 1;
      local_size.data[0u] *= local_size.data[1u];
      local_size.data[1u] = 1;
    }
    // Global and local size for regular 1D buffer.
    else {
      uint32_t numel = v_src.numel();
      global_size = {api::utils::div_up(numel, uint32_t(4)), 1u, 1u};
      local_size = {64u, 1u, 1u};
    }
  }

  api::UniformParamsBuffer params(context, block);
  return context->submit_compute_job(
      // shader descriptor
      compute_shader,
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      fence_handle,
      // shader arguments
      v_src.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      dst_buffer,
      // params buffer
      params.buffer());
}

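// Records a buffer_to_buffer compute job that copies src_buffer (contiguous
// NCHW data) into the GPU buffer backing v_dst, using the CPU and GPU buffer
// metadata to map between the two layouts.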
void record_nchw_to_buffer_op(
    api::Context* const context,
    api::VulkanBuffer& src_buffer,
    vTensor& v_dst,
    api::PipelineBarrier pipeline_barrier,
    VkFence fence_handle) {
  uint32_t gpu_buf_len = api::utils::safe_downcast<uint32_t>(v_dst.gpu_numel());

  api::utils::uvec3 global_size = {gpu_buf_len, 1u, 1u};
  api::utils::uvec3 local_size = {32u, 1u, 1u};

  api::UniformParamsBuffer cpu_buffer_metadata(
      context, v_dst.get_cpu_buffer_metadata());

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(buffer_to_buffer),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      fence_handle,
      // shader arguments
      v_dst.buffer(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_dst.buffer_metadata(),
      src_buffer,
      cpu_buffer_metadata.buffer());
}

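// Records a buffer_to_buffer compute job that copies the GPU buffer backing
// v_src into dst_buffer in NCHW layout, and returns the result of submitting
// the job.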
bool record_buffer_to_nchw_op(
    api::Context* const context,
    vTensor& v_src,
    api::VulkanBuffer& dst_buffer,
    api::PipelineBarrier pipeline_barrier,
    VkFence fence_handle) {
  uint32_t buf_len = api::utils::safe_downcast<uint32_t>(v_src.numel());

  api::utils::uvec3 global_size = {buf_len, 1u, 1u};
  api::utils::uvec3 local_size = {4u, 1u, 1u};

  api::UniformParamsBuffer cpu_buffer_metadata(
      context, v_src.get_cpu_buffer_metadata());

  return context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(buffer_to_buffer),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      fence_handle,
      // shader arguments
      dst_buffer,
      cpu_buffer_metadata.buffer(),
      v_src.buffer(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_src.buffer_metadata());
}

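// Allocates an output tensor with the requested memory layout and dispatches
// the given shader to repack v_input's channels-packed image into it.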
static vTensor channel_image_repacking(
    const vTensor& v_input,
    api::GPUMemoryLayout target_layout,
    const api::ShaderInfo& shader_descriptor) {
  api::Context* const context = api::context();

  vTensor v_output{
      context,
      v_input.sizes(),
      v_input.dtype(),
      v_input.storage_type(),
      target_layout,
  };

  // Required to determine how to insert memory barriers in the command buffer
  api::PipelineBarrier pipeline_barrier{};

  // The shader assumes a 4d NCHW tensor when calculating the lookup
  // coordinate. If the input is not 4d, its sizes are padded with leading 1's.
  const struct Block final {
    api::utils::ivec4 sizes;
  } block{
      api::utils::make_ivec4_prepadded1(v_input.sizes()),
  };

  api::UniformParamsBuffer params(context, block);

  context->submit_compute_job(
      // shader descriptor
      // VK_KERNEL(packing_channel_to_height),
      shader_descriptor,
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_output.extents(),
      // local work group size
      adaptive_work_group_size(v_output.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return v_output;
}

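// Repacks a channels-packed image tensor into a height-packed layout.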
vTensor convert_image_channels_packed_to_height_packed(const vTensor& v_input) {
  return channel_image_repacking(
      v_input,
      api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED,
      VK_KERNEL(convert_channels_to_height_packed));
}

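// Repacks a channels-packed image tensor into a width-packed layout.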
vTensor convert_image_channels_packed_to_width_packed(const vTensor& v_input) {
  return channel_image_repacking(
      v_input,
      api::GPUMemoryLayout::TENSOR_WIDTH_PACKED,
      VK_KERNEL(convert_channels_to_width_packed));
}

} // namespace packing
} // namespace vulkan
} // namespace native
} // namespace at