// xref: /aosp_15_r20/external/mesa3d/src/amd/vulkan/bvh/lbvh_generate_ir.comp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
/*
 * Copyright © 2022 Bas Nieuwenhuizen
 *
 * SPDX-License-Identifier: MIT
 */
6
7#version 460
8
9#extension GL_GOOGLE_include_directive : require
10
11#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
12#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
13#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
14#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
15#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
16#extension GL_EXT_scalar_block_layout : require
17#extension GL_EXT_buffer_reference : require
18#extension GL_EXT_buffer_reference2 : require
19#extension GL_KHR_memory_scope_semantics : require
20
21layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
22
23#include "build_interface.h"
24
/* Reference type for lbvh_node_info with 4-byte alignment; main() accesses args.node_info
 * through INDEX/DEREF with it. (Exact expansion lives in the TYPE macro from the included
 * headers — presumably a buffer_reference wrapper; confirm there.)
 */
TYPE(lbvh_node_info, 4);

/* Per-dispatch parameters pushed by the host. main() reads args.node_info, args.header,
 * args.bvh and args.internal_node_base from here.
 */
layout(push_constant) uniform CONSTS {
   lbvh_generate_ir_args args;
};
31
/* Returns `acc` grown to also enclose the IR bounds of the node `child_id`.
 * A child id of RADV_BVH_INVALID_NODE contributes nothing and `acc` is returned unchanged.
 */
radv_aabb
lbvh_ir_merge_child_bounds(radv_aabb acc, uint32_t child_id)
{
   if (child_id == RADV_BVH_INVALID_NODE)
      return acc;

   uint32_t child_offset = ir_id_to_offset(child_id);
   REF(radv_ir_node) child_ref = REF(radv_ir_node)(OFFSET(args.bvh, child_offset));
   radv_aabb child_bounds = DEREF(child_ref).aabb;

   acc.min = min(acc.min, child_bounds.min);
   acc.max = max(acc.max, child_bounds.max);
   return acc;
}

/*
 * Walks upward from node_info[gl_GlobalInvocationID.x], emitting one IR box node per level.
 *
 * At each level the invocation increments that entry's path_count. Only the invocation whose
 * atomicAdd returns 2 continues; all earlier arrivals stop here. The continuing invocation:
 *   - allocates an internal IR slot,
 *   - unions both children's bounds,
 *   - stores the box node,
 *   - publishes the new node id into its parent's children[] entry, and
 *   - moves on to the parent.
 * It stops once it writes a node whose info.parent is RADV_BVH_INVALID_NODE.
 */
void
main(void)
{
   uint32_t node_index = gl_GlobalInvocationID.x;

   /* Id/bounds of the box node written on the previous iteration. At the parent, the child
    * slot holding this id does not need its bounds re-read from memory.
    */
   uint32_t last_written_id = RADV_BVH_INVALID_NODE;
   radv_aabb last_written_bounds;
   last_written_bounds.min = vec3(INFINITY);
   last_written_bounds.max = vec3(-INFINITY);

   while (true) {
      /* The atomic returns the previous count, so the last path arriving from a child sees 2
       * and passes; earlier paths leave. Acquire/release plus make-available/visible is what
       * lets the winner observe the siblings' writes.
       */
      uint32_t prior_arrivals = atomicAdd(
         DEREF(INDEX(lbvh_node_info, args.node_info, node_index)).path_count, 1, gl_ScopeDevice,
         gl_StorageSemanticsBuffer,
         gl_SemanticsAcquireRelease | gl_SemanticsMakeAvailable | gl_SemanticsMakeVisible);
      if (prior_arrivals != 2)
         break;

      /* Slots are taken on demand so children are allocated before their parents, which the
       * encoder requires.
       */
      uint32_t ir_slot =
         atomicAdd(DEREF(REF(radv_ir_header)(args.header)).ir_internal_node_count, 1);
      uint32_t node_offset = args.internal_node_base + ir_slot * SIZEOF(radv_ir_box_node);
      uint32_t node_id = pack_ir_node_id(node_offset, radv_ir_node_internal);

      lbvh_node_info info = DEREF(INDEX(lbvh_node_info, args.node_info, node_index));

      /* Seed with the cached bounds. If the cached id matches a child, only the other child
       * is fetched. Otherwise both children are fetched: slot 0, then slot 1.
       */
      radv_aabb bounds = last_written_bounds;
      uint32_t uncached_slot;
      if (last_written_id == info.children[0]) {
         uncached_slot = 1;
      } else if (last_written_id == info.children[1]) {
         uncached_slot = 0;
      } else {
         bounds = lbvh_ir_merge_child_bounds(bounds, info.children[0]);
         uncached_slot = 1;
      }
      bounds = lbvh_ir_merge_child_bounds(bounds, info.children[uncached_slot]);

      radv_ir_box_node box;
      box.base.aabb = bounds;
      box.bvh_offset = RADV_UNKNOWN_BVH_OFFSET;
      box.children = info.children;

      REF(radv_ir_box_node) box_ref = REF(radv_ir_box_node)(OFFSET(args.bvh, node_offset));
      DEREF(box_ref) = box;

      /* No parent: nothing further to propagate. */
      if (info.parent == RADV_BVH_INVALID_NODE)
         break;

      /* info.parent carries the parent index plus a bit selecting which of the parent's two
       * child slots this node occupies.
       */
      node_index = info.parent & ~LBVH_RIGHT_CHILD_BIT;
      uint32_t child_slot = (info.parent >> LBVH_RIGHT_CHILD_BIT_SHIFT) & 1;
      DEREF(INDEX(lbvh_node_info, args.node_info, node_index)).children[child_slot] = node_id;

      last_written_id = node_id;
      last_written_bounds = bounds;

      memoryBarrierBuffer();
   }
}
121}
122