/*
 * Copyright © 2024 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "lvp_private.h"
#include "lvp_acceleration_structure.h"
#include "lvp_nir_ray_tracing.h"

#include "vk_pipeline.h"

#include "nir.h"
#include "nir_builder.h"

#include "spirv/spirv.h"

#include "util/mesa-sha1.h"
#include "util/simple_mtx.h"

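/* Fill pipeline->rt.groups from the create info: groups defined directly on
 * this pipeline come first, followed by groups imported from pipeline
 * libraries with their stage indices rebased past this pipeline's own stages.
 * Each locally created group gets a unique handle index from the device-wide
 * allocator so SBT records can identify it later.
 */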
static void
lvp_init_ray_tracing_groups(struct lvp_pipeline *pipeline,
                            const VkRayTracingPipelineCreateInfoKHR *create_info)
{
   uint32_t i = 0;
   for (; i < create_info->groupCount; i++) {
      const VkRayTracingShaderGroupCreateInfoKHR *group_info = create_info->pGroups + i;
      struct lvp_ray_tracing_group *dst = pipeline->rt.groups + i;

      dst->recursive_index = VK_SHADER_UNUSED_KHR;
      dst->ahit_index = VK_SHADER_UNUSED_KHR;
      dst->isec_index = VK_SHADER_UNUSED_KHR;

      switch (group_info->type) {
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
         if (group_info->generalShader != VK_SHADER_UNUSED_KHR) {
            dst->recursive_index = group_info->generalShader;
         }
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
         if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR) {
            dst->recursive_index = group_info->closestHitShader;
         }
         if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR) {
            dst->ahit_index = group_info->anyHitShader;
         }
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
         if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR) {
            dst->recursive_index = group_info->closestHitShader;
         }
         if (group_info->intersectionShader != VK_SHADER_UNUSED_KHR) {
            dst->isec_index = group_info->intersectionShader;

            if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR)
               dst->ahit_index = group_info->anyHitShader;
         }
         break;
      default:
         unreachable("Unimplemented VkRayTracingShaderGroupTypeKHR");
      }

      dst->handle.index = p_atomic_inc_return(&pipeline->device->group_handle_alloc);
   }

   if (!create_info->pLibraryInfo)
      return;

   uint32_t stage_base_index = create_info->stageCount;
   for (uint32_t library_index = 0; library_index < create_info->pLibraryInfo->libraryCount; library_index++) {
      VK_FROM_HANDLE(lvp_pipeline, library, create_info->pLibraryInfo->pLibraries[library_index]);
      for (uint32_t group_index = 0; group_index < library->rt.group_count; group_index++) {
         const struct lvp_ray_tracing_group *src = library->rt.groups + group_index;
         struct lvp_ray_tracing_group *dst = pipeline->rt.groups + i;

         dst->handle = src->handle;

         if (src->recursive_index != VK_SHADER_UNUSED_KHR)
            dst->recursive_index = stage_base_index + src->recursive_index;
         else
            dst->recursive_index = VK_SHADER_UNUSED_KHR;

         if (src->ahit_index != VK_SHADER_UNUSED_KHR)
            dst->ahit_index = stage_base_index + src->ahit_index;
         else
            dst->ahit_index = VK_SHADER_UNUSED_KHR;

         if (src->isec_index != VK_SHADER_UNUSED_KHR)
            dst->isec_index = stage_base_index + src->isec_index;
         else
            dst->isec_index = VK_SHADER_UNUSED_KHR;

         i++;
      }
      stage_base_index += library->rt.stage_count;
   }
}

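/* Rewrite derefs of ray payload (shader_call_data) and hit attribute
 * variables into function_temp derefs at explicit offsets, so both end up on
 * the caller-managed scratch stack: payloads at the offset passed in via
 * load_shader_call_data_offset_lvp, hit attributes at offset 0 of the
 * callee's scratch area.
 */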
static bool
lvp_lower_ray_tracing_derefs(nir_shader *shader)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   bool progress = false;

   nir_builder _b = nir_builder_at(nir_before_impl(impl));
   nir_builder *b = &_b;

   nir_def *arg_offset = nir_load_shader_call_data_offset_lvp(b);

   nir_foreach_block (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_deref)
            continue;

         nir_deref_instr *deref = nir_instr_as_deref(instr);
         if (!nir_deref_mode_is_one_of(deref, nir_var_shader_call_data |
                                              nir_var_ray_hit_attrib))
            continue;

         bool is_shader_call_data = nir_deref_mode_is(deref, nir_var_shader_call_data);

         deref->modes = nir_var_function_temp;
         progress = true;

         if (deref->deref_type == nir_deref_type_var) {
            b->cursor = nir_before_instr(&deref->instr);
            nir_def *offset = is_shader_call_data ? arg_offset : nir_imm_int(b, 0);
            nir_deref_instr *replacement =
               nir_build_deref_cast(b, offset, nir_var_function_temp, deref->var->type, 0);
            nir_def_replace(&deref->def, &replacement->def);
         }
      }
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_control_flow);
   else
      nir_metadata_preserve(impl, nir_metadata_all);

   return progress;
}

static bool
lvp_move_ray_tracing_intrinsic(nir_builder *b, nir_intrinsic_instr *instr, void *data)
{
   switch (instr->intrinsic) {
   case nir_intrinsic_load_shader_record_ptr:
   case nir_intrinsic_load_ray_flags:
   case nir_intrinsic_load_ray_object_origin:
   case nir_intrinsic_load_ray_world_origin:
   case nir_intrinsic_load_ray_t_min:
   case nir_intrinsic_load_ray_object_direction:
   case nir_intrinsic_load_ray_world_direction:
   case nir_intrinsic_load_ray_t_max:
      nir_instr_move(nir_before_impl(b->impl), &instr->instr);
      return true;
   default:
      return false;
   }
}

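/* Translate each SPIR-V stage of the pipeline to NIR and run the lowering
 * passes that turn ray payloads and hit attributes into scratch accesses.
 * Hit-group stages reserve LVP_RAY_HIT_ATTRIBS_SIZE bytes of scratch up
 * front for the hit attributes. Stages from pipeline libraries are appended
 * by reference after the locally compiled ones.
 */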
static VkResult
lvp_compile_ray_tracing_stages(struct lvp_pipeline *pipeline,
                               const VkRayTracingPipelineCreateInfoKHR *create_info)
{
   VkResult result = VK_SUCCESS;

   uint32_t i = 0;
   for (; i < create_info->stageCount; i++) {
      nir_shader *nir;
      result = lvp_spirv_to_nir(pipeline, create_info->pStages + i, &nir);
      if (result != VK_SUCCESS)
         return result;

      assert(!nir->scratch_size);
      if (nir->info.stage == MESA_SHADER_ANY_HIT ||
          nir->info.stage == MESA_SHADER_CLOSEST_HIT ||
          nir->info.stage == MESA_SHADER_INTERSECTION)
         nir->scratch_size = LVP_RAY_HIT_ATTRIBS_SIZE;

      NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
               nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib,
               glsl_get_natural_size_align_bytes);

      NIR_PASS(_, nir, lvp_lower_ray_tracing_derefs);

      NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset);

      NIR_PASS(_, nir, nir_shader_intrinsics_pass, lvp_move_ray_tracing_intrinsic,
               nir_metadata_control_flow, NULL);

      pipeline->rt.stages[i] = lvp_create_pipeline_nir(nir);
      if (!pipeline->rt.stages[i]) {
         result = VK_ERROR_OUT_OF_HOST_MEMORY;
         ralloc_free(nir);
         return result;
      }
   }

   if (!create_info->pLibraryInfo)
      return result;

   for (uint32_t library_index = 0; library_index < create_info->pLibraryInfo->libraryCount; library_index++) {
      VK_FROM_HANDLE(lvp_pipeline, library, create_info->pLibraryInfo->pLibraries[library_index]);
      for (uint32_t stage_index = 0; stage_index < library->rt.stage_count; stage_index++) {
         lvp_pipeline_nir_ref(pipeline->rt.stages + i, library->rt.stages[stage_index]);
         i++;
      }
   }

   return result;
}

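/* The launch parameters are read out of a VkTraceRaysIndirectCommand2KHR
 * structure that the trace-rays dispatch code is assumed to bind as SSBO 0,
 * so SBT addresses and launch sizes can be loaded by struct offset.
 */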
static nir_def *
lvp_load_trace_ray_command_field(nir_builder *b, uint32_t command_offset,
                                 uint32_t num_components, uint32_t bit_size)
{
   return nir_load_ssbo(b, num_components, bit_size, nir_imm_int(b, 0),
                        nir_imm_int(b, command_offset));
}

struct lvp_sbt_entry {
   nir_def *value;
   nir_def *shader_record_ptr;
};

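/* Load one shader binding table entry: the group handle index stored at the
 * start of the record and a pointer to the application's shader record data
 * that follows the handle. "index" selects the record within the table using
 * the stride from the trace-rays command; a NULL index reads the record at
 * the table base (used for the raygen record).
 */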
static struct lvp_sbt_entry
lvp_load_sbt_entry(nir_builder *b, nir_def *index,
                   uint32_t command_offset, uint32_t index_offset)
{
   nir_def *addr = lvp_load_trace_ray_command_field(b, command_offset, 1, 64);

   if (index) {
      /* The 32 high bits of stride can be ignored. */
      nir_def *stride = lvp_load_trace_ray_command_field(
         b, command_offset + sizeof(VkDeviceSize) * 2, 1, 32);
      addr = nir_iadd(b, addr, nir_u2u64(b, nir_imul(b, index, stride)));
   }

   return (struct lvp_sbt_entry) {
      .value = nir_build_load_global(b, 1, 32, nir_iadd_imm(b, addr, index_offset)),
      .shader_record_ptr = nir_iadd_imm(b, addr, LVP_RAY_TRACING_GROUP_HANDLE_SIZE),
   };
}

struct lvp_ray_traversal_state {
   nir_variable *origin;
   nir_variable *dir;
   nir_variable *inv_dir;
   nir_variable *bvh_base;
   nir_variable *current_node;
   nir_variable *stack_base;
   nir_variable *stack_ptr;
   nir_variable *stack;
   nir_variable *hit;

   nir_variable *instance_addr;
   nir_variable *sbt_offset_and_flags;
};

struct lvp_ray_tracing_state {
   nir_variable *bvh_base;
   nir_variable *flags;
   nir_variable *cull_mask;
   nir_variable *sbt_offset;
   nir_variable *sbt_stride;
   nir_variable *miss_index;
   nir_variable *origin;
   nir_variable *tmin;
   nir_variable *dir;
   nir_variable *tmax;

   nir_variable *instance_addr;
   nir_variable *primitive_id;
   nir_variable *geometry_id_and_flags;
   nir_variable *hit_kind;
   nir_variable *sbt_index;

   nir_variable *shader_record_ptr;
   nir_variable *stack_ptr;
   nir_variable *shader_call_data_offset;

   nir_variable *accept;
   nir_variable *terminate;
   nir_variable *opaque;

   struct lvp_ray_traversal_state traversal;
};

struct lvp_ray_tracing_pipeline_compiler {
   struct lvp_pipeline *pipeline;
   VkPipelineCreateFlags2KHR flags;

   struct lvp_ray_tracing_state state;

   struct hash_table *functions;

   uint32_t raygen_size;
   uint32_t ahit_size;
   uint32_t chit_size;
   uint32_t miss_size;
   uint32_t isec_size;
   uint32_t callable_size;
};

static uint32_t
lvp_ray_tracing_pipeline_compiler_get_stack_size(
   struct lvp_ray_tracing_pipeline_compiler *compiler, nir_function *function)
{
   hash_table_foreach(compiler->functions, entry) {
      if (entry->data == function) {
         const nir_shader *shader = entry->key;
         return shader->scratch_size;
      }
   }
   return 0;
}

static void
lvp_ray_tracing_state_init(nir_shader *nir, struct lvp_ray_tracing_state *state)
{
   state->bvh_base = nir_variable_create(nir, nir_var_shader_temp, glsl_uint64_t_type(), "bvh_base");
   state->flags = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "flags");
   state->cull_mask = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "cull_mask");
   state->sbt_offset = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "sbt_offset");
   state->sbt_stride = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "sbt_stride");
   state->miss_index = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "miss_index");
   state->origin = nir_variable_create(nir, nir_var_shader_temp, glsl_vec_type(3), "origin");
   state->tmin = nir_variable_create(nir, nir_var_shader_temp, glsl_float_type(), "tmin");
   state->dir = nir_variable_create(nir, nir_var_shader_temp, glsl_vec_type(3), "dir");
   state->tmax = nir_variable_create(nir, nir_var_shader_temp, glsl_float_type(), "tmax");

   state->instance_addr = nir_variable_create(nir, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
   state->primitive_id = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "primitive_id");
   state->geometry_id_and_flags = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "geometry_id_and_flags");
   state->hit_kind = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "hit_kind");
   state->sbt_index = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "sbt_index");

   state->shader_record_ptr = nir_variable_create(nir, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr");
   state->stack_ptr = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
   state->shader_call_data_offset = nir_variable_create(nir, nir_var_shader_temp, glsl_uint_type(), "shader_call_data_offset");

   state->accept = nir_variable_create(nir, nir_var_shader_temp, glsl_bool_type(), "accept");
   state->terminate = nir_variable_create(nir, nir_var_shader_temp, glsl_bool_type(), "terminate");
   state->opaque = nir_variable_create(nir, nir_var_shader_temp, glsl_bool_type(), "opaque");
}

static void
lvp_ray_traversal_state_init(nir_function_impl *impl, struct lvp_ray_traversal_state *state)
{
   state->origin = nir_local_variable_create(impl, glsl_vec_type(3), "traversal.origin");
   state->dir = nir_local_variable_create(impl, glsl_vec_type(3), "traversal.dir");
   state->inv_dir = nir_local_variable_create(impl, glsl_vec_type(3), "traversal.inv_dir");
   state->bvh_base = nir_local_variable_create(impl, glsl_uint64_t_type(), "traversal.bvh_base");
   state->current_node = nir_local_variable_create(impl, glsl_uint_type(), "traversal.current_node");
   state->stack_base = nir_local_variable_create(impl, glsl_uint_type(), "traversal.stack_base");
   state->stack_ptr = nir_local_variable_create(impl, glsl_uint_type(), "traversal.stack_ptr");
   state->stack = nir_local_variable_create(impl, glsl_array_type(glsl_uint_type(), 24 * 2, 0), "traversal.stack");
   state->hit = nir_local_variable_create(impl, glsl_bool_type(), "traversal.hit");

   state->instance_addr = nir_local_variable_create(impl, glsl_uint64_t_type(), "traversal.instance_addr");
   state->sbt_offset_and_flags = nir_local_variable_create(impl, glsl_uint_type(), "traversal.sbt_offset_and_flags");
}

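/* Inline a ray tracing stage into the monolithic shader as a NIR function
 * call. The stage's entrypoint impl is cloned into the destination shader
 * (remapping any non-local variables it references) the first time the stage
 * is used and cached in compiler->functions; per-stage scratch sizes are
 * accumulated so the final stack size can be computed later.
 */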
static void
lvp_call_ray_tracing_stage(nir_builder *b, struct lvp_ray_tracing_pipeline_compiler *compiler, nir_shader *stage)
{
   nir_function *function;

   struct hash_entry *entry = _mesa_hash_table_search(compiler->functions, stage);
   if (entry) {
      function = entry->data;
   } else {
      nir_function_impl *stage_entrypoint = nir_shader_get_entrypoint(stage);
      nir_function_impl *copy = nir_function_impl_clone(b->shader, stage_entrypoint);

      struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);

      nir_foreach_block(block, copy) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (deref->deref_type != nir_deref_type_var ||
                deref->var->data.mode == nir_var_function_temp)
               continue;

            struct hash_entry *entry =
               _mesa_hash_table_search(var_remap, deref->var);
            if (!entry) {
               nir_variable *new_var = nir_variable_clone(deref->var, b->shader);
               nir_shader_add_variable(b->shader, new_var);
               entry = _mesa_hash_table_insert(var_remap,
                                               deref->var, new_var);
            }
            deref->var = entry->data;
         }
      }

      function = nir_function_create(
         b->shader, _mesa_shader_stage_to_string(stage->info.stage));
      nir_function_set_impl(function, copy);

      ralloc_free(var_remap);

      _mesa_hash_table_insert(compiler->functions, stage, function);
   }

   nir_build_call(b, function, 0, NULL);

   switch (stage->info.stage) {
   case MESA_SHADER_RAYGEN:
      compiler->raygen_size = MAX2(compiler->raygen_size, stage->scratch_size);
      break;
   case MESA_SHADER_ANY_HIT:
      compiler->ahit_size = MAX2(compiler->ahit_size, stage->scratch_size);
      break;
   case MESA_SHADER_CLOSEST_HIT:
      compiler->chit_size = MAX2(compiler->chit_size, stage->scratch_size);
      break;
   case MESA_SHADER_MISS:
      compiler->miss_size = MAX2(compiler->miss_size, stage->scratch_size);
      break;
   case MESA_SHADER_INTERSECTION:
      compiler->isec_size = MAX2(compiler->isec_size, stage->scratch_size);
      break;
   case MESA_SHADER_CALLABLE:
      compiler->callable_size = MAX2(compiler->callable_size, stage->scratch_size);
      break;
   default:
      unreachable("Invalid ray tracing stage");
      break;
   }
}

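/* Lower OpExecuteCallableKHR: look up the callable SBT record selected by the
 * first source, advance the stack pointer past the caller's scratch frame,
 * record the payload offset relative to the callee's frame, and dispatch to
 * the matching callable stage by comparing the record's handle index against
 * every callable group in the pipeline. The stack pointer is restored
 * afterwards.
 */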
static void
lvp_execute_callable(nir_builder *b, struct lvp_ray_tracing_pipeline_compiler *compiler,
                     nir_intrinsic_instr *instr)
{
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_def *sbt_index = instr->src[0].ssa;
   nir_def *payload = instr->src[1].ssa;

   struct lvp_sbt_entry callable_entry = lvp_load_sbt_entry(
      b,
      sbt_index,
      offsetof(VkTraceRaysIndirectCommand2KHR, callableShaderBindingTableAddress),
      offsetof(struct lvp_ray_tracing_group_handle, index));
   nir_store_var(b, compiler->state.shader_record_ptr, callable_entry.shader_record_ptr, 0x1);

   uint32_t stack_size =
      lvp_ray_tracing_pipeline_compiler_get_stack_size(compiler, b->impl->function);
   nir_def *stack_ptr = nir_load_var(b, state->stack_ptr);
   nir_store_var(b, state->stack_ptr, nir_iadd_imm(b, stack_ptr, stack_size), 0x1);

   nir_store_var(b, state->shader_call_data_offset, nir_iadd_imm(b, payload, -stack_size), 0x1);

   for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
      struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
      if (group->recursive_index == VK_SHADER_UNUSED_KHR)
         continue;

      nir_shader *stage = compiler->pipeline->rt.stages[group->recursive_index]->nir;
      if (stage->info.stage != MESA_SHADER_CALLABLE)
         continue;

      nir_push_if(b, nir_ieq_imm(b, callable_entry.value, group->handle.index));
      lvp_call_ray_tracing_stage(b, compiler, stage);
      nir_pop_if(b, NULL);
   }

   nir_store_var(b, state->stack_ptr, stack_ptr, 0x1);
}

struct lvp_lower_isec_intrinsic_state {
   struct lvp_ray_tracing_pipeline_compiler *compiler;
   nir_shader *ahit;
};

static bool
lvp_lower_isec_intrinsic(nir_builder *b, nir_intrinsic_instr *instr, void *data)
{
   if (instr->intrinsic != nir_intrinsic_report_ray_intersection)
      return false;

   struct lvp_lower_isec_intrinsic_state *isec_state = data;
   struct lvp_ray_tracing_pipeline_compiler *compiler = isec_state->compiler;
   struct lvp_ray_tracing_state *state = &compiler->state;

   b->cursor = nir_after_instr(&instr->instr);

   nir_def *t = instr->src[0].ssa;
   nir_def *hit_kind = instr->src[1].ssa;

   nir_def *prev_accept = nir_load_var(b, state->accept);
   nir_def *prev_tmax = nir_load_var(b, state->tmax);
   nir_def *prev_hit_kind = nir_load_var(b, state->hit_kind);

   nir_variable *commit = nir_local_variable_create(b->impl, glsl_bool_type(), "commit");
   nir_store_var(b, commit, nir_imm_false(b), 0x1);

   nir_push_if(b, nir_iand(b, nir_fge(b, t, nir_load_var(b, state->tmin)), nir_fge(b, nir_load_var(b, state->tmax), t)));
   {
      nir_store_var(b, state->accept, nir_imm_true(b), 0x1);

      nir_store_var(b, state->tmax, t, 1);
      nir_store_var(b, state->hit_kind, hit_kind, 1);

      if (isec_state->ahit) {
         nir_def *prev_terminate = nir_load_var(b, state->terminate);
         nir_store_var(b, state->terminate, nir_imm_false(b), 0x1);

         nir_push_if(b, nir_inot(b, nir_load_var(b, state->opaque)));
         {
            lvp_call_ray_tracing_stage(b, compiler, isec_state->ahit);
         }
         nir_pop_if(b, NULL);

         nir_def *terminate = nir_load_var(b, state->terminate);
         nir_store_var(b, state->terminate, nir_ior(b, terminate, prev_terminate), 0x1);

         nir_push_if(b, terminate);
         nir_jump(b, nir_jump_return);
         nir_pop_if(b, NULL);
      }

      nir_push_if(b, nir_load_var(b, state->accept));
      {
         nir_store_var(b, commit, nir_imm_true(b), 0x1);
      }
      nir_push_else(b, NULL);
      {
         nir_store_var(b, state->accept, prev_accept, 0x1);
         nir_store_var(b, state->tmax, prev_tmax, 1);
         nir_store_var(b, state->hit_kind, prev_hit_kind, 1);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_if(b, NULL);

   nir_def_replace(&instr->def, nir_load_var(b, commit));

   return true;
}

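/* Traversal callback for AABB leaves. Snapshot the committed hit state, then
 * run the intersection shader of whichever hit group the SBT record selects;
 * the report_ray_intersection intrinsics it emits are lowered afterwards and
 * invoke the group's any-hit shader for non-opaque hits. The new hit is
 * committed or rolled back depending on whether it was accepted.
 */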
static void
lvp_handle_aabb_intersection(nir_builder *b, struct lvp_leaf_intersection *intersection,
                             const struct lvp_ray_traversal_args *args,
                             const struct lvp_ray_flags *ray_flags)
{
   struct lvp_ray_tracing_pipeline_compiler *compiler = args->data;
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_store_var(b, state->accept, nir_imm_false(b), 0x1);
   nir_store_var(b, state->terminate, ray_flags->terminate_on_first_hit, 0x1);
   nir_store_var(b, state->opaque, intersection->opaque, 0x1);

   nir_def *prev_instance_addr = nir_load_var(b, state->instance_addr);
   nir_def *prev_primitive_id = nir_load_var(b, state->primitive_id);
   nir_def *prev_geometry_id_and_flags = nir_load_var(b, state->geometry_id_and_flags);

   nir_store_var(b, state->instance_addr, nir_load_var(b, state->traversal.instance_addr), 0x1);
   nir_store_var(b, state->primitive_id, intersection->primitive_id, 0x1);
   nir_store_var(b, state->geometry_id_and_flags, intersection->geometry_id_and_flags, 0x1);

   nir_def *geometry_id = nir_iand_imm(b, intersection->geometry_id_and_flags, 0xfffffff);
   nir_def *sbt_index =
      nir_iadd(b,
               nir_iadd(b, nir_load_var(b, state->sbt_offset),
                        nir_iand_imm(b, nir_load_var(b, state->traversal.sbt_offset_and_flags), 0xffffff)),
               nir_imul(b, nir_load_var(b, state->sbt_stride), geometry_id));

   struct lvp_sbt_entry isec_entry = lvp_load_sbt_entry(
      b,
      sbt_index,
      offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
      offsetof(struct lvp_ray_tracing_group_handle, index));
   nir_store_var(b, compiler->state.shader_record_ptr, isec_entry.shader_record_ptr, 0x1);

   for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
      struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
      if (group->isec_index == VK_SHADER_UNUSED_KHR)
         continue;

      nir_shader *stage = compiler->pipeline->rt.stages[group->isec_index]->nir;

      nir_push_if(b, nir_ieq_imm(b, isec_entry.value, group->handle.index));
      lvp_call_ray_tracing_stage(b, compiler, stage);
      nir_pop_if(b, NULL);

      nir_shader *ahit_stage = NULL;
      if (group->ahit_index != VK_SHADER_UNUSED_KHR)
         ahit_stage = compiler->pipeline->rt.stages[group->ahit_index]->nir;

      struct lvp_lower_isec_intrinsic_state isec_state = {
         .compiler = compiler,
         .ahit = ahit_stage,
      };
      nir_shader_intrinsics_pass(b->shader, lvp_lower_isec_intrinsic,
                                 nir_metadata_none, &isec_state);
   }

   nir_push_if(b, nir_load_var(b, state->accept));
   {
      nir_store_var(b, state->sbt_index, sbt_index, 0x1);
      nir_store_var(b, state->traversal.hit, nir_imm_true(b), 0x1);

      nir_break_if(b, nir_load_var(b, state->terminate));
   }
   nir_push_else(b, NULL);
   {
      nir_store_var(b, state->instance_addr, prev_instance_addr, 0x1);
      nir_store_var(b, state->primitive_id, prev_primitive_id, 0x1);
      nir_store_var(b, state->geometry_id_and_flags, prev_geometry_id_and_flags, 0x1);
   }
   nir_pop_if(b, NULL);
}

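/* Traversal callback for triangle hits. Tentatively commit the hit
 * (t, instance, primitive, hit kind and barycentrics stored on the scratch
 * stack), run the matching any-hit shader for non-opaque geometry, then keep
 * the hit or restore the previous state depending on state->accept.
 */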
static void
lvp_handle_triangle_intersection(nir_builder *b,
                                 struct lvp_triangle_intersection *intersection,
                                 const struct lvp_ray_traversal_args *args,
                                 const struct lvp_ray_flags *ray_flags)
{
   struct lvp_ray_tracing_pipeline_compiler *compiler = args->data;
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_store_var(b, state->accept, nir_imm_true(b), 0x1);
   nir_store_var(b, state->terminate, ray_flags->terminate_on_first_hit, 0x1);

   nir_def *barycentrics_offset = nir_load_var(b, state->stack_ptr);

   nir_def *prev_tmax = nir_load_var(b, state->tmax);
   nir_def *prev_instance_addr = nir_load_var(b, state->instance_addr);
   nir_def *prev_primitive_id = nir_load_var(b, state->primitive_id);
   nir_def *prev_geometry_id_and_flags = nir_load_var(b, state->geometry_id_and_flags);
   nir_def *prev_hit_kind = nir_load_var(b, state->hit_kind);
   nir_def *prev_barycentrics = nir_load_scratch(b, 2, 32, barycentrics_offset);

   nir_store_var(b, state->tmax, intersection->t, 0x1);
   nir_store_var(b, state->instance_addr, nir_load_var(b, state->traversal.instance_addr), 0x1);
   nir_store_var(b, state->primitive_id, intersection->base.primitive_id, 0x1);
   nir_store_var(b, state->geometry_id_and_flags, intersection->base.geometry_id_and_flags, 0x1);
   nir_store_var(b, state->hit_kind,
                 nir_bcsel(b, intersection->frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF)), 0x1);

   nir_store_scratch(b, intersection->barycentrics, barycentrics_offset);

   nir_def *geometry_id = nir_iand_imm(b, intersection->base.geometry_id_and_flags, 0xfffffff);
   nir_def *sbt_index =
      nir_iadd(b,
               nir_iadd(b, nir_load_var(b, state->sbt_offset),
                        nir_iand_imm(b, nir_load_var(b, state->traversal.sbt_offset_and_flags), 0xffffff)),
               nir_imul(b, nir_load_var(b, state->sbt_stride), geometry_id));

   nir_push_if(b, nir_inot(b, intersection->base.opaque));
   {
      struct lvp_sbt_entry ahit_entry = lvp_load_sbt_entry(
         b,
         sbt_index,
         offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
         offsetof(struct lvp_ray_tracing_group_handle, index));
      nir_store_var(b, compiler->state.shader_record_ptr, ahit_entry.shader_record_ptr, 0x1);

      for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
         struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
         if (group->ahit_index == VK_SHADER_UNUSED_KHR)
            continue;

         nir_shader *stage = compiler->pipeline->rt.stages[group->ahit_index]->nir;

         nir_push_if(b, nir_ieq_imm(b, ahit_entry.value, group->handle.index));
         lvp_call_ray_tracing_stage(b, compiler, stage);
         nir_pop_if(b, NULL);
      }
   }
   nir_pop_if(b, NULL);

   nir_push_if(b, nir_load_var(b, state->accept));
   {
      nir_store_var(b, state->sbt_index, sbt_index, 0x1);
      nir_store_var(b, state->traversal.hit, nir_imm_true(b), 0x1);

      nir_break_if(b, nir_load_var(b, state->terminate));
   }
   nir_push_else(b, NULL);
   {
      nir_store_var(b, state->tmax, prev_tmax, 0x1);
      nir_store_var(b, state->instance_addr, prev_instance_addr, 0x1);
      nir_store_var(b, state->primitive_id, prev_primitive_id, 0x1);
      nir_store_var(b, state->geometry_id_and_flags, prev_geometry_id_and_flags, 0x1);
      nir_store_var(b, state->hit_kind, prev_hit_kind, 0x1);
      nir_store_scratch(b, prev_barycentrics, barycentrics_offset);
   }
   nir_pop_if(b, NULL);
}

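/* Lower OpTraceRayKHR: set up the global and per-traversal ray state, walk
 * the BVH with lvp_build_ray_traversal() (the AABB/triangle callbacks above
 * invoke intersection and any-hit shaders), and afterwards run the
 * closest-hit or miss shader selected by the SBT unless the corresponding
 * skip flags are set. The caller's stack pointer is saved and restored around
 * the recursion.
 */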
static void
lvp_trace_ray(nir_builder *b, struct lvp_ray_tracing_pipeline_compiler *compiler,
              nir_intrinsic_instr *instr)
{
   struct lvp_ray_tracing_state *state = &compiler->state;

   nir_def *accel_struct = instr->src[0].ssa;
   nir_def *flags = instr->src[1].ssa;
   nir_def *cull_mask = instr->src[2].ssa;
   nir_def *sbt_offset = nir_iand_imm(b, instr->src[3].ssa, 0xF);
   nir_def *sbt_stride = nir_iand_imm(b, instr->src[4].ssa, 0xF);
   nir_def *miss_index = nir_iand_imm(b, instr->src[5].ssa, 0xFFFF);
   nir_def *origin = instr->src[6].ssa;
   nir_def *tmin = instr->src[7].ssa;
   nir_def *dir = instr->src[8].ssa;
   nir_def *tmax = instr->src[9].ssa;
   nir_def *payload = instr->src[10].ssa;

   uint32_t stack_size =
      lvp_ray_tracing_pipeline_compiler_get_stack_size(compiler, b->impl->function);
   nir_def *stack_ptr = nir_load_var(b, state->stack_ptr);
   nir_store_var(b, state->stack_ptr, nir_iadd_imm(b, stack_ptr, stack_size), 0x1);

   nir_store_var(b, state->shader_call_data_offset, nir_iadd_imm(b, payload, -stack_size), 0x1);

   nir_def *bvh_base = accel_struct;
   if (bvh_base->bit_size != 64) {
      assert(bvh_base->num_components >= 2);
      bvh_base = nir_load_ubo(
         b, 1, 64, nir_channel(b, accel_struct, 0),
         nir_imul_imm(b, nir_channel(b, accel_struct, 1), sizeof(struct lp_descriptor)), .range = ~0);
   }

   lvp_ray_traversal_state_init(b->impl, &state->traversal);

   nir_store_var(b, state->bvh_base, bvh_base, 0x1);
   nir_store_var(b, state->flags, flags, 0x1);
   nir_store_var(b, state->cull_mask, cull_mask, 0x1);
   nir_store_var(b, state->sbt_offset, sbt_offset, 0x1);
   nir_store_var(b, state->sbt_stride, sbt_stride, 0x1);
   nir_store_var(b, state->miss_index, miss_index, 0x1);
   nir_store_var(b, state->origin, origin, 0x7);
   nir_store_var(b, state->tmin, tmin, 0x1);
   nir_store_var(b, state->dir, dir, 0x7);
   nir_store_var(b, state->tmax, tmax, 0x1);

   nir_store_var(b, state->traversal.bvh_base, bvh_base, 0x1);
   nir_store_var(b, state->traversal.origin, origin, 0x7);
   nir_store_var(b, state->traversal.dir, dir, 0x7);
   nir_store_var(b, state->traversal.inv_dir, nir_frcp(b, dir), 0x7);
   nir_store_var(b, state->traversal.current_node, nir_imm_int(b, LVP_BVH_ROOT_NODE), 0x1);
   nir_store_var(b, state->traversal.stack_base, nir_imm_int(b, -1), 0x1);
   nir_store_var(b, state->traversal.stack_ptr, nir_imm_int(b, 0), 0x1);

   nir_store_var(b, state->traversal.hit, nir_imm_false(b), 0x1);

   struct lvp_ray_traversal_vars vars = {
      .tmax = nir_build_deref_var(b, state->tmax),
      .origin = nir_build_deref_var(b, state->traversal.origin),
      .dir = nir_build_deref_var(b, state->traversal.dir),
      .inv_dir = nir_build_deref_var(b, state->traversal.inv_dir),
      .bvh_base = nir_build_deref_var(b, state->traversal.bvh_base),
      .current_node = nir_build_deref_var(b, state->traversal.current_node),
      .stack_base = nir_build_deref_var(b, state->traversal.stack_base),
      .stack_ptr = nir_build_deref_var(b, state->traversal.stack_ptr),
      .stack = nir_build_deref_var(b, state->traversal.stack),
      .instance_addr = nir_build_deref_var(b, state->traversal.instance_addr),
      .sbt_offset_and_flags = nir_build_deref_var(b, state->traversal.sbt_offset_and_flags),
   };

   struct lvp_ray_traversal_args args = {
      .root_bvh_base = bvh_base,
      .flags = flags,
      .cull_mask = nir_ishl_imm(b, cull_mask, 24),
      .origin = origin,
      .tmin = tmin,
      .dir = dir,
      .vars = vars,
      .aabb_cb = (compiler->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_AABBS_BIT_KHR) ?
         NULL : lvp_handle_aabb_intersection,
      .triangle_cb = (compiler->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR) ?
         NULL : lvp_handle_triangle_intersection,
      .data = compiler,
   };

   nir_push_if(b, nir_ine_imm(b, bvh_base, 0));
   lvp_build_ray_traversal(b, &args);
   nir_pop_if(b, NULL);

   nir_push_if(b, nir_load_var(b, state->traversal.hit));
   {
      nir_def *skip_chit = nir_test_mask(b, flags, SpvRayFlagsSkipClosestHitShaderKHRMask);
      nir_push_if(b, nir_inot(b, skip_chit));

      struct lvp_sbt_entry chit_entry = lvp_load_sbt_entry(
         b,
         nir_load_var(b, state->sbt_index),
         offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
         offsetof(struct lvp_ray_tracing_group_handle, index));
      nir_store_var(b, compiler->state.shader_record_ptr, chit_entry.shader_record_ptr, 0x1);

      for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
         struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
         if (group->recursive_index == VK_SHADER_UNUSED_KHR)
            continue;

         nir_shader *stage = compiler->pipeline->rt.stages[group->recursive_index]->nir;
         if (stage->info.stage != MESA_SHADER_CLOSEST_HIT)
            continue;

         nir_push_if(b, nir_ieq_imm(b, chit_entry.value, group->handle.index));
         lvp_call_ray_tracing_stage(b, compiler, stage);
         nir_pop_if(b, NULL);
      }

      nir_pop_if(b, NULL);
   }
   nir_push_else(b, NULL);
   {
      struct lvp_sbt_entry miss_entry = lvp_load_sbt_entry(
         b,
         miss_index,
         offsetof(VkTraceRaysIndirectCommand2KHR, missShaderBindingTableAddress),
         offsetof(struct lvp_ray_tracing_group_handle, index));
      nir_store_var(b, compiler->state.shader_record_ptr, miss_entry.shader_record_ptr, 0x1);

      for (uint32_t i = 0; i < compiler->pipeline->rt.group_count; i++) {
         struct lvp_ray_tracing_group *group = compiler->pipeline->rt.groups + i;
         if (group->recursive_index == VK_SHADER_UNUSED_KHR)
            continue;

         nir_shader *stage = compiler->pipeline->rt.stages[group->recursive_index]->nir;
         if (stage->info.stage != MESA_SHADER_MISS)
            continue;

         nir_push_if(b, nir_ieq_imm(b, miss_entry.value, group->handle.index));
         lvp_call_ray_tracing_stage(b, compiler, stage);
         nir_pop_if(b, NULL);
      }
   }
   nir_pop_if(b, NULL);

   nir_store_var(b, state->stack_ptr, stack_ptr, 0x1);
}

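/* Main lowering callback for the monolithic shader: traceRay/executeCallable
 * calls are expanded inline, terminate/ignore-intersection become returns
 * after updating the accept/terminate state, ray tracing system values are
 * read from the compiler state variables or the BVH instance node, and
 * scratch accesses are rebased onto the software stack pointer.
 */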
static bool
lvp_lower_ray_tracing_instr(nir_builder *b, nir_instr *instr, void *data)
{
   struct lvp_ray_tracing_pipeline_compiler *compiler = data;
   struct lvp_ray_tracing_state *state = &compiler->state;

   if (instr->type == nir_instr_type_jump) {
      nir_jump_instr *jump = nir_instr_as_jump(instr);
      if (jump->type == nir_jump_halt) {
         jump->type = nir_jump_return;
         return true;
      }
      return false;
   } else if (instr->type != nir_instr_type_intrinsic) {
      return false;
   }

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   nir_def *def = NULL;

   b->cursor = nir_before_instr(instr);

   switch (intr->intrinsic) {
   /* Ray tracing instructions */
   case nir_intrinsic_execute_callable:
      lvp_execute_callable(b, compiler, intr);
      break;
   case nir_intrinsic_trace_ray:
      lvp_trace_ray(b, compiler, intr);
      break;
   case nir_intrinsic_ignore_ray_intersection: {
      nir_store_var(b, state->accept, nir_imm_false(b), 0x1);

      nir_push_if(b, nir_imm_true(b));
      nir_jump(b, nir_jump_return);
      nir_pop_if(b, NULL);
      break;
   }
   case nir_intrinsic_terminate_ray: {
      nir_store_var(b, state->accept, nir_imm_true(b), 0x1);
      nir_store_var(b, state->terminate, nir_imm_true(b), 0x1);

      nir_push_if(b, nir_imm_true(b));
      nir_jump(b, nir_jump_return);
      nir_pop_if(b, NULL);
      break;
   }
   /* Ray tracing system values */
   case nir_intrinsic_load_ray_launch_id:
      def = nir_load_global_invocation_id(b, 32);
      break;
   case nir_intrinsic_load_ray_launch_size:
      def = lvp_load_trace_ray_command_field(
         b, offsetof(VkTraceRaysIndirectCommand2KHR, width), 3, 32);
      break;
   case nir_intrinsic_load_shader_record_ptr:
      def = nir_load_var(b, state->shader_record_ptr);
      break;
   case nir_intrinsic_load_ray_t_min:
      def = nir_load_var(b, state->tmin);
      break;
   case nir_intrinsic_load_ray_t_max:
      def = nir_load_var(b, state->tmax);
      break;
   case nir_intrinsic_load_ray_world_origin:
      def = nir_load_var(b, state->origin);
      break;
   case nir_intrinsic_load_ray_world_direction:
      def = nir_load_var(b, state->dir);
      break;
   case nir_intrinsic_load_ray_instance_custom_index: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *custom_instance_and_mask = nir_build_load_global(
         b, 1, 32,
         nir_iadd_imm(b, instance_node_addr, offsetof(struct lvp_bvh_instance_node, custom_instance_and_mask)));
      def = nir_iand_imm(b, custom_instance_and_mask, 0xFFFFFF);
      break;
   }
   case nir_intrinsic_load_primitive_id:
      def = nir_load_var(b, state->primitive_id);
      break;
   case nir_intrinsic_load_ray_geometry_index:
      def = nir_load_var(b, state->geometry_id_and_flags);
      def = nir_iand_imm(b, def, 0xFFFFFFF);
      break;
   case nir_intrinsic_load_instance_id: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      def = nir_build_load_global(
         b, 1, 32, nir_iadd_imm(b, instance_node_addr, offsetof(struct lvp_bvh_instance_node, instance_id)));
      break;
   }
   case nir_intrinsic_load_ray_flags:
      def = nir_load_var(b, state->flags);
      break;
   case nir_intrinsic_load_ray_hit_kind:
      def = nir_load_var(b, state->hit_kind);
      break;
   case nir_intrinsic_load_ray_world_to_object: {
      unsigned c = nir_intrinsic_column(intr);
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *wto_matrix[3];
      lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);

      nir_def *vals[3];
      for (unsigned i = 0; i < 3; ++i)
         vals[i] = nir_channel(b, wto_matrix[i], c);

      def = nir_vec(b, vals, 3);
      break;
   }
   case nir_intrinsic_load_ray_object_to_world: {
      unsigned c = nir_intrinsic_column(intr);
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *rows[3];
      for (unsigned r = 0; r < 3; ++r)
         rows[r] = nir_build_load_global(
            b, 4, 32,
            nir_iadd_imm(b, instance_node_addr, offsetof(struct lvp_bvh_instance_node, otw_matrix) + r * 16));
      def = nir_vec3(b, nir_channel(b, rows[0], c), nir_channel(b, rows[1], c), nir_channel(b, rows[2], c));
      break;
   }
   case nir_intrinsic_load_ray_object_origin: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *wto_matrix[3];
      lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);
      def = lvp_mul_vec3_mat(b, nir_load_var(b, state->origin), wto_matrix, true);
      break;
   }
   case nir_intrinsic_load_ray_object_direction: {
      nir_def *instance_node_addr = nir_load_var(b, state->instance_addr);
      nir_def *wto_matrix[3];
      lvp_load_wto_matrix(b, instance_node_addr, wto_matrix);
      def = lvp_mul_vec3_mat(b, nir_load_var(b, state->dir), wto_matrix, false);
      break;
   }
   case nir_intrinsic_load_cull_mask:
      def = nir_iand_imm(b, nir_load_var(b, state->cull_mask), 0xFF);
      break;
   /* Ray tracing stack lowering */
   case nir_intrinsic_load_scratch: {
      nir_src_rewrite(&intr->src[0], nir_iadd(b, nir_load_var(b, state->stack_ptr), intr->src[0].ssa));
      return true;
   }
   case nir_intrinsic_store_scratch: {
      nir_src_rewrite(&intr->src[1], nir_iadd(b, nir_load_var(b, state->stack_ptr), intr->src[1].ssa));
      return true;
   }
   case nir_intrinsic_load_ray_triangle_vertex_positions: {
      def = lvp_load_vertex_position(
         b, nir_load_var(b, state->instance_addr), nir_load_var(b, state->primitive_id),
         nir_intrinsic_column(intr));
      break;
   }
   /* Internal system values */
   case nir_intrinsic_load_shader_call_data_offset_lvp:
      def = nir_load_var(b, state->shader_call_data_offset);
      break;
   default:
      return false;
   }

   if (def)
      nir_def_rewrite_uses(&intr->def, def);
   nir_instr_remove(instr);

   return true;
}

static bool
lvp_lower_ray_tracing_stack_base(nir_builder *b, nir_intrinsic_instr *instr, void *data)
{
   if (instr->intrinsic != nir_intrinsic_load_ray_tracing_stack_base_lvp)
      return false;

   b->cursor = nir_after_instr(&instr->instr);

   nir_def_replace(&instr->def, nir_imm_int(b, b->shader->scratch_size));

   return true;
}

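/* Build the single compute shader that implements the whole ray tracing
 * pipeline: bounds-check the launch ID, look up the raygen record from the
 * SBT, inline every reachable stage as a NIR function, and lower the ray
 * tracing intrinsics. The final scratch size is a worst-case estimate since
 * dynamic stack sizes are not supported.
 */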
static void
lvp_compile_ray_tracing_pipeline(struct lvp_pipeline *pipeline,
                                 const VkRayTracingPipelineCreateInfoKHR *create_info)
{
   nir_builder _b = nir_builder_init_simple_shader(
      MESA_SHADER_COMPUTE,
      pipeline->device->pscreen->get_compiler_options(pipeline->device->pscreen, PIPE_SHADER_IR_NIR, MESA_SHADER_COMPUTE),
      "ray tracing pipeline");
   nir_builder *b = &_b;

   b->shader->info.workgroup_size[0] = 8;

   struct lvp_ray_tracing_pipeline_compiler compiler = {
      .pipeline = pipeline,
      .flags = vk_rt_pipeline_create_flags(create_info),
   };
   lvp_ray_tracing_state_init(b->shader, &compiler.state);
   compiler.functions = _mesa_pointer_hash_table_create(NULL);

   nir_def *launch_id = nir_load_ray_launch_id(b);
   nir_def *launch_size = nir_load_ray_launch_size(b);
   nir_def *oob = nir_ige(b, nir_channel(b, launch_id, 0), nir_channel(b, launch_size, 0));
   oob = nir_ior(b, oob, nir_ige(b, nir_channel(b, launch_id, 1), nir_channel(b, launch_size, 1)));
   oob = nir_ior(b, oob, nir_ige(b, nir_channel(b, launch_id, 2), nir_channel(b, launch_size, 2)));

   nir_push_if(b, oob);
   nir_jump(b, nir_jump_return);
   nir_pop_if(b, NULL);

   nir_store_var(b, compiler.state.stack_ptr, nir_load_ray_tracing_stack_base_lvp(b), 0x1);

   struct lvp_sbt_entry raygen_entry = lvp_load_sbt_entry(
      b,
      NULL,
      offsetof(VkTraceRaysIndirectCommand2KHR, raygenShaderRecordAddress),
      offsetof(struct lvp_ray_tracing_group_handle, index));
   nir_store_var(b, compiler.state.shader_record_ptr, raygen_entry.shader_record_ptr, 0x1);

   for (uint32_t i = 0; i < pipeline->rt.group_count; i++) {
      struct lvp_ray_tracing_group *group = pipeline->rt.groups + i;
      if (group->recursive_index == VK_SHADER_UNUSED_KHR)
         continue;

      nir_shader *stage = pipeline->rt.stages[group->recursive_index]->nir;

      if (stage->info.stage != MESA_SHADER_RAYGEN)
         continue;

      nir_push_if(b, nir_ieq_imm(b, raygen_entry.value, group->handle.index));
      lvp_call_ray_tracing_stage(b, &compiler, stage);
      nir_pop_if(b, NULL);
   }

   nir_shader_instructions_pass(b->shader, lvp_lower_ray_tracing_instr, nir_metadata_none, &compiler);

   NIR_PASS(_, b->shader, nir_lower_returns);

   const struct nir_lower_compute_system_values_options compute_system_values = {0};
   NIR_PASS(_, b->shader, nir_lower_compute_system_values, &compute_system_values);
   NIR_PASS(_, b->shader, nir_lower_global_vars_to_local);
   NIR_PASS(_, b->shader, nir_lower_vars_to_ssa);

   NIR_PASS(_, b->shader, nir_lower_vars_to_explicit_types,
            nir_var_shader_temp,
            glsl_get_natural_size_align_bytes);

   NIR_PASS(_, b->shader, nir_lower_explicit_io, nir_var_shader_temp,
            nir_address_format_32bit_offset);

   NIR_PASS(_, b->shader, nir_shader_intrinsics_pass, lvp_lower_ray_tracing_stack_base,
            nir_metadata_control_flow, NULL);

   /* We can not support dynamic stack sizes, assume the worst. */
   b->shader->scratch_size +=
      compiler.raygen_size +
      MIN2(create_info->maxPipelineRayRecursionDepth, 1) * MAX3(compiler.chit_size, compiler.miss_size, compiler.isec_size + compiler.ahit_size) +
      MAX2(0, (int)create_info->maxPipelineRayRecursionDepth - 1) * MAX2(compiler.chit_size, compiler.miss_size) + 31 * compiler.callable_size;

   struct lvp_shader *shader = &pipeline->shaders[MESA_SHADER_RAYGEN];
   lvp_shader_init(shader, b->shader);
   shader->shader_cso = lvp_shader_compile(pipeline->device, shader, nir_shader_clone(NULL, shader->pipeline_nir->nir), false);

   _mesa_hash_table_destroy(compiler.functions, NULL);
}

static VkResult
lvp_create_ray_tracing_pipeline(VkDevice _device, const VkAllocationCallbacks *allocator,
                                const VkRayTracingPipelineCreateInfoKHR *create_info,
                                VkPipeline *out_pipeline)
{
   VK_FROM_HANDLE(lvp_device, device, _device);
   VK_FROM_HANDLE(lvp_pipeline_layout, layout, create_info->layout);

   VkResult result = VK_SUCCESS;

   struct lvp_pipeline *pipeline = vk_zalloc2(&device->vk.alloc, allocator, sizeof(struct lvp_pipeline), 8,
                                              VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
   if (!pipeline)
      return VK_ERROR_OUT_OF_HOST_MEMORY;

   vk_object_base_init(&device->vk, &pipeline->base,
                       VK_OBJECT_TYPE_PIPELINE);

   vk_pipeline_layout_ref(&layout->vk);

   pipeline->device = device;
   pipeline->layout = layout;
   pipeline->type = LVP_PIPELINE_RAY_TRACING;
   pipeline->flags = vk_rt_pipeline_create_flags(create_info);

   pipeline->rt.stage_count = create_info->stageCount;
   pipeline->rt.group_count = create_info->groupCount;
   if (create_info->pLibraryInfo) {
      for (uint32_t i = 0; i < create_info->pLibraryInfo->libraryCount; i++) {
         VK_FROM_HANDLE(lvp_pipeline, library, create_info->pLibraryInfo->pLibraries[i]);
         pipeline->rt.stage_count += library->rt.stage_count;
         pipeline->rt.group_count += library->rt.group_count;
      }
   }

   pipeline->rt.stages = calloc(pipeline->rt.stage_count, sizeof(struct lvp_pipeline_nir *));
   pipeline->rt.groups = calloc(pipeline->rt.group_count, sizeof(struct lvp_ray_tracing_group));
   if (!pipeline->rt.stages || !pipeline->rt.groups) {
      result = VK_ERROR_OUT_OF_HOST_MEMORY;
      goto fail;
   }

   result = lvp_compile_ray_tracing_stages(pipeline, create_info);
   if (result != VK_SUCCESS)
      goto fail;

   lvp_init_ray_tracing_groups(pipeline, create_info);

   if (!(pipeline->flags & VK_PIPELINE_CREATE_2_LIBRARY_BIT_KHR)) {
      lvp_compile_ray_tracing_pipeline(pipeline, create_info);
   }

   *out_pipeline = lvp_pipeline_to_handle(pipeline);

   return VK_SUCCESS;

fail:
   lvp_pipeline_destroy(device, pipeline, false);
   return result;
}

VKAPI_ATTR VkResult VKAPI_CALL
lvp_CreateRayTracingPipelinesKHR(
   VkDevice device,
   VkDeferredOperationKHR deferredOperation,
   VkPipelineCache pipelineCache,
   uint32_t createInfoCount,
   const VkRayTracingPipelineCreateInfoKHR *pCreateInfos,
   const VkAllocationCallbacks *pAllocator,
   VkPipeline *pPipelines)
{
   VkResult result = VK_SUCCESS;

   uint32_t i = 0;
   for (; i < createInfoCount; i++) {
      VkResult tmp_result = lvp_create_ray_tracing_pipeline(
         device, pAllocator, pCreateInfos + i, pPipelines + i);

      if (tmp_result != VK_SUCCESS) {
         result = tmp_result;
         pPipelines[i] = VK_NULL_HANDLE;

         if (vk_rt_pipeline_create_flags(&pCreateInfos[i]) &
             VK_PIPELINE_CREATE_2_EARLY_RETURN_ON_FAILURE_BIT_KHR)
            break;
      }
   }

   for (; i < createInfoCount; i++)
      pPipelines[i] = VK_NULL_HANDLE;

   return result;
}


VKAPI_ATTR VkResult VKAPI_CALL
lvp_GetRayTracingShaderGroupHandlesKHR(
   VkDevice _device,
   VkPipeline _pipeline,
   uint32_t firstGroup,
   uint32_t groupCount,
   size_t dataSize,
   void *pData)
{
   VK_FROM_HANDLE(lvp_pipeline, pipeline, _pipeline);

   uint8_t *data = pData;
   memset(data, 0, dataSize);

   for (uint32_t i = 0; i < groupCount; i++) {
      memcpy(data + i * LVP_RAY_TRACING_GROUP_HANDLE_SIZE,
             pipeline->rt.groups + firstGroup + i,
             sizeof(struct lvp_ray_tracing_group_handle));
   }

   return VK_SUCCESS;
}

VKAPI_ATTR VkResult VKAPI_CALL
lvp_GetRayTracingCaptureReplayShaderGroupHandlesKHR(
   VkDevice device,
   VkPipeline pipeline,
   uint32_t firstGroup,
   uint32_t groupCount,
   size_t dataSize,
   void *pData)
{
   return VK_SUCCESS;
}

VKAPI_ATTR VkDeviceSize VKAPI_CALL
lvp_GetRayTracingShaderGroupStackSizeKHR(
   VkDevice device,
   VkPipeline pipeline,
   uint32_t group,
   VkShaderGroupShaderKHR groupShader)
{
   return 4;
}