xref: /aosp_15_r20/external/mesa3d/src/amd/vulkan/bvh/lbvh_main.comp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1/*
2 * Copyright © 2022 Bas Nieuwenhuizen
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7#version 460
8
9#extension GL_GOOGLE_include_directive : require
10
11#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
12#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
13#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
14#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
15#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
16#extension GL_EXT_scalar_block_layout : require
17#extension GL_EXT_buffer_reference : require
18#extension GL_EXT_buffer_reference2 : require
19
20layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
21
22#include "build_interface.h"
23
24TYPE(lbvh_node_info, 4);
25
26layout(push_constant) uniform CONSTS
27{
28   lbvh_main_args args;
29};
30
/*
 * Length, in bits, of the common prefix shared by leaf i and leaf j.
 *
 * i     - index of the reference leaf
 * key_i - key of leaf i, passed in so that it is not loaded again
 * j     - index of the leaf to compare against; may be out of range
 *
 * Returns -1 when j is not a valid index into args.src_ids. When the two keys
 * differ, the result is the number of leading bits they have in common (0..31).
 * When the keys are equal, the leaf indices are used as a tie-breaker and 32 is
 * added, so the result is in 32..63 for i != j and 64 for i == j (findMSB(0)
 * is -1).
 */
int32_t
longest_common_prefix(int32_t i, uint32_t key_i, int32_t j)
{
   bool in_range = j >= 0 && j < args.id_count;
   if (!in_range)
      return -1;

   uint32_t other_key = DEREF(INDEX(key_id_pair, args.src_ids, j)).key;

   /* Distinct keys: the highest differing bit ends the common prefix. */
   if (other_key != key_i)
      return 31 - findMSB(key_i ^ other_key);

   /* Identical keys: all 32 key bits match, keep comparing on the indices. */
   uint32_t index_bits = i ^ j;
   return 63 - findMSB(index_bits);
}
48
49/*
50 * The LBVH algorithm constructs a radix tree of the sorted nodes according to their key.
51 *
52 * We do this by making the following decision:
53 *
54 *    Node N always either starts or ends at leaf N.
55 *
56 * From there it follows that we always have to extend it into the direction which has
57 * a longer common prefix with the direct neighbour. Then we try to make the node cover
58 * as many leaves as possible without including the other neighbour.
59 *
 * For finding the split point we compute the longest common prefix of all the leaves within the
 * node, and look for the first leaf whose common prefix with leaf N has exactly that length.
62 *
63 * To give an example: leaves=[000,001,010,011,100,101,110,111], node=5 (with value 101)
64 *
65 * lcp(101, 100) = 2 and lcp(101, 110) = 1, so we extend down.
66 * lcp(101, 011) = 0, so the range of the node is [4,5] with values [100, 101]
67 *
 * The lcp of all the leaves in the range is the same as the lcp of the first and last leaf, in this
 * case that is lcp(101, 100) = 2. Then we have lcp(101, 101) = 3 and lcp(101, 100) = 2, so the first
 * leaf whose lcp is not longer than that is 4. Hence the two children of this node have range [4,4]
 * and [5,5]
71 */
/*
 * Entry point: every invocation builds one internal node of the radix tree.
 *
 * Invocation `id` works out which range of sorted leaves internal node `id`
 * covers and where that range is split, then writes node_info[id] (children and
 * path_count) and the parent link of each child that is itself an internal node.
 *
 * With zero or one leaf there is nothing to split; a single node_info entry is
 * written at the start of args.node_info instead.
 */
void
main()
{
   if (args.id_count <= 1) {
      REF(lbvh_node_info) dst = REF(lbvh_node_info)(args.node_info);
      DEREF(dst).parent = RADV_BVH_INVALID_NODE;
      DEREF(dst).path_count = 2;
      DEREF(dst).children[0] =
         args.id_count == 1 ? DEREF(INDEX(key_id_pair, args.src_ids, 0)).id : RADV_BVH_INVALID_NODE;
      DEREF(dst).children[1] = RADV_BVH_INVALID_NODE;
      return;
   }

   int32_t id = int32_t(gl_GlobalInvocationID.x);

   /* A binary tree over id_count leaves has id_count - 1 internal nodes, so only
    * invocations 0 .. id_count - 2 have work to do. Invocation id_count - 1 would
    * compute the same range as node 0 and overwrite the parent links of its
    * children, and anything beyond that would read src_ids out of range.
    * NOTE(review): the dispatch size is chosen by the host and is not visible
    * here; this guard is a no-op if exactly id_count - 1 invocations are launched.
    */
   if (id >= int32_t(args.id_count) - 1)
      return;

   uint32_t id_key = DEREF(INDEX(key_id_pair, args.src_ids, id)).key;

   /* Extend towards the neighbour that shares the longer prefix with leaf id.
    * An out-of-range neighbour yields -1 and therefore never wins.
    */
   int32_t left_lcp = longest_common_prefix(id, id_key, id - 1);
   int32_t right_lcp = longest_common_prefix(id, id_key, id + 1);
   int32_t dir = right_lcp > left_lcp ? 1 : -1;
   int32_t lcp_min = min(left_lcp, right_lcp);

   /* Determine the bounds for the binary search for the length of the range that
    * this subtree is going to own.
    */
   int32_t lmax = 128;
   while (longest_common_prefix(id, id_key, id + dir * lmax) > lcp_min) {
      lmax *= 2;
   }

   /* Binary search for the largest offset whose leaf still shares more than
    * lcp_min bits with leaf id.
    */
   int32_t length = 0;
   for (int32_t t = lmax / 2; t >= 1; t /= 2) {
      if (longest_common_prefix(id, id_key, id + (length + t) * dir) > lcp_min)
         length += t;
   }
   int32_t other_end = id + length * dir;

   /* The number of bits in the prefix that is the same for all elements in the
    * range.
    */
   int32_t lcp_node = longest_common_prefix(id, id_key, other_end);

   /* Second binary search: the largest offset whose leaf shares more than
    * lcp_node bits with leaf id. The split lies right after that leaf.
    */
   int32_t child_range = 0;
   for (int32_t diff = 2; diff < 2 * length; diff *= 2) {
      int32_t t = DIV_ROUND_UP(length, diff);
      if (longest_common_prefix(id, id_key, id + (child_range + t) * dir) > lcp_node)
         child_range += t;
   }

   int32_t child_split = id + child_range * dir;

   /* If dir = -1, right = child_split */
   int32_t left = child_split + min(dir, 0);
   int32_t right = left + 1;

   /* if the number of leaves covered by a child is 1, we can use the leaf directly */
   bool left_leaf = min(id, other_end) == left;
   bool right_leaf = max(id, other_end) == right;

   /* Children that are internal nodes get a link back to this node; the right
    * child is tagged with LBVH_RIGHT_CHILD_BIT.
    */
   if (!left_leaf)
      DEREF(INDEX(lbvh_node_info, args.node_info, left)).parent = id;
   if (!right_leaf)
      DEREF(INDEX(lbvh_node_info, args.node_info, right)).parent = LBVH_RIGHT_CHILD_BIT | id;

   REF(lbvh_node_info) dst = INDEX(lbvh_node_info, args.node_info, id);
   /* path_count starts out as the number of children that are leaves. */
   DEREF(dst).path_count = (left_leaf ? 1 : 0) + (right_leaf ? 1 : 0);
   DEREF(dst).children[0] = DEREF(INDEX(key_id_pair, args.src_ids, left)).id;
   DEREF(dst).children[1] = DEREF(INDEX(key_id_pair, args.src_ids, right)).id;

   /* Node 0 has no left neighbour, so it extends over every leaf: it is the
    * root and no other invocation writes its parent field.
    */
   if (id == 0)
      DEREF(dst).parent = RADV_BVH_INVALID_NODE;
}
141