/*
 * Copyright © 2024 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "lvp_private.h"
#include "lvp_acceleration_structure.h"
#include "lvp_nir_ray_tracing.h"

#include "vk_pipeline.h"

#include "nir.h"
#include "nir_builder.h"

#include "spirv/spirv.h"

#include "util/mesa-sha1.h"
#include "util/simple_mtx.h"

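/*
 * Ray tracing pipelines in lavapipe are compiled into a single compute
 * shader: each stage is inlined as a NIR function, and raygen/hit/miss/
 * callable dispatch is an if-ladder comparing the group handle index
 * loaded from the shader binding table against every group in the
 * pipeline.
 */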
static void
lvp_init_ray_tracing_groups(struct lvp_pipeline *pipeline,
                            const VkRayTracingPipelineCreateInfoKHR *create_info)
{
   uint32_t i = 0;
   for (; i < create_info->groupCount; i++) {
      const VkRayTracingShaderGroupCreateInfoKHR *group_info = create_info->pGroups + i;
      struct lvp_ray_tracing_group *dst = pipeline->rt.groups + i;

      dst->recursive_index = VK_SHADER_UNUSED_KHR;
      dst->ahit_index = VK_SHADER_UNUSED_KHR;
      dst->isec_index = VK_SHADER_UNUSED_KHR;

      switch (group_info->type) {
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
         if (group_info->generalShader != VK_SHADER_UNUSED_KHR) {
            dst->recursive_index = group_info->generalShader;
         }
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
         if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR) {
            dst->recursive_index = group_info->closestHitShader;
         }
         if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR) {
            dst->ahit_index = group_info->anyHitShader;
         }
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
         if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR) {
            dst->recursive_index = group_info->closestHitShader;
         }
         if (group_info->intersectionShader != VK_SHADER_UNUSED_KHR) {
            dst->isec_index = group_info->intersectionShader;

            if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR)
               dst->ahit_index = group_info->anyHitShader;
         }
         break;
      default:
         unreachable("Unimplemented VkRayTracingShaderGroupTypeKHR");
      }

      dst->handle.index = p_atomic_inc_return(&pipeline->device->group_handle_alloc);
   }

   if (!create_info->pLibraryInfo)
      return;

   uint32_t stage_base_index = create_info->stageCount;
   for (uint32_t library_index = 0; library_index < create_info->pLibraryInfo->libraryCount; library_index++) {
      VK_FROM_HANDLE(lvp_pipeline, library, create_info->pLibraryInfo->pLibraries[library_index]);
      for (uint32_t group_index = 0; group_index < library->rt.group_count; group_index++) {
         const struct lvp_ray_tracing_group *src = library->rt.groups + group_index;
         struct lvp_ray_tracing_group *dst = pipeline->rt.groups + i;

         dst->handle = src->handle;

         if (src->recursive_index != VK_SHADER_UNUSED_KHR)
            dst->recursive_index = stage_base_index + src->recursive_index;
         else
            dst->recursive_index = VK_SHADER_UNUSED_KHR;

         if (src->ahit_index != VK_SHADER_UNUSED_KHR)
            dst->ahit_index = stage_base_index + src->ahit_index;
         else
            dst->ahit_index = VK_SHADER_UNUSED_KHR;

         if (src->isec_index != VK_SHADER_UNUSED_KHR)
            dst->isec_index = stage_base_index + src->isec_index;
         else
            dst->isec_index = VK_SHADER_UNUSED_KHR;

         i++;
      }
      stage_base_index += library->rt.stage_count;
   }
}

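/* Rewrite shader_call_data and ray_hit_attrib derefs into function_temp
 * derefs so the later nir_lower_explicit_io pass turns them into scratch
 * accesses: incoming call data lives at the offset passed through
 * load_shader_call_data_offset_lvp, hit attributes at scratch offset 0.
 */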
static bool
lvp_lower_ray_tracing_derefs(nir_shader *shader)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   bool progress = false;

   nir_builder _b = nir_builder_at(nir_before_impl(impl));
   nir_builder *b = &_b;

   nir_def *arg_offset = nir_load_shader_call_data_offset_lvp(b);

   nir_foreach_block (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_deref)
            continue;

         nir_deref_instr *deref = nir_instr_as_deref(instr);
         if (!nir_deref_mode_is_one_of(deref, nir_var_shader_call_data |
                                       nir_var_ray_hit_attrib))
            continue;

         bool is_shader_call_data = nir_deref_mode_is(deref, nir_var_shader_call_data);

         deref->modes = nir_var_function_temp;
         progress = true;

         if (deref->deref_type == nir_deref_type_var) {
            b->cursor = nir_before_instr(&deref->instr);
            nir_def *offset = is_shader_call_data ? arg_offset : nir_imm_int(b, 0);
            nir_deref_instr *replacement =
               nir_build_deref_cast(b, offset, nir_var_function_temp, deref->var->type, 0);
            nir_def_replace(&deref->def, &replacement->def);
         }
      }
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_control_flow);
   else
      nir_metadata_preserve(impl, nir_metadata_all);

   return progress;
}

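/* Hoist per-ray system value loads to the top of the entrypoint. These
 * intrinsics are later rewritten to reads of the shared ray tracing state,
 * so hoisting them presumably keeps a stage from observing state that a
 * nested trace_ray or callable dispatch overwrote mid-function.
 */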
static bool
lvp_move_ray_tracing_intrinsic(nir_builder *b, nir_intrinsic_instr *instr, void *data)
{
   switch (instr->intrinsic) {
   case nir_intrinsic_load_shader_record_ptr:
   case nir_intrinsic_load_ray_flags:
   case nir_intrinsic_load_ray_object_origin:
   case nir_intrinsic_load_ray_world_origin:
   case nir_intrinsic_load_ray_t_min:
   case nir_intrinsic_load_ray_object_direction:
   case nir_intrinsic_load_ray_world_direction:
   case nir_intrinsic_load_ray_t_max:
      nir_instr_move(nir_before_impl(b->impl), &instr->instr);
      return true;
   default:
      return false;
   }
}

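/* Translate each SPIR-V stage to NIR and apply the lowering above. Stages
 * imported from pipeline libraries are appended after the pipeline's own
 * stages, matching the index rebasing in lvp_init_ray_tracing_groups().
 */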
static VkResult
lvp_compile_ray_tracing_stages(struct lvp_pipeline *pipeline,
                               const VkRayTracingPipelineCreateInfoKHR *create_info)
{
   VkResult result = VK_SUCCESS;

   uint32_t i = 0;
   for (; i < create_info->stageCount; i++) {
      nir_shader *nir;
      result = lvp_spirv_to_nir(pipeline, create_info->pStages + i, &nir);
      if (result != VK_SUCCESS)
         return result;

      assert(!nir->scratch_size);
      if (nir->info.stage == MESA_SHADER_ANY_HIT ||
          nir->info.stage == MESA_SHADER_CLOSEST_HIT ||
          nir->info.stage == MESA_SHADER_INTERSECTION)
         nir->scratch_size = LVP_RAY_HIT_ATTRIBS_SIZE;

      NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
               nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib,
               glsl_get_natural_size_align_bytes);

      NIR_PASS(_, nir, lvp_lower_ray_tracing_derefs);

      NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset);

      NIR_PASS(_, nir, nir_shader_intrinsics_pass, lvp_move_ray_tracing_intrinsic,
               nir_metadata_control_flow, NULL);

      pipeline->rt.stages[i] = lvp_create_pipeline_nir(nir);
      if (!pipeline->rt.stages[i]) {
         result = VK_ERROR_OUT_OF_HOST_MEMORY;
         ralloc_free(nir);
         return result;
      }
   }

   if (!create_info->pLibraryInfo)
      return result;

   for (uint32_t library_index = 0; library_index < create_info->pLibraryInfo->libraryCount; library_index++) {
      VK_FROM_HANDLE(lvp_pipeline, library, create_info->pLibraryInfo->pLibraries[library_index]);
      for (uint32_t stage_index = 0; stage_index < library->rt.stage_count; stage_index++) {
         lvp_pipeline_nir_ref(pipeline->rt.stages + i, library->rt.stages[stage_index]);
         i++;
      }
   }

   return result;
}

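/* The vkCmdTraceRays* parameters reach the shader as a
 * VkTraceRaysIndirectCommand2KHR structure read from SSBO index 0, so each
 * field is addressed by its offsetof() within that struct.
 * lvp_load_sbt_entry() below then resolves one shader binding table record:
 * the entry address is table base + index * stride, .value is the 32-bit
 * group handle index stored in the record, and .shader_record_ptr points at
 * the application data following the LVP_RAY_TRACING_GROUP_HANDLE_SIZE-byte
 * handle.
 */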
static nir_def *
lvp_load_trace_ray_command_field(nir_builder *b, uint32_t command_offset,
                                 uint32_t num_components, uint32_t bit_size)
{
   return nir_load_ssbo(b, num_components, bit_size, nir_imm_int(b, 0),
                        nir_imm_int(b, command_offset));
}

struct lvp_sbt_entry {
   nir_def *value;
   nir_def *shader_record_ptr;
};

static struct lvp_sbt_entry
lvp_load_sbt_entry(nir_builder *b, nir_def *index,
                   uint32_t command_offset, uint32_t index_offset)
{
   nir_def *addr = lvp_load_trace_ray_command_field(b, command_offset, 1, 64);

   if (index) {
      /* The 32 high bits of stride can be ignored. */
      nir_def *stride = lvp_load_trace_ray_command_field(
         b, command_offset + sizeof(VkDeviceSize) * 2, 1, 32);
      addr = nir_iadd(b, addr, nir_u2u64(b, nir_imul(b, index, stride)));
   }

   return (struct lvp_sbt_entry) {
      .value = nir_build_load_global(b, 1, 32, nir_iadd_imm(b, addr, index_offset)),
      .shader_record_ptr = nir_iadd_imm(b, addr, LVP_RAY_TRACING_GROUP_HANDLE_SIZE),
   };
}

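/* Compile-time handles to the variables holding the shared ray tracing
 * state. lvp_ray_tracing_state uses nir_var_shader_temp so the inlined
 * stage functions can all access it; the traversal sub-state is recreated
 * as function locals each time lvp_trace_ray() expands an intrinsic.
 */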
struct lvp_ray_traversal_state {
   nir_variable *origin;
   nir_variable *dir;
   nir_variable *inv_dir;
   nir_variable *bvh_base;
   nir_variable *current_node;
   nir_variable *stack_base;
   nir_variable *stack_ptr;
   nir_variable *stack;
   nir_variable *hit;

   nir_variable *instance_addr;
   nir_variable *sbt_offset_and_flags;
};

struct lvp_ray_tracing_state {
   nir_variable *bvh_base;
   nir_variable *flags;
   nir_variable *cull_mask;
   nir_variable *sbt_offset;
   nir_variable *sbt_stride;
   nir_variable *miss_index;
   nir_variable *origin;
   nir_variable *tmin;
   nir_variable *dir;
   nir_variable *tmax;

   nir_variable *instance_addr;
   nir_variable *primitive_id;
   nir_variable *geometry_id_and_flags;
   nir_variable *hit_kind;
   nir_variable *sbt_index;

   nir_variable *shader_record_ptr;
   nir_variable *stack_ptr;
   nir_variable *shader_call_data_offset;

   nir_variable *accept;
   nir_variable *terminate;
   nir_variable *opaque;

   struct lvp_ray_traversal_state traversal;
};

struct lvp_ray_tracing_pipeline_compiler {
   struct lvp_pipeline *pipeline;
   VkPipelineCreateFlags2KHR flags;

   struct lvp_ray_tracing_state state;

   struct hash_table *functions;

   uint32_t raygen_size;
   uint32_t ahit_size;
   uint32_t chit_size;
   uint32_t miss_size;
   uint32_t isec_size;
   uint32_t callable_size;
};

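/* Return the scratch size of the stage that was inlined as this function,
 * i.e. how far the software stack pointer must be advanced before a nested
 * trace_ray or execute_callable call.
 */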
static uint32_t
lvp_ray_tracing_pipeline_compiler_get_stack_size(
   struct lvp_ray_tracing_pipeline_compiler *compiler, nir_function *function)
{
   hash_table_foreach(compiler->functions, entry) {
      if (entry->data == function) {
         const nir_shader *shader = entry->key;
         return shader->scratch_size;
      }
   }
   return 0;
}

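/* Create the shared state variables. nir_var_shader_temp is lowered to
 * scratch at the end of pipeline compilation, so these effectively behave
 * as globals of the monolithic shader.
 */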
static void
lvp_ray_tracing_state_init(nir_shader *nir, struct lvp_ray_tracing_state *state)
{
   state->bvh_base = nir_variable_create(nir, nir_var_shader_temp, glsl_uint64_t_type(), "bvh_base");
   state->flags = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "flags");
   state->cull_mask = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "cull_mask");
   state->sbt_offset = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "sbt_offset");
   state->sbt_stride = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "sbt_stride");
   state->miss_index = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "miss_index");
   state->origin = nir_variable_create(nir, nir_var_shader_temp, glsl_vec_type(3), "origin");
   state->tmin = nir_variable_create(nir, nir_var_shader_temp, glsl_float_type(), "tmin");
   state->dir = nir_variable_create(nir, nir_var_shader_temp, glsl_vec_type(3), "dir");
   state->tmax = nir_variable_create(nir, nir_var_shader_temp, glsl_float_type(), "tmax");

   state->instance_addr = nir_variable_create(nir, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
   state->primitive_id = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "primitive_id");
   state->geometry_id_and_flags = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "geometry_id_and_flags");
   state->hit_kind = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "hit_kind");
   state->sbt_index = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "sbt_index");

   state->shader_record_ptr = nir_variable_create(nir, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr");
   state->stack_ptr = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
   state->shader_call_data_offset = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "shader_call_data_offset");

   state->accept = nir_variable_create(nir, nir_var_shader_temp, glsl_bool_type(), "accept");
   state->terminate = nir_variable_create(nir, nir_var_shader_temp, glsl_bool_type(), "terminate");
   state->opaque = nir_variable_create(nir, nir_var_shader_temp, glsl_bool_type(), "opaque");
}

static void
lvp_ray_traversal_state_init(nir_function_impl *impl, struct lvp_ray_traversal_state *state)
{
   state->origin = nir_local_variable_create(impl, glsl_vec_type(3), "traversal.origin");
   state->dir = nir_local_variable_create(impl, glsl_vec_type(3), "traversal.dir");
   state->inv_dir = nir_local_variable_create(impl, glsl_vec_type(3), "traversal.inv_dir");
   state->bvh_base = nir_local_variable_create(impl, glsl_uint64_t_type(), "traversal.bvh_base");
   state->current_node = nir_local_variable_create(impl, glsl_uint_type(), "traversal.current_node");
   state->stack_base = nir_local_variable_create(impl, glsl_uint_type(), "traversal.stack_base");
   state->stack_ptr = nir_local_variable_create(impl, glsl_uint_type(), "traversal.stack_ptr");
   state->stack = nir_local_variable_create(impl, glsl_array_type(glsl_uint_type(), 24 * 2, 0), "traversal.stack");
   state->hit = nir_local_variable_create(impl, glsl_bool_type(), "traversal.hit");

   state->instance_addr = nir_local_variable_create(impl, glsl_uint64_t_type(), "traversal.instance_addr");
   state->sbt_offset_and_flags = nir_local_variable_create(impl, glsl_uint_type(), "traversal.sbt_offset_and_flags");
}

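/* Inline a stage into the monolithic shader: clone its entrypoint as a new
 * nir_function (remapping any non-local variables into this shader) and
 * emit a call to it. Clones are cached in compiler->functions so each stage
 * is materialized once; the per-stage scratch sizes recorded here feed the
 * pipeline stack size computation.
 */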
static void
lvp_call_ray_tracing_stage(nir_builder *b, struct lvp_ray_tracing_pipeline_compiler *compiler, nir_shader *stage)
{
   nir_function *function;

   struct hash_entry *entry = _mesa_hash_table_search(compiler->functions, stage);
   if (entry) {
      function = entry->data;
   } else {
      nir_function_impl *stage_entrypoint = nir_shader_get_entrypoint(stage);
      nir_function_impl *copy = nir_function_impl_clone(b->shader, stage_entrypoint);

      struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);

      nir_foreach_block(block, copy) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (deref->deref_type != nir_deref_type_var ||
                deref->var->data.mode == nir_var_function_temp)
               continue;

            struct hash_entry *entry =
               _mesa_hash_table_search(var_remap, deref->var);
            if (!entry) {
               nir_variable *new_var = nir_variable_clone(deref->var, b->shader);
               nir_shader_add_variable(b->shader, new_var);
               entry = _mesa_hash_table_insert(var_remap,
                                               deref->var, new_var);
            }
            deref->var = entry->data;
         }
      }

      function = nir_function_create(
         b->shader, _mesa_shader_stage_to_string(stage->info.stage));
      nir_function_set_impl(function, copy);

      ralloc_free(var_remap);

      _mesa_hash_table_insert(compiler->functions, stage, function);
   }

   nir_build_call(b, function, 0, NULL);

   switch (stage->info.stage) {
   case MESA_SHADER_RAYGEN:
      compiler->raygen_size = MAX2(compiler->raygen_size, stage->scratch_size);
      break;
   case MESA_SHADER_ANY_HIT:
      compiler->ahit_size = MAX2(compiler->ahit_size, stage->scratch_size);
      break;
   case MESA_SHADER_CLOSEST_HIT:
      compiler->chit_size = MAX2(compiler->chit_size, stage->scratch_size);
      break;
   case MESA_SHADER_MISS:
      compiler->miss_size = MAX2(compiler->miss_size, stage->scratch_size);
      break;
   case MESA_SHADER_INTERSECTION:
      compiler->isec_size = MAX2(compiler->isec_size, stage->scratch_size);
      break;
   case MESA_SHADER_CALLABLE:
      compiler->callable_size = MAX2(compiler->callable_size, stage->scratch_size);
      break;
   default:
      unreachable("Invalid ray tracing stage");
      break;
   }
}

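/* Lower execute_callable: resolve the callable SBT record, advance the
 * stack pointer past the caller's frame, point the callee's incoming call
 * data at the payload, and dispatch through the callable if-ladder.
 */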
static void
lvp_execute_callable(nir_builder *b, struct lvp_ray_tracing_pipeline_compiler *compiler,
                     nir_intrinsic_instr *instr)
{
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_def *sbt_index = instr->src[0].ssa;
   nir_def *payload = instr->src[1].ssa;

   struct lvp_sbt_entry callable_entry = lvp_load_sbt_entry(
      b,
      sbt_index,
      offsetof(VkTraceRaysIndirectCommand2KHR, callableShaderBindingTableAddress),
      offsetof(struct lvp_ray_tracing_group_handle, index));
   nir_store_var(b, compiler->state.shader_record_ptr, callable_entry.shader_record_ptr, 0x1);

   uint32_t stack_size =
      lvp_ray_tracing_pipeline_compiler_get_stack_size(compiler, b->impl->function);
   nir_def *stack_ptr = nir_load_var(b, state->stack_ptr);
   nir_store_var(b, state->stack_ptr, nir_iadd_imm(b, stack_ptr, stack_size), 0x1);

   nir_store_var(b, state->shader_call_data_offset, nir_iadd_imm(b, payload, -stack_size), 0x1);

   for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
      struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
      if (group->recursive_index == VK_SHADER_UNUSED_KHR)
         continue;

      nir_shader *stage = compiler->pipeline->rt.stages[group->recursive_index]->nir;
      if (stage->info.stage != MESA_SHADER_CALLABLE)
         continue;

      nir_push_if(b, nir_ieq_imm(b, callable_entry.value, group->handle.index));
      lvp_call_ray_tracing_stage(b, compiler, stage);
      nir_pop_if(b, NULL);
   }

   nir_store_var(b, state->stack_ptr, stack_ptr, 0x1);
}

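/* Lower report_ray_intersection inside an intersection shader: tentatively
 * commit the hit if t lies within [tmin, tmax], run the paired any-hit
 * shader for non-opaque geometry, then either keep the commit or roll
 * tmax/hit_kind back if the hit was ignored.
 */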
struct lvp_lower_isec_intrinsic_state {
   struct lvp_ray_tracing_pipeline_compiler *compiler;
   nir_shader *ahit;
};

static bool
lvp_lower_isec_intrinsic(nir_builder *b, nir_intrinsic_instr *instr, void *data)
{
   if (instr->intrinsic != nir_intrinsic_report_ray_intersection)
      return false;

   struct lvp_lower_isec_intrinsic_state *isec_state = data;
   struct lvp_ray_tracing_pipeline_compiler *compiler = isec_state->compiler;
   struct lvp_ray_tracing_state *state = &compiler->state;

   b->cursor = nir_after_instr(&instr->instr);

   nir_def *t = instr->src[0].ssa;
   nir_def *hit_kind = instr->src[1].ssa;

   nir_def *prev_accept = nir_load_var(b, state->accept);
   nir_def *prev_tmax = nir_load_var(b, state->tmax);
   nir_def *prev_hit_kind = nir_load_var(b, state->hit_kind);

   nir_variable *commit = nir_local_variable_create(b->impl, glsl_bool_type(), "commit");
   nir_store_var(b, commit, nir_imm_false(b), 0x1);

   nir_push_if(b, nir_iand(b, nir_fge(b, t, nir_load_var(b, state->tmin)), nir_fge(b, nir_load_var(b, state->tmax), t)));
   {
      nir_store_var(b, state->accept, nir_imm_true(b), 0x1);

      nir_store_var(b, state->tmax, t, 0x1);
      nir_store_var(b, state->hit_kind, hit_kind, 0x1);

      if (isec_state->ahit) {
         nir_def *prev_terminate = nir_load_var(b, state->terminate);
         nir_store_var(b, state->terminate, nir_imm_false(b), 0x1);

         nir_push_if(b, nir_inot(b, nir_load_var(b, state->opaque)));
         {
            lvp_call_ray_tracing_stage(b, compiler, isec_state->ahit);
         }
         nir_pop_if(b, NULL);

         nir_def *terminate = nir_load_var(b, state->terminate);
         nir_store_var(b, state->terminate, nir_ior(b, terminate, prev_terminate), 0x1);

         nir_push_if(b, terminate);
         nir_jump(b, nir_jump_return);
         nir_pop_if(b, NULL);
      }

      nir_push_if(b, nir_load_var(b, state->accept));
      {
         nir_store_var(b, commit, nir_imm_true(b), 0x1);
      }
      nir_push_else(b, NULL);
      {
         nir_store_var(b, state->accept, prev_accept, 0x1);
         nir_store_var(b, state->tmax, prev_tmax, 0x1);
         nir_store_var(b, state->hit_kind, prev_hit_kind, 0x1);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_if(b, NULL);

   nir_def_replace(&instr->def, nir_load_var(b, commit));

   return true;
}

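/* Traversal callback for procedural (AABB) leaves: dispatch the matching
 * intersection shader with its group's any-hit shader wired into
 * report_ray_intersection, and on an accepted hit record the SBT index and
 * honor terminate-on-first-hit.
 */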
static void
lvp_handle_aabb_intersection(nir_builder *b, struct lvp_leaf_intersection *intersection,
                             const struct lvp_ray_traversal_args *args,
                             const struct lvp_ray_flags *ray_flags)
{
   struct lvp_ray_tracing_pipeline_compiler *compiler = args->data;
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_store_var(b, state->accept, nir_imm_false(b), 0x1);
   nir_store_var(b, state->terminate, ray_flags->terminate_on_first_hit, 0x1);
   nir_store_var(b, state->opaque, intersection->opaque, 0x1);

   nir_def *prev_instance_addr = nir_load_var(b, state->instance_addr);
   nir_def *prev_primitive_id = nir_load_var(b, state->primitive_id);
   nir_def *prev_geometry_id_and_flags = nir_load_var(b, state->geometry_id_and_flags);

   nir_store_var(b, state->instance_addr, nir_load_var(b, state->traversal.instance_addr), 0x1);
   nir_store_var(b, state->primitive_id, intersection->primitive_id, 0x1);
   nir_store_var(b, state->geometry_id_and_flags, intersection->geometry_id_and_flags, 0x1);

   nir_def *geometry_id = nir_iand_imm(b, intersection->geometry_id_and_flags, 0xfffffff);
   nir_def *sbt_index =
      nir_iadd(b,
               nir_iadd(b, nir_load_var(b, state->sbt_offset),
                        nir_iand_imm(b, nir_load_var(b, state->traversal.sbt_offset_and_flags), 0xffffff)),
               nir_imul(b, nir_load_var(b, state->sbt_stride), geometry_id));

   struct lvp_sbt_entry isec_entry = lvp_load_sbt_entry(
      b,
      sbt_index,
      offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
      offsetof(struct lvp_ray_tracing_group_handle, index));
   nir_store_var(b, compiler->state.shader_record_ptr, isec_entry.shader_record_ptr, 0x1);

   for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
      struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
      if (group->isec_index == VK_SHADER_UNUSED_KHR)
         continue;

      nir_shader *stage = compiler->pipeline->rt.stages[group->isec_index]->nir;

      nir_push_if(b, nir_ieq_imm(b, isec_entry.value, group->handle.index));
      lvp_call_ray_tracing_stage(b, compiler, stage);
      nir_pop_if(b, NULL);

      nir_shader *ahit_stage = NULL;
      if (group->ahit_index != VK_SHADER_UNUSED_KHR)
         ahit_stage = compiler->pipeline->rt.stages[group->ahit_index]->nir;

      struct lvp_lower_isec_intrinsic_state isec_state = {
         .compiler = compiler,
         .ahit = ahit_stage,
      };
      nir_shader_intrinsics_pass(b->shader, lvp_lower_isec_intrinsic,
                                 nir_metadata_none, &isec_state);
   }

   nir_push_if(b, nir_load_var(b, state->accept));
   {
      nir_store_var(b, state->sbt_index, sbt_index, 0x1);
      nir_store_var(b, state->traversal.hit, nir_imm_true(b), 0x1);

      nir_break_if(b, nir_load_var(b, state->terminate));
   }
   nir_push_else(b, NULL);
   {
      nir_store_var(b, state->instance_addr, prev_instance_addr, 0x1);
      nir_store_var(b, state->primitive_id, prev_primitive_id, 0x1);
      nir_store_var(b, state->geometry_id_and_flags, prev_geometry_id_and_flags, 0x1);
   }
   nir_pop_if(b, NULL);
}

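/* Traversal callback for triangle leaves: the hit starts out accepted, the
 * barycentrics are spilled to the hit-attribute scratch slot, any-hit
 * shaders run for non-opaque geometry, and a rejected hit restores the
 * previously committed hit state.
 */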
static void
lvp_handle_triangle_intersection(nir_builder *b,
                                 struct lvp_triangle_intersection *intersection,
                                 const struct lvp_ray_traversal_args *args,
                                 const struct lvp_ray_flags *ray_flags)
{
   struct lvp_ray_tracing_pipeline_compiler *compiler = args->data;
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_store_var(b, state->accept, nir_imm_true(b), 0x1);
   nir_store_var(b, state->terminate, ray_flags->terminate_on_first_hit, 0x1);

   nir_def *barycentrics_offset = nir_load_var(b, state->stack_ptr);

   nir_def *prev_tmax = nir_load_var(b, state->tmax);
   nir_def *prev_instance_addr = nir_load_var(b, state->instance_addr);
   nir_def *prev_primitive_id = nir_load_var(b, state->primitive_id);
   nir_def *prev_geometry_id_and_flags = nir_load_var(b, state->geometry_id_and_flags);
   nir_def *prev_hit_kind = nir_load_var(b, state->hit_kind);
   nir_def *prev_barycentrics = nir_load_scratch(b, 2, 32, barycentrics_offset);

   nir_store_var(b, state->tmax, intersection->t, 0x1);
   nir_store_var(b, state->instance_addr, nir_load_var(b, state->traversal.instance_addr), 0x1);
   nir_store_var(b, state->primitive_id, intersection->base.primitive_id, 0x1);
   nir_store_var(b, state->geometry_id_and_flags, intersection->base.geometry_id_and_flags, 0x1);
   nir_store_var(b, state->hit_kind,
                 nir_bcsel(b, intersection->frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF)), 0x1);

   nir_store_scratch(b, intersection->barycentrics, barycentrics_offset);

   nir_def *geometry_id = nir_iand_imm(b, intersection->base.geometry_id_and_flags, 0xfffffff);
   nir_def *sbt_index =
      nir_iadd(b,
               nir_iadd(b, nir_load_var(b, state->sbt_offset),
                        nir_iand_imm(b, nir_load_var(b, state->traversal.sbt_offset_and_flags), 0xffffff)),
               nir_imul(b, nir_load_var(b, state->sbt_stride), geometry_id));

   nir_push_if(b, nir_inot(b, intersection->base.opaque));
   {
      struct lvp_sbt_entry ahit_entry = lvp_load_sbt_entry(
         b,
         sbt_index,
         offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
         offsetof(struct lvp_ray_tracing_group_handle, index));
      nir_store_var(b, compiler->state.shader_record_ptr, ahit_entry.shader_record_ptr, 0x1);

      for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
         struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
         if (group->ahit_index == VK_SHADER_UNUSED_KHR)
            continue;

         nir_shader *stage = compiler->pipeline->rt.stages[group->ahit_index]->nir;

         nir_push_if(b, nir_ieq_imm(b, ahit_entry.value, group->handle.index));
         lvp_call_ray_tracing_stage(b, compiler, stage);
         nir_pop_if(b, NULL);
      }
   }
   nir_pop_if(b, NULL);

   nir_push_if(b, nir_load_var(b, state->accept));
   {
      nir_store_var(b, state->sbt_index, sbt_index, 0x1);
      nir_store_var(b, state->traversal.hit, nir_imm_true(b), 0x1);

      nir_break_if(b, nir_load_var(b, state->terminate));
   }
   nir_push_else(b, NULL);
   {
      nir_store_var(b, state->tmax, prev_tmax, 0x1);
      nir_store_var(b, state->instance_addr, prev_instance_addr, 0x1);
      nir_store_var(b, state->primitive_id, prev_primitive_id, 0x1);
      nir_store_var(b, state->geometry_id_and_flags, prev_geometry_id_and_flags, 0x1);
      nir_store_var(b, state->hit_kind, prev_hit_kind, 0x1);
      nir_store_scratch(b, prev_barycentrics, barycentrics_offset);
   }
   nir_pop_if(b, NULL);
}

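/* Expand a trace_ray intrinsic in place: capture the ray parameters in the
 * shared state, walk the BVH with the callbacks above, then run the
 * closest-hit shader (unless SkipClosestHitShader is set) or the miss
 * shader via the usual SBT if-ladders.
 */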
static void
lvp_trace_ray(nir_builder *b, struct lvp_ray_tracing_pipeline_compiler *compiler,
              nir_intrinsic_instr *instr)
{
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_def *accel_struct = instr->src[0].ssa;
   nir_def *flags = instr->src[1].ssa;
   nir_def *cull_mask = instr->src[2].ssa;
   nir_def *sbt_offset = nir_iand_imm(b, instr->src[3].ssa, 0xF);
   nir_def *sbt_stride = nir_iand_imm(b, instr->src[4].ssa, 0xF);
   nir_def *miss_index = nir_iand_imm(b, instr->src[5].ssa, 0xFFFF);
   nir_def *origin = instr->src[6].ssa;
   nir_def *tmin = instr->src[7].ssa;
   nir_def *dir = instr->src[8].ssa;
   nir_def *tmax = instr->src[9].ssa;
   nir_def *payload = instr->src[10].ssa;

   uint32_t stack_size =
      lvp_ray_tracing_pipeline_compiler_get_stack_size(compiler, b->impl->function);
   nir_def *stack_ptr = nir_load_var(b, state->stack_ptr);
   nir_store_var(b, state->stack_ptr, nir_iadd_imm(b, stack_ptr, stack_size), 0x1);

   nir_store_var(b, state->shader_call_data_offset, nir_iadd_imm(b, payload, -stack_size), 0x1);

   nir_def *bvh_base = accel_struct;
   if (bvh_base->bit_size != 64) {
      assert(bvh_base->num_components >= 2);
      bvh_base = nir_load_ubo(
         b, 1, 64, nir_channel(b, accel_struct, 0),
         nir_imul_imm(b, nir_channel(b, accel_struct, 1), sizeof(struct lp_descriptor)), .range = ~0);
   }

   lvp_ray_traversal_state_init(b->impl, &state->traversal);

   nir_store_var(b, state->bvh_base, bvh_base, 0x1);
   nir_store_var(b, state->flags, flags, 0x1);
   nir_store_var(b, state->cull_mask, cull_mask, 0x1);
   nir_store_var(b, state->sbt_offset, sbt_offset, 0x1);
   nir_store_var(b, state->sbt_stride, sbt_stride, 0x1);
   nir_store_var(b, state->miss_index, miss_index, 0x1);
   nir_store_var(b, state->origin, origin, 0x7);
   nir_store_var(b, state->tmin, tmin, 0x1);
   nir_store_var(b, state->dir, dir, 0x7);
   nir_store_var(b, state->tmax, tmax, 0x1);

   nir_store_var(b, state->traversal.bvh_base, bvh_base, 0x1);
   nir_store_var(b, state->traversal.origin, origin, 0x7);
   nir_store_var(b, state->traversal.dir, dir, 0x7);
   nir_store_var(b, state->traversal.inv_dir, nir_frcp(b, dir), 0x7);
   nir_store_var(b, state->traversal.current_node, nir_imm_int(b, LVP_BVH_ROOT_NODE), 0x1);
   nir_store_var(b, state->traversal.stack_base, nir_imm_int(b, -1), 0x1);
   nir_store_var(b, state->traversal.stack_ptr, nir_imm_int(b, 0), 0x1);

   nir_store_var(b, state->traversal.hit, nir_imm_false(b), 0x1);

   struct lvp_ray_traversal_vars vars = {
      .tmax = nir_build_deref_var(b, state->tmax),
      .origin = nir_build_deref_var(b, state->traversal.origin),
      .dir = nir_build_deref_var(b, state->traversal.dir),
      .inv_dir = nir_build_deref_var(b, state->traversal.inv_dir),
      .bvh_base = nir_build_deref_var(b, state->traversal.bvh_base),
      .current_node = nir_build_deref_var(b, state->traversal.current_node),
      .stack_base = nir_build_deref_var(b, state->traversal.stack_base),
      .stack_ptr = nir_build_deref_var(b, state->traversal.stack_ptr),
      .stack = nir_build_deref_var(b, state->traversal.stack),
      .instance_addr = nir_build_deref_var(b, state->traversal.instance_addr),
      .sbt_offset_and_flags = nir_build_deref_var(b, state->traversal.sbt_offset_and_flags),
   };

   struct lvp_ray_traversal_args args = {
      .root_bvh_base = bvh_base,
      .flags = flags,
      .cull_mask = nir_ishl_imm(b, cull_mask, 24),
      .origin = origin,
      .tmin = tmin,
      .dir = dir,
      .vars = vars,
      .aabb_cb = (compiler->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_AABBS_BIT_KHR) ?
                 NULL : lvp_handle_aabb_intersection,
      .triangle_cb = (compiler->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR) ?
                     NULL : lvp_handle_triangle_intersection,
      .data = compiler,
   };

   nir_push_if(b, nir_ine_imm(b, bvh_base, 0));
   lvp_build_ray_traversal(b, &args);
   nir_pop_if(b, NULL);

   nir_push_if(b, nir_load_var(b, state->traversal.hit));
   {
      nir_def *skip_chit = nir_test_mask(b, flags, SpvRayFlagsSkipClosestHitShaderKHRMask);
      nir_push_if(b, nir_inot(b, skip_chit));

      struct lvp_sbt_entry chit_entry = lvp_load_sbt_entry(
         b,
         nir_load_var(b, state->sbt_index),
         offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
         offsetof(struct lvp_ray_tracing_group_handle, index));
      nir_store_var(b, compiler->state.shader_record_ptr, chit_entry.shader_record_ptr, 0x1);

      for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
         struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
         if (group->recursive_index == VK_SHADER_UNUSED_KHR)
            continue;

         nir_shader *stage = compiler->pipeline->rt.stages[group->recursive_index]->nir;
         if (stage->info.stage != MESA_SHADER_CLOSEST_HIT)
            continue;

         nir_push_if(b, nir_ieq_imm(b, chit_entry.value, group->handle.index));
         lvp_call_ray_tracing_stage(b, compiler, stage);
         nir_pop_if(b, NULL);
      }

      nir_pop_if(b, NULL);
   }
   nir_push_else(b, NULL);
   {
      struct lvp_sbt_entry miss_entry = lvp_load_sbt_entry(
         b,
         miss_index,
         offsetof(VkTraceRaysIndirectCommand2KHR, missShaderBindingTableAddress),
         offsetof(struct lvp_ray_tracing_group_handle, index));
      nir_store_var(b, compiler->state.shader_record_ptr, miss_entry.shader_record_ptr, 0x1);

      for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
         struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
         if (group->recursive_index == VK_SHADER_UNUSED_KHR)
            continue;

         nir_shader *stage = compiler->pipeline->rt.stages[group->recursive_index]->nir;
         if (stage->info.stage != MESA_SHADER_MISS)
            continue;

         nir_push_if(b, nir_ieq_imm(b, miss_entry.value, group->handle.index));
         lvp_call_ray_tracing_stage(b, compiler, stage);
         nir_pop_if(b, NULL);
      }
   }
   nir_pop_if(b, NULL);

   nir_store_var(b, state->stack_ptr, stack_ptr, 0x1);
}

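/* Main lowering pass for the monolithic shader: expands trace_ray and
 * execute_callable, turns halt jumps into returns, maps ray tracing system
 * values onto the shared state and BVH instance nodes, and rebases scratch
 * accesses onto the software stack pointer.
 */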
static bool
lvp_lower_ray_tracing_instr(nir_builder *b, nir_instr *instr, void *data)
{
   struct lvp_ray_tracing_pipeline_compiler *compiler = data;
   struct lvp_ray_tracing_state *state = &compiler->state;

   if (instr->type == nir_instr_type_jump) {
      nir_jump_instr *jump = nir_instr_as_jump(instr);
      if (jump->type == nir_jump_halt) {
         jump->type = nir_jump_return;
         return true;
      }
      return false;
   } else if (instr->type != nir_instr_type_intrinsic) {
      return false;
   }

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   nir_def *def = NULL;

   b->cursor = nir_before_instr(instr);

   switch (intr->intrinsic) {
   /* Ray tracing instructions */
   case nir_intrinsic_execute_callable:
      lvp_execute_callable(b, compiler, intr);
      break;
   case nir_intrinsic_trace_ray:
      lvp_trace_ray(b, compiler, intr);
      break;
   case nir_intrinsic_ignore_ray_intersection: {
      nir_store_var(b, state->accept, nir_imm_false(b), 0x1);

      nir_push_if(b, nir_imm_true(b));
      nir_jump(b, nir_jump_return);
      nir_pop_if(b, NULL);
      break;
   }
   case nir_intrinsic_terminate_ray: {
      nir_store_var(b, state->accept, nir_imm_true(b), 0x1);
      nir_store_var(b, state->terminate, nir_imm_true(b), 0x1);

      nir_push_if(b, nir_imm_true(b));
      nir_jump(b, nir_jump_return);
      nir_pop_if(b, NULL);
      break;
   }
   /* Ray tracing system values */
   case nir_intrinsic_load_ray_launch_id:
      def = nir_load_global_invocation_id(b, 32);
      break;
   case nir_intrinsic_load_ray_launch_size:
      def = lvp_load_trace_ray_command_field(
         b, offsetof(VkTraceRaysIndirectCommand2KHR, width), 3, 32);
      break;
   case nir_intrinsic_load_shader_record_ptr:
      def = nir_load_var(b, state->shader_record_ptr);
      break;
   case nir_intrinsic_load_ray_t_min:
      def = nir_load_var(b, state->tmin);
      break;
   case nir_intrinsic_load_ray_t_max:
      def = nir_load_var(b, state->tmax);
      break;
   case nir_intrinsic_load_ray_world_origin:
      def = nir_load_var(b, state->origin);
      break;
   case nir_intrinsic_load_ray_world_direction:
      def = nir_load_var(b, state->dir);
      break;
   case nir_intrinsic_load_ray_instance_custom_index: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *custom_instance_and_mask = nir_build_load_global(
         b, 1, 32,
         nir_iadd_imm(b, instance_node_addr, offsetof(struct lvp_bvh_instance_node, custom_instance_and_mask)));
      def = nir_iand_imm(b, custom_instance_and_mask, 0xFFFFFF);
      break;
   }
   case nir_intrinsic_load_primitive_id:
      def = nir_load_var(b, state->primitive_id);
      break;
   case nir_intrinsic_load_ray_geometry_index:
      def = nir_load_var(b, state->geometry_id_and_flags);
      def = nir_iand_imm(b, def, 0xFFFFFFF);
      break;
   case nir_intrinsic_load_instance_id: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      def = nir_build_load_global(
         b, 1, 32, nir_iadd_imm(b, instance_node_addr, offsetof(struct lvp_bvh_instance_node, instance_id)));
      break;
   }
   case nir_intrinsic_load_ray_flags:
      def = nir_load_var(b, state->flags);
      break;
   case nir_intrinsic_load_ray_hit_kind:
      def = nir_load_var(b, state->hit_kind);
      break;
   case nir_intrinsic_load_ray_world_to_object: {
      unsigned c = nir_intrinsic_column(intr);
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *wto_matrix[3];
      lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);

      nir_def *vals[3];
      for (unsigned i = 0; i < 3; ++i)
         vals[i] = nir_channel(b, wto_matrix[i], c);

      def = nir_vec(b, vals, 3);
      break;
   }
   case nir_intrinsic_load_ray_object_to_world: {
      unsigned c = nir_intrinsic_column(intr);
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *rows[3];
      for (unsigned r = 0; r < 3; ++r)
         rows[r] = nir_build_load_global(
            b, 4, 32,
            nir_iadd_imm(b, instance_node_addr, offsetof(struct lvp_bvh_instance_node, otw_matrix) + r * 16));
      def = nir_vec3(b, nir_channel(b, rows[0], c), nir_channel(b, rows[1], c), nir_channel(b, rows[2], c));
      break;
   }
   case nir_intrinsic_load_ray_object_origin: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *wto_matrix[3];
      lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);
      def = lvp_mul_vec3_mat(b, nir_load_var(b, state->origin), wto_matrix, true);
      break;
   }
   case nir_intrinsic_load_ray_object_direction: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *wto_matrix[3];
      lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);
      def = lvp_mul_vec3_mat(b, nir_load_var(b, state->dir), wto_matrix, false);
      break;
   }
   case nir_intrinsic_load_cull_mask:
      def = nir_iand_imm(b, nir_load_var(b, state->cull_mask), 0xFF);
      break;
   /* Ray tracing stack lowering */
   case nir_intrinsic_load_scratch: {
      nir_src_rewrite(&intr->src[0], nir_iadd(b, nir_load_var(b, state->stack_ptr), intr->src[0].ssa));
      return true;
   }
   case nir_intrinsic_store_scratch: {
      nir_src_rewrite(&intr->src[1], nir_iadd(b, nir_load_var(b, state->stack_ptr), intr->src[1].ssa));
      return true;
   }
   case nir_intrinsic_load_ray_triangle_vertex_positions: {
      def = lvp_load_vertex_position(
         b, nir_load_var(b, state->instance_addr), nir_load_var(b, state->primitive_id),
         nir_intrinsic_column(intr));
      break;
   }
   /* Internal system values */
   case nir_intrinsic_load_shader_call_data_offset_lvp:
      def = nir_load_var(b, state->shader_call_data_offset);
      break;
   default:
      return false;
   }

   if (def)
      nir_def_rewrite_uses(&intr->def, def);
   nir_instr_remove(instr);

   return true;
}

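/* load_ray_tracing_stack_base_lvp resolves to the scratch size the shader
 * has accumulated by this point, so the software ray tracing stack sits
 * directly above the regular scratch allocations.
 */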
static bool
lvp_lower_ray_tracing_stack_base(nir_builder *b, nir_intrinsic_instr *instr, void *data)
{
   if (instr->intrinsic != nir_intrinsic_load_ray_tracing_stack_base_lvp)
      return false;

   b->cursor = nir_after_instr(&instr->instr);

   nir_def_replace(&instr->def, nir_imm_int(b, b->shader->scratch_size));

   return true;
}

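/* Build the monolithic compute shader for a non-library pipeline:
 * bounds-check the launch ID against the launch size, resolve the raygen
 * SBT record, dispatch the raygen shaders, then lower everything and size
 * the ray tracing stack.
 */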
static void
lvp_compile_ray_tracing_pipeline(struct lvp_pipeline *pipeline,
                                 const VkRayTracingPipelineCreateInfoKHR *create_info)
{
   nir_builder _b = nir_builder_init_simple_shader(
      MESA_SHADER_COMPUTE,
      pipeline->device->pscreen->get_compiler_options(pipeline->device->pscreen, PIPE_SHADER_IR_NIR, MESA_SHADER_COMPUTE),
      "ray tracing pipeline");
   nir_builder *b = &_b;

   b->shader->info.workgroup_size[0] = 8;

   struct lvp_ray_tracing_pipeline_compiler compiler = {
      .pipeline = pipeline,
      .flags = vk_rt_pipeline_create_flags(create_info),
   };
   lvp_ray_tracing_state_init(b->shader, &compiler.state);
   compiler.functions = _mesa_pointer_hash_table_create(NULL);

   nir_def *launch_id = nir_load_ray_launch_id(b);
   nir_def *launch_size = nir_load_ray_launch_size(b);
   nir_def *oob = nir_ige(b, nir_channel(b, launch_id, 0), nir_channel(b, launch_size, 0));
   oob = nir_ior(b, oob, nir_ige(b, nir_channel(b, launch_id, 1), nir_channel(b, launch_size, 1)));
   oob = nir_ior(b, oob, nir_ige(b, nir_channel(b, launch_id, 2), nir_channel(b, launch_size, 2)));

   nir_push_if(b, oob);
   nir_jump(b, nir_jump_return);
   nir_pop_if(b, NULL);

   nir_store_var(b, compiler.state.stack_ptr, nir_load_ray_tracing_stack_base_lvp(b), 0x1);

   struct lvp_sbt_entry raygen_entry = lvp_load_sbt_entry(
      b,
      NULL,
      offsetof(VkTraceRaysIndirectCommand2KHR, raygenShaderRecordAddress),
      offsetof(struct lvp_ray_tracing_group_handle, index));
   nir_store_var(b, compiler.state.shader_record_ptr, raygen_entry.shader_record_ptr, 0x1);

   for (uint32_t i = 0; i < pipeline->rt.group_count; i++) {
      struct lvp_ray_tracing_group *group = pipeline->rt.groups + i;
      if (group->recursive_index == VK_SHADER_UNUSED_KHR)
         continue;

      nir_shader *stage = pipeline->rt.stages[group->recursive_index]->nir;

      if (stage->info.stage != MESA_SHADER_RAYGEN)
         continue;

      nir_push_if(b, nir_ieq_imm(b, raygen_entry.value, group->handle.index));
      lvp_call_ray_tracing_stage(b, &compiler, stage);
      nir_pop_if(b, NULL);
   }

   nir_shader_instructions_pass(b->shader, lvp_lower_ray_tracing_instr, nir_metadata_none, &compiler);

   NIR_PASS(_, b->shader, nir_lower_returns);

   const struct nir_lower_compute_system_values_options compute_system_values = {0};
   NIR_PASS(_, b->shader, nir_lower_compute_system_values, &compute_system_values);
   NIR_PASS(_, b->shader, nir_lower_global_vars_to_local);
   NIR_PASS(_, b->shader, nir_lower_vars_to_ssa);

   NIR_PASS(_, b->shader, nir_lower_vars_to_explicit_types,
            nir_var_shader_temp,
            glsl_get_natural_size_align_bytes);

   NIR_PASS(_, b->shader, nir_lower_explicit_io, nir_var_shader_temp,
            nir_address_format_32bit_offset);

   NIR_PASS(_, b->shader, nir_shader_intrinsics_pass, lvp_lower_ray_tracing_stack_base,
            nir_metadata_control_flow, NULL);

   /* We cannot support dynamic stack sizes, so assume the worst. */
   b->shader->scratch_size +=
      compiler.raygen_size +
      MIN2(create_info->maxPipelineRayRecursionDepth, 1) *
         MAX3(compiler.chit_size, compiler.miss_size,
              compiler.isec_size + compiler.ahit_size) +
      MAX2(0, (int)create_info->maxPipelineRayRecursionDepth - 1) *
         MAX2(compiler.chit_size, compiler.miss_size) +
      31 * compiler.callable_size;
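   /* Illustrative numbers (not from any real pipeline): with
    * maxPipelineRayRecursionDepth = 2, raygen = 64, chit = 32, miss = 16,
    * isec = ahit = 0 and callable = 8, this reserves
    * 64 + 1 * 32 + 1 * 32 + 31 * 8 = 376 bytes: one frame for the first
    * recursion level (where intersection and any-hit frames may be live at
    * the same time), one hit/miss frame per additional level, and headroom
    * for 31 nested callable frames.
    */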

   struct lvp_shader *shader = &pipeline->shaders[MESA_SHADER_RAYGEN];
   lvp_shader_init(shader, b->shader);
   shader->shader_cso = lvp_shader_compile(pipeline->device, shader, nir_shader_clone(NULL, shader->pipeline_nir->nir), false);

   _mesa_hash_table_destroy(compiler.functions, NULL);
}

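/* Create a single ray tracing pipeline: size the stage/group arrays to
 * include any pipeline libraries, compile the stages, resolve the groups,
 * and emit the monolithic shader unless this pipeline is itself a library.
 */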
static VkResult
lvp_create_ray_tracing_pipeline(VkDevice _device, const VkAllocationCallbacks *allocator,
                                const VkRayTracingPipelineCreateInfoKHR *create_info,
                                VkPipeline *out_pipeline)
{
   VK_FROM_HANDLE(lvp_device, device, _device);
   VK_FROM_HANDLE(lvp_pipeline_layout, layout, create_info->layout);

   VkResult result = VK_SUCCESS;

   struct lvp_pipeline *pipeline = vk_zalloc2(&device->vk.alloc, allocator, sizeof(struct lvp_pipeline), 8,
                                              VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
   if (!pipeline)
      return VK_ERROR_OUT_OF_HOST_MEMORY;

   vk_object_base_init(&device->vk, &pipeline->base,
                       VK_OBJECT_TYPE_PIPELINE);

   vk_pipeline_layout_ref(&layout->vk);

   pipeline->device = device;
   pipeline->layout = layout;
   pipeline->type = LVP_PIPELINE_RAY_TRACING;
   pipeline->flags = vk_rt_pipeline_create_flags(create_info);

   pipeline->rt.stage_count = create_info->stageCount;
   pipeline->rt.group_count = create_info->groupCount;
   if (create_info->pLibraryInfo) {
      for (uint32_t i = 0; i < create_info->pLibraryInfo->libraryCount; i++) {
         VK_FROM_HANDLE(lvp_pipeline, library, create_info->pLibraryInfo->pLibraries[i]);
         pipeline->rt.stage_count += library->rt.stage_count;
         pipeline->rt.group_count += library->rt.group_count;
      }
   }

   pipeline->rt.stages = calloc(pipeline->rt.stage_count, sizeof(struct lvp_pipeline_nir *));
   pipeline->rt.groups = calloc(pipeline->rt.group_count, sizeof(struct lvp_ray_tracing_group));
   if (!pipeline->rt.stages || !pipeline->rt.groups) {
      result = VK_ERROR_OUT_OF_HOST_MEMORY;
      goto fail;
   }

   result = lvp_compile_ray_tracing_stages(pipeline, create_info);
   if (result != VK_SUCCESS)
      goto fail;

   lvp_init_ray_tracing_groups(pipeline, create_info);

   if (!(pipeline->flags & VK_PIPELINE_CREATE_2_LIBRARY_BIT_KHR)) {
      lvp_compile_ray_tracing_pipeline(pipeline, create_info);
   }

   *out_pipeline = lvp_pipeline_to_handle(pipeline);

   return VK_SUCCESS;

fail:
   lvp_pipeline_destroy(device, pipeline, false);
   return result;
}

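/* The deferred operation is not used; pipelines are created synchronously.
 * On failure the failing and all remaining handles are cleared, and
 * creation stops early only if EARLY_RETURN_ON_FAILURE was requested.
 */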
VKAPI_ATTR VkResult VKAPI_CALL
lvp_CreateRayTracingPipelinesKHR(
   VkDevice device,
   VkDeferredOperationKHR deferredOperation,
   VkPipelineCache pipelineCache,
   uint32_t createInfoCount,
   const VkRayTracingPipelineCreateInfoKHR *pCreateInfos,
   const VkAllocationCallbacks *pAllocator,
   VkPipeline *pPipelines)
{
   VkResult result = VK_SUCCESS;

   uint32_t i = 0;
   for (; i < createInfoCount; i++) {
      VkResult tmp_result = lvp_create_ray_tracing_pipeline(
         device, pAllocator, pCreateInfos + i, pPipelines + i);

      if (tmp_result != VK_SUCCESS) {
         result = tmp_result;
         pPipelines[i] = VK_NULL_HANDLE;

         if (vk_rt_pipeline_create_flags(&pCreateInfos[i]) &
             VK_PIPELINE_CREATE_2_EARLY_RETURN_ON_FAILURE_BIT_KHR)
            break;
      }
   }

   for (; i < createInfoCount; i++)
      pPipelines[i] = VK_NULL_HANDLE;

   return result;
}

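/* Group handles are lvp_ray_tracing_group_handle structs whose index was
 * allocated from the device-wide counter at pipeline creation time.
 */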
VKAPI_ATTR VkResult VKAPI_CALL
lvp_GetRayTracingShaderGroupHandlesKHR(
   VkDevice _device,
   VkPipeline _pipeline,
   uint32_t firstGroup,
   uint32_t groupCount,
   size_t dataSize,
   void *pData)
{
   VK_FROM_HANDLE(lvp_pipeline, pipeline, _pipeline);

   uint8_t *data = pData;
   memset(data, 0, dataSize);

   for (uint32_t i = 0; i < groupCount; i++) {
      memcpy(data + i * LVP_RAY_TRACING_GROUP_HANDLE_SIZE,
             pipeline->rt.groups + firstGroup + i,
             sizeof(struct lvp_ray_tracing_group_handle));
   }

   return VK_SUCCESS;
}

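/* Capture/replay of group handles is not implemented; the stub succeeds
 * without writing any data.
 */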
VKAPI_ATTR VkResult VKAPI_CALL
lvp_GetRayTracingCaptureReplayShaderGroupHandlesKHR(
   VkDevice device,
   VkPipeline pipeline,
   uint32_t firstGroup,
   uint32_t groupCount,
   size_t dataSize,
   void *pData)
{
   return VK_SUCCESS;
}

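/* Stack sizes are folded into the monolithic shader's worst-case scratch
 * size, so only a nominal value is reported here.
 */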
VKAPI_ATTR VkDeviceSize VKAPI_CALL
lvp_GetRayTracingShaderGroupStackSizeKHR(
   VkDevice device,
   VkPipeline pipeline,
   uint32_t group,
   VkShaderGroupShaderKHR groupShader)
{
   return 4;
}