xref: /aosp_15_r20/external/mesa3d/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2022 Konstantin Seurer
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "nir/nir.h"
8 #include "nir/nir_builder.h"
9 
10 #include "util/hash_table.h"
11 
12 #include "bvh/bvh.h"
13 #include "nir/radv_nir_rt_common.h"
14 #include "radv_debug.h"
15 #include "radv_nir.h"
16 #include "radv_shader.h"
17 
/* Traversal stack size, in 32-bit entries, of the per-query stack array that init_ray_query_vars()
 * creates when the stack is not placed in shared memory. Traversal supports backtracking so we can
 * go deeper than this size if needed. However, we keep a large stack size to avoid it being put into
 * registers, which hurts occupancy. */
#define MAX_SCRATCH_STACK_ENTRY_COUNT 76
22 
/* One piece of lowered ray query state, backed by a single NIR variable. */
typedef struct {
   /* Backing variable; rq_variable_create() wraps its type in an array unless array_length is 1. */
   nir_variable *variable;
   /* Number of ray queries sharing `variable`; 1 means `variable` is accessed without an index. */
   unsigned array_length;
} rq_variable;
27 
28 static rq_variable *
rq_variable_create(void * ctx,nir_shader * shader,unsigned array_length,const struct glsl_type * type,const char * name)29 rq_variable_create(void *ctx, nir_shader *shader, unsigned array_length, const struct glsl_type *type, const char *name)
30 {
31    rq_variable *result = ralloc(ctx, rq_variable);
32    result->array_length = array_length;
33 
34    const struct glsl_type *variable_type = type;
35    if (array_length != 1)
36       variable_type = glsl_array_type(type, array_length, glsl_get_explicit_stride(type));
37 
38    result->variable = nir_variable_create(shader, nir_var_shader_temp, variable_type, name);
39 
40    return result;
41 }
42 
43 static nir_def *
nir_load_array(nir_builder * b,nir_variable * array,nir_def * index)44 nir_load_array(nir_builder *b, nir_variable *array, nir_def *index)
45 {
46    return nir_load_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, array), index));
47 }
48 
49 static void
nir_store_array(nir_builder * b,nir_variable * array,nir_def * index,nir_def * value,unsigned writemask)50 nir_store_array(nir_builder *b, nir_variable *array, nir_def *index, nir_def *value, unsigned writemask)
51 {
52    nir_store_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, array), index), value, writemask);
53 }
54 
55 static nir_deref_instr *
rq_deref_var(nir_builder * b,nir_def * index,rq_variable * var)56 rq_deref_var(nir_builder *b, nir_def *index, rq_variable *var)
57 {
58    if (var->array_length == 1)
59       return nir_build_deref_var(b, var->variable);
60 
61    return nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index);
62 }
63 
64 static nir_def *
rq_load_var(nir_builder * b,nir_def * index,rq_variable * var)65 rq_load_var(nir_builder *b, nir_def *index, rq_variable *var)
66 {
67    if (var->array_length == 1)
68       return nir_load_var(b, var->variable);
69 
70    return nir_load_array(b, var->variable, index);
71 }
72 
73 static void
rq_store_var(nir_builder * b,nir_def * index,rq_variable * var,nir_def * value,unsigned writemask)74 rq_store_var(nir_builder *b, nir_def *index, rq_variable *var, nir_def *value, unsigned writemask)
75 {
76    if (var->array_length == 1) {
77       nir_store_var(b, var->variable, value, writemask);
78    } else {
79       nir_store_array(b, var->variable, index, value, writemask);
80    }
81 }
82 
83 static void
rq_copy_var(nir_builder * b,nir_def * index,rq_variable * dst,rq_variable * src,unsigned mask)84 rq_copy_var(nir_builder *b, nir_def *index, rq_variable *dst, rq_variable *src, unsigned mask)
85 {
86    rq_store_var(b, index, dst, rq_load_var(b, index, src), mask);
87 }
88 
89 static nir_def *
rq_load_array(nir_builder * b,nir_def * index,rq_variable * var,nir_def * array_index)90 rq_load_array(nir_builder *b, nir_def *index, rq_variable *var, nir_def *array_index)
91 {
92    if (var->array_length == 1)
93       return nir_load_array(b, var->variable, array_index);
94 
95    return nir_load_deref(
96       b, nir_build_deref_array(b, nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index), array_index));
97 }
98 
99 static void
rq_store_array(nir_builder * b,nir_def * index,rq_variable * var,nir_def * array_index,nir_def * value,unsigned writemask)100 rq_store_array(nir_builder *b, nir_def *index, rq_variable *var, nir_def *array_index, nir_def *value,
101                unsigned writemask)
102 {
103    if (var->array_length == 1) {
104       nir_store_array(b, var->variable, array_index, value, writemask);
105    } else {
106       nir_store_deref(
107          b,
108          nir_build_deref_array(b, nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index), array_index),
109          value, writemask);
110    }
111 }
112 
/* Per-query traversal state. lower_rq_initialize() seeds every member and lower_rq_proceed() passes
 * derefs of all of them to radv_build_ray_traversal(). */
struct ray_query_traversal_vars {
   /* vec3; seeded from the rq_initialize origin/direction operands. */
   rq_variable *origin;
   rq_variable *direction;

   /* 64-bit; seeded with the same value as ray_query_vars::root_bvh_base. */
   rq_variable *bvh_base;
   /* Seeded with 0 when a stack array variable is used, otherwise with the invocation's shared-memory
    * offset. */
   rq_variable *stack;
   /* Seeded with -1. */
   rq_variable *top_stack;
   /* Seeded with the same value as `stack`. */
   rq_variable *stack_low_watermark;
   /* Seeded with RADV_BVH_ROOT_NODE. */
   rq_variable *current_node;
   /* Seeded with RADV_BVH_INVALID_NODE. */
   rq_variable *previous_node;
   /* Seeded with RADV_BVH_INVALID_NODE. */
   rq_variable *instance_top_node;
   /* Seeded with RADV_BVH_NO_INSTANCE_ROOT. */
   rq_variable *instance_bottom_node;
};
126 
/* State describing one intersection. ray_query_vars holds two of these: `closest` and `candidate`. */
struct ray_query_intersection_vars {
   rq_variable *primitive_id;
   /* uint; lower_rq_load() masks it with 0xFFFFFF to obtain the geometry index. */
   rq_variable *geometry_id_and_flags;
   /* 64-bit address that lower_rq_load() uses as the base of instance node loads. */
   rq_variable *instance_addr;
   /* uint holding an enum rq_intersection_type value. */
   rq_variable *intersection_type;
   /* bool */
   rq_variable *opaque;
   /* bool */
   rq_variable *frontface;
   /* uint; lower_rq_load() masks it with 0xFFFFFF to obtain the SBT index. */
   rq_variable *sbt_offset_and_flags;
   /* vec2 */
   rq_variable *barycentrics;
   /* float */
   rq_variable *t;
};
138 
/* All lowered state for one ray query variable (or one array of ray queries). Allocated by
 * lower_ray_query(); the struct itself is also the ralloc context of its rq_variables and their
 * names. */
struct ray_query_vars {
   /* Ray parameters written by lower_rq_initialize(). */
   rq_variable *root_bvh_base;
   rq_variable *flags;
   /* Stored shifted left by 24 bits. */
   rq_variable *cull_mask;
   rq_variable *origin;
   rq_variable *tmin;
   rq_variable *direction;

   /* bool; lower_rq_proceed() only runs traversal while this is true and returns it as its result.
    * Cleared by lower_rq_terminate() and by terminate-on-first-hit handling. */
   rq_variable *incomplete;

   /* Committed intersection and the candidate most recently reported by traversal. */
   struct ray_query_intersection_vars closest;
   struct ray_query_intersection_vars candidate;

   struct ray_query_traversal_vars trav;

   /* Stack array variable, or NULL when the traversal stack is placed in shared memory. */
   rq_variable *stack;
   /* Byte offset of the stack inside shared memory; only meaningful when `stack` is NULL. */
   uint32_t shared_base;
   /* Number of stack entries passed to traversal as radv_ray_traversal_args::stack_entries. */
   uint32_t stack_entries;

   /* The rq_initialize intrinsic most recently lowered for this query (set by lower_rq_initialize());
    * lower_rq_proceed() inspects its cull mask operand. */
   nir_intrinsic_instr *initialize;
};
160 
/* Builds the string "<base_name><name>" allocated out of `ctx`. Relies on `ctx` and `base_name`
 * being in scope at the expansion site. Uses ralloc_asprintf() so a failed allocation is not passed
 * on to strcpy()/strcat(). */
#define VAR_NAME(name) ralloc_asprintf(ctx, "%s%s", base_name, (name))
162 
163 static struct ray_query_traversal_vars
init_ray_query_traversal_vars(void * ctx,nir_shader * shader,unsigned array_length,const char * base_name)164 init_ray_query_traversal_vars(void *ctx, nir_shader *shader, unsigned array_length, const char *base_name)
165 {
166    struct ray_query_traversal_vars result;
167 
168    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
169 
170    result.origin = rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_origin"));
171    result.direction = rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_direction"));
172 
173    result.bvh_base = rq_variable_create(ctx, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_bvh_base"));
174    result.stack = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_stack"));
175    result.top_stack = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_top_stack"));
176    result.stack_low_watermark =
177       rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_stack_low_watermark"));
178    result.current_node = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_current_node"));
179    result.previous_node = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_previous_node"));
180    result.instance_top_node =
181       rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_instance_top_node"));
182    result.instance_bottom_node =
183       rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_instance_bottom_node"));
184    return result;
185 }
186 
187 static struct ray_query_intersection_vars
init_ray_query_intersection_vars(void * ctx,nir_shader * shader,unsigned array_length,const char * base_name)188 init_ray_query_intersection_vars(void *ctx, nir_shader *shader, unsigned array_length, const char *base_name)
189 {
190    struct ray_query_intersection_vars result;
191 
192    const struct glsl_type *vec2_type = glsl_vector_type(GLSL_TYPE_FLOAT, 2);
193 
194    result.primitive_id = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_primitive_id"));
195    result.geometry_id_and_flags =
196       rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_geometry_id_and_flags"));
197    result.instance_addr =
198       rq_variable_create(ctx, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_instance_addr"));
199    result.intersection_type =
200       rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_intersection_type"));
201    result.opaque = rq_variable_create(ctx, shader, array_length, glsl_bool_type(), VAR_NAME("_opaque"));
202    result.frontface = rq_variable_create(ctx, shader, array_length, glsl_bool_type(), VAR_NAME("_frontface"));
203    result.sbt_offset_and_flags =
204       rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_sbt_offset_and_flags"));
205    result.barycentrics = rq_variable_create(ctx, shader, array_length, vec2_type, VAR_NAME("_barycentrics"));
206    result.t = rq_variable_create(ctx, shader, array_length, glsl_float_type(), VAR_NAME("_t"));
207 
208    return result;
209 }
210 
211 static void
init_ray_query_vars(nir_shader * shader,unsigned array_length,struct ray_query_vars * dst,const char * base_name,uint32_t max_shared_size)212 init_ray_query_vars(nir_shader *shader, unsigned array_length, struct ray_query_vars *dst, const char *base_name,
213                     uint32_t max_shared_size)
214 {
215    void *ctx = dst;
216    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
217 
218    dst->root_bvh_base = rq_variable_create(dst, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_root_bvh_base"));
219    dst->flags = rq_variable_create(dst, shader, array_length, glsl_uint_type(), VAR_NAME("_flags"));
220    dst->cull_mask = rq_variable_create(dst, shader, array_length, glsl_uint_type(), VAR_NAME("_cull_mask"));
221    dst->origin = rq_variable_create(dst, shader, array_length, vec3_type, VAR_NAME("_origin"));
222    dst->tmin = rq_variable_create(dst, shader, array_length, glsl_float_type(), VAR_NAME("_tmin"));
223    dst->direction = rq_variable_create(dst, shader, array_length, vec3_type, VAR_NAME("_direction"));
224 
225    dst->incomplete = rq_variable_create(dst, shader, array_length, glsl_bool_type(), VAR_NAME("_incomplete"));
226 
227    dst->closest = init_ray_query_intersection_vars(dst, shader, array_length, VAR_NAME("_closest"));
228    dst->candidate = init_ray_query_intersection_vars(dst, shader, array_length, VAR_NAME("_candidate"));
229 
230    dst->trav = init_ray_query_traversal_vars(dst, shader, array_length, VAR_NAME("_top"));
231 
232    uint32_t workgroup_size =
233       shader->info.workgroup_size[0] * shader->info.workgroup_size[1] * shader->info.workgroup_size[2];
234    uint32_t shared_stack_entries = shader->info.ray_queries == 1 ? 16 : 8;
235    uint32_t shared_stack_size = workgroup_size * shared_stack_entries * 4;
236    uint32_t shared_offset = align(shader->info.shared_size, 4);
237    if (shader->info.stage != MESA_SHADER_COMPUTE || array_length > 1 ||
238        shared_offset + shared_stack_size > max_shared_size) {
239       dst->stack =
240          rq_variable_create(dst, shader, array_length,
241                             glsl_array_type(glsl_uint_type(), MAX_SCRATCH_STACK_ENTRY_COUNT, 0), VAR_NAME("_stack"));
242       dst->stack_entries = MAX_SCRATCH_STACK_ENTRY_COUNT;
243    } else {
244       dst->stack = NULL;
245       dst->shared_base = shared_offset;
246       dst->stack_entries = shared_stack_entries;
247 
248       shader->info.shared_size = shared_offset + shared_stack_size;
249    }
250 }
251 
252 #undef VAR_NAME
253 
254 static void
lower_ray_query(nir_shader * shader,nir_variable * ray_query,struct hash_table * ht,uint32_t max_shared_size)255 lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *ht, uint32_t max_shared_size)
256 {
257    struct ray_query_vars *vars = ralloc(ht, struct ray_query_vars);
258 
259    unsigned array_length = 1;
260    if (glsl_type_is_array(ray_query->type))
261       array_length = glsl_get_length(ray_query->type);
262 
263    init_ray_query_vars(shader, array_length, vars, ray_query->name == NULL ? "" : ray_query->name, max_shared_size);
264 
265    _mesa_hash_table_insert(ht, ray_query, vars);
266 }
267 
268 static void
copy_candidate_to_closest(nir_builder * b,nir_def * index,struct ray_query_vars * vars)269 copy_candidate_to_closest(nir_builder *b, nir_def *index, struct ray_query_vars *vars)
270 {
271    rq_copy_var(b, index, vars->closest.barycentrics, vars->candidate.barycentrics, 0x3);
272    rq_copy_var(b, index, vars->closest.geometry_id_and_flags, vars->candidate.geometry_id_and_flags, 0x1);
273    rq_copy_var(b, index, vars->closest.instance_addr, vars->candidate.instance_addr, 0x1);
274    rq_copy_var(b, index, vars->closest.intersection_type, vars->candidate.intersection_type, 0x1);
275    rq_copy_var(b, index, vars->closest.opaque, vars->candidate.opaque, 0x1);
276    rq_copy_var(b, index, vars->closest.frontface, vars->candidate.frontface, 0x1);
277    rq_copy_var(b, index, vars->closest.sbt_offset_and_flags, vars->candidate.sbt_offset_and_flags, 0x1);
278    rq_copy_var(b, index, vars->closest.primitive_id, vars->candidate.primitive_id, 0x1);
279    rq_copy_var(b, index, vars->closest.t, vars->candidate.t, 0x1);
280 }
281 
282 static void
insert_terminate_on_first_hit(nir_builder * b,nir_def * index,struct ray_query_vars * vars,const struct radv_ray_flags * ray_flags,bool break_on_terminate)283 insert_terminate_on_first_hit(nir_builder *b, nir_def *index, struct ray_query_vars *vars,
284                               const struct radv_ray_flags *ray_flags, bool break_on_terminate)
285 {
286    nir_def *terminate_on_first_hit;
287    if (ray_flags)
288       terminate_on_first_hit = ray_flags->terminate_on_first_hit;
289    else
290       terminate_on_first_hit =
291          nir_test_mask(b, rq_load_var(b, index, vars->flags), SpvRayFlagsTerminateOnFirstHitKHRMask);
292    nir_push_if(b, terminate_on_first_hit);
293    {
294       rq_store_var(b, index, vars->incomplete, nir_imm_false(b), 0x1);
295       if (break_on_terminate)
296          nir_jump(b, nir_jump_break);
297    }
298    nir_pop_if(b, NULL);
299 }
300 
301 static void
lower_rq_confirm_intersection(nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars)302 lower_rq_confirm_intersection(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars)
303 {
304    copy_candidate_to_closest(b, index, vars);
305    insert_terminate_on_first_hit(b, index, vars, NULL, false);
306 }
307 
308 static void
lower_rq_generate_intersection(nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars)309 lower_rq_generate_intersection(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars)
310 {
311    nir_push_if(b, nir_iand(b, nir_fge(b, rq_load_var(b, index, vars->closest.t), instr->src[1].ssa),
312                            nir_fge(b, instr->src[1].ssa, rq_load_var(b, index, vars->tmin))));
313    {
314       copy_candidate_to_closest(b, index, vars);
315       insert_terminate_on_first_hit(b, index, vars, NULL, false);
316       rq_store_var(b, index, vars->closest.t, instr->src[1].ssa, 0x1);
317    }
318    nir_pop_if(b, NULL);
319 }
320 
/* Values stored in ray_query_intersection_vars::intersection_type. The numeric values matter:
 * lower_rq_load() returns the committed type unchanged and the candidate type minus one. */
enum rq_intersection_type { intersection_type_none, intersection_type_triangle, intersection_type_aabb };
322 
323 static void
lower_rq_initialize(nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars,struct radv_instance * instance)324 lower_rq_initialize(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars,
325                     struct radv_instance *instance)
326 {
327    rq_store_var(b, index, vars->flags, instr->src[2].ssa, 0x1);
328    rq_store_var(b, index, vars->cull_mask, nir_ishl_imm(b, instr->src[3].ssa, 24), 0x1);
329 
330    rq_store_var(b, index, vars->origin, instr->src[4].ssa, 0x7);
331    rq_store_var(b, index, vars->trav.origin, instr->src[4].ssa, 0x7);
332 
333    rq_store_var(b, index, vars->tmin, instr->src[5].ssa, 0x1);
334 
335    rq_store_var(b, index, vars->direction, instr->src[6].ssa, 0x7);
336    rq_store_var(b, index, vars->trav.direction, instr->src[6].ssa, 0x7);
337 
338    rq_store_var(b, index, vars->closest.t, instr->src[7].ssa, 0x1);
339    rq_store_var(b, index, vars->closest.intersection_type, nir_imm_int(b, intersection_type_none), 0x1);
340 
341    nir_def *accel_struct = instr->src[1].ssa;
342 
343    /* Make sure that instance data loads don't hang in case of a miss by setting a valid initial address. */
344    rq_store_var(b, index, vars->closest.instance_addr, accel_struct, 1);
345    rq_store_var(b, index, vars->candidate.instance_addr, accel_struct, 1);
346 
347    nir_def *bvh_offset = nir_build_load_global(
348       b, 1, 32, nir_iadd_imm(b, accel_struct, offsetof(struct radv_accel_struct_header, bvh_offset)),
349       .access = ACCESS_NON_WRITEABLE);
350    nir_def *bvh_base = nir_iadd(b, accel_struct, nir_u2u64(b, bvh_offset));
351    bvh_base = build_addr_to_node(b, bvh_base);
352 
353    rq_store_var(b, index, vars->root_bvh_base, bvh_base, 0x1);
354    rq_store_var(b, index, vars->trav.bvh_base, bvh_base, 1);
355 
356    if (vars->stack) {
357       rq_store_var(b, index, vars->trav.stack, nir_imm_int(b, 0), 0x1);
358       rq_store_var(b, index, vars->trav.stack_low_watermark, nir_imm_int(b, 0), 0x1);
359    } else {
360       nir_def *base_offset = nir_imul_imm(b, nir_load_local_invocation_index(b), sizeof(uint32_t));
361       base_offset = nir_iadd_imm(b, base_offset, vars->shared_base);
362       rq_store_var(b, index, vars->trav.stack, base_offset, 0x1);
363       rq_store_var(b, index, vars->trav.stack_low_watermark, base_offset, 0x1);
364    }
365 
366    rq_store_var(b, index, vars->trav.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
367    rq_store_var(b, index, vars->trav.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
368    rq_store_var(b, index, vars->trav.instance_top_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
369    rq_store_var(b, index, vars->trav.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 0x1);
370 
371    rq_store_var(b, index, vars->trav.top_stack, nir_imm_int(b, -1), 1);
372 
373    rq_store_var(b, index, vars->incomplete, nir_imm_bool(b, !(instance->debug_flags & RADV_DEBUG_NO_RT)), 0x1);
374 
375    vars->initialize = instr;
376 }
377 
378 static nir_def *
lower_rq_load(struct radv_device * device,nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars)379 lower_rq_load(struct radv_device *device, nir_builder *b, nir_def *index, nir_intrinsic_instr *instr,
380               struct ray_query_vars *vars)
381 {
382    bool committed = nir_intrinsic_committed(instr);
383    struct ray_query_intersection_vars *intersection = committed ? &vars->closest : &vars->candidate;
384 
385    uint32_t column = nir_intrinsic_column(instr);
386 
387    nir_ray_query_value value = nir_intrinsic_ray_query_value(instr);
388    switch (value) {
389    case nir_ray_query_value_flags:
390       return rq_load_var(b, index, vars->flags);
391    case nir_ray_query_value_intersection_barycentrics:
392       return rq_load_var(b, index, intersection->barycentrics);
393    case nir_ray_query_value_intersection_candidate_aabb_opaque:
394       return nir_iand(b, rq_load_var(b, index, vars->candidate.opaque),
395                       nir_ieq_imm(b, rq_load_var(b, index, vars->candidate.intersection_type), intersection_type_aabb));
396    case nir_ray_query_value_intersection_front_face:
397       return rq_load_var(b, index, intersection->frontface);
398    case nir_ray_query_value_intersection_geometry_index:
399       return nir_iand_imm(b, rq_load_var(b, index, intersection->geometry_id_and_flags), 0xFFFFFF);
400    case nir_ray_query_value_intersection_instance_custom_index: {
401       nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
402       return nir_iand_imm(
403          b,
404          nir_build_load_global(
405             b, 1, 32,
406             nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, custom_instance_and_mask))),
407          0xFFFFFF);
408    }
409    case nir_ray_query_value_intersection_instance_id: {
410       nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
411       return nir_build_load_global(
412          b, 1, 32, nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, instance_id)));
413    }
414    case nir_ray_query_value_intersection_instance_sbt_index:
415       return nir_iand_imm(b, rq_load_var(b, index, intersection->sbt_offset_and_flags), 0xFFFFFF);
416    case nir_ray_query_value_intersection_object_ray_direction: {
417       nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
418       nir_def *wto_matrix[3];
419       nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
420       return nir_build_vec3_mat_mult(b, rq_load_var(b, index, vars->direction), wto_matrix, false);
421    }
422    case nir_ray_query_value_intersection_object_ray_origin: {
423       nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
424       nir_def *wto_matrix[3];
425       nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
426       return nir_build_vec3_mat_mult(b, rq_load_var(b, index, vars->origin), wto_matrix, true);
427    }
428    case nir_ray_query_value_intersection_object_to_world: {
429       nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
430       nir_def *rows[3];
431       for (unsigned r = 0; r < 3; ++r)
432          rows[r] = nir_build_load_global(
433             b, 4, 32,
434             nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, otw_matrix) + r * 16));
435 
436       return nir_vec3(b, nir_channel(b, rows[0], column), nir_channel(b, rows[1], column),
437                       nir_channel(b, rows[2], column));
438    }
439    case nir_ray_query_value_intersection_primitive_index:
440       return rq_load_var(b, index, intersection->primitive_id);
441    case nir_ray_query_value_intersection_t:
442       return rq_load_var(b, index, intersection->t);
443    case nir_ray_query_value_intersection_type: {
444       nir_def *intersection_type = rq_load_var(b, index, intersection->intersection_type);
445       if (!committed)
446          intersection_type = nir_iadd_imm(b, intersection_type, -1);
447 
448       return intersection_type;
449    }
450    case nir_ray_query_value_intersection_world_to_object: {
451       nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
452 
453       nir_def *wto_matrix[3];
454       nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
455 
456       nir_def *vals[3];
457       for (unsigned i = 0; i < 3; ++i)
458          vals[i] = nir_channel(b, wto_matrix[i], column);
459 
460       return nir_vec(b, vals, 3);
461    }
462    case nir_ray_query_value_tmin:
463       return rq_load_var(b, index, vars->tmin);
464    case nir_ray_query_value_world_ray_direction:
465       return rq_load_var(b, index, vars->direction);
466    case nir_ray_query_value_world_ray_origin:
467       return rq_load_var(b, index, vars->origin);
468    case nir_ray_query_value_intersection_triangle_vertex_positions: {
469       nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
470       nir_def *primitive_id = rq_load_var(b, index, intersection->primitive_id);
471       return radv_load_vertex_position(device, b, instance_node_addr, primitive_id, nir_intrinsic_column(instr));
472    }
473    default:
474       unreachable("Invalid nir_ray_query_value!");
475    }
476 
477    return NULL;
478 }
479 
/* Context that lower_rq_proceed() passes through radv_ray_traversal_args::data to the traversal
 * callbacks in this file. */
struct traversal_data {
   /* Lowered state of the ray query being traversed. */
   struct ray_query_vars *vars;
   /* Index of the query within its array; passed on to the rq_* accessors. */
   nir_def *index;
};
484 
485 static void
handle_candidate_aabb(nir_builder * b,struct radv_leaf_intersection * intersection,const struct radv_ray_traversal_args * args)486 handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersection,
487                       const struct radv_ray_traversal_args *args)
488 {
489    struct traversal_data *data = args->data;
490    struct ray_query_vars *vars = data->vars;
491    nir_def *index = data->index;
492 
493    rq_store_var(b, index, vars->candidate.primitive_id, intersection->primitive_id, 1);
494    rq_store_var(b, index, vars->candidate.geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
495    rq_store_var(b, index, vars->candidate.opaque, intersection->opaque, 0x1);
496    rq_store_var(b, index, vars->candidate.intersection_type, nir_imm_int(b, intersection_type_aabb), 0x1);
497 
498    nir_jump(b, nir_jump_break);
499 }
500 
501 static void
handle_candidate_triangle(nir_builder * b,struct radv_triangle_intersection * intersection,const struct radv_ray_traversal_args * args,const struct radv_ray_flags * ray_flags)502 handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *intersection,
503                           const struct radv_ray_traversal_args *args, const struct radv_ray_flags *ray_flags)
504 {
505    struct traversal_data *data = args->data;
506    struct ray_query_vars *vars = data->vars;
507    nir_def *index = data->index;
508 
509    rq_store_var(b, index, vars->candidate.barycentrics, intersection->barycentrics, 3);
510    rq_store_var(b, index, vars->candidate.primitive_id, intersection->base.primitive_id, 1);
511    rq_store_var(b, index, vars->candidate.geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
512    rq_store_var(b, index, vars->candidate.t, intersection->t, 0x1);
513    rq_store_var(b, index, vars->candidate.opaque, intersection->base.opaque, 0x1);
514    rq_store_var(b, index, vars->candidate.frontface, intersection->frontface, 0x1);
515    rq_store_var(b, index, vars->candidate.intersection_type, nir_imm_int(b, intersection_type_triangle), 0x1);
516 
517    nir_push_if(b, intersection->base.opaque);
518    {
519       copy_candidate_to_closest(b, index, vars);
520       insert_terminate_on_first_hit(b, index, vars, ray_flags, true);
521    }
522    nir_push_else(b, NULL);
523    {
524       nir_jump(b, nir_jump_break);
525    }
526    nir_pop_if(b, NULL);
527 }
528 
529 static void
store_stack_entry(nir_builder * b,nir_def * index,nir_def * value,const struct radv_ray_traversal_args * args)530 store_stack_entry(nir_builder *b, nir_def *index, nir_def *value, const struct radv_ray_traversal_args *args)
531 {
532    struct traversal_data *data = args->data;
533    if (data->vars->stack)
534       rq_store_array(b, data->index, data->vars->stack, index, value, 1);
535    else
536       nir_store_shared(b, value, index, .base = 0, .align_mul = 4);
537 }
538 
539 static nir_def *
load_stack_entry(nir_builder * b,nir_def * index,const struct radv_ray_traversal_args * args)540 load_stack_entry(nir_builder *b, nir_def *index, const struct radv_ray_traversal_args *args)
541 {
542    struct traversal_data *data = args->data;
543    if (data->vars->stack)
544       return rq_load_array(b, data->index, data->vars->stack, index);
545    else
546       return nir_load_shared(b, 1, 32, index, .base = 0, .align_mul = 4);
547 }
548 
/* Lowers rq_proceed: emits BVH traversal for ray query `index`, guarded by its `incomplete` flag,
 * and returns the value of `incomplete` after traversal.
 *
 * Reads vars->initialize, i.e. the rq_initialize intrinsic most recently lowered for this query. */
static nir_def *
lower_rq_proceed(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars,
                 struct radv_device *device)
{
   /* nir_block_dominates() below needs dominance information. */
   nir_metadata_require(nir_cf_node_get_function(&instr->instr.block->cf_node), nir_metadata_dominance);

   /* If the recorded rq_initialize dominates this proceed and its cull mask operand is the constant
    * 0xFF, tell traversal that the cull mask can be ignored. */
   bool ignore_cull_mask = false;
   if (nir_block_dominates(vars->initialize->instr.block, instr->instr.block)) {
      nir_src cull_mask = vars->initialize->src[3];
      if (nir_src_is_const(cull_mask) && nir_src_as_uint(cull_mask) == 0xFF)
         ignore_cull_mask = true;
   }

   /* Function-local vec3 holding the reciprocal of the current traversal direction. */
   nir_variable *inv_dir = nir_local_variable_create(b->impl, glsl_vector_type(GLSL_TYPE_FLOAT, 3), "inv_dir");
   nir_store_var(b, inv_dir, nir_frcp(b, rq_load_var(b, index, vars->trav.direction)), 0x7);

   /* Derefs of the per-query state handed to traversal. tmax aliases closest.t, and the
    * instance_addr/sbt_offset_and_flags slots alias the candidate intersection. */
   struct radv_ray_traversal_vars trav_vars = {
      .tmax = rq_deref_var(b, index, vars->closest.t),
      .origin = rq_deref_var(b, index, vars->trav.origin),
      .dir = rq_deref_var(b, index, vars->trav.direction),
      .inv_dir = nir_build_deref_var(b, inv_dir),
      .bvh_base = rq_deref_var(b, index, vars->trav.bvh_base),
      .stack = rq_deref_var(b, index, vars->trav.stack),
      .top_stack = rq_deref_var(b, index, vars->trav.top_stack),
      .stack_low_watermark = rq_deref_var(b, index, vars->trav.stack_low_watermark),
      .current_node = rq_deref_var(b, index, vars->trav.current_node),
      .previous_node = rq_deref_var(b, index, vars->trav.previous_node),
      .instance_top_node = rq_deref_var(b, index, vars->trav.instance_top_node),
      .instance_bottom_node = rq_deref_var(b, index, vars->trav.instance_bottom_node),
      .instance_addr = rq_deref_var(b, index, vars->candidate.instance_addr),
      .sbt_offset_and_flags = rq_deref_var(b, index, vars->candidate.sbt_offset_and_flags),
   };

   /* Context for the stack and candidate callbacks below. */
   struct traversal_data data = {
      .vars = vars,
      .index = index,
   };

   struct radv_ray_traversal_args args = {
      .root_bvh_base = rq_load_var(b, index, vars->root_bvh_base),
      .flags = rq_load_var(b, index, vars->flags),
      .cull_mask = rq_load_var(b, index, vars->cull_mask),
      .origin = rq_load_var(b, index, vars->origin),
      .tmin = rq_load_var(b, index, vars->tmin),
      .dir = rq_load_var(b, index, vars->direction),
      .vars = trav_vars,
      .stack_entries = vars->stack_entries,
      .ignore_cull_mask = ignore_cull_mask,
      .stack_store_cb = store_stack_entry,
      .stack_load_cb = load_stack_entry,
      .aabb_cb = handle_candidate_aabb,
      .triangle_cb = handle_candidate_triangle,
      .data = &data,
   };

   if (vars->stack) {
      /* Stack array variable: slots are consecutive array indices starting at 0. */
      args.stack_stride = 1;
      args.stack_base = 0;
   } else {
      /* Shared-memory stack: consecutive slots of one invocation are workgroup_size * 4 bytes apart,
       * starting at shared_base. */
      uint32_t workgroup_size =
         b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] * b->shader->info.workgroup_size[2];
      args.stack_stride = workgroup_size * 4;
      args.stack_base = vars->shared_base;
   }

   /* Only traverse while the query is still incomplete. The flag is re-read after traversal because
    * the callbacks may have cleared it, and is then ANDed with the traversal result. */
   nir_push_if(b, rq_load_var(b, index, vars->incomplete));
   {
      nir_def *incomplete = radv_build_ray_traversal(device, b, &args);
      rq_store_var(b, index, vars->incomplete, nir_iand(b, rq_load_var(b, index, vars->incomplete), incomplete), 1);
   }
   nir_pop_if(b, NULL);

   return rq_load_var(b, index, vars->incomplete);
}
623 
624 static void
lower_rq_terminate(nir_builder * b,nir_def * index,nir_intrinsic_instr * instr,struct ray_query_vars * vars)625 lower_rq_terminate(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars)
626 {
627    rq_store_var(b, index, vars->incomplete, nir_imm_false(b), 0x1);
628 }
629 
630 bool
radv_nir_lower_ray_queries(struct nir_shader * shader,struct radv_device * device)631 radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device)
632 {
633    const struct radv_physical_device *pdev = radv_device_physical(device);
634    struct radv_instance *instance = radv_physical_device_instance(pdev);
635    bool progress = false;
636    struct hash_table *query_ht = _mesa_pointer_hash_table_create(NULL);
637 
638    nir_foreach_variable_in_list (var, &shader->variables) {
639       if (!var->data.ray_query)
640          continue;
641 
642       lower_ray_query(shader, var, query_ht, pdev->max_shared_size);
643 
644       progress = true;
645    }
646 
647    nir_foreach_function (function, shader) {
648       if (!function->impl)
649          continue;
650 
651       nir_builder builder = nir_builder_create(function->impl);
652 
653       nir_foreach_variable_in_list (var, &function->impl->locals) {
654          if (!var->data.ray_query)
655             continue;
656 
657          lower_ray_query(shader, var, query_ht, pdev->max_shared_size);
658 
659          progress = true;
660       }
661 
662       nir_foreach_block (block, function->impl) {
663          nir_foreach_instr_safe (instr, block) {
664             if (instr->type != nir_instr_type_intrinsic)
665                continue;
666 
667             nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
668 
669             if (!nir_intrinsic_is_ray_query(intrinsic->intrinsic))
670                continue;
671 
672             nir_deref_instr *ray_query_deref = nir_instr_as_deref(intrinsic->src[0].ssa->parent_instr);
673             nir_def *index = NULL;
674 
675             if (ray_query_deref->deref_type == nir_deref_type_array) {
676                index = ray_query_deref->arr.index.ssa;
677                ray_query_deref = nir_instr_as_deref(ray_query_deref->parent.ssa->parent_instr);
678             }
679 
680             assert(ray_query_deref->deref_type == nir_deref_type_var);
681 
682             struct ray_query_vars *vars =
683                (struct ray_query_vars *)_mesa_hash_table_search(query_ht, ray_query_deref->var)->data;
684 
685             builder.cursor = nir_before_instr(instr);
686 
687             nir_def *new_dest = NULL;
688 
689             switch (intrinsic->intrinsic) {
690             case nir_intrinsic_rq_confirm_intersection:
691                lower_rq_confirm_intersection(&builder, index, intrinsic, vars);
692                break;
693             case nir_intrinsic_rq_generate_intersection:
694                lower_rq_generate_intersection(&builder, index, intrinsic, vars);
695                break;
696             case nir_intrinsic_rq_initialize:
697                lower_rq_initialize(&builder, index, intrinsic, vars, instance);
698                break;
699             case nir_intrinsic_rq_load:
700                new_dest = lower_rq_load(device, &builder, index, intrinsic, vars);
701                break;
702             case nir_intrinsic_rq_proceed:
703                new_dest = lower_rq_proceed(&builder, index, intrinsic, vars, device);
704                break;
705             case nir_intrinsic_rq_terminate:
706                lower_rq_terminate(&builder, index, intrinsic, vars);
707                break;
708             default:
709                unreachable("Unsupported ray query intrinsic!");
710             }
711 
712             if (new_dest)
713                nir_def_rewrite_uses(&intrinsic->def, new_dest);
714 
715             nir_instr_remove(instr);
716             nir_instr_free(instr);
717 
718             progress = true;
719          }
720       }
721 
722       nir_metadata_preserve(function->impl, nir_metadata_none);
723    }
724 
725    ralloc_free(query_ht);
726 
727    return progress;
728 }
729