xref: /aosp_15_r20/external/mesa3d/src/amd/vulkan/bvh/copy.comp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
/*
 * Copyright © 2022 Bas Nieuwenhuizen
 *
 * SPDX-License-Identifier: MIT
 */
6
7#version 460
8
9#extension GL_GOOGLE_include_directive : require
10
11#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
12#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
13#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
14#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
15#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
16#extension GL_EXT_scalar_block_layout : require
17#extension GL_EXT_buffer_reference : require
18#extension GL_EXT_buffer_reference2 : require
19
20layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
21
22#include "build_interface.h"
23
/* Per-dispatch arguments (source/destination addresses and copy mode),
 * provided by the host via push constants. */
layout(push_constant) uniform CONSTS
{
   copy_args args;
};
27
/* Copies an acceleration structure, optionally converting to/from the
 * serialized layout. A serialized blob is laid out as:
 *
 *   radv_accel_struct_serialization_header
 *   uint64_t blas_pointers[instance_count]
 *   <BVH data (starts with radv_accel_struct_header)>
 *
 * All lanes cooperatively copy 16-byte chunks in a grid-stride loop; the
 * lane that copies the first chunk of an instance node also rewrites its
 * BLAS pointer (device address for serialize, node pointer for deserialize)
 * so no cross-invocation synchronization is needed.
 */
void
main(void)
{
   uint32_t global_id = gl_GlobalInvocationID.x;
   /* Use the compile-time workgroup size instead of duplicating the literal
    * from the local_size_x layout qualifier. */
   uint32_t lanes = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
   uint32_t increment = lanes * 16;

   uint64_t copy_src_addr = args.src_addr;
   uint64_t copy_dst_addr = args.dst_addr;

   if (args.mode == RADV_COPY_MODE_DESERIALIZE) {
      /* Skip the serialization header and the BLAS pointer array so that
       * copy_src_addr points at the start of the BVH data. */
      copy_src_addr += SIZEOF(radv_accel_struct_serialization_header) +
                       DEREF(REF(radv_accel_struct_serialization_header)(args.src_addr)).instance_count * SIZEOF(uint64_t);
   }

   /* The BVH data always begins with the acceleration-structure header. */
   REF(radv_accel_struct_header) header = REF(radv_accel_struct_header)(copy_src_addr);

   /* Address of the BLAS pointer array: read from the source blob when
    * deserializing; overwritten below when serializing. */
   uint64_t instance_base = args.src_addr + SIZEOF(radv_accel_struct_serialization_header);
   uint64_t node_offset = DEREF(header).instance_offset;
   uint64_t node_end = DEREF(header).instance_count * SIZEOF(radv_bvh_instance_node);
   if (node_end > 0)
      node_end += node_offset;

   if (args.mode == RADV_COPY_MODE_SERIALIZE) {
      /* Leave room for the serialization header and the BLAS pointer array
       * in front of the BVH data in the destination. */
      copy_dst_addr += SIZEOF(radv_accel_struct_serialization_header) +
                       DEREF(REF(radv_accel_struct_header)(args.src_addr)).instance_count * SIZEOF(uint64_t);

      /* Only one lane writes the (small) serialization header. */
      if (global_id == 0) {
         REF(radv_accel_struct_serialization_header) ser_header =
            REF(radv_accel_struct_serialization_header)(args.dst_addr);
         DEREF(ser_header).serialization_size = DEREF(header).serialization_size;
         DEREF(ser_header).compacted_size = DEREF(header).compacted_size;
         DEREF(ser_header).instance_count = DEREF(header).instance_count;
      }

      instance_base = args.dst_addr + SIZEOF(radv_accel_struct_serialization_header);
   } else if (args.mode == RADV_COPY_MODE_COPY) {
      /* Plain copies need no instance-pointer fixup; disable the range. */
      node_end = 0;
   }

   uint64_t size = DEREF(header).compacted_size;
   for (uint64_t offset = global_id * 16; offset < size; offset += increment) {
      DEREF(REF(uvec4)(copy_dst_addr + offset)) =
         DEREF(REF(uvec4)(copy_src_addr + offset));

      /* Do the adjustment inline in the same invocation that copies the data so that we don't have
       * to synchronize. Fires only for the first 16-byte chunk of each instance node. */
      if (offset < node_end && offset >= node_offset &&
          (offset - node_offset) % SIZEOF(radv_bvh_instance_node) == 0) {
         uint64_t idx = (offset - node_offset) / SIZEOF(radv_bvh_instance_node);

         uint32_t bvh_offset = DEREF(REF(radv_bvh_instance_node)(copy_src_addr + offset)).bvh_offset;
         if (args.mode == RADV_COPY_MODE_SERIALIZE) {
            /* Store the BLAS base device address (node pointer minus the
             * offset of the root node within the structure). */
            DEREF(INDEX(uint64_t, instance_base, idx)) =
               node_to_addr(DEREF(REF(radv_bvh_instance_node)(copy_src_addr + offset)).bvh_ptr) - bvh_offset;
         } else { /* RADV_COPY_MODE_DESERIALIZE */
            /* Rebuild the node pointer from the stored BLAS base address. */
            uint64_t blas_addr = DEREF(INDEX(uint64_t, instance_base, idx));
            DEREF(REF(radv_bvh_instance_node)(copy_dst_addr + offset)).bvh_ptr = addr_to_node(blas_addr + bvh_offset);
         }
      }
   }
}
90