1 /*
2  * Copyright © 2021 Google
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "nir/nir.h"
8 #include "nir/nir_builder.h"
9 
10 #include "bvh/bvh.h"
11 #include "meta/radv_meta.h"
12 #include "nir/radv_nir.h"
13 #include "nir/radv_nir_rt_common.h"
14 #include "ac_nir.h"
15 #include "radv_pipeline_cache.h"
16 #include "radv_pipeline_rt.h"
17 #include "radv_shader.h"
18 
19 #include "vk_pipeline.h"
20 
21 /* Traversal stack size. This stack is put in LDS; experimentally, 16 entries give the best
22  * performance. */
23 #define MAX_STACK_ENTRY_COUNT 16
24 
25 #define RADV_RT_SWITCH_NULL_CHECK_THRESHOLD 3
26 
27 /* Minimum number of inlined shaders to use binary search to select which shader to run. */
28 #define INLINED_SHADER_BSEARCH_THRESHOLD 16
29 
30 struct radv_rt_case_data {
31    struct radv_device *device;
32    struct radv_ray_tracing_pipeline *pipeline;
33    struct rt_variables *vars;
34 };
35 
36 typedef void (*radv_get_group_info)(struct radv_ray_tracing_group *, uint32_t *, uint32_t *,
37                                     struct radv_rt_case_data *);
38 typedef void (*radv_insert_shader_case)(nir_builder *, nir_def *, struct radv_ray_tracing_group *,
39                                         struct radv_rt_case_data *);
40 
41 struct inlined_shader_case {
42    struct radv_ray_tracing_group *group;
43    uint32_t call_idx;
44 };
45 
46 static int
47 compare_inlined_shader_case(const void *a, const void *b)
48 {
49    const struct inlined_shader_case *visit_a = a;
50    const struct inlined_shader_case *visit_b = b;
51    return visit_a->call_idx > visit_b->call_idx ? 1 : visit_a->call_idx < visit_b->call_idx ? -1 : 0;
52 }
53 
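/* Emit the cases for a sorted range of inlined shaders. Ranges with at least
 * INLINED_SHADER_BSEARCH_THRESHOLD entries are split at the median call index with a branch
 * (binary search on sbt_idx); smaller ranges are emitted as a linear sequence of cases. */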
54 static void
55 insert_inlined_range(nir_builder *b, nir_def *sbt_idx, radv_insert_shader_case shader_case,
56                      struct radv_rt_case_data *data, struct inlined_shader_case *cases, uint32_t length)
57 {
58    if (length >= INLINED_SHADER_BSEARCH_THRESHOLD) {
59       nir_push_if(b, nir_ige_imm(b, sbt_idx, cases[length / 2].call_idx));
60       {
61          insert_inlined_range(b, sbt_idx, shader_case, data, cases + (length / 2), length - (length / 2));
62       }
63       nir_push_else(b, NULL);
64       {
65          insert_inlined_range(b, sbt_idx, shader_case, data, cases, length / 2);
66       }
67       nir_pop_if(b, NULL);
68    } else {
69       for (uint32_t i = 0; i < length; ++i)
70          shader_case(b, sbt_idx, cases[i].group, data);
71    }
72 }
73 
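/* Collect the handles of all pipeline groups that provide a shader for this case type, skip
 * duplicates, sort them by call index and emit one inlined case per unique shader. With enough
 * cases, the whole switch is additionally guarded by an 'sbt_idx != 0' check to skip NULL
 * shaders. */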
74 static void
75 radv_visit_inlined_shaders(nir_builder *b, nir_def *sbt_idx, bool can_have_null_shaders, struct radv_rt_case_data *data,
76                            radv_get_group_info group_info, radv_insert_shader_case shader_case)
77 {
78    struct inlined_shader_case *cases = calloc(data->pipeline->group_count, sizeof(struct inlined_shader_case));
79    uint32_t case_count = 0;
80 
81    for (unsigned i = 0; i < data->pipeline->group_count; i++) {
82       struct radv_ray_tracing_group *group = &data->pipeline->groups[i];
83 
84       uint32_t shader_index = VK_SHADER_UNUSED_KHR;
85       uint32_t handle_index = VK_SHADER_UNUSED_KHR;
86       group_info(group, &shader_index, &handle_index, data);
87       if (shader_index == VK_SHADER_UNUSED_KHR)
88          continue;
89 
90       /* Avoid emitting stages with the same shaders/handles multiple times. */
91       bool duplicate = false;
92       for (unsigned j = 0; j < i; j++) {
93          uint32_t other_shader_index = VK_SHADER_UNUSED_KHR;
94          uint32_t other_handle_index = VK_SHADER_UNUSED_KHR;
95          group_info(&data->pipeline->groups[j], &other_shader_index, &other_handle_index, data);
96 
97          if (handle_index == other_handle_index) {
98             duplicate = true;
99             break;
100          }
101       }
102 
103       if (!duplicate) {
104          cases[case_count++] = (struct inlined_shader_case){
105             .group = group,
106             .call_idx = handle_index,
107          };
108       }
109    }
110 
111    qsort(cases, case_count, sizeof(struct inlined_shader_case), compare_inlined_shader_case);
112 
113    /* Do not emit 'if (sbt_idx != 0) { ... }' if there are only a few cases. */
114    can_have_null_shaders &= case_count >= RADV_RT_SWITCH_NULL_CHECK_THRESHOLD;
115 
116    if (can_have_null_shaders)
117       nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
118 
119    insert_inlined_range(b, sbt_idx, shader_case, data, cases, case_count);
120 
121    if (can_have_null_shaders)
122       nir_pop_if(b, NULL);
123 
124    free(cases);
125 }
126 
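/* Rewrite shader_call_data derefs (the incoming payload) into function_temp derefs rooted at the
 * caller-provided argument offset (load_rt_arg_scratch_offset_amd), so that payload accesses turn
 * into plain scratch accesses once explicit IO is lowered. */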
127 static bool
128 lower_rt_derefs(nir_shader *shader)
129 {
130    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
131 
132    bool progress = false;
133 
134    nir_builder b = nir_builder_at(nir_before_impl(impl));
135 
136    nir_def *arg_offset = nir_load_rt_arg_scratch_offset_amd(&b);
137 
138    nir_foreach_block (block, impl) {
139       nir_foreach_instr_safe (instr, block) {
140          if (instr->type != nir_instr_type_deref)
141             continue;
142 
143          nir_deref_instr *deref = nir_instr_as_deref(instr);
144          if (!nir_deref_mode_is(deref, nir_var_shader_call_data))
145             continue;
146 
147          deref->modes = nir_var_function_temp;
148          progress = true;
149 
150          if (deref->deref_type == nir_deref_type_var) {
151             b.cursor = nir_before_instr(&deref->instr);
152             nir_deref_instr *replacement =
153                nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0);
154             nir_def_replace(&deref->def, &replacement->def);
155          }
156       }
157    }
158 
159    if (progress)
160       nir_metadata_preserve(impl, nir_metadata_control_flow);
161    else
162       nir_metadata_preserve(impl, nir_metadata_all);
163 
164    return progress;
165 }
166 
167 /*
168  * Global variables for an RT pipeline
169  */
170 struct rt_variables {
171    struct radv_device *device;
172    const VkPipelineCreateFlags2KHR flags;
173    bool monolithic;
174 
175    /* idx of the next shader to run in the next iteration of the main loop.
176     * During traversal, idx is used to store the SBT index and will contain
177     * the correct resume index upon returning.
178     */
179    nir_variable *idx;
180    nir_variable *shader_addr;
181    nir_variable *traversal_addr;
182 
183    /* scratch offset of the argument area relative to stack_ptr */
184    nir_variable *arg;
185    uint32_t payload_offset;
186 
187    nir_variable *stack_ptr;
188 
189    nir_variable *ahit_isec_count;
190 
191    nir_variable *launch_sizes[3];
192    nir_variable *launch_ids[3];
193 
194    /* global address of the SBT entry used for the shader */
195    nir_variable *shader_record_ptr;
196 
197    /* trace_ray arguments */
198    nir_variable *accel_struct;
199    nir_variable *cull_mask_and_flags;
200    nir_variable *sbt_offset;
201    nir_variable *sbt_stride;
202    nir_variable *miss_index;
203    nir_variable *origin;
204    nir_variable *tmin;
205    nir_variable *direction;
206    nir_variable *tmax;
207 
208    /* Properties of the primitive currently being visited. */
209    nir_variable *primitive_id;
210    nir_variable *geometry_id_and_flags;
211    nir_variable *instance_addr;
212    nir_variable *hit_kind;
213    nir_variable *opaque;
214 
215    /* Output variables for intersection & anyhit shaders. */
216    nir_variable *ahit_accept;
217    nir_variable *ahit_terminate;
218 
219    unsigned stack_size;
220 };
221 
222 static struct rt_variables
223 create_rt_variables(nir_shader *shader, struct radv_device *device, const VkPipelineCreateFlags2KHR flags,
224                     bool monolithic)
225 {
226    struct rt_variables vars = {
227       .device = device,
228       .flags = flags,
229       .monolithic = monolithic,
230    };
231    vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
232    vars.shader_addr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_addr");
233    vars.traversal_addr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "traversal_addr");
234    vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
235    vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
236    vars.shader_record_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr");
237 
238    vars.launch_sizes[0] = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "launch_size_x");
239    vars.launch_sizes[1] = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "launch_size_y");
240    vars.launch_sizes[2] = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "launch_size_z");
241 
242    vars.launch_ids[0] = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "launch_id_x");
243    vars.launch_ids[1] = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "launch_id_y");
244    vars.launch_ids[2] = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "launch_id_z");
245 
246    if (device->rra_trace.ray_history_addr)
247       vars.ahit_isec_count = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "ahit_isec_count");
248 
249    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
250    vars.accel_struct = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "accel_struct");
251    vars.cull_mask_and_flags = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "cull_mask_and_flags");
252    vars.sbt_offset = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_offset");
253    vars.sbt_stride = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_stride");
254    vars.miss_index = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "miss_index");
255    vars.origin = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_origin");
256    vars.tmin = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmin");
257    vars.direction = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_direction");
258    vars.tmax = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmax");
259 
260    vars.primitive_id = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "primitive_id");
261    vars.geometry_id_and_flags =
262       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "geometry_id_and_flags");
263    vars.instance_addr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
264    vars.hit_kind = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "hit_kind");
265    vars.opaque = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "opaque");
266 
267    vars.ahit_accept = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_accept");
268    vars.ahit_terminate = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_terminate");
269 
270    return vars;
271 }
272 
273 /*
274  * Remap all the variables between the two rt_variables struct for inlining.
275  */
276 static void
277 map_rt_variables(struct hash_table *var_remap, struct rt_variables *src, const struct rt_variables *dst)
278 {
279    _mesa_hash_table_insert(var_remap, src->idx, dst->idx);
280    _mesa_hash_table_insert(var_remap, src->shader_addr, dst->shader_addr);
281    _mesa_hash_table_insert(var_remap, src->traversal_addr, dst->traversal_addr);
282    _mesa_hash_table_insert(var_remap, src->arg, dst->arg);
283    _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr);
284    _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr);
285 
286    for (uint32_t i = 0; i < ARRAY_SIZE(src->launch_sizes); i++)
287       _mesa_hash_table_insert(var_remap, src->launch_sizes[i], dst->launch_sizes[i]);
288 
289    for (uint32_t i = 0; i < ARRAY_SIZE(src->launch_ids); i++)
290       _mesa_hash_table_insert(var_remap, src->launch_ids[i], dst->launch_ids[i]);
291 
292    if (dst->ahit_isec_count)
293       _mesa_hash_table_insert(var_remap, src->ahit_isec_count, dst->ahit_isec_count);
294 
295    _mesa_hash_table_insert(var_remap, src->accel_struct, dst->accel_struct);
296    _mesa_hash_table_insert(var_remap, src->cull_mask_and_flags, dst->cull_mask_and_flags);
297    _mesa_hash_table_insert(var_remap, src->sbt_offset, dst->sbt_offset);
298    _mesa_hash_table_insert(var_remap, src->sbt_stride, dst->sbt_stride);
299    _mesa_hash_table_insert(var_remap, src->miss_index, dst->miss_index);
300    _mesa_hash_table_insert(var_remap, src->origin, dst->origin);
301    _mesa_hash_table_insert(var_remap, src->tmin, dst->tmin);
302    _mesa_hash_table_insert(var_remap, src->direction, dst->direction);
303    _mesa_hash_table_insert(var_remap, src->tmax, dst->tmax);
304 
305    _mesa_hash_table_insert(var_remap, src->primitive_id, dst->primitive_id);
306    _mesa_hash_table_insert(var_remap, src->geometry_id_and_flags, dst->geometry_id_and_flags);
307    _mesa_hash_table_insert(var_remap, src->instance_addr, dst->instance_addr);
308    _mesa_hash_table_insert(var_remap, src->hit_kind, dst->hit_kind);
309    _mesa_hash_table_insert(var_remap, src->opaque, dst->opaque);
310    _mesa_hash_table_insert(var_remap, src->ahit_accept, dst->ahit_accept);
311    _mesa_hash_table_insert(var_remap, src->ahit_terminate, dst->ahit_terminate);
312 }
313 
314 /*
315  * Create a copy of the global rt variables where the primitive/instance related variables are
316  * independent. This is needed because we have to keep the old values of the global variables around
317  * in case e.g. an any-hit shader rejects the intersection. So there are inner variables that get copied
318  * to the outer variables once we commit to a better hit.
319  */
320 static struct rt_variables
321 create_inner_vars(nir_builder *b, const struct rt_variables *vars)
322 {
323    struct rt_variables inner_vars = *vars;
324    inner_vars.idx = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_idx");
325    inner_vars.shader_record_ptr =
326       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_shader_record_ptr");
327    inner_vars.primitive_id =
328       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_primitive_id");
329    inner_vars.geometry_id_and_flags =
330       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_geometry_id_and_flags");
331    inner_vars.tmax = nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "inner_tmax");
332    inner_vars.instance_addr =
333       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_instance_addr");
334    inner_vars.hit_kind = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_hit_kind");
335 
336    return inner_vars;
337 }
338 
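/* Return to the calling shader: pop the 16-byte return slot off the scratch stack and load the
 * resume shader address that rt_trace_ray/rt_execute_callable stored there. */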
339 static void
340 insert_rt_return(nir_builder *b, const struct rt_variables *vars)
341 {
342    nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), -16), 1);
343    nir_store_var(b, vars->shader_addr, nir_load_scratch(b, 1, 64, nir_load_var(b, vars->stack_ptr), .align_mul = 16),
344                  1);
345 }
346 
347 enum sbt_type {
348    SBT_RAYGEN = offsetof(VkTraceRaysIndirectCommand2KHR, raygenShaderRecordAddress),
349    SBT_MISS = offsetof(VkTraceRaysIndirectCommand2KHR, missShaderBindingTableAddress),
350    SBT_HIT = offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
351    SBT_CALLABLE = offsetof(VkTraceRaysIndirectCommand2KHR, callableShaderBindingTableAddress),
352 };
353 
354 enum sbt_entry {
355    SBT_RECURSIVE_PTR = offsetof(struct radv_pipeline_group_handle, recursive_shader_ptr),
356    SBT_GENERAL_IDX = offsetof(struct radv_pipeline_group_handle, general_index),
357    SBT_CLOSEST_HIT_IDX = offsetof(struct radv_pipeline_group_handle, closest_hit_index),
358    SBT_INTERSECTION_IDX = offsetof(struct radv_pipeline_group_handle, intersection_index),
359    SBT_ANY_HIT_IDX = offsetof(struct radv_pipeline_group_handle, any_hit_index),
360 };
361 
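/* Compute the address of SBT record 'idx' in the given binding table (base + idx * stride +
 * offset) and load either the recursive shader pointer or the requested shader index from the
 * handle. The pointer to the shader record data following the handle is stored as well. */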
362 static void
363 load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_def *idx, enum sbt_type binding,
364                enum sbt_entry offset)
365 {
366    nir_def *desc_base_addr = nir_load_sbt_base_amd(b);
367 
368    nir_def *desc = nir_pack_64_2x32(b, nir_load_smem_amd(b, 2, desc_base_addr, nir_imm_int(b, binding)));
369 
370    nir_def *stride_offset = nir_imm_int(b, binding + (binding == SBT_RAYGEN ? 8 : 16));
371    nir_def *stride = nir_load_smem_amd(b, 1, desc_base_addr, stride_offset);
372 
373    nir_def *addr = nir_iadd(b, desc, nir_u2u64(b, nir_iadd_imm(b, nir_imul(b, idx, stride), offset)));
374 
375    if (offset == SBT_RECURSIVE_PTR) {
376       nir_store_var(b, vars->shader_addr, nir_build_load_global(b, 1, 64, addr), 1);
377    } else {
378       nir_store_var(b, vars->idx, nir_build_load_global(b, 1, 32, addr), 1);
379    }
380 
381    nir_def *record_addr = nir_iadd_imm(b, addr, RADV_RT_HANDLE_SIZE - offset);
382    nir_store_var(b, vars->shader_record_ptr, record_addr, 1);
383 }
384 
385 struct radv_rt_shader_info {
386    bool uses_launch_id;
387    bool uses_launch_size;
388 };
389 
390 struct radv_lower_rt_instruction_data {
391    struct rt_variables *vars;
392    bool late_lowering;
393 
394    struct radv_rt_shader_info *out_info;
395 };
396 
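/* Per-instruction callback: lowers halt jumps to returns and RT intrinsics to loads/stores of the
 * rt_variables, including the explicit scratch stack manipulation for traces, callable calls and
 * resume points. */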
397 static bool
398 radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data)
399 {
400    if (instr->type == nir_instr_type_jump) {
401       nir_jump_instr *jump = nir_instr_as_jump(instr);
402       if (jump->type == nir_jump_halt) {
403          jump->type = nir_jump_return;
404          return true;
405       }
406       return false;
407    } else if (instr->type != nir_instr_type_intrinsic) {
408       return false;
409    }
410 
411    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
412 
413    struct radv_lower_rt_instruction_data *data = _data;
414    struct rt_variables *vars = data->vars;
415 
416    b->cursor = nir_before_instr(&intr->instr);
417 
418    nir_def *ret = NULL;
419    switch (intr->intrinsic) {
420    case nir_intrinsic_rt_execute_callable: {
421       uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
422       nir_def *ret_ptr = nir_load_resume_shader_address_amd(b, nir_intrinsic_call_idx(intr));
423       ret_ptr = nir_ior_imm(b, ret_ptr, radv_get_rt_priority(b->shader->info.stage));
424 
425       nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), size), 1);
426       nir_store_scratch(b, ret_ptr, nir_load_var(b, vars->stack_ptr), .align_mul = 16);
427 
428       nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), 16), 1);
429       load_sbt_entry(b, vars, intr->src[0].ssa, SBT_CALLABLE, SBT_RECURSIVE_PTR);
430 
431       nir_store_var(b, vars->arg, nir_iadd_imm(b, intr->src[1].ssa, -size - 16), 1);
432 
433       vars->stack_size = MAX2(vars->stack_size, size + 16);
434       break;
435    }
436    case nir_intrinsic_rt_trace_ray: {
437       uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
438       nir_def *ret_ptr = nir_load_resume_shader_address_amd(b, nir_intrinsic_call_idx(intr));
439       ret_ptr = nir_ior_imm(b, ret_ptr, radv_get_rt_priority(b->shader->info.stage));
440 
441       nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), size), 1);
442       nir_store_scratch(b, ret_ptr, nir_load_var(b, vars->stack_ptr), .align_mul = 16);
443 
444       nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), 16), 1);
445 
446       nir_store_var(b, vars->shader_addr, nir_load_var(b, vars->traversal_addr), 1);
447       nir_store_var(b, vars->arg, nir_iadd_imm(b, intr->src[10].ssa, -size - 16), 1);
448 
449       vars->stack_size = MAX2(vars->stack_size, size + 16);
450 
451       /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
452       nir_store_var(b, vars->accel_struct, intr->src[0].ssa, 0x1);
453       nir_store_var(b, vars->cull_mask_and_flags, nir_ior(b, nir_ishl_imm(b, intr->src[2].ssa, 24), intr->src[1].ssa),
454                     0x1);
455       nir_store_var(b, vars->sbt_offset, nir_iand_imm(b, intr->src[3].ssa, 0xf), 0x1);
456       nir_store_var(b, vars->sbt_stride, nir_iand_imm(b, intr->src[4].ssa, 0xf), 0x1);
457       nir_store_var(b, vars->miss_index, nir_iand_imm(b, intr->src[5].ssa, 0xffff), 0x1);
458       nir_store_var(b, vars->origin, intr->src[6].ssa, 0x7);
459       nir_store_var(b, vars->tmin, intr->src[7].ssa, 0x1);
460       nir_store_var(b, vars->direction, intr->src[8].ssa, 0x7);
461       nir_store_var(b, vars->tmax, intr->src[9].ssa, 0x1);
462       break;
463    }
464    case nir_intrinsic_rt_resume: {
465       uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
466 
467       nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), -size), 1);
468       break;
469    }
470    case nir_intrinsic_rt_return_amd: {
471       if (b->shader->info.stage == MESA_SHADER_RAYGEN) {
472          nir_terminate(b);
473          break;
474       }
475       insert_rt_return(b, vars);
476       break;
477    }
478    case nir_intrinsic_load_scratch: {
479       if (data->late_lowering)
480          nir_src_rewrite(&intr->src[0], nir_iadd_nuw(b, nir_load_var(b, vars->stack_ptr), intr->src[0].ssa));
481       return true;
482    }
483    case nir_intrinsic_store_scratch: {
484       if (data->late_lowering)
485          nir_src_rewrite(&intr->src[1], nir_iadd_nuw(b, nir_load_var(b, vars->stack_ptr), intr->src[1].ssa));
486       return true;
487    }
488    case nir_intrinsic_load_rt_arg_scratch_offset_amd: {
489       ret = nir_load_var(b, vars->arg);
490       break;
491    }
492    case nir_intrinsic_load_shader_record_ptr: {
493       ret = nir_load_var(b, vars->shader_record_ptr);
494       break;
495    }
496    case nir_intrinsic_load_ray_launch_size: {
497       if (data->out_info)
498          data->out_info->uses_launch_size = true;
499 
500       if (!data->late_lowering)
501          return false;
502 
503       ret = nir_vec3(b, nir_load_var(b, vars->launch_sizes[0]), nir_load_var(b, vars->launch_sizes[1]),
504                      nir_load_var(b, vars->launch_sizes[2]));
505       break;
506    };
507    case nir_intrinsic_load_ray_launch_id: {
508       if (data->out_info)
509          data->out_info->uses_launch_id = true;
510 
511       if (!data->late_lowering)
512          return false;
513 
514       ret = nir_vec3(b, nir_load_var(b, vars->launch_ids[0]), nir_load_var(b, vars->launch_ids[1]),
515                      nir_load_var(b, vars->launch_ids[2]));
516       break;
517    }
518    case nir_intrinsic_load_ray_t_min: {
519       ret = nir_load_var(b, vars->tmin);
520       break;
521    }
522    case nir_intrinsic_load_ray_t_max: {
523       ret = nir_load_var(b, vars->tmax);
524       break;
525    }
526    case nir_intrinsic_load_ray_world_origin: {
527       ret = nir_load_var(b, vars->origin);
528       break;
529    }
530    case nir_intrinsic_load_ray_world_direction: {
531       ret = nir_load_var(b, vars->direction);
532       break;
533    }
534    case nir_intrinsic_load_ray_instance_custom_index: {
535       nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
536       nir_def *custom_instance_and_mask = nir_build_load_global(
537          b, 1, 32,
538          nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, custom_instance_and_mask)));
539       ret = nir_iand_imm(b, custom_instance_and_mask, 0xFFFFFF);
540       break;
541    }
542    case nir_intrinsic_load_primitive_id: {
543       ret = nir_load_var(b, vars->primitive_id);
544       break;
545    }
546    case nir_intrinsic_load_ray_geometry_index: {
547       ret = nir_load_var(b, vars->geometry_id_and_flags);
548       ret = nir_iand_imm(b, ret, 0xFFFFFFF);
549       break;
550    }
551    case nir_intrinsic_load_instance_id: {
552       nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
553       ret = nir_build_load_global(
554          b, 1, 32, nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, instance_id)));
555       break;
556    }
557    case nir_intrinsic_load_ray_flags: {
558       ret = nir_iand_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 0xFFFFFF);
559       break;
560    }
561    case nir_intrinsic_load_ray_hit_kind: {
562       ret = nir_load_var(b, vars->hit_kind);
563       break;
564    }
565    case nir_intrinsic_load_ray_world_to_object: {
566       unsigned c = nir_intrinsic_column(intr);
567       nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
568       nir_def *wto_matrix[3];
569       nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
570 
571       nir_def *vals[3];
572       for (unsigned i = 0; i < 3; ++i)
573          vals[i] = nir_channel(b, wto_matrix[i], c);
574 
575       ret = nir_vec(b, vals, 3);
576       break;
577    }
578    case nir_intrinsic_load_ray_object_to_world: {
579       unsigned c = nir_intrinsic_column(intr);
580       nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
581       nir_def *rows[3];
582       for (unsigned r = 0; r < 3; ++r)
583          rows[r] = nir_build_load_global(
584             b, 4, 32,
585             nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, otw_matrix) + r * 16));
586       ret = nir_vec3(b, nir_channel(b, rows[0], c), nir_channel(b, rows[1], c), nir_channel(b, rows[2], c));
587       break;
588    }
589    case nir_intrinsic_load_ray_object_origin: {
590       nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
591       nir_def *wto_matrix[3];
592       nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
593       ret = nir_build_vec3_mat_mult(b, nir_load_var(b, vars->origin), wto_matrix, true);
594       break;
595    }
596    case nir_intrinsic_load_ray_object_direction: {
597       nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
598       nir_def *wto_matrix[3];
599       nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
600       ret = nir_build_vec3_mat_mult(b, nir_load_var(b, vars->direction), wto_matrix, false);
601       break;
602    }
603    case nir_intrinsic_load_intersection_opaque_amd: {
604       ret = nir_load_var(b, vars->opaque);
605       break;
606    }
607    case nir_intrinsic_load_cull_mask: {
608       ret = nir_ushr_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 24);
609       break;
610    }
611    case nir_intrinsic_ignore_ray_intersection: {
612       nir_store_var(b, vars->ahit_accept, nir_imm_false(b), 0x1);
613 
614       /* The if is a workaround to avoid having to fix up control flow manually */
615       nir_push_if(b, nir_imm_true(b));
616       nir_jump(b, nir_jump_return);
617       nir_pop_if(b, NULL);
618       break;
619    }
620    case nir_intrinsic_terminate_ray: {
621       nir_store_var(b, vars->ahit_accept, nir_imm_true(b), 0x1);
622       nir_store_var(b, vars->ahit_terminate, nir_imm_true(b), 0x1);
623 
624       /* The if is a workaround to avoid having to fix up control flow manually */
625       nir_push_if(b, nir_imm_true(b));
626       nir_jump(b, nir_jump_return);
627       nir_pop_if(b, NULL);
628       break;
629    }
630    case nir_intrinsic_report_ray_intersection: {
631       nir_push_if(b, nir_iand(b, nir_fge(b, nir_load_var(b, vars->tmax), intr->src[0].ssa),
632                               nir_fge(b, intr->src[0].ssa, nir_load_var(b, vars->tmin))));
633       {
634          nir_store_var(b, vars->ahit_accept, nir_imm_true(b), 0x1);
635          nir_store_var(b, vars->tmax, intr->src[0].ssa, 1);
636          nir_store_var(b, vars->hit_kind, intr->src[1].ssa, 1);
637       }
638       nir_pop_if(b, NULL);
639       break;
640    }
641    case nir_intrinsic_load_sbt_offset_amd: {
642       ret = nir_load_var(b, vars->sbt_offset);
643       break;
644    }
645    case nir_intrinsic_load_sbt_stride_amd: {
646       ret = nir_load_var(b, vars->sbt_stride);
647       break;
648    }
649    case nir_intrinsic_load_accel_struct_amd: {
650       ret = nir_load_var(b, vars->accel_struct);
651       break;
652    }
653    case nir_intrinsic_load_cull_mask_and_flags_amd: {
654       ret = nir_load_var(b, vars->cull_mask_and_flags);
655       break;
656    }
657    case nir_intrinsic_execute_closest_hit_amd: {
658       nir_store_var(b, vars->tmax, intr->src[1].ssa, 0x1);
659       nir_store_var(b, vars->primitive_id, intr->src[2].ssa, 0x1);
660       nir_store_var(b, vars->instance_addr, intr->src[3].ssa, 0x1);
661       nir_store_var(b, vars->geometry_id_and_flags, intr->src[4].ssa, 0x1);
662       nir_store_var(b, vars->hit_kind, intr->src[5].ssa, 0x1);
663       load_sbt_entry(b, vars, intr->src[0].ssa, SBT_HIT, SBT_RECURSIVE_PTR);
664 
665       nir_def *should_return =
666          nir_test_mask(b, nir_load_var(b, vars->cull_mask_and_flags), SpvRayFlagsSkipClosestHitShaderKHRMask);
667 
668       if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR)) {
669          should_return = nir_ior(b, should_return, nir_ieq_imm(b, nir_load_var(b, vars->shader_addr), 0));
670       }
671 
672       /* should_return is set if we had a hit but we won't be calling the closest hit
673        * shader and hence need to return immediately to the calling shader. */
674       nir_push_if(b, should_return);
675       insert_rt_return(b, vars);
676       nir_pop_if(b, NULL);
677       break;
678    }
679    case nir_intrinsic_execute_miss_amd: {
680       nir_store_var(b, vars->tmax, intr->src[0].ssa, 0x1);
681       nir_def *undef = nir_undef(b, 1, 32);
682       nir_store_var(b, vars->primitive_id, undef, 0x1);
683       nir_store_var(b, vars->instance_addr, nir_undef(b, 1, 64), 0x1);
684       nir_store_var(b, vars->geometry_id_and_flags, undef, 0x1);
685       nir_store_var(b, vars->hit_kind, undef, 0x1);
686       nir_def *miss_index = nir_load_var(b, vars->miss_index);
687       load_sbt_entry(b, vars, miss_index, SBT_MISS, SBT_RECURSIVE_PTR);
688 
689       if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR)) {
690          /* In case of a NULL miss shader, do nothing and just return. */
691          nir_push_if(b, nir_ieq_imm(b, nir_load_var(b, vars->shader_addr), 0));
692          insert_rt_return(b, vars);
693          nir_pop_if(b, NULL);
694       }
695 
696       break;
697    }
698    case nir_intrinsic_load_ray_triangle_vertex_positions: {
699       nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
700       nir_def *primitive_id = nir_load_var(b, vars->primitive_id);
701       ret = radv_load_vertex_position(vars->device, b, instance_node_addr, primitive_id, nir_intrinsic_column(intr));
702       break;
703    }
704    default:
705       return false;
706    }
707 
708    if (ret)
709       nir_def_rewrite_uses(&intr->def, ret);
710    nir_instr_remove(&intr->instr);
711 
712    return true;
713 }
714 
715 /* This lowers all the RT instructions that we do not want to pass on to the combined shader and
716  * that we can implement using the variables from the shader we are going to inline into. */
717 static void
718 lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, bool late_lowering,
719                       struct radv_rt_shader_info *out_info)
720 {
721    struct radv_lower_rt_instruction_data data = {
722       .vars = vars,
723       .late_lowering = late_lowering,
724       .out_info = out_info,
725    };
726    nir_shader_instructions_pass(shader, radv_lower_rt_instruction, nir_metadata_none, &data);
727 }
728 
729 /* Lowers hit attributes to registers or shared memory. If hit_attribs is NULL, attributes are
730  * lowered to shared memory. */
731 static void
732 lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs, uint32_t workgroup_size)
733 {
734    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
735 
736    nir_foreach_variable_with_modes (attrib, shader, nir_var_ray_hit_attrib)
737       attrib->data.mode = nir_var_shader_temp;
738 
739    nir_builder b = nir_builder_create(impl);
740 
741    nir_foreach_block (block, impl) {
742       nir_foreach_instr_safe (instr, block) {
743          if (instr->type != nir_instr_type_intrinsic)
744             continue;
745 
746          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
747          if (intrin->intrinsic != nir_intrinsic_load_hit_attrib_amd &&
748              intrin->intrinsic != nir_intrinsic_store_hit_attrib_amd)
749             continue;
750 
751          b.cursor = nir_after_instr(instr);
752 
753          nir_def *offset;
754          if (!hit_attribs)
755             offset = nir_imul_imm(
756                &b, nir_iadd_imm(&b, nir_load_local_invocation_index(&b), nir_intrinsic_base(intrin) * workgroup_size),
757                sizeof(uint32_t));
758 
759          if (intrin->intrinsic == nir_intrinsic_load_hit_attrib_amd) {
760             nir_def *ret;
761             if (hit_attribs)
762                ret = nir_load_var(&b, hit_attribs[nir_intrinsic_base(intrin)]);
763             else
764                ret = nir_load_shared(&b, 1, 32, offset, .base = 0, .align_mul = 4);
765             nir_def_rewrite_uses(nir_instr_def(instr), ret);
766          } else {
767             if (hit_attribs)
768                nir_store_var(&b, hit_attribs[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1);
769             else
770                nir_store_shared(&b, intrin->src->ssa, offset, .base = 0, .align_mul = 4);
771          }
772          nir_instr_remove(instr);
773       }
774    }
775 
776    if (!hit_attribs)
777       shader->info.shared_size = MAX2(shader->info.shared_size, workgroup_size * RADV_MAX_HIT_ATTRIB_SIZE);
778 }
779 
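/* Append the constant data of 'src' to 'dst' so that 'src' can be inlined into 'dst'. The data is
 * placed at the next offset that satisfies the largest load_constant alignment used by 'src', and
 * the load_constant intrinsics in 'src' are rebased accordingly. */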
780 static void
781 inline_constants(nir_shader *dst, nir_shader *src)
782 {
783    if (!src->constant_data_size)
784       return;
785 
786    uint32_t align_mul = 1;
787    if (dst->constant_data_size) {
788       nir_foreach_block (block, nir_shader_get_entrypoint(src)) {
789          nir_foreach_instr (instr, block) {
790             if (instr->type != nir_instr_type_intrinsic)
791                continue;
792 
793             nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
794             if (intrinsic->intrinsic == nir_intrinsic_load_constant)
795                align_mul = MAX2(align_mul, nir_intrinsic_align_mul(intrinsic));
796          }
797       }
798    }
799 
800    uint32_t old_constant_data_size = dst->constant_data_size;
801    uint32_t base_offset = align(dst->constant_data_size, align_mul);
802    dst->constant_data_size = base_offset + src->constant_data_size;
803    dst->constant_data = rerzalloc_size(dst, dst->constant_data, old_constant_data_size, dst->constant_data_size);
804    memcpy((char *)dst->constant_data + base_offset, src->constant_data, src->constant_data_size);
805 
806    if (!base_offset)
807       return;
808 
809    nir_foreach_block (block, nir_shader_get_entrypoint(src)) {
810       nir_foreach_instr (instr, block) {
811          if (instr->type != nir_instr_type_intrinsic)
812             continue;
813 
814          nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
815          if (intrinsic->intrinsic == nir_intrinsic_load_constant)
816             nir_intrinsic_set_base(intrinsic, base_offset + nir_intrinsic_base(intrinsic));
817       }
818    }
819 }
820 
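/* Inline 'shader' into the builder's shader as one case of the shader switch: remap its RT
 * variables to the caller's, lower its RT instructions and returns, merge its constant data and
 * wrap the inlined body in 'if (idx == call_idx)'. */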
821 static void
822 insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_def *idx, uint32_t call_idx)
823 {
824    struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
825 
826    nir_opt_dead_cf(shader);
827 
828    struct rt_variables src_vars = create_rt_variables(shader, vars->device, vars->flags, vars->monolithic);
829    map_rt_variables(var_remap, &src_vars, vars);
830 
831    NIR_PASS_V(shader, lower_rt_instructions, &src_vars, false, NULL);
832 
833    NIR_PASS(_, shader, nir_lower_returns);
834    NIR_PASS(_, shader, nir_opt_dce);
835 
836    inline_constants(b->shader, shader);
837 
838    nir_push_if(b, nir_ieq_imm(b, idx, call_idx));
839    nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
840    nir_pop_if(b, NULL);
841 
842    ralloc_free(var_remap);
843 }
844 
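/* For monolithic ray generation shaders, replace the payload variable argument of trace_ray with
 * the scratch offset (driver_location) that was assigned to that variable. */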
845 static bool
846 radv_lower_payload_arg_to_offset(nir_builder *b, nir_intrinsic_instr *instr, void *data)
847 {
848    if (instr->intrinsic != nir_intrinsic_trace_ray)
849       return false;
850 
851    nir_deref_instr *payload = nir_src_as_deref(instr->src[10]);
852    assert(payload->deref_type == nir_deref_type_var);
853 
854    b->cursor = nir_before_instr(&instr->instr);
855    nir_def *offset = nir_imm_int(b, payload->var->data.driver_location);
856 
857    nir_src_rewrite(&instr->src[10], offset);
858 
859    return true;
860 }
861 
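/* Lower ray payload and shader call data IO: without monolithic compilation, payload derefs become
 * scratch accesses relative to the RT argument offset; with monolithic compilation, payload
 * variables are assigned offsets and their derefs are lowered by radv_nir_lower_ray_payload_derefs. */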
862 void
863 radv_nir_lower_rt_io(nir_shader *nir, bool monolithic, uint32_t payload_offset)
864 {
865    if (!monolithic) {
866       NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data,
867                glsl_get_natural_size_align_bytes);
868 
869       NIR_PASS(_, nir, lower_rt_derefs);
870 
871       NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset);
872    } else {
873       if (nir->info.stage == MESA_SHADER_RAYGEN) {
874          /* Use nir_lower_vars_to_explicit_types to assign the payload locations. We call
875           * nir_lower_vars_to_explicit_types later after splitting the payloads.
876           */
877          uint32_t scratch_size = nir->scratch_size;
878          nir_lower_vars_to_explicit_types(nir, nir_var_function_temp, glsl_get_natural_size_align_bytes);
879          nir->scratch_size = scratch_size;
880 
881          nir_shader_intrinsics_pass(nir, radv_lower_payload_arg_to_offset, nir_metadata_control_flow, NULL);
882       }
883 
884       NIR_PASS(_, nir, radv_nir_lower_ray_payload_derefs, payload_offset);
885    }
886 }
887 
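/* Begin a ray history token for RRA captures: reserve 'token_size' bytes in the ray history
 * buffer with a global atomic, skip the write if this launch ID is filtered out by the resolution
 * scale or the buffer would overflow, and store the packed launch index/hit/token type header.
 * Returns the address right after the header. Must be paired with radv_build_token_end. */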
888 static nir_def *
889 radv_build_token_begin(nir_builder *b, struct rt_variables *vars, nir_def *hit, enum radv_packed_token_type token_type,
890                        nir_def *token_size, uint32_t max_token_size)
891 {
892    struct radv_rra_trace_data *rra_trace = &vars->device->rra_trace;
893    assert(rra_trace->ray_history_addr);
894    assert(rra_trace->ray_history_buffer_size >= max_token_size);
895 
896    nir_def *ray_history_addr = nir_imm_int64(b, rra_trace->ray_history_addr);
897 
898    nir_def *launch_id = nir_load_ray_launch_id(b);
899 
900    nir_def *trace = nir_imm_true(b);
901    for (uint32_t i = 0; i < 3; i++) {
902       nir_def *remainder = nir_umod_imm(b, nir_channel(b, launch_id, i), rra_trace->ray_history_resolution_scale);
903       trace = nir_iand(b, trace, nir_ieq_imm(b, remainder, 0));
904    }
905    nir_push_if(b, trace);
906 
907    static_assert(offsetof(struct radv_ray_history_header, offset) == 0, "Unexpected offset");
908    nir_def *base_offset = nir_global_atomic(b, 32, ray_history_addr, token_size, .atomic_op = nir_atomic_op_iadd);
909 
910    /* Abuse the dword alignment of token_size to add an invalid bit to offset. */
911    trace = nir_ieq_imm(b, nir_iand_imm(b, base_offset, 1), 0);
912 
913    nir_def *in_bounds = nir_ule_imm(b, base_offset, rra_trace->ray_history_buffer_size - max_token_size);
914    /* Make sure we don't overwrite the header in case of an overflow. */
915    in_bounds = nir_iand(b, in_bounds, nir_uge_imm(b, base_offset, sizeof(struct radv_ray_history_header)));
916 
917    nir_push_if(b, nir_iand(b, trace, in_bounds));
918 
919    nir_def *dst_addr = nir_iadd(b, ray_history_addr, nir_u2u64(b, base_offset));
920 
921    nir_def *launch_size = nir_load_ray_launch_size(b);
922 
923    nir_def *launch_id_comps[3];
924    nir_def *launch_size_comps[3];
925    for (uint32_t i = 0; i < 3; i++) {
926       launch_id_comps[i] = nir_udiv_imm(b, nir_channel(b, launch_id, i), rra_trace->ray_history_resolution_scale);
927       launch_size_comps[i] = nir_udiv_imm(b, nir_channel(b, launch_size, i), rra_trace->ray_history_resolution_scale);
928    }
929 
930    nir_def *global_index =
931       nir_iadd(b, launch_id_comps[0],
932                nir_iadd(b, nir_imul(b, launch_id_comps[1], launch_size_comps[0]),
933                         nir_imul(b, launch_id_comps[2], nir_imul(b, launch_size_comps[0], launch_size_comps[1]))));
934    nir_def *launch_index_and_hit = nir_bcsel(b, hit, nir_ior_imm(b, global_index, 1u << 29u), global_index);
935    nir_build_store_global(b, nir_ior_imm(b, launch_index_and_hit, token_type << 30), dst_addr, .align_mul = 4);
936 
937    return nir_iadd_imm(b, dst_addr, 4);
938 }
939 
940 static void
941 radv_build_token_end(nir_builder *b)
942 {
943    nir_pop_if(b, NULL);
944    nir_pop_if(b, NULL);
945 }
946 
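/* Emit an end-of-trace ray history token containing the trace_ray arguments, the traversal
 * statistics and, if the ray hit something, the committed hit data. */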
947 static void
948 radv_build_end_trace_token(nir_builder *b, struct rt_variables *vars, nir_def *tmax, nir_def *hit,
949                            nir_def *iteration_instance_count)
950 {
951    nir_def *token_size = nir_bcsel(b, hit, nir_imm_int(b, sizeof(struct radv_packed_end_trace_token)),
952                                    nir_imm_int(b, offsetof(struct radv_packed_end_trace_token, primitive_id)));
953 
954    nir_def *dst_addr = radv_build_token_begin(b, vars, hit, radv_packed_token_end_trace, token_size,
955                                               sizeof(struct radv_packed_end_trace_token));
956    {
957       nir_build_store_global(b, nir_load_var(b, vars->accel_struct), dst_addr, .align_mul = 4);
958       dst_addr = nir_iadd_imm(b, dst_addr, 8);
959 
960       nir_def *dispatch_indices =
961          nir_load_smem_amd(b, 2, nir_imm_int64(b, vars->device->rra_trace.ray_history_addr),
962                            nir_imm_int(b, offsetof(struct radv_ray_history_header, dispatch_index)), .align_mul = 4);
963       nir_def *dispatch_index = nir_iadd(b, nir_channel(b, dispatch_indices, 0), nir_channel(b, dispatch_indices, 1));
964       nir_def *dispatch_and_flags = nir_iand_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 0xFFFF);
965       dispatch_and_flags = nir_ior(b, dispatch_and_flags, dispatch_index);
966       nir_build_store_global(b, dispatch_and_flags, dst_addr, .align_mul = 4);
967       dst_addr = nir_iadd_imm(b, dst_addr, 4);
968 
969       nir_def *shifted_cull_mask = nir_iand_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 0xFF000000);
970 
971       nir_def *packed_args = nir_load_var(b, vars->sbt_offset);
972       packed_args = nir_ior(b, packed_args, nir_ishl_imm(b, nir_load_var(b, vars->sbt_stride), 4));
973       packed_args = nir_ior(b, packed_args, nir_ishl_imm(b, nir_load_var(b, vars->miss_index), 8));
974       packed_args = nir_ior(b, packed_args, shifted_cull_mask);
975       nir_build_store_global(b, packed_args, dst_addr, .align_mul = 4);
976       dst_addr = nir_iadd_imm(b, dst_addr, 4);
977 
978       nir_build_store_global(b, nir_load_var(b, vars->origin), dst_addr, .align_mul = 4);
979       dst_addr = nir_iadd_imm(b, dst_addr, 12);
980 
981       nir_build_store_global(b, nir_load_var(b, vars->tmin), dst_addr, .align_mul = 4);
982       dst_addr = nir_iadd_imm(b, dst_addr, 4);
983 
984       nir_build_store_global(b, nir_load_var(b, vars->direction), dst_addr, .align_mul = 4);
985       dst_addr = nir_iadd_imm(b, dst_addr, 12);
986 
987       nir_build_store_global(b, tmax, dst_addr, .align_mul = 4);
988       dst_addr = nir_iadd_imm(b, dst_addr, 4);
989 
990       nir_build_store_global(b, iteration_instance_count, dst_addr, .align_mul = 4);
991       dst_addr = nir_iadd_imm(b, dst_addr, 4);
992 
993       nir_build_store_global(b, nir_load_var(b, vars->ahit_isec_count), dst_addr, .align_mul = 4);
994       dst_addr = nir_iadd_imm(b, dst_addr, 4);
995 
996       nir_push_if(b, hit);
997       {
998          nir_build_store_global(b, nir_load_var(b, vars->primitive_id), dst_addr, .align_mul = 4);
999          dst_addr = nir_iadd_imm(b, dst_addr, 4);
1000 
1001          nir_def *geometry_id = nir_iand_imm(b, nir_load_var(b, vars->geometry_id_and_flags), 0xFFFFFFF);
1002          nir_build_store_global(b, geometry_id, dst_addr, .align_mul = 4);
1003          dst_addr = nir_iadd_imm(b, dst_addr, 4);
1004 
1005          nir_def *instance_id_and_hit_kind =
1006             nir_build_load_global(b, 1, 32,
1007                                   nir_iadd_imm(b, nir_load_var(b, vars->instance_addr),
1008                                                offsetof(struct radv_bvh_instance_node, instance_id)));
1009          instance_id_and_hit_kind =
1010             nir_ior(b, instance_id_and_hit_kind, nir_ishl_imm(b, nir_load_var(b, vars->hit_kind), 24));
1011          nir_build_store_global(b, instance_id_and_hit_kind, dst_addr, .align_mul = 4);
1012          dst_addr = nir_iadd_imm(b, dst_addr, 4);
1013 
1014          nir_build_store_global(b, nir_load_var(b, vars->tmax), dst_addr, .align_mul = 4);
1015          dst_addr = nir_iadd_imm(b, dst_addr, 4);
1016       }
1017       nir_pop_if(b, NULL);
1018    }
1019    radv_build_token_end(b);
1020 }
1021 
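/* Turn an any-hit shader into a function that can be inlined into an intersection shader. It
 * takes a pointer to the commit boolean, the candidate hit T, the hit kind and a scratch offset
 * as parameters, and its scratch accesses are rebased behind the intersection shader's scratch. */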
1022 static nir_function_impl *
1023 lower_any_hit_for_intersection(nir_shader *any_hit)
1024 {
1025    nir_function_impl *impl = nir_shader_get_entrypoint(any_hit);
1026 
1027    /* Any-hit shaders need four parameters */
1028    assert(impl->function->num_params == 0);
1029    nir_parameter params[] = {
1030       {
1031          /* A pointer to a boolean value for whether or not the hit was
1032           * accepted.
1033           */
1034          .num_components = 1,
1035          .bit_size = 32,
1036       },
1037       {
1038          /* The hit T value */
1039          .num_components = 1,
1040          .bit_size = 32,
1041       },
1042       {
1043          /* The hit kind */
1044          .num_components = 1,
1045          .bit_size = 32,
1046       },
1047       {
1048          /* Scratch offset */
1049          .num_components = 1,
1050          .bit_size = 32,
1051       },
1052    };
1053    impl->function->num_params = ARRAY_SIZE(params);
1054    impl->function->params = ralloc_array(any_hit, nir_parameter, ARRAY_SIZE(params));
1055    memcpy(impl->function->params, params, sizeof(params));
1056 
1057    nir_builder build = nir_builder_at(nir_before_impl(impl));
1058    nir_builder *b = &build;
1059 
1060    nir_def *commit_ptr = nir_load_param(b, 0);
1061    nir_def *hit_t = nir_load_param(b, 1);
1062    nir_def *hit_kind = nir_load_param(b, 2);
1063    nir_def *scratch_offset = nir_load_param(b, 3);
1064 
1065    nir_deref_instr *commit = nir_build_deref_cast(b, commit_ptr, nir_var_function_temp, glsl_bool_type(), 0);
1066 
1067    nir_foreach_block_safe (block, impl) {
1068       nir_foreach_instr_safe (instr, block) {
1069          switch (instr->type) {
1070          case nir_instr_type_intrinsic: {
1071             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1072             switch (intrin->intrinsic) {
1073             case nir_intrinsic_ignore_ray_intersection:
1074                b->cursor = nir_instr_remove(&intrin->instr);
1075                /* We put the newly emitted code inside a dummy if because it's
1076                 * going to contain a jump instruction and we don't want to
1077                 * deal with that mess here.  It'll get dealt with by our
1078                 * control-flow optimization passes.
1079                 */
1080                nir_store_deref(b, commit, nir_imm_false(b), 0x1);
1081                nir_push_if(b, nir_imm_true(b));
1082                nir_jump(b, nir_jump_return);
1083                nir_pop_if(b, NULL);
1084                break;
1085 
1086             case nir_intrinsic_terminate_ray:
1087                /* The "normal" handling of terminateRay works fine in
1088                 * intersection shaders.
1089                 */
1090                break;
1091 
1092             case nir_intrinsic_load_ray_t_max:
1093                nir_def_replace(&intrin->def, hit_t);
1094                break;
1095 
1096             case nir_intrinsic_load_ray_hit_kind:
1097                nir_def_replace(&intrin->def, hit_kind);
1098                break;
1099 
1100             /* We place all any_hit scratch variables after intersection scratch variables.
1101              * For that reason, we increment the scratch offset by the intersection scratch
1102              * size. For call_data, we have to subtract the offset again.
1103              *
1104              * Note that we don't increase the scratch size as it is already reflected via
1105              * the any_hit stack_size.
1106              */
1107             case nir_intrinsic_load_scratch:
1108                b->cursor = nir_before_instr(instr);
1109                nir_src_rewrite(&intrin->src[0], nir_iadd_nuw(b, scratch_offset, intrin->src[0].ssa));
1110                break;
1111             case nir_intrinsic_store_scratch:
1112                b->cursor = nir_before_instr(instr);
1113                nir_src_rewrite(&intrin->src[1], nir_iadd_nuw(b, scratch_offset, intrin->src[1].ssa));
1114                break;
1115             case nir_intrinsic_load_rt_arg_scratch_offset_amd:
1116                b->cursor = nir_after_instr(instr);
1117                nir_def *arg_offset = nir_isub(b, &intrin->def, scratch_offset);
1118                nir_def_rewrite_uses_after(&intrin->def, arg_offset, arg_offset->parent_instr);
1119                break;
1120 
1121             default:
1122                break;
1123             }
1124             break;
1125          }
1126          case nir_instr_type_jump: {
1127             nir_jump_instr *jump = nir_instr_as_jump(instr);
1128             if (jump->type == nir_jump_halt) {
1129                b->cursor = nir_instr_remove(instr);
1130                nir_jump(b, nir_jump_return);
1131             }
1132             break;
1133          }
1134 
1135          default:
1136             break;
1137          }
1138       }
1139    }
1140 
1141    nir_validate_shader(any_hit, "after initial any-hit lowering");
1142 
1143    nir_lower_returns_impl(impl);
1144 
1145    nir_validate_shader(any_hit, "after lowering returns");
1146 
1147    return impl;
1148 }
1149 
1150 /* Inline the any_hit shader into the intersection shader so we don't have
1151  * to implement yet another shader call interface here, nor handle any recursion.
1152  */
1153 static void
1154 nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit)
1155 {
1156    void *dead_ctx = ralloc_context(intersection);
1157 
1158    nir_function_impl *any_hit_impl = NULL;
1159    struct hash_table *any_hit_var_remap = NULL;
1160    if (any_hit) {
1161       any_hit = nir_shader_clone(dead_ctx, any_hit);
1162       NIR_PASS(_, any_hit, nir_opt_dce);
1163 
1164       inline_constants(intersection, any_hit);
1165 
1166       any_hit_impl = lower_any_hit_for_intersection(any_hit);
1167       any_hit_var_remap = _mesa_pointer_hash_table_create(dead_ctx);
1168    }
1169 
1170    nir_function_impl *impl = nir_shader_get_entrypoint(intersection);
1171 
1172    nir_builder build = nir_builder_create(impl);
1173    nir_builder *b = &build;
1174 
1175    b->cursor = nir_before_impl(impl);
1176 
1177    nir_variable *commit = nir_local_variable_create(impl, glsl_bool_type(), "ray_commit");
1178    nir_store_var(b, commit, nir_imm_false(b), 0x1);
1179 
1180    nir_foreach_block_safe (block, impl) {
1181       nir_foreach_instr_safe (instr, block) {
1182          if (instr->type != nir_instr_type_intrinsic)
1183             continue;
1184 
1185          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1186          if (intrin->intrinsic != nir_intrinsic_report_ray_intersection)
1187             continue;
1188 
1189          b->cursor = nir_instr_remove(&intrin->instr);
1190          nir_def *hit_t = intrin->src[0].ssa;
1191          nir_def *hit_kind = intrin->src[1].ssa;
1192          nir_def *min_t = nir_load_ray_t_min(b);
1193          nir_def *max_t = nir_load_ray_t_max(b);
1194 
1195          /* bool commit_tmp = false; */
1196          nir_variable *commit_tmp = nir_local_variable_create(impl, glsl_bool_type(), "commit_tmp");
1197          nir_store_var(b, commit_tmp, nir_imm_false(b), 0x1);
1198 
1199          nir_push_if(b, nir_iand(b, nir_fge(b, hit_t, min_t), nir_fge(b, max_t, hit_t)));
1200          {
1201             /* Any-hit defaults to commit */
1202             nir_store_var(b, commit_tmp, nir_imm_true(b), 0x1);
1203 
1204             if (any_hit_impl != NULL) {
1205                nir_push_if(b, nir_inot(b, nir_load_intersection_opaque_amd(b)));
1206                {
1207                   nir_def *params[] = {
1208                      &nir_build_deref_var(b, commit_tmp)->def,
1209                      hit_t,
1210                      hit_kind,
1211                      nir_imm_int(b, intersection->scratch_size),
1212                   };
1213                   nir_inline_function_impl(b, any_hit_impl, params, any_hit_var_remap);
1214                }
1215                nir_pop_if(b, NULL);
1216             }
1217 
1218             nir_push_if(b, nir_load_var(b, commit_tmp));
1219             {
1220                nir_report_ray_intersection(b, 1, hit_t, hit_kind);
1221             }
1222             nir_pop_if(b, NULL);
1223          }
1224          nir_pop_if(b, NULL);
1225 
1226          nir_def *accepted = nir_load_var(b, commit_tmp);
1227          nir_def_rewrite_uses(&intrin->def, accepted);
1228       }
1229    }
1230    nir_metadata_preserve(impl, nir_metadata_none);
1231 
1232    /* We did some inlining; have to re-index SSA defs */
1233    nir_index_ssa_defs(impl);
1234 
1235    /* Eliminate the casts introduced for the commit return of the any-hit shader. */
1236    NIR_PASS(_, intersection, nir_opt_deref);
1237 
1238    ralloc_free(dead_ctx);
1239 }
1240 
1241 /* Variables only used internally to ray traversal. This is data that describes
1242  * the current state of the traversal vs. what we'd give to a shader.  e.g. what
1243  * is the instance we're currently visiting vs. what is the instance of the
1244  * closest hit. */
1245 struct rt_traversal_vars {
1246    nir_variable *origin;
1247    nir_variable *dir;
1248    nir_variable *inv_dir;
1249    nir_variable *sbt_offset_and_flags;
1250    nir_variable *instance_addr;
1251    nir_variable *hit;
1252    nir_variable *bvh_base;
1253    nir_variable *stack;
1254    nir_variable *top_stack;
1255    nir_variable *stack_low_watermark;
1256    nir_variable *current_node;
1257    nir_variable *previous_node;
1258    nir_variable *instance_top_node;
1259    nir_variable *instance_bottom_node;
1260 };
1261 
1262 static struct rt_traversal_vars
1263 init_traversal_vars(nir_builder *b)
1264 {
1265    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
1266    struct rt_traversal_vars ret;
1267 
1268    ret.origin = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_origin");
1269    ret.dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_dir");
1270    ret.inv_dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_inv_dir");
1271    ret.sbt_offset_and_flags =
1272       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_sbt_offset_and_flags");
1273    ret.instance_addr = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
1274    ret.hit = nir_variable_create(b->shader, nir_var_shader_temp, glsl_bool_type(), "traversal_hit");
1275    ret.bvh_base = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "traversal_bvh_base");
1276    ret.stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_ptr");
1277    ret.top_stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_top_stack_ptr");
1278    ret.stack_low_watermark =
1279       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_low_watermark");
1280    ret.current_node = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "current_node");
1281    ret.previous_node = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "previous_node");
1282    ret.instance_top_node = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "instance_top_node");
1283    ret.instance_bottom_node =
1284       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "instance_bottom_node");
1285    return ret;
1286 }
1287 
1288 struct traversal_data {
1289    struct radv_device *device;
1290    struct rt_variables *vars;
1291    struct rt_traversal_vars *trav_vars;
1292    nir_variable *barycentrics;
1293 
1294    struct radv_ray_tracing_pipeline *pipeline;
1295 };
1296 
1297 static void
1298 radv_ray_tracing_group_ahit_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
1299                                  struct radv_rt_case_data *data)
1300 {
1301    if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR) {
1302       *shader_index = group->any_hit_shader;
1303       *handle_index = group->handle.any_hit_index;
1304    }
1305 }
1306 
1307 static void
1308 radv_build_ahit_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
1309                      struct radv_rt_case_data *data)
1310 {
1311    nir_shader *nir_stage =
1312       radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir);
1313    assert(nir_stage);
1314 
1315    radv_nir_lower_rt_io(nir_stage, data->vars->monolithic, data->vars->payload_offset);
1316 
1317    insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.any_hit_index);
1318    ralloc_free(nir_stage);
1319 }
1320 
1321 static void
1322 radv_ray_tracing_group_isec_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
1323                                  struct radv_rt_case_data *data)
1324 {
1325    if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR) {
1326       *shader_index = group->intersection_shader;
1327       *handle_index = group->handle.intersection_index;
1328    }
1329 }
1330 
1331 static void
1332 radv_build_isec_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
1333                      struct radv_rt_case_data *data)
1334 {
1335    nir_shader *nir_stage =
1336       radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->intersection_shader].nir);
1337    assert(nir_stage);
1338 
1339    radv_nir_lower_rt_io(nir_stage, data->vars->monolithic, data->vars->payload_offset);
1340 
1341    nir_shader *any_hit_stage = NULL;
1342    if (group->any_hit_shader != VK_SHADER_UNUSED_KHR) {
1343       any_hit_stage =
1344          radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir);
1345       assert(any_hit_stage);
1346 
1347       radv_nir_lower_rt_io(any_hit_stage, data->vars->monolithic, data->vars->payload_offset);
1348 
1349       /* reserve stack size for any_hit before it is inlined */
1350       data->pipeline->stages[group->any_hit_shader].stack_size = any_hit_stage->scratch_size;
1351 
1352       nir_lower_intersection_shader(nir_stage, any_hit_stage);
1353       ralloc_free(any_hit_stage);
1354    }
1355 
1356    insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.intersection_index);
1357    ralloc_free(nir_stage);
1358 }
1359 
1360 static void
1361 radv_ray_tracing_group_chit_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
1362                                  struct radv_rt_case_data *data)
1363 {
1364    if (group->type != VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR) {
1365       *shader_index = group->recursive_shader;
1366       *handle_index = group->handle.closest_hit_index;
1367    }
1368 }
1369 
1370 static void
1371 radv_ray_tracing_group_miss_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
1372                                  struct radv_rt_case_data *data)
1373 {
1374    if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR) {
1375       if (data->pipeline->stages[group->recursive_shader].stage != MESA_SHADER_MISS)
1376          return;
1377 
1378       *shader_index = group->recursive_shader;
1379       *handle_index = group->handle.general_index;
1380    }
1381 }
1382 
1383 static void
1384 radv_build_recursive_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
1385                           struct radv_rt_case_data *data)
1386 {
1387    nir_shader *nir_stage =
1388       radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->recursive_shader].nir);
1389    assert(nir_stage);
1390 
1391    radv_nir_lower_rt_io(nir_stage, data->vars->monolithic, data->vars->payload_offset);
1392 
1393    insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.general_index);
1394    ralloc_free(nir_stage);
1395 }
1396 
1397 static void
1398 handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *intersection,
1399                           const struct radv_ray_traversal_args *args, const struct radv_ray_flags *ray_flags)
1400 {
1401    struct traversal_data *data = args->data;
1402 
1403    nir_def *geometry_id = nir_iand_imm(b, intersection->base.geometry_id_and_flags, 0xfffffff);
1404    nir_def *sbt_idx =
1405       nir_iadd(b,
1406                nir_iadd(b, nir_load_var(b, data->vars->sbt_offset),
1407                         nir_iand_imm(b, nir_load_var(b, data->trav_vars->sbt_offset_and_flags), 0xffffff)),
1408                nir_imul(b, nir_load_var(b, data->vars->sbt_stride), geometry_id));
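   /* This follows the Vulkan hit-group SBT indexing rule:
    *    sbt_idx = sbtRecordOffset (from traceRay) + instanceShaderBindingTableRecordOffset
    *            + geometryIndex * sbtRecordStride (from traceRay)
    */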
1409 
1410    nir_def *hit_kind = nir_bcsel(b, intersection->frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF));
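   /* 0xFE/0xFF are the SPIR-V HitKindFrontFacingTriangleKHR / HitKindBackFacingTriangleKHR
    * values passed to any-hit and closest-hit shaders. */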
1411 
1412    nir_def *prev_barycentrics = nir_load_var(b, data->barycentrics);
1413    nir_store_var(b, data->barycentrics, intersection->barycentrics, 0x3);
1414 
1415    nir_store_var(b, data->vars->ahit_accept, nir_imm_true(b), 0x1);
1416    nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
1417 
1418    nir_push_if(b, nir_inot(b, intersection->base.opaque));
1419    {
1420       struct rt_variables inner_vars = create_inner_vars(b, data->vars);
1421 
1422       nir_store_var(b, inner_vars.primitive_id, intersection->base.primitive_id, 1);
1423       nir_store_var(b, inner_vars.geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
1424       nir_store_var(b, inner_vars.tmax, intersection->t, 0x1);
1425       nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
1426       nir_store_var(b, inner_vars.hit_kind, hit_kind, 0x1);
1427 
1428       load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX);
1429 
1430       struct radv_rt_case_data case_data = {
1431          .device = data->device,
1432          .pipeline = data->pipeline,
1433          .vars = &inner_vars,
1434       };
1435 
1436       if (data->vars->ahit_isec_count)
1437          nir_store_var(b, data->vars->ahit_isec_count, nir_iadd_imm(b, nir_load_var(b, data->vars->ahit_isec_count), 1),
1438                        0x1);
1439 
1440       radv_visit_inlined_shaders(
1441          b, nir_load_var(b, inner_vars.idx),
1442          !(data->vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR), &case_data,
1443          radv_ray_tracing_group_ahit_info, radv_build_ahit_case);
1444 
1445       nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept)));
1446       {
1447          nir_store_var(b, data->barycentrics, prev_barycentrics, 0x3);
1448          nir_jump(b, nir_jump_continue);
1449       }
1450       nir_pop_if(b, NULL);
1451    }
1452    nir_pop_if(b, NULL);
1453 
1454    nir_store_var(b, data->vars->primitive_id, intersection->base.primitive_id, 1);
1455    nir_store_var(b, data->vars->geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
1456    nir_store_var(b, data->vars->tmax, intersection->t, 0x1);
1457    nir_store_var(b, data->vars->instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
1458    nir_store_var(b, data->vars->hit_kind, hit_kind, 0x1);
1459 
1460    nir_store_var(b, data->vars->idx, sbt_idx, 1);
1461    nir_store_var(b, data->trav_vars->hit, nir_imm_true(b), 1);
1462 
1463    nir_def *ray_terminated = nir_load_var(b, data->vars->ahit_terminate);
1464    nir_break_if(b, nir_ior(b, ray_flags->terminate_on_first_hit, ray_terminated));
1465 }
1466 
1467 static void
1468 handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersection,
1469                       const struct radv_ray_traversal_args *args)
1470 {
1471    struct traversal_data *data = args->data;
1472 
1473    nir_def *geometry_id = nir_iand_imm(b, intersection->geometry_id_and_flags, 0xfffffff);
1474    nir_def *sbt_idx =
1475       nir_iadd(b,
1476                nir_iadd(b, nir_load_var(b, data->vars->sbt_offset),
1477                         nir_iand_imm(b, nir_load_var(b, data->trav_vars->sbt_offset_and_flags), 0xffffff)),
1478                nir_imul(b, nir_load_var(b, data->vars->sbt_stride), geometry_id));
1479 
1480    struct rt_variables inner_vars = create_inner_vars(b, data->vars);
1481 
1482    /* For AABBs the intersection shader writes the hit kind, and only does it if it is the
1483     * next closest hit candidate. */
1484    inner_vars.hit_kind = data->vars->hit_kind;
1485 
1486    nir_store_var(b, inner_vars.primitive_id, intersection->primitive_id, 1);
1487    nir_store_var(b, inner_vars.geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
1488    nir_store_var(b, inner_vars.tmax, nir_load_var(b, data->vars->tmax), 0x1);
1489    nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
1490    nir_store_var(b, inner_vars.opaque, intersection->opaque, 1);
1491 
1492    load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, SBT_INTERSECTION_IDX);
1493 
1494    nir_store_var(b, data->vars->ahit_accept, nir_imm_false(b), 0x1);
1495    nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
1496 
1497    if (data->vars->ahit_isec_count)
1498       nir_store_var(b, data->vars->ahit_isec_count,
1499                     nir_iadd_imm(b, nir_load_var(b, data->vars->ahit_isec_count), 1 << 16), 0x1);
1500 
1501    struct radv_rt_case_data case_data = {
1502       .device = data->device,
1503       .pipeline = data->pipeline,
1504       .vars = &inner_vars,
1505    };
1506 
1507    radv_visit_inlined_shaders(
1508       b, nir_load_var(b, inner_vars.idx),
1509       !(data->vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR), &case_data,
1510       radv_ray_tracing_group_isec_info, radv_build_isec_case);
1511 
1512    nir_push_if(b, nir_load_var(b, data->vars->ahit_accept));
1513    {
1514       nir_store_var(b, data->vars->primitive_id, intersection->primitive_id, 1);
1515       nir_store_var(b, data->vars->geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
1516       nir_store_var(b, data->vars->tmax, nir_load_var(b, inner_vars.tmax), 0x1);
1517       nir_store_var(b, data->vars->instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
1518 
1519       nir_store_var(b, data->vars->idx, sbt_idx, 1);
1520       nir_store_var(b, data->trav_vars->hit, nir_imm_true(b), 1);
1521 
1522       nir_def *terminate_on_first_hit = nir_test_mask(b, args->flags, SpvRayFlagsTerminateOnFirstHitKHRMask);
1523       nir_def *ray_terminated = nir_load_var(b, data->vars->ahit_terminate);
1524       nir_break_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
1525    }
1526    nir_pop_if(b, NULL);
1527 }
1528 
1529 static void
1530 store_stack_entry(nir_builder *b, nir_def *index, nir_def *value, const struct radv_ray_traversal_args *args)
1531 {
1532    nir_store_shared(b, value, index, .base = 0, .align_mul = 4);
1533 }
1534 
1535 static nir_def *
1536 load_stack_entry(nir_builder *b, nir_def *index, const struct radv_ray_traversal_args *args)
1537 {
1538    return nir_load_shared(b, 1, 32, index, .base = 0, .align_mul = 4);
1539 }
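/* These callbacks back the traversal stack with LDS: lane i starts at byte offset
 * i * 4 (see the trav_vars.stack initialization in radv_build_traversal below) and
 * consecutive stack entries of a lane are interleaved with a stride of
 * wave_size * 4 bytes (the stack_stride in radv_ray_traversal_args), so the lanes
 * of a wave access contiguous dwords for any given stack level.
 */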
1540 
1541 static void
1542 radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
1543                      const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, bool monolithic, nir_builder *b,
1544                      struct rt_variables *vars, bool ignore_cull_mask, struct radv_ray_tracing_stage_info *info)
1545 {
1546    const struct radv_physical_device *pdev = radv_device_physical(device);
1547    nir_variable *barycentrics =
1548       nir_variable_create(b->shader, nir_var_ray_hit_attrib, glsl_vector_type(GLSL_TYPE_FLOAT, 2), "barycentrics");
1549    barycentrics->data.driver_location = 0;
1550 
1551    struct rt_traversal_vars trav_vars = init_traversal_vars(b);
1552 
1553    nir_store_var(b, trav_vars.hit, nir_imm_false(b), 1);
1554 
1555    nir_def *accel_struct = nir_load_var(b, vars->accel_struct);
1556    nir_def *bvh_offset = nir_build_load_global(
1557       b, 1, 32, nir_iadd_imm(b, accel_struct, offsetof(struct radv_accel_struct_header, bvh_offset)),
1558       .access = ACCESS_NON_WRITEABLE);
1559    nir_def *root_bvh_base = nir_iadd(b, accel_struct, nir_u2u64(b, bvh_offset));
1560    root_bvh_base = build_addr_to_node(b, root_bvh_base);
1561 
1562    nir_store_var(b, trav_vars.bvh_base, root_bvh_base, 1);
1563 
1564    nir_def *vec3ones = nir_imm_vec3(b, 1.0, 1.0, 1.0);
1565 
1566    nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
1567    nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
1568    nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
1569    nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_imm_int(b, 0), 1);
1570    nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
1571 
1572    nir_store_var(b, trav_vars.stack, nir_imul_imm(b, nir_load_local_invocation_index(b), sizeof(uint32_t)), 1);
1573    nir_store_var(b, trav_vars.stack_low_watermark, nir_load_var(b, trav_vars.stack), 1);
1574    nir_store_var(b, trav_vars.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
1575    nir_store_var(b, trav_vars.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
1576    nir_store_var(b, trav_vars.instance_top_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
1577    nir_store_var(b, trav_vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 0x1);
1578 
1579    nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, -1), 1);
1580 
1581    struct radv_ray_traversal_vars trav_vars_args = {
1582       .tmax = nir_build_deref_var(b, vars->tmax),
1583       .origin = nir_build_deref_var(b, trav_vars.origin),
1584       .dir = nir_build_deref_var(b, trav_vars.dir),
1585       .inv_dir = nir_build_deref_var(b, trav_vars.inv_dir),
1586       .bvh_base = nir_build_deref_var(b, trav_vars.bvh_base),
1587       .stack = nir_build_deref_var(b, trav_vars.stack),
1588       .top_stack = nir_build_deref_var(b, trav_vars.top_stack),
1589       .stack_low_watermark = nir_build_deref_var(b, trav_vars.stack_low_watermark),
1590       .current_node = nir_build_deref_var(b, trav_vars.current_node),
1591       .previous_node = nir_build_deref_var(b, trav_vars.previous_node),
1592       .instance_top_node = nir_build_deref_var(b, trav_vars.instance_top_node),
1593       .instance_bottom_node = nir_build_deref_var(b, trav_vars.instance_bottom_node),
1594       .instance_addr = nir_build_deref_var(b, trav_vars.instance_addr),
1595       .sbt_offset_and_flags = nir_build_deref_var(b, trav_vars.sbt_offset_and_flags),
1596    };
1597 
1598    nir_variable *iteration_instance_count = NULL;
1599    if (vars->device->rra_trace.ray_history_addr) {
1600       iteration_instance_count =
1601          nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "iteration_instance_count");
1602       nir_store_var(b, iteration_instance_count, nir_imm_int(b, 0), 0x1);
1603       trav_vars_args.iteration_instance_count = nir_build_deref_var(b, iteration_instance_count);
1604 
1605       nir_store_var(b, vars->ahit_isec_count, nir_imm_int(b, 0), 0x1);
1606    }
1607 
1608    struct traversal_data data = {
1609       .device = device,
1610       .vars = vars,
1611       .trav_vars = &trav_vars,
1612       .barycentrics = barycentrics,
1613       .pipeline = pipeline,
1614    };
1615 
1616    nir_def *cull_mask_and_flags = nir_load_var(b, vars->cull_mask_and_flags);
1617    struct radv_ray_traversal_args args = {
1618       .root_bvh_base = root_bvh_base,
1619       .flags = cull_mask_and_flags,
1620       .cull_mask = cull_mask_and_flags,
1621       .origin = nir_load_var(b, vars->origin),
1622       .tmin = nir_load_var(b, vars->tmin),
1623       .dir = nir_load_var(b, vars->direction),
1624       .vars = trav_vars_args,
1625       .stack_stride = pdev->rt_wave_size * sizeof(uint32_t),
1626       .stack_entries = MAX_STACK_ENTRY_COUNT,
1627       .stack_base = 0,
1628       .ignore_cull_mask = ignore_cull_mask,
1629       .set_flags = info ? info->set_flags : 0,
1630       .unset_flags = info ? info->unset_flags : 0,
1631       .stack_store_cb = store_stack_entry,
1632       .stack_load_cb = load_stack_entry,
1633       .aabb_cb = (pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_AABBS_BIT_KHR)
1634                     ? NULL
1635                     : handle_candidate_aabb,
1636       .triangle_cb = (pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR)
1637                         ? NULL
1638                         : handle_candidate_triangle,
1639       .data = &data,
1640    };
1641 
1642    nir_def *original_tmax = nir_load_var(b, vars->tmax);
1643 
1644    radv_build_ray_traversal(device, b, &args);
1645 
1646    if (vars->device->rra_trace.ray_history_addr)
1647       radv_build_end_trace_token(b, vars, original_tmax, nir_load_var(b, trav_vars.hit),
1648                                  nir_load_var(b, iteration_instance_count));
1649 
1650    nir_metadata_preserve(nir_shader_get_entrypoint(b->shader), nir_metadata_none);
1651    radv_nir_lower_hit_attrib_derefs(b->shader);
1652 
1653    /* Register storage for hit attributes */
1654    nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_DWORDS];
1655 
1656    if (!monolithic) {
1657       for (uint32_t i = 0; i < ARRAY_SIZE(hit_attribs); i++)
1658          hit_attribs[i] =
1659             nir_local_variable_create(nir_shader_get_entrypoint(b->shader), glsl_uint_type(), "ahit_attrib");
1660 
1661       lower_hit_attribs(b->shader, hit_attribs, pdev->rt_wave_size);
1662    }
1663 
1664    /* Initialize follow-up shader. */
1665    nir_push_if(b, nir_load_var(b, trav_vars.hit));
1666    {
1667       if (monolithic) {
1668          load_sbt_entry(b, vars, nir_load_var(b, vars->idx), SBT_HIT, SBT_CLOSEST_HIT_IDX);
1669 
1670          nir_def *should_return =
1671             nir_test_mask(b, nir_load_var(b, vars->cull_mask_and_flags), SpvRayFlagsSkipClosestHitShaderKHRMask);
1672 
1673          /* should_return is set if we had a hit but we won't be calling the closest hit
1674           * shader and hence need to return immediately to the calling shader. */
1675          nir_push_if(b, nir_inot(b, should_return));
1676 
1677          struct radv_rt_case_data case_data = {
1678             .device = device,
1679             .pipeline = pipeline,
1680             .vars = vars,
1681          };
1682 
1683          radv_visit_inlined_shaders(
1684             b, nir_load_var(b, vars->idx),
1685             !(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR), &case_data,
1686             radv_ray_tracing_group_chit_info, radv_build_recursive_case);
1687 
1688          nir_pop_if(b, NULL);
1689       } else {
1690          for (int i = 0; i < ARRAY_SIZE(hit_attribs); ++i)
1691             nir_store_hit_attrib_amd(b, nir_load_var(b, hit_attribs[i]), .base = i);
1692          nir_execute_closest_hit_amd(b, nir_load_var(b, vars->idx), nir_load_var(b, vars->tmax),
1693                                      nir_load_var(b, vars->primitive_id), nir_load_var(b, vars->instance_addr),
1694                                      nir_load_var(b, vars->geometry_id_and_flags), nir_load_var(b, vars->hit_kind));
1695       }
1696    }
1697    nir_push_else(b, NULL);
1698    {
1699       if (monolithic) {
1700          load_sbt_entry(b, vars, nir_load_var(b, vars->miss_index), SBT_MISS, SBT_GENERAL_IDX);
1701 
1702          struct radv_rt_case_data case_data = {
1703             .device = device,
1704             .pipeline = pipeline,
1705             .vars = vars,
1706          };
1707 
1708          radv_visit_inlined_shaders(b, nir_load_var(b, vars->idx),
1709                                     !(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR),
1710                                     &case_data, radv_ray_tracing_group_miss_info, radv_build_recursive_case);
1711       } else {
1712          /* Only load the miss shader if we actually miss. It is valid to not specify an SBT pointer
1713           * for miss shaders if none of the rays miss. */
1714          nir_execute_miss_amd(b, nir_load_var(b, vars->tmax));
1715       }
1716    }
1717    nir_pop_if(b, NULL);
1718 }
1719 
1720 nir_shader *
1721 radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
1722                             const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
1723                             struct radv_ray_tracing_stage_info *info)
1724 {
1725    const struct radv_physical_device *pdev = radv_device_physical(device);
1726    const VkPipelineCreateFlagBits2KHR create_flags = vk_rt_pipeline_create_flags(pCreateInfo);
1727 
1728    /* Create the traversal shader as an intersection shader to prevent validation failures due to
1729     * invalid variable modes.*/
1730    nir_builder b = radv_meta_init_shader(device, MESA_SHADER_INTERSECTION, "rt_traversal");
1731    b.shader->info.internal = false;
1732    b.shader->info.workgroup_size[0] = 8;
1733    b.shader->info.workgroup_size[1] = pdev->rt_wave_size == 64 ? 8 : 4;
1734    b.shader->info.shared_size = pdev->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
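   /* A workgroup is exactly one wave (8x8 = 64 invocations for wave64, 8x4 = 32 for
    * wave32), so the shared_size above gives each lane MAX_STACK_ENTRY_COUNT slots
    * of the LDS traversal stack. */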
1735    struct rt_variables vars = create_rt_variables(b.shader, device, create_flags, false);
1736 
1737    if (info->tmin.state == RADV_RT_CONST_ARG_STATE_VALID)
1738       nir_store_var(&b, vars.tmin, nir_imm_int(&b, info->tmin.value), 0x1);
1739    else
1740       nir_store_var(&b, vars.tmin, nir_load_ray_t_min(&b), 0x1);
1741 
1742    if (info->tmax.state == RADV_RT_CONST_ARG_STATE_VALID)
1743       nir_store_var(&b, vars.tmax, nir_imm_int(&b, info->tmax.value), 0x1);
1744    else
1745       nir_store_var(&b, vars.tmax, nir_load_ray_t_max(&b), 0x1);
1746 
1747    if (info->sbt_offset.state == RADV_RT_CONST_ARG_STATE_VALID)
1748       nir_store_var(&b, vars.sbt_offset, nir_imm_int(&b, info->sbt_offset.value), 0x1);
1749    else
1750       nir_store_var(&b, vars.sbt_offset, nir_load_sbt_offset_amd(&b), 0x1);
1751 
1752    if (info->sbt_stride.state == RADV_RT_CONST_ARG_STATE_VALID)
1753       nir_store_var(&b, vars.sbt_stride, nir_imm_int(&b, info->sbt_stride.value), 0x1);
1754    else
1755       nir_store_var(&b, vars.sbt_stride, nir_load_sbt_stride_amd(&b), 0x1);
1756 
1757    /* initialize trace_ray arguments */
1758    nir_store_var(&b, vars.accel_struct, nir_load_accel_struct_amd(&b), 1);
1759    nir_store_var(&b, vars.cull_mask_and_flags, nir_load_cull_mask_and_flags_amd(&b), 0x1);
1760    nir_store_var(&b, vars.origin, nir_load_ray_world_origin(&b), 0x7);
1761    nir_store_var(&b, vars.direction, nir_load_ray_world_direction(&b), 0x7);
1762    nir_store_var(&b, vars.arg, nir_load_rt_arg_scratch_offset_amd(&b), 0x1);
1763    nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
1764 
1765    radv_build_traversal(device, pipeline, pCreateInfo, false, &b, &vars, false, info);
1766 
1767    /* Deal with all the inline functions. */
1768    nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
1769    nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
1770 
1771    /* Lower and cleanup variables */
1772    NIR_PASS_V(b.shader, nir_lower_global_vars_to_local);
1773    NIR_PASS_V(b.shader, nir_lower_vars_to_ssa);
1774 
1775    return b.shader;
1776 }
1777 
1778 struct lower_rt_instruction_monolithic_state {
1779    struct radv_device *device;
1780    struct radv_ray_tracing_pipeline *pipeline;
1781    const VkRayTracingPipelineCreateInfoKHR *pCreateInfo;
1782 
1783    struct rt_variables *vars;
1784 };
1785 
1786 static bool
1787 lower_rt_instruction_monolithic(nir_builder *b, nir_instr *instr, void *data)
1788 {
1789    if (instr->type != nir_instr_type_intrinsic)
1790       return false;
1791 
1792    b->cursor = nir_after_instr(instr);
1793 
1794    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
1795 
1796    struct lower_rt_instruction_monolithic_state *state = data;
1797    const struct radv_physical_device *pdev = radv_device_physical(state->device);
1798    struct rt_variables *vars = state->vars;
1799 
1800    switch (intr->intrinsic) {
1801    case nir_intrinsic_execute_callable:
1802       unreachable("nir_intrinsic_execute_callable");
1803    case nir_intrinsic_trace_ray: {
1804       vars->payload_offset = nir_src_as_uint(intr->src[10]);
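      /* The trace_ray sources mirror the traceRayEXT() arguments: acceleration structure,
       * ray flags, cull mask, SBT offset, SBT stride, miss index, origin, tmin, direction,
       * tmax; the payload source has already been lowered to a constant scratch offset,
       * hence the nir_src_as_uint above. */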
1805 
1806       nir_src cull_mask = intr->src[2];
1807       bool ignore_cull_mask = nir_src_is_const(cull_mask) && (nir_src_as_uint(cull_mask) & 0xFF) == 0xFF;
1808 
1809       /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
1810       nir_store_var(b, vars->accel_struct, intr->src[0].ssa, 0x1);
1811       nir_store_var(b, vars->cull_mask_and_flags, nir_ior(b, nir_ishl_imm(b, cull_mask.ssa, 24), intr->src[1].ssa),
1812                     0x1);
1813       nir_store_var(b, vars->sbt_offset, nir_iand_imm(b, intr->src[3].ssa, 0xf), 0x1);
1814       nir_store_var(b, vars->sbt_stride, nir_iand_imm(b, intr->src[4].ssa, 0xf), 0x1);
1815       nir_store_var(b, vars->miss_index, nir_iand_imm(b, intr->src[5].ssa, 0xffff), 0x1);
1816       nir_store_var(b, vars->origin, intr->src[6].ssa, 0x7);
1817       nir_store_var(b, vars->tmin, intr->src[7].ssa, 0x1);
1818       nir_store_var(b, vars->direction, intr->src[8].ssa, 0x7);
1819       nir_store_var(b, vars->tmax, intr->src[9].ssa, 0x1);
1820 
1821       nir_def *stack_ptr = nir_load_var(b, vars->stack_ptr);
1822       nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, stack_ptr, b->shader->scratch_size), 0x1);
1823 
1824       radv_build_traversal(state->device, state->pipeline, state->pCreateInfo, true, b, vars, ignore_cull_mask, NULL);
1825       b->shader->info.shared_size =
1826          MAX2(b->shader->info.shared_size, pdev->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t));
1827 
1828       nir_store_var(b, vars->stack_ptr, stack_ptr, 0x1);
1829 
1830       nir_instr_remove(instr);
1831       return true;
1832    }
1833    case nir_intrinsic_rt_resume:
1834       unreachable("nir_intrinsic_rt_resume");
1835    case nir_intrinsic_rt_return_amd:
1836       unreachable("nir_intrinsic_rt_return_amd");
1837    case nir_intrinsic_execute_closest_hit_amd:
1838       unreachable("nir_intrinsic_execute_closest_hit_amd");
1839    case nir_intrinsic_execute_miss_amd:
1840       unreachable("nir_intrinsic_execute_miss_amd");
1841    default:
1842       return false;
1843    }
1844 }
1845 
1846 static bool
1847 radv_count_hit_attrib_slots(nir_builder *b, nir_intrinsic_instr *instr, void *data)
1848 {
1849    uint32_t *count = data;
1850    if (instr->intrinsic == nir_intrinsic_load_hit_attrib_amd || instr->intrinsic == nir_intrinsic_store_hit_attrib_amd)
1851       *count = MAX2(*count, nir_intrinsic_base(instr) + 1);
1852 
1853    return false;
1854 }
1855 
1856 static void
1857 lower_rt_instructions_monolithic(nir_shader *shader, struct radv_device *device,
1858                                  struct radv_ray_tracing_pipeline *pipeline,
1859                                  const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct rt_variables *vars)
1860 {
1861    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1862 
1863    struct lower_rt_instruction_monolithic_state state = {
1864       .device = device,
1865       .pipeline = pipeline,
1866       .pCreateInfo = pCreateInfo,
1867       .vars = vars,
1868    };
1869 
1870    nir_shader_instructions_pass(shader, lower_rt_instruction_monolithic, nir_metadata_none, &state);
1871    nir_index_ssa_defs(impl);
1872 
1873    uint32_t hit_attrib_count = 0;
1874    nir_shader_intrinsics_pass(shader, radv_count_hit_attrib_slots, nir_metadata_all, &hit_attrib_count);
1875 
1876    /* Register storage for hit attributes */
1877    STACK_ARRAY(nir_variable *, hit_attribs, hit_attrib_count);
1878    for (uint32_t i = 0; i < hit_attrib_count; i++)
1879       hit_attribs[i] = nir_local_variable_create(impl, glsl_uint_type(), "ahit_attrib");
1880 
1881    lower_hit_attribs(shader, hit_attribs, 0);
1882 }
1883 
1884 /** Select the next shader based on priorities:
1885  *
1886  * Detect the priority of the shader stage by the lowest bits in the address (low to high):
1887  *  - Raygen              - idx 0
1888  *  - Traversal           - idx 1
1889  *  - Closest Hit / Miss  - idx 2
1890  *  - Callable            - idx 3
1891  *
1892  *
1893  * This gives us the following priorities:
1894  * Raygen       :  Callable  >               >  Traversal  >  Raygen
1895  * Traversal    :            >  Chit / Miss  >             >  Raygen
1896  * CHit / Miss  :  Callable  >  Chit / Miss  >  Traversal  >  Raygen
1897  * Callable     :  Callable  >  Chit / Miss  >             >  Raygen
1898  */
1899 static nir_def *
1900 select_next_shader(nir_builder *b, nir_def *shader_addr, unsigned wave_size)
1901 {
1902    gl_shader_stage stage = b->shader->info.stage;
1903    nir_def *prio = nir_iand_imm(b, shader_addr, radv_rt_priority_mask);
1904    nir_def *ballot = nir_ballot(b, 1, wave_size, nir_imm_bool(b, true));
1905    nir_def *ballot_traversal = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_traversal));
1906    nir_def *ballot_hit_miss = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_hit_miss));
1907    nir_def *ballot_callable = nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_callable));
1908 
1909    if (stage != MESA_SHADER_CALLABLE && stage != MESA_SHADER_INTERSECTION)
1910       ballot = nir_bcsel(b, nir_ine_imm(b, ballot_traversal, 0), ballot_traversal, ballot);
1911    if (stage != MESA_SHADER_RAYGEN)
1912       ballot = nir_bcsel(b, nir_ine_imm(b, ballot_hit_miss, 0), ballot_hit_miss, ballot);
1913    if (stage != MESA_SHADER_INTERSECTION)
1914       ballot = nir_bcsel(b, nir_ine_imm(b, ballot_callable, 0), ballot_callable, ballot);
1915 
1916    nir_def *lsb = nir_find_lsb(b, ballot);
1917    nir_def *next = nir_read_invocation(b, shader_addr, lsb);
1918    return nir_iand_imm(b, next, ~radv_rt_priority_mask);
1919 }
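/* Example (a closest-hit wave): if some lanes want to run a callable next, some a
 * miss shader and the rest the traversal shader, the callable ballot is non-empty
 * and wins, so the whole wave uniformly continues with the callable address of the
 * first such lane. Lanes whose own shader_addr differs are masked off by the
 * uniform_shader_addr guard in radv_nir_lower_rt_abi and simply skip the shader
 * body until their own shader is selected.
 */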
1920 
1921 static void
1922 radv_store_arg(nir_builder *b, const struct radv_shader_args *args, const struct radv_ray_tracing_stage_info *info,
1923                struct ac_arg arg, nir_def *value)
1924 {
1925    /* Do not pass unused data to the next stage. */
1926    if (!info || !BITSET_TEST(info->unused_args, arg.arg_index))
1927       ac_nir_store_arg(b, &args->ac, arg, value);
1928 }
1929 
1930 void
1931 radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
1932                       const struct radv_shader_args *args, const struct radv_shader_info *info, uint32_t *stack_size,
1933                       bool resume_shader, struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
1934                       bool monolithic, const struct radv_ray_tracing_stage_info *traversal_info)
1935 {
1936    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1937 
1938    const VkPipelineCreateFlagBits2KHR create_flags = vk_rt_pipeline_create_flags(pCreateInfo);
1939 
1940    struct rt_variables vars = create_rt_variables(shader, device, create_flags, monolithic);
1941 
1942    if (monolithic)
1943       lower_rt_instructions_monolithic(shader, device, pipeline, pCreateInfo, &vars);
1944 
1945    struct radv_rt_shader_info rt_info = {0};
1946 
1947    lower_rt_instructions(shader, &vars, true, &rt_info);
1948 
1949    if (stack_size) {
1950       vars.stack_size = MAX2(vars.stack_size, shader->scratch_size);
1951       *stack_size = MAX2(*stack_size, vars.stack_size);
1952    }
1953    shader->scratch_size = 0;
1954 
1955    NIR_PASS(_, shader, nir_lower_returns);
1956 
1957    nir_cf_list list;
1958    nir_cf_extract(&list, nir_before_impl(impl), nir_after_impl(impl));
1959 
1960    /* initialize variables */
1961    nir_builder b = nir_builder_at(nir_before_impl(impl));
1962 
1963    nir_def *descriptor_sets = ac_nir_load_arg(&b, &args->ac, args->descriptor_sets[0]);
1964    nir_def *push_constants = ac_nir_load_arg(&b, &args->ac, args->ac.push_constants);
1965    nir_def *sbt_descriptors = ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_descriptors);
1966 
1967    nir_def *launch_sizes[3];
1968    for (uint32_t i = 0; i < ARRAY_SIZE(launch_sizes); i++) {
1969       launch_sizes[i] = ac_nir_load_arg(&b, &args->ac, args->ac.rt.launch_sizes[i]);
1970       nir_store_var(&b, vars.launch_sizes[i], launch_sizes[i], 1);
1971    }
1972 
1973    nir_def *scratch_offset = NULL;
1974    if (args->ac.scratch_offset.used)
1975       scratch_offset = ac_nir_load_arg(&b, &args->ac, args->ac.scratch_offset);
1976    nir_def *ring_offsets = NULL;
1977    if (args->ac.ring_offsets.used)
1978       ring_offsets = ac_nir_load_arg(&b, &args->ac, args->ac.ring_offsets);
1979 
1980    nir_def *launch_ids[3];
1981    for (uint32_t i = 0; i < ARRAY_SIZE(launch_ids); i++) {
1982       launch_ids[i] = ac_nir_load_arg(&b, &args->ac, args->ac.rt.launch_ids[i]);
1983       nir_store_var(&b, vars.launch_ids[i], launch_ids[i], 1);
1984    }
1985 
1986    nir_def *traversal_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.traversal_shader_addr);
1987    nir_store_var(&b, vars.traversal_addr, nir_pack_64_2x32(&b, traversal_addr), 1);
1988 
1989    nir_def *shader_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_addr);
1990    shader_addr = nir_pack_64_2x32(&b, shader_addr);
1991    nir_store_var(&b, vars.shader_addr, shader_addr, 1);
1992 
1993    nir_store_var(&b, vars.stack_ptr, ac_nir_load_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base), 1);
1994    nir_def *record_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_record);
1995    nir_store_var(&b, vars.shader_record_ptr, nir_pack_64_2x32(&b, record_ptr), 1);
1996    nir_store_var(&b, vars.arg, ac_nir_load_arg(&b, &args->ac, args->ac.rt.payload_offset), 1);
1997 
1998    nir_def *accel_struct = ac_nir_load_arg(&b, &args->ac, args->ac.rt.accel_struct);
1999    nir_store_var(&b, vars.accel_struct, nir_pack_64_2x32(&b, accel_struct), 1);
2000    nir_store_var(&b, vars.cull_mask_and_flags, ac_nir_load_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags), 1);
2001    nir_store_var(&b, vars.sbt_offset, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_offset), 1);
2002    nir_store_var(&b, vars.sbt_stride, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_stride), 1);
2003    nir_store_var(&b, vars.origin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_origin), 0x7);
2004    nir_store_var(&b, vars.tmin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmin), 1);
2005    nir_store_var(&b, vars.direction, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_direction), 0x7);
2006    nir_store_var(&b, vars.tmax, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmax), 1);
2007 
2008    if (traversal_info && traversal_info->miss_index.state == RADV_RT_CONST_ARG_STATE_VALID)
2009       nir_store_var(&b, vars.miss_index, nir_imm_int(&b, traversal_info->miss_index.value), 0x1);
2010    else
2011       nir_store_var(&b, vars.miss_index, ac_nir_load_arg(&b, &args->ac, args->ac.rt.miss_index), 0x1);
2012 
2013    nir_store_var(&b, vars.primitive_id, ac_nir_load_arg(&b, &args->ac, args->ac.rt.primitive_id), 1);
2014    nir_def *instance_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.instance_addr);
2015    nir_store_var(&b, vars.instance_addr, nir_pack_64_2x32(&b, instance_addr), 1);
2016    nir_store_var(&b, vars.geometry_id_and_flags, ac_nir_load_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags), 1);
2017    nir_store_var(&b, vars.hit_kind, ac_nir_load_arg(&b, &args->ac, args->ac.rt.hit_kind), 1);
2018 
2019    /* guard the shader, so that only the correct invocations execute it */
2020    nir_if *shader_guard = NULL;
2021    if (shader->info.stage != MESA_SHADER_RAYGEN || resume_shader) {
2022       nir_def *uniform_shader_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.uniform_shader_addr);
2023       uniform_shader_addr = nir_pack_64_2x32(&b, uniform_shader_addr);
2024       uniform_shader_addr = nir_ior_imm(&b, uniform_shader_addr, radv_get_rt_priority(shader->info.stage));
2025 
2026       shader_guard = nir_push_if(&b, nir_ieq(&b, uniform_shader_addr, shader_addr));
2027       shader_guard->control = nir_selection_control_divergent_always_taken;
2028    }
2029 
2030    nir_cf_reinsert(&list, b.cursor);
2031 
2032    if (shader_guard)
2033       nir_pop_if(&b, shader_guard);
2034 
2035    b.cursor = nir_after_impl(impl);
2036 
2037    if (monolithic) {
2038       nir_terminate(&b);
2039    } else {
2040       /* select next shader */
2041       shader_addr = nir_load_var(&b, vars.shader_addr);
2042       nir_def *next = select_next_shader(&b, shader_addr, info->wave_size);
2043       ac_nir_store_arg(&b, &args->ac, args->ac.rt.uniform_shader_addr, next);
2044 
2045       ac_nir_store_arg(&b, &args->ac, args->descriptor_sets[0], descriptor_sets);
2046       ac_nir_store_arg(&b, &args->ac, args->ac.push_constants, push_constants);
2047       ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_descriptors, sbt_descriptors);
2048       ac_nir_store_arg(&b, &args->ac, args->ac.rt.traversal_shader_addr, traversal_addr);
2049 
2050       for (uint32_t i = 0; i < ARRAY_SIZE(launch_sizes); i++) {
2051          if (rt_info.uses_launch_size)
2052             ac_nir_store_arg(&b, &args->ac, args->ac.rt.launch_sizes[i], launch_sizes[i]);
2053          else
2054             radv_store_arg(&b, args, traversal_info, args->ac.rt.launch_sizes[i], launch_sizes[i]);
2055       }
2056 
2057       if (scratch_offset)
2058          ac_nir_store_arg(&b, &args->ac, args->ac.scratch_offset, scratch_offset);
2059       if (ring_offsets)
2060          ac_nir_store_arg(&b, &args->ac, args->ac.ring_offsets, ring_offsets);
2061 
2062       for (uint32_t i = 0; i < ARRAY_SIZE(launch_ids); i++) {
2063          if (rt_info.uses_launch_id)
2064             ac_nir_store_arg(&b, &args->ac, args->ac.rt.launch_ids[i], launch_ids[i]);
2065          else
2066             radv_store_arg(&b, args, traversal_info, args->ac.rt.launch_ids[i], launch_ids[i]);
2067       }
2068 
2069       /* store back all variables to registers */
2070       ac_nir_store_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base, nir_load_var(&b, vars.stack_ptr));
2071       ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_addr, shader_addr);
2072       radv_store_arg(&b, args, traversal_info, args->ac.rt.shader_record, nir_load_var(&b, vars.shader_record_ptr));
2073       radv_store_arg(&b, args, traversal_info, args->ac.rt.payload_offset, nir_load_var(&b, vars.arg));
2074       radv_store_arg(&b, args, traversal_info, args->ac.rt.accel_struct, nir_load_var(&b, vars.accel_struct));
2075       radv_store_arg(&b, args, traversal_info, args->ac.rt.cull_mask_and_flags,
2076                      nir_load_var(&b, vars.cull_mask_and_flags));
2077       radv_store_arg(&b, args, traversal_info, args->ac.rt.sbt_offset, nir_load_var(&b, vars.sbt_offset));
2078       radv_store_arg(&b, args, traversal_info, args->ac.rt.sbt_stride, nir_load_var(&b, vars.sbt_stride));
2079       radv_store_arg(&b, args, traversal_info, args->ac.rt.miss_index, nir_load_var(&b, vars.miss_index));
2080       radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_origin, nir_load_var(&b, vars.origin));
2081       radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_tmin, nir_load_var(&b, vars.tmin));
2082       radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_direction, nir_load_var(&b, vars.direction));
2083       radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_tmax, nir_load_var(&b, vars.tmax));
2084 
2085       radv_store_arg(&b, args, traversal_info, args->ac.rt.primitive_id, nir_load_var(&b, vars.primitive_id));
2086       radv_store_arg(&b, args, traversal_info, args->ac.rt.instance_addr, nir_load_var(&b, vars.instance_addr));
2087       radv_store_arg(&b, args, traversal_info, args->ac.rt.geometry_id_and_flags,
2088                      nir_load_var(&b, vars.geometry_id_and_flags));
2089       radv_store_arg(&b, args, traversal_info, args->ac.rt.hit_kind, nir_load_var(&b, vars.hit_kind));
2090    }
2091 
2092    nir_metadata_preserve(impl, nir_metadata_none);
2093 
2094    /* cleanup passes */
2095    NIR_PASS_V(shader, nir_lower_global_vars_to_local);
2096    NIR_PASS_V(shader, nir_lower_vars_to_ssa);
2097    if (shader->info.stage == MESA_SHADER_CLOSEST_HIT || shader->info.stage == MESA_SHADER_INTERSECTION)
2098       NIR_PASS_V(shader, lower_hit_attribs, NULL, info->wave_size);
2099 }
2100 
2101 static bool
2102 radv_arg_def_is_unused(nir_def *def)
2103 {
2104    nir_foreach_use (use, def) {
2105       nir_instr *use_instr = nir_src_parent_instr(use);
2106       if (use_instr->type == nir_instr_type_intrinsic) {
2107          nir_intrinsic_instr *use_intr = nir_instr_as_intrinsic(use_instr);
2108          if (use_intr->intrinsic == nir_intrinsic_store_scalar_arg_amd ||
2109              use_intr->intrinsic == nir_intrinsic_store_vector_arg_amd)
2110             continue;
2111       } else if (use_instr->type == nir_instr_type_phi) {
2112          nir_cf_node *prev_node = nir_cf_node_prev(&use_instr->block->cf_node);
2113          if (!prev_node)
2114             return false;
2115 
2116          nir_phi_instr *phi = nir_instr_as_phi(use_instr);
2117          if (radv_arg_def_is_unused(&phi->def))
2118             continue;
2119       }
2120 
2121       return false;
2122    }
2123 
2124    return true;
2125 }
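/* An argument is considered unused when every use of it merely forwards the value
 * to the next stage (store_scalar/vector_arg_amd), possibly through phis whose
 * results are themselves only forwarded; anything else is a real use and clears the
 * bit in info->unused_args below. radv_store_arg then skips storing values that the
 * receiving shader never actually consumes.
 */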
2126 
2127 static bool
2128 radv_gather_unused_args_instr(nir_builder *b, nir_intrinsic_instr *instr, void *data)
2129 {
2130    if (instr->intrinsic != nir_intrinsic_load_scalar_arg_amd && instr->intrinsic != nir_intrinsic_load_vector_arg_amd)
2131       return false;
2132 
2133    if (!radv_arg_def_is_unused(&instr->def)) {
2134       /* This arg is used for more than passing data to the next stage. */
2135       struct radv_ray_tracing_stage_info *info = data;
2136       BITSET_CLEAR(info->unused_args, nir_intrinsic_base(instr));
2137    }
2138 
2139    return false;
2140 }
2141 
2142 void
2143 radv_gather_unused_args(struct radv_ray_tracing_stage_info *info, nir_shader *nir)
2144 {
2145    nir_shader_intrinsics_pass(nir, radv_gather_unused_args_instr, nir_metadata_all, info);
2146 }
2147