xref: /aosp_15_r20/external/mesa3d/src/amd/vulkan/bvh/ploc_internal.comp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
/*
 * Copyright © 2022 Bas Nieuwenhuizen
 *
 * SPDX-License-Identifier: MIT
 */

#version 460

#extension GL_GOOGLE_include_directive : require

#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_scalar_block_layout : require
#extension GL_EXT_buffer_reference : require
#extension GL_EXT_buffer_reference2 : require
#extension GL_KHR_memory_scope_semantics : require
#extension GL_KHR_shader_subgroup_vote : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_KHR_shader_subgroup_ballot : require

layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in;

#define USE_GLOBAL_SYNC
#include "build_interface.h"

TYPE(ploc_prefix_scan_partition, 4);

layout(push_constant) uniform CONSTS
{
   ploc_args args;
};

shared uint32_t exclusive_prefix_sum;
shared uint32_t aggregate_sums[PLOC_WORKGROUP_SIZE / 64];

/*
 * Global prefix scan over all workgroups to find out the index of the collapsed node to write.
 * See https://research.nvidia.com/sites/default/files/publications/nvr-2016-002.pdf
 * One partition = one workgroup in this case.
 */
uint32_t
prefix_scan(uvec4 ballot, REF(ploc_prefix_scan_partition) partitions, uint32_t task_index)
{
   if (gl_LocalInvocationIndex == 0) {
      /* Temporary copy of exclusive_prefix_sum to avoid reading+writing LDS each addition */
      uint32_t local_exclusive_prefix_sum = 0;
      if (task_index >= gl_WorkGroupSize.x) {
         REF(ploc_prefix_scan_partition) current_partition =
            REF(ploc_prefix_scan_partition)(INDEX(ploc_prefix_scan_partition, partitions, task_index / gl_WorkGroupSize.x));

         REF(ploc_prefix_scan_partition) previous_partition = current_partition - 1;

         while (true) {
            /* See if this previous workgroup already set their inclusive sum */
            if (atomicLoad(DEREF(previous_partition).inclusive_sum, gl_ScopeDevice,
                           gl_StorageSemanticsBuffer,
                           gl_SemanticsAcquire | gl_SemanticsMakeVisible) != 0xFFFFFFFF) {
               local_exclusive_prefix_sum += DEREF(previous_partition).inclusive_sum;
               break;
            } else {
               local_exclusive_prefix_sum += DEREF(previous_partition).aggregate;
               previous_partition -= 1;
            }
         }
         /* Set the inclusive sum for the next workgroups */
         atomicStore(DEREF(current_partition).inclusive_sum,
                     DEREF(current_partition).aggregate + local_exclusive_prefix_sum, gl_ScopeDevice,
                     gl_StorageSemanticsBuffer, gl_SemanticsRelease | gl_SemanticsMakeAvailable);
      }
      exclusive_prefix_sum = local_exclusive_prefix_sum;
   }

   if (subgroupElect())
      aggregate_sums[gl_SubgroupID] = subgroupBallotBitCount(ballot);
   barrier();

   if (gl_LocalInvocationID.x < PLOC_WORKGROUP_SIZE / 64) {
      aggregate_sums[gl_LocalInvocationID.x] =
         exclusive_prefix_sum + subgroupExclusiveAdd(aggregate_sums[gl_LocalInvocationID.x]);
   }
   barrier();

   return aggregate_sums[gl_SubgroupID] + subgroupBallotExclusiveBitCount(ballot);
}
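
/* Illustrative sketch (ours, not part of the upstream shader): a serial
 * reference for the decoupled look-back above. The exclusive prefix of a
 * partition is simply the sum of the aggregates of all preceding partitions;
 * prefix_scan() computes the same value, but stops walking backwards at the
 * first predecessor that has already published its inclusive_sum and adds
 * that single value instead of the remaining aggregates.
 */
uint32_t
reference_exclusive_prefix(REF(ploc_prefix_scan_partition) partitions, uint32_t partition_index)
{
   uint32_t sum = 0;
   for (uint32_t i = 0; i < partition_index; ++i) {
      REF(ploc_prefix_scan_partition) preceding =
         REF(ploc_prefix_scan_partition)(INDEX(ploc_prefix_scan_partition, partitions, i));
      sum += DEREF(preceding).aggregate;
   }
   return sum;
}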

/* Relative cost of increasing the BVH depth. Deep BVHs will require more backtracking. */
#define BVH_LEVEL_COST 0.2

uint32_t
push_node(uint32_t children[2], radv_aabb bounds[2])
{
   uint32_t internal_node_index = atomicAdd(DEREF(args.header).ir_internal_node_count, 1);
   uint32_t dst_offset = args.internal_node_offset + internal_node_index * SIZEOF(radv_ir_box_node);
   uint32_t dst_id = pack_ir_node_id(dst_offset, radv_ir_node_internal);
   REF(radv_ir_box_node) dst_node = REF(radv_ir_box_node)(OFFSET(args.bvh, dst_offset));

   radv_aabb total_bounds;
   total_bounds.min = vec3(INFINITY);
   total_bounds.max = vec3(-INFINITY);

   for (uint i = 0; i < 2; ++i) {
      VOID_REF node = OFFSET(args.bvh, ir_id_to_offset(children[i]));
      REF(radv_ir_node) child = REF(radv_ir_node)(node);

      total_bounds.min = min(total_bounds.min, bounds[i].min);
      total_bounds.max = max(total_bounds.max, bounds[i].max);

      DEREF(dst_node).children[i] = children[i];
   }

   DEREF(dst_node).base.aabb = total_bounds;
   DEREF(dst_node).bvh_offset = RADV_UNKNOWN_BVH_OFFSET;
   return dst_id;
}

#define PLOC_NEIGHBOURHOOD 16
#define PLOC_OFFSET_MASK   ((1 << 5) - 1)

uint32_t
encode_neighbour_offset(float sah, uint32_t i, uint32_t j)
{
   int32_t offset = int32_t(j) - int32_t(i);
   uint32_t encoded_offset = offset + PLOC_NEIGHBOURHOOD - (offset > 0 ? 1 : 0);
   return (floatBitsToUint(sah) & (~PLOC_OFFSET_MASK)) | encoded_offset;
}

int32_t
decode_neighbour_offset(uint32_t encoded_offset)
{
   int32_t offset = int32_t(encoded_offset & PLOC_OFFSET_MASK) - PLOC_NEIGHBOURHOOD;
   return offset + (offset >= 0 ? 1 : 0);
}
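
/* Illustrative note (ours, not upstream documentation): the helpers above pack
 * a nonzero neighbour offset in [-PLOC_NEIGHBOURHOOD, PLOC_NEIGHBOURHOOD] into
 * the low 5 bits of the SAH bit pattern. Since the surface area is a
 * non-negative float, its bit pattern compares like the value itself, so the
 * atomicMin() in main() keeps the candidate with the lowest (truncated) SAH
 * and breaks ties towards the smaller encoded offset, i.e. the leftmost
 * candidate.
 *
 * Worked example: for i = 10, j = 12 the offset is +2, encoded as
 * 2 + 16 - 1 = 17; decoding gives (17 & 31) - 16 = 1, plus 1 for the
 * non-negative case = +2. For i = 10, j = 7 the offset is -3, encoded as
 * -3 + 16 - 0 = 13; decoding gives 13 - 16 = -3. Offsets -16..-1 map to
 * 0..15 and +1..+16 map to 16..31, so every possible offset fits in
 * PLOC_OFFSET_MASK.
 */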

#define NUM_PLOC_LDS_ITEMS (PLOC_WORKGROUP_SIZE + 4 * PLOC_NEIGHBOURHOOD)

shared radv_aabb shared_bounds[NUM_PLOC_LDS_ITEMS];
shared uint32_t nearest_neighbour_indices[NUM_PLOC_LDS_ITEMS];
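
/* Illustrative note (ours, not upstream documentation): each workgroup caches
 * the bounds of the PLOC_WORKGROUP_SIZE nodes it owns plus up to
 * 2 * PLOC_NEIGHBOURHOOD nodes of overlap on either side. Nodes within
 * PLOC_NEIGHBOURHOOD of the workgroup's range also get their nearest
 * neighbour computed here (see write_bound in main()), and their candidates
 * can lie another PLOC_NEIGHBOURHOOD further out (see search_bound), which is
 * where the PLOC_WORKGROUP_SIZE + 4 * PLOC_NEIGHBOURHOOD sizing comes from.
 * Shared memory is indexed with i - lds_base, lds_base being the first cached
 * global node index.
 */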

uint32_t
load_id(VOID_REF ids, uint32_t iter, uint32_t index)
{
   if (iter == 0)
      return DEREF(REF(key_id_pair)(INDEX(key_id_pair, ids, index))).id;
   else
      return DEREF(REF(uint32_t)(INDEX(uint32_t, ids, index)));
}

void
load_bounds(VOID_REF ids, uint32_t iter, uint32_t task_index, uint32_t lds_base,
            uint32_t neighbourhood_overlap, uint32_t search_bound)
{
   for (uint32_t i = task_index - 2 * neighbourhood_overlap; i < search_bound;
        i += gl_WorkGroupSize.x) {
      uint32_t id = load_id(ids, iter, i);
      if (id == RADV_BVH_INVALID_NODE)
         continue;

      VOID_REF addr = OFFSET(args.bvh, ir_id_to_offset(id));
      REF(radv_ir_node) node = REF(radv_ir_node)(addr);

      shared_bounds[i - lds_base] = DEREF(node).aabb;
   }
}

float
combined_node_cost(uint32_t lds_base, uint32_t i, uint32_t j)
{
   radv_aabb combined_bounds;
   combined_bounds.min = min(shared_bounds[i - lds_base].min, shared_bounds[j - lds_base].min);
   combined_bounds.max = max(shared_bounds[i - lds_base].max, shared_bounds[j - lds_base].max);
   return aabb_surface_area(combined_bounds);
}

shared uint32_t shared_aggregate_sum;

void
main(void)
{
   VOID_REF src_ids = args.ids_0;
   VOID_REF dst_ids = args.ids_1;

   /* We try to use LBVH for BVHs where we know there will be fewer than 5 leaves,
    * but sometimes all leaves might be inactive. */
   if (DEREF(args.header).active_leaf_count <= 2) {
      if (gl_GlobalInvocationID.x == 0) {
         uint32_t internal_node_index = atomicAdd(DEREF(args.header).ir_internal_node_count, 1);
         uint32_t dst_offset = args.internal_node_offset + internal_node_index * SIZEOF(radv_ir_box_node);
         REF(radv_ir_box_node) dst_node = REF(radv_ir_box_node)(OFFSET(args.bvh, dst_offset));

         radv_aabb total_bounds;
         total_bounds.min = vec3(INFINITY);
         total_bounds.max = vec3(-INFINITY);

         uint32_t i = 0;
         for (; i < DEREF(args.header).active_leaf_count; i++) {
            uint32_t child_id = DEREF(INDEX(key_id_pair, src_ids, i)).id;

            if (child_id != RADV_BVH_INVALID_NODE) {
               VOID_REF node = OFFSET(args.bvh, ir_id_to_offset(child_id));
               REF(radv_ir_node) child = REF(radv_ir_node)(node);
               radv_aabb bounds = DEREF(child).aabb;

               total_bounds.min = min(total_bounds.min, bounds.min);
               total_bounds.max = max(total_bounds.max, bounds.max);
            }

            DEREF(dst_node).children[i] = child_id;
         }
         for (; i < 2; i++)
            DEREF(dst_node).children[i] = RADV_BVH_INVALID_NODE;

         DEREF(dst_node).base.aabb = total_bounds;
         DEREF(dst_node).bvh_offset = RADV_UNKNOWN_BVH_OFFSET;
      }
      return;
   }

   /* Only initialize sync_data once per workgroup. For intra-workgroup synchronization,
    * fetch_task contains a workgroup-scoped control+memory barrier.
    */
   if (gl_LocalInvocationIndex == 0) {
      atomicCompSwap(DEREF(args.header).sync_data.task_counts[0], 0xFFFFFFFF,
                     DEREF(args.header).active_leaf_count);
      atomicCompSwap(DEREF(args.header).sync_data.current_phase_end_counter, 0xFFFFFFFF,
                     DIV_ROUND_UP(DEREF(args.header).active_leaf_count, gl_WorkGroupSize.x));
   }

   REF(ploc_prefix_scan_partition)
   partitions = REF(ploc_prefix_scan_partition)(args.prefix_scan_partitions);

   uint32_t task_index = fetch_task(args.header, false);

   for (uint iter = 0;; ++iter) {
      uint32_t current_task_count = task_count(args.header);
      if (task_index == TASK_INDEX_INVALID)
         break;

      /* Find preferred partners and merge them */
      PHASE (args.header) {
         uint32_t base_index = task_index - gl_LocalInvocationID.x;
         uint32_t neighbourhood_overlap = min(PLOC_NEIGHBOURHOOD, base_index);
         uint32_t double_neighbourhood_overlap = min(2 * PLOC_NEIGHBOURHOOD, base_index);
         /* Upper bound to where valid nearest node indices are written. */
         uint32_t write_bound =
            min(current_task_count, base_index + gl_WorkGroupSize.x + PLOC_NEIGHBOURHOOD);
         /* Upper bound to where valid nearest node indices are searched. */
         uint32_t search_bound =
            min(current_task_count, base_index + gl_WorkGroupSize.x + 2 * PLOC_NEIGHBOURHOOD);
         uint32_t lds_base = base_index - double_neighbourhood_overlap;

         load_bounds(src_ids, iter, task_index, lds_base, neighbourhood_overlap, search_bound);

         for (uint32_t i = gl_LocalInvocationID.x; i < NUM_PLOC_LDS_ITEMS; i += gl_WorkGroupSize.x)
            nearest_neighbour_indices[i] = 0xFFFFFFFF;
         barrier();

         for (uint32_t i = task_index - double_neighbourhood_overlap; i < write_bound;
              i += gl_WorkGroupSize.x) {
            uint32_t right_bound = min(search_bound - 1 - i, PLOC_NEIGHBOURHOOD);

            uint32_t fallback_pair = i == 0 ? (i + 1) : (i - 1);
            uint32_t min_offset = encode_neighbour_offset(INFINITY, i, fallback_pair);

            for (uint32_t j = max(i + 1, base_index - neighbourhood_overlap); j <= i + right_bound;
                 ++j) {

               float sah = combined_node_cost(lds_base, i, j);
               uint32_t i_encoded_offset = encode_neighbour_offset(sah, i, j);
               uint32_t j_encoded_offset = encode_neighbour_offset(sah, j, i);
               min_offset = min(min_offset, i_encoded_offset);
               atomicMin(nearest_neighbour_indices[j - lds_base], j_encoded_offset);
            }
            if (i >= base_index - neighbourhood_overlap)
               atomicMin(nearest_neighbour_indices[i - lds_base], min_offset);
         }

         if (gl_LocalInvocationID.x == 0)
            shared_aggregate_sum = 0;
         barrier();

         for (uint32_t i = task_index - neighbourhood_overlap; i < write_bound;
              i += gl_WorkGroupSize.x) {
            uint32_t left_bound = min(i, PLOC_NEIGHBOURHOOD);
            uint32_t right_bound = min(search_bound - 1 - i, PLOC_NEIGHBOURHOOD);
            /*
             * Workaround for a worst-case scenario in PLOC: If the combined area of
             * all nodes (in the neighbourhood) is the same, then the chosen nearest
             * neighbour is the first neighbour. However, this means that no nodes
             * except the first two will find each other as nearest neighbour. Therefore,
             * only one node is contained in each BVH level. By first testing if the immediate
             * neighbour on one side is the nearest, all immediate neighbours will be merged
             * on every step.
             */
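            /* Illustrative example (ours): with four equal-area leaves 0..3,
             * the min-SAH pass above would give every node its leftmost
             * candidate as nearest neighbour, so only nodes 0 and 1 pick each
             * other and a single merge happens per iteration. Preferring the
             * immediate left neighbour for odd i and the immediate right
             * neighbour for even i (at equal cost) makes 0/1 and 2/3 mutual
             * pairs, halving the number of nodes per iteration instead.
             */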
            uint32_t preferred_pair;
            if ((i & 1) != 0)
               preferred_pair = i - min(left_bound, 1);
            else
               preferred_pair = i + min(right_bound, 1);

            if (preferred_pair != i) {
               uint32_t encoded_min_sah =
                  nearest_neighbour_indices[i - lds_base] & (~PLOC_OFFSET_MASK);
               float sah = combined_node_cost(lds_base, i, preferred_pair);
               uint32_t encoded_sah = floatBitsToUint(sah) & (~PLOC_OFFSET_MASK);
               uint32_t encoded_offset = encode_neighbour_offset(sah, i, preferred_pair);
               if (encoded_sah <= encoded_min_sah) {
                  nearest_neighbour_indices[i - lds_base] = encoded_offset;
               }
            }
         }
         barrier();

         bool has_valid_node = true;

         if (task_index < current_task_count) {
            uint32_t base_index = task_index - gl_LocalInvocationID.x;

            uint32_t neighbour_index =
               task_index +
               decode_neighbour_offset(nearest_neighbour_indices[task_index - lds_base]);
            uint32_t other_neighbour_index =
               neighbour_index +
               decode_neighbour_offset(nearest_neighbour_indices[neighbour_index - lds_base]);
            uint32_t id = load_id(src_ids, iter, task_index);
            if (other_neighbour_index == task_index) {
               if (task_index < neighbour_index) {
                  uint32_t neighbour_id = load_id(src_ids, iter, neighbour_index);
                  uint32_t children[2] = {id, neighbour_id};
                  radv_aabb bounds[2] = {shared_bounds[task_index - lds_base], shared_bounds[neighbour_index - lds_base]};

                  DEREF(REF(uint32_t)(INDEX(uint32_t, dst_ids, task_index))) = push_node(children, bounds);
                  DEREF(REF(uint32_t)(INDEX(uint32_t, dst_ids, neighbour_index))) =
                     RADV_BVH_INVALID_NODE;
               } else {
                  /* We still store in the other case so we don't destroy the node id needed to
                   * create the internal node */
                  has_valid_node = false;
               }
            } else {
               DEREF(REF(uint32_t)(INDEX(uint32_t, dst_ids, task_index))) = id;
            }

            /* Compact - prepare prefix scan */
            uvec4 ballot = subgroupBallot(has_valid_node);

            uint32_t aggregate_sum = subgroupBallotBitCount(ballot);
            if (subgroupElect())
               atomicAdd(shared_aggregate_sum, aggregate_sum);
         }

         barrier();
         /*
          * The paper proposes initializing all partitions to an invalid state
          * and only computing aggregates afterwards. We skip that step and
          * initialize the partitions to a valid state. This also simplifies
          * the look-back, as there will never be any blocking due to invalid
          * partitions.
          */
         if (gl_LocalInvocationIndex == 0) {
            REF(ploc_prefix_scan_partition)
            current_partition = REF(ploc_prefix_scan_partition)(
               INDEX(ploc_prefix_scan_partition, partitions, task_index / gl_WorkGroupSize.x));
            DEREF(current_partition).aggregate = shared_aggregate_sum;
            if (task_index < gl_WorkGroupSize.x) {
               DEREF(current_partition).inclusive_sum = shared_aggregate_sum;
            } else {
               DEREF(current_partition).inclusive_sum = 0xFFFFFFFF;
            }
         }

         if (task_index == 0)
            set_next_task_count(args.header, task_count(args.header));
      }

      /* Compact - prefix scan and update */
      PHASE (args.header) {
         uint32_t current_task_count = task_count(args.header);

         uint32_t id = task_index < current_task_count
                          ? DEREF(REF(uint32_t)(INDEX(uint32_t, dst_ids, task_index)))
                          : RADV_BVH_INVALID_NODE;
         uvec4 ballot = subgroupBallot(id != RADV_BVH_INVALID_NODE);

         uint32_t new_offset = prefix_scan(ballot, partitions, task_index);
         if (task_index >= current_task_count)
            continue;

         if (id != RADV_BVH_INVALID_NODE) {
            DEREF(REF(uint32_t)(INDEX(uint32_t, src_ids, new_offset))) = id;
            ++new_offset;
         }

         if (task_index == current_task_count - 1) {
            set_next_task_count(args.header, new_offset);
            if (new_offset == 1)
               DEREF(args.header).sync_data.next_phase_exit_flag = 1;
         }
      }
   }
}