/*
 * Copyright © 2020 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#ifndef BRW_NIR_RT_BUILDER_H
#define BRW_NIR_RT_BUILDER_H

/* This file provides helpers to access memory based data structures that the
 * RT hardware reads/writes and their locations.
 *
 * See also "Memory Based Data Structures for Ray Tracing" (BSpec 47547) and
 * "Ray Tracing Address Computation for Memory Resident Structures" (BSpec
 * 47550).
 */

#include "brw_rt.h"
#include "nir_builder.h"

#define is_access_for_builder(b) \
   ((b)->shader->info.stage == MESA_SHADER_FRAGMENT ? \
    ACCESS_INCLUDE_HELPERS : 0)

static inline nir_def *
brw_nir_rt_load(nir_builder *b, nir_def *addr, unsigned align,
                unsigned components, unsigned bit_size)
{
   return nir_build_load_global(b, components, bit_size, addr,
                                .align_mul = align,
                                .access = is_access_for_builder(b));
}

static inline void
brw_nir_rt_store(nir_builder *b, nir_def *addr, unsigned align,
                 nir_def *value, unsigned write_mask)
{
   nir_build_store_global(b, value, addr,
                          .align_mul = align,
                          .write_mask = (write_mask) &
                                        BITFIELD_MASK(value->num_components),
                          .access = is_access_for_builder(b));
}

static inline nir_def *
brw_nir_rt_load_const(nir_builder *b, unsigned components, nir_def *addr)
{
   return nir_load_global_constant_uniform_block_intel(
      b, components, 32, addr,
      .access = ACCESS_CAN_REORDER | ACCESS_NON_WRITEABLE,
      .align_mul = 64);
}

static inline nir_def *
brw_load_btd_dss_id(nir_builder *b)
{
   return nir_load_topology_id_intel(b, .base = BRW_TOPOLOGY_ID_DSS);
}

static inline nir_def *
brw_nir_rt_load_num_simd_lanes_per_dss(nir_builder *b,
                                       const struct intel_device_info *devinfo)
{
   return nir_imm_int(b, devinfo->num_thread_per_eu *
                         devinfo->max_eus_per_subslice *
                         16 /* The RT computation is based off SIMD16 */);
}

static inline nir_def *
brw_load_eu_thread_simd(nir_builder *b)
{
   return nir_load_topology_id_intel(b, .base = BRW_TOPOLOGY_ID_EU_THREAD_SIMD);
}

static inline nir_def *
brw_nir_rt_async_stack_id(nir_builder *b)
{
   return nir_iadd(b, nir_umul_32x16(b, nir_load_ray_num_dss_rt_stacks_intel(b),
                                        brw_load_btd_dss_id(b)),
                      nir_load_btd_stack_id_intel(b));
}

static inline nir_def *
brw_nir_rt_sync_stack_id(nir_builder *b)
{
   return brw_load_eu_thread_simd(b);
}

/* We have our own load/store scratch helpers because they emit a global
 * memory read or write based on the scratch_base_ptr system value rather
 * than a load/store_scratch intrinsic.
 */
static inline nir_def *
brw_nir_rt_load_scratch(nir_builder *b, uint32_t offset, unsigned align,
                        unsigned num_components, unsigned bit_size)
{
   nir_def *addr =
      nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 64, 1), offset);
   return brw_nir_rt_load(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
                          num_components, bit_size);
}

static inline void
brw_nir_rt_store_scratch(nir_builder *b, uint32_t offset, unsigned align,
                         nir_def *value, nir_component_mask_t write_mask)
{
   nir_def *addr =
      nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 64, 1), offset);
   brw_nir_rt_store(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
                    value, write_mask);
}

static inline void
brw_nir_btd_spawn(nir_builder *b, nir_def *record_addr)
{
   nir_btd_spawn_intel(b, nir_load_btd_global_arg_addr_intel(b), record_addr);
}

static inline void
brw_nir_btd_retire(nir_builder *b)
{
   nir_btd_retire_intel(b);
}

/** This is a pseudo-op which does a bindless return
 *
 * It loads the return address from the stack and calls btd_spawn to spawn the
 * resume shader.
 */
static inline void
brw_nir_btd_return(struct nir_builder *b)
{
   nir_def *resume_addr =
      brw_nir_rt_load_scratch(b, BRW_BTD_STACK_RESUME_BSR_ADDR_OFFSET,
                              8 /* align */, 1, 64);
   brw_nir_btd_spawn(b, resume_addr);
}

static inline void
assert_def_size(nir_def *def, unsigned num_components, unsigned bit_size)
{
   assert(def->num_components == num_components);
   assert(def->bit_size == bit_size);
}

static inline nir_def *
brw_nir_num_rt_stacks(nir_builder *b,
                      const struct intel_device_info *devinfo)
{
   return nir_imul_imm(b, nir_load_ray_num_dss_rt_stacks_intel(b),
                          intel_device_info_dual_subslice_id_bound(devinfo));
}

static inline nir_def *
brw_nir_rt_sw_hotzone_addr(nir_builder *b,
                           const struct intel_device_info *devinfo)
{
   nir_def *offset32 =
      nir_imul_imm(b, brw_nir_rt_async_stack_id(b),
                      BRW_RT_SIZEOF_HOTZONE);

   offset32 = nir_iadd(b, offset32, nir_ineg(b,
      nir_imul_imm(b, brw_nir_num_rt_stacks(b, devinfo),
                      BRW_RT_SIZEOF_HOTZONE)));

   return nir_iadd(b, nir_load_ray_base_mem_addr_intel(b),
                      nir_i2i64(b, offset32));
}

static inline nir_def *
brw_nir_rt_sync_stack_addr(nir_builder *b,
                           nir_def *base_mem_addr,
                           const struct intel_device_info *devinfo)
{
   /* For Ray queries (Synchronous Ray Tracing), the formula is similar but
    * goes down from rtMemBasePtr:
    *
    *    syncBase  = RTDispatchGlobals.rtMemBasePtr
    *              - (DSSID * NUM_SIMD_LANES_PER_DSS + SyncStackID + 1)
    *              * syncStackSize
    *
    * We assume that we can calculate a 32-bit offset first and then add it
    * to the 64-bit base address at the end.
    */
   nir_def *offset32 =
      nir_imul(b,
               nir_iadd(b,
                        nir_imul(b, brw_load_btd_dss_id(b),
                                    brw_nir_rt_load_num_simd_lanes_per_dss(b, devinfo)),
                        nir_iadd_imm(b, brw_nir_rt_sync_stack_id(b), 1)),
               nir_imm_int(b, BRW_RT_SIZEOF_RAY_QUERY));
   return nir_isub(b, base_mem_addr, nir_u2u64(b, offset32));
}

static inline nir_def *
brw_nir_rt_stack_addr(nir_builder *b)
{
   /* From the BSpec "Address Computation for Memory Based Data Structures:
    * Ray and TraversalStack (Async Ray Tracing)":
    *
    *    stackBase = RTDispatchGlobals.rtMemBasePtr
    *              + (DSSID * RTDispatchGlobals.numDSSRTStacks + stackID)
    *              * RTDispatchGlobals.stackSizePerRay // 64B aligned
    *
    * We assume that we can calculate a 32-bit offset first and then add it
    * to the 64-bit base address at the end.
    */
   nir_def *offset32 =
      nir_imul(b, brw_nir_rt_async_stack_id(b),
                  nir_load_ray_hw_stack_size_intel(b));
   return nir_iadd(b, nir_load_ray_base_mem_addr_intel(b),
                      nir_u2u64(b, offset32));
}

static inline nir_def *
brw_nir_rt_mem_hit_addr_from_addr(nir_builder *b,
                                  nir_def *stack_addr,
                                  bool committed)
{
   return nir_iadd_imm(b, stack_addr, committed ? 0 : BRW_RT_SIZEOF_HIT_INFO);
}

static inline nir_def *
brw_nir_rt_mem_hit_addr(nir_builder *b, bool committed)
{
   return nir_iadd_imm(b, brw_nir_rt_stack_addr(b),
                          committed ? 0 : BRW_RT_SIZEOF_HIT_INFO);
}

static inline nir_def *
brw_nir_rt_hit_attrib_data_addr(nir_builder *b)
{
   return nir_iadd_imm(b, brw_nir_rt_stack_addr(b),
                          BRW_RT_OFFSETOF_HIT_ATTRIB_DATA);
}

static inline nir_def *
brw_nir_rt_mem_ray_addr(nir_builder *b,
                        nir_def *stack_addr,
                        enum brw_rt_bvh_level bvh_level)
{
   /* From the BSpec "Address Computation for Memory Based Data Structures:
    * Ray and TraversalStack (Async Ray Tracing)":
    *
    *    rayBase = stackBase + sizeof(HitInfo) * 2 // 64B aligned
    *    rayPtr  = rayBase + bvhLevel * sizeof(Ray); // 64B aligned
    *
    * In Vulkan, we always have exactly two levels of BVH: World and Object.
    */
   uint32_t offset = BRW_RT_SIZEOF_HIT_INFO * 2 +
                     bvh_level * BRW_RT_SIZEOF_RAY;
   return nir_iadd_imm(b, stack_addr, offset);
}

static inline nir_def *
brw_nir_rt_sw_stack_addr(nir_builder *b,
                         const struct intel_device_info *devinfo)
{
   nir_def *addr = nir_load_ray_base_mem_addr_intel(b);

   nir_def *offset32 = nir_imul(b, brw_nir_num_rt_stacks(b, devinfo),
                                   nir_load_ray_hw_stack_size_intel(b));
   addr = nir_iadd(b, addr, nir_u2u64(b, offset32));

   nir_def *offset_in_stack =
      nir_imul(b, nir_u2u64(b, brw_nir_rt_async_stack_id(b)),
                  nir_u2u64(b, nir_load_ray_sw_stack_size_intel(b)));

   return nir_iadd(b, addr, offset_in_stack);
}

static inline nir_def *
nir_unpack_64_4x16_split_z(nir_builder *b, nir_def *val)
{
   return nir_unpack_32_2x16_split_x(b, nir_unpack_64_2x32_split_y(b, val));
}

struct brw_nir_rt_globals_defs {
   nir_def *base_mem_addr;
   nir_def *call_stack_handler_addr;
   nir_def *hw_stack_size;
   nir_def *num_dss_rt_stacks;
   nir_def *hit_sbt_addr;
   nir_def *hit_sbt_stride;
   nir_def *miss_sbt_addr;
   nir_def *miss_sbt_stride;
   nir_def *sw_stack_size;
   nir_def *launch_size;
   nir_def *call_sbt_addr;
   nir_def *call_sbt_stride;
   nir_def *resume_sbt_addr;
};

static inline void
brw_nir_rt_load_globals_addr(nir_builder *b,
                             struct brw_nir_rt_globals_defs *defs,
                             nir_def *addr)
{
   nir_def *data;
   data = brw_nir_rt_load_const(b, 16, addr);
   defs->base_mem_addr = nir_pack_64_2x32(b, nir_trim_vector(b, data, 2));

   defs->call_stack_handler_addr =
      nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 2));

   defs->hw_stack_size = nir_channel(b, data, 4);
   defs->num_dss_rt_stacks = nir_iand_imm(b, nir_channel(b, data, 5), 0xffff);
   defs->hit_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 8),
                                nir_extract_i16(b, nir_channel(b, data, 9),
                                                   nir_imm_int(b, 0)));
   defs->hit_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 9));
   defs->miss_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 10),
                                nir_extract_i16(b, nir_channel(b, data, 11),
                                                   nir_imm_int(b, 0)));
   defs->miss_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 11));
   defs->sw_stack_size = nir_channel(b, data, 12);
   defs->launch_size = nir_channels(b, data, 0x7u << 13);

   data = brw_nir_rt_load_const(b, 8, nir_iadd_imm(b, addr, 64));
   defs->call_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 0),
                                nir_extract_i16(b, nir_channel(b, data, 1),
                                                   nir_imm_int(b, 0)));
   defs->call_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 1));

   defs->resume_sbt_addr =
      nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 2));
}

static inline void
brw_nir_rt_load_globals(nir_builder *b,
                        struct brw_nir_rt_globals_defs *defs)
{
   brw_nir_rt_load_globals_addr(b, defs, nir_load_btd_global_arg_addr_intel(b));
}
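
/* Illustrative sketch (not part of the upstream helpers): a hypothetical
 * example of how the globals loader above is typically consumed by a lowering
 * pass.  It computes a miss SBT record address as base + index * stride,
 * which is an assumption about the record layout made only for this example.
 */
static inline nir_def *
brw_nir_rt_example_miss_record_addr(nir_builder *b, nir_def *miss_index)
{
   struct brw_nir_rt_globals_defs globals;
   brw_nir_rt_load_globals(b, &globals);

   /* miss_sbt_stride is a 16-bit value; widen it before multiplying by the
    * 32-bit record index, then widen the byte offset to 64-bit for the add.
    */
   nir_def *offset =
      nir_imul(b, miss_index, nir_u2u32(b, globals.miss_sbt_stride));
   return nir_iadd(b, globals.miss_sbt_addr, nir_u2u64(b, offset));
}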

static inline nir_def *
brw_nir_rt_unpack_leaf_ptr(nir_builder *b, nir_def *vec2)
{
   /* Hit record leaf pointers are 42-bit and assumed to be in 64B chunks.
    * This leaves 22 bits at the top for other stuff.
    */
   nir_def *ptr64 = nir_imul_imm(b, nir_pack_64_2x32(b, vec2), 64);

   /* The top 16 bits (remember, we shifted by 6 already) contain garbage
    * that we need to get rid of.
    */
   nir_def *ptr_lo = nir_unpack_64_2x32_split_x(b, ptr64);
   nir_def *ptr_hi = nir_unpack_64_2x32_split_y(b, ptr64);
   ptr_hi = nir_extract_i16(b, ptr_hi, nir_imm_int(b, 0));
   return nir_pack_64_2x32_split(b, ptr_lo, ptr_hi);
}

/**
 * MemHit memory layout (BSpec 47547):
 *
 *      name            bits    description
 *    - t               32      hit distance of current hit (or initial traversal distance)
 *    - u               32      barycentric hit coordinates
 *    - v               32      barycentric hit coordinates
 *    - primIndexDelta  16      prim index delta for compressed meshlets and quads
 *    - valid            1      set if there is a hit
 *    - leafType         3      type of node primLeafPtr is pointing to
 *    - primLeafIndex    4      index of the hit primitive inside the leaf
 *    - bvhLevel         3      the instancing level at which the hit occurred
 *    - frontFace        1      whether we hit the front-facing side of a triangle (also used to pass opaque flag when calling intersection shaders)
 *    - pad0             4      unused bits
 *    - primLeafPtr     42      pointer to BVH leaf node (multiple of 64 bytes)
 *    - hitGroupRecPtr0 22      LSB of hit group record of the hit triangle (multiple of 16 bytes)
 *    - instLeafPtr     42      pointer to BVH instance leaf node (in multiple of 64 bytes)
 *    - hitGroupRecPtr1 22      MSB of hit group record of the hit triangle (multiple of 32 bytes)
 */
struct brw_nir_rt_mem_hit_defs {
   nir_def *t;
   nir_def *tri_bary; /**< Only valid for triangle geometry */
   nir_def *aabb_hit_kind; /**< Only valid for AABB geometry */
   nir_def *valid;
   nir_def *leaf_type;
   nir_def *prim_index_delta;
   nir_def *prim_leaf_index;
   nir_def *bvh_level;
   nir_def *front_face;
   nir_def *done; /**< Only for ray queries */
   nir_def *prim_leaf_ptr;
   nir_def *inst_leaf_ptr;
};

static inline void
brw_nir_rt_load_mem_hit_from_addr(nir_builder *b,
                                  struct brw_nir_rt_mem_hit_defs *defs,
                                  nir_def *stack_addr,
                                  bool committed)
{
   nir_def *hit_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, committed);

   nir_def *data = brw_nir_rt_load(b, hit_addr, 16, 4, 32);
   defs->t = nir_channel(b, data, 0);
   defs->aabb_hit_kind = nir_channel(b, data, 1);
   defs->tri_bary = nir_channels(b, data, 0x6);
   nir_def *bitfield = nir_channel(b, data, 3);
   defs->prim_index_delta =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 0), nir_imm_int(b, 16));
   defs->valid = nir_i2b(b, nir_iand_imm(b, bitfield, 1u << 16));
   defs->leaf_type =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 17), nir_imm_int(b, 3));
   defs->prim_leaf_index =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 20), nir_imm_int(b, 4));
   defs->bvh_level =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 24), nir_imm_int(b, 3));
   defs->front_face = nir_i2b(b, nir_iand_imm(b, bitfield, 1 << 27));
   defs->done = nir_i2b(b, nir_iand_imm(b, bitfield, 1 << 28));

   data = brw_nir_rt_load(b, nir_iadd_imm(b, hit_addr, 16), 16, 4, 32);
   defs->prim_leaf_ptr =
      brw_nir_rt_unpack_leaf_ptr(b, nir_channels(b, data, 0x3 << 0));
   defs->inst_leaf_ptr =
      brw_nir_rt_unpack_leaf_ptr(b, nir_channels(b, data, 0x3 << 2));
}

static inline void
brw_nir_rt_load_mem_hit(nir_builder *b,
                        struct brw_nir_rt_mem_hit_defs *defs,
                        bool committed)
{
   brw_nir_rt_load_mem_hit_from_addr(b, defs, brw_nir_rt_stack_addr(b),
                                     committed);
}
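
/* Illustrative sketch (hypothetical, not an upstream helper): how the hit
 * loader above might be consumed.  This reads the committed hit for the
 * current stack and returns its hit distance, or 0.0f when no hit has been
 * committed.
 */
static inline nir_def *
brw_nir_rt_example_committed_hit_t(nir_builder *b)
{
   struct brw_nir_rt_mem_hit_defs hit = {};
   brw_nir_rt_load_mem_hit(b, &hit, true /* committed */);

   /* hit.valid is a boolean; select between the loaded t and a constant. */
   return nir_bcsel(b, hit.valid, hit.t, nir_imm_float(b, 0.0f));
}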

static inline void
brw_nir_memcpy_global(nir_builder *b,
                      nir_def *dst_addr, uint32_t dst_align,
                      nir_def *src_addr, uint32_t src_align,
                      uint32_t size)
{
   /* We're going to copy in 16B chunks */
   assert(size % 16 == 0);
   dst_align = MIN2(dst_align, 16);
   src_align = MIN2(src_align, 16);

   for (unsigned offset = 0; offset < size; offset += 16) {
      nir_def *data =
         brw_nir_rt_load(b, nir_iadd_imm(b, src_addr, offset), 16,
                         4, 32);
      brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), 16,
                       data, 0xf /* write_mask */);
   }
}

static inline void
brw_nir_memclear_global(nir_builder *b,
                        nir_def *dst_addr, uint32_t dst_align,
                        uint32_t size)
{
   /* We're going to clear in 16B chunks */
   assert(size % 16 == 0);
   dst_align = MIN2(dst_align, 16);

   nir_def *zero = nir_imm_ivec4(b, 0, 0, 0, 0);
   for (unsigned offset = 0; offset < size; offset += 16) {
      brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), dst_align,
                       zero, 0xf /* write_mask */);
   }
}

static inline nir_def *
brw_nir_rt_query_done(nir_builder *b, nir_def *stack_addr)
{
   struct brw_nir_rt_mem_hit_defs hit_in = {};
   brw_nir_rt_load_mem_hit_from_addr(b, &hit_in, stack_addr,
                                     false /* committed */);

   return hit_in.done;
}

static inline void
brw_nir_rt_set_dword_bit_at(nir_builder *b,
                            nir_def *addr,
                            uint32_t addr_offset,
                            uint32_t bit)
{
   nir_def *dword_addr = nir_iadd_imm(b, addr, addr_offset);
   nir_def *dword = brw_nir_rt_load(b, dword_addr, 4, 1, 32);
   brw_nir_rt_store(b, dword_addr, 4, nir_ior_imm(b, dword, 1u << bit), 0x1);
}

static inline void
brw_nir_rt_query_mark_done(nir_builder *b, nir_def *stack_addr)
{
   brw_nir_rt_set_dword_bit_at(b,
                               brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr,
                                                                 false /* committed */),
                               4 * 3 /* dword offset */, 28 /* bit */);
}
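
/* Illustrative sketch (hypothetical): the done-bit helpers above compose
 * naturally with NIR control flow, e.g. only marking the query done when it
 * is not already flagged, to skip a redundant read-modify-write.
 */
static inline void
brw_nir_rt_example_mark_done_once(nir_builder *b, nir_def *stack_addr)
{
   nir_push_if(b, nir_inot(b, brw_nir_rt_query_done(b, stack_addr)));
   {
      brw_nir_rt_query_mark_done(b, stack_addr);
   }
   nir_pop_if(b, NULL);
}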

/* This helper clears the 3rd dword of the MemHit structure where the valid
 * bit is located.
 */
static inline void
brw_nir_rt_query_mark_init(nir_builder *b, nir_def *stack_addr)
{
   nir_def *dword_addr;

   for (uint32_t i = 0; i < 2; i++) {
      dword_addr =
         nir_iadd_imm(b,
                      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr,
                                                        i == 0 /* committed */),
                      4 * 3 /* dword offset */);
      brw_nir_rt_store(b, dword_addr, 4, nir_imm_int(b, 0), 0x1);
   }
}

/* This helper is pretty much a memcpy of the uncommitted hit structure into
 * the committed one, just adding the valid bit.
 */
static inline void
brw_nir_rt_commit_hit_addr(nir_builder *b, nir_def *stack_addr)
{
   nir_def *dst_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, true /* committed */);
   nir_def *src_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, false /* committed */);

   for (unsigned offset = 0; offset < BRW_RT_SIZEOF_HIT_INFO; offset += 16) {
      nir_def *data =
         brw_nir_rt_load(b, nir_iadd_imm(b, src_addr, offset), 16, 4, 32);

      if (offset == 0) {
         data = nir_vec4(b,
                         nir_channel(b, data, 0),
                         nir_channel(b, data, 1),
                         nir_channel(b, data, 2),
                         nir_ior_imm(b,
                                     nir_channel(b, data, 3),
                                     0x1 << 16 /* valid */));

         /* Also write the potential hit as we change it. */
         brw_nir_rt_store(b, nir_iadd_imm(b, src_addr, offset), 16,
                          data, 0xf /* write_mask */);
      }

      brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), 16,
                       data, 0xf /* write_mask */);
   }
}

static inline void
brw_nir_rt_commit_hit(nir_builder *b)
{
   nir_def *stack_addr = brw_nir_rt_stack_addr(b);
   brw_nir_rt_commit_hit_addr(b, stack_addr);
}

static inline void
brw_nir_rt_generate_hit_addr(nir_builder *b, nir_def *stack_addr, nir_def *t_val)
{
   nir_def *committed_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, true /* committed */);
   nir_def *potential_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, false /* committed */);

   /* Set:
    *
    *   potential.t     = t_val;
    *   potential.valid = true;
    */
   nir_def *potential_hit_dwords_0_3 =
      brw_nir_rt_load(b, potential_addr, 16, 4, 32);
   potential_hit_dwords_0_3 =
      nir_vec4(b,
               t_val,
               nir_channel(b, potential_hit_dwords_0_3, 1),
               nir_channel(b, potential_hit_dwords_0_3, 2),
               nir_ior_imm(b, nir_channel(b, potential_hit_dwords_0_3, 3),
                           (0x1 << 16) /* valid */));
   brw_nir_rt_store(b, potential_addr, 16, potential_hit_dwords_0_3, 0xf /* write_mask */);

   /* Set:
    *
    *   committed.t               = t_val;
    *   committed.u               = 0.0f;
    *   committed.v               = 0.0f;
    *   committed.valid           = true;
    *   committed.leaf_type       = potential.leaf_type;
    *   committed.bvh_level       = BRW_RT_BVH_LEVEL_OBJECT;
    *   committed.front_face      = false;
    *   committed.prim_leaf_index = 0;
    *   committed.done            = false;
    */
   nir_def *committed_hit_dwords_0_3 =
      brw_nir_rt_load(b, committed_addr, 16, 4, 32);
   committed_hit_dwords_0_3 =
      nir_vec4(b,
               t_val,
               nir_imm_float(b, 0.0f),
               nir_imm_float(b, 0.0f),
               nir_ior_imm(b,
                           nir_ior_imm(b, nir_channel(b, potential_hit_dwords_0_3, 3), 0x000e0000),
                           (0x1 << 16)                     /* valid */ |
                           (BRW_RT_BVH_LEVEL_OBJECT << 24) /* leaf_type */));
   brw_nir_rt_store(b, committed_addr, 16, committed_hit_dwords_0_3, 0xf /* write_mask */);

   /* Set:
    *
    *   committed.prim_leaf_ptr   = potential.prim_leaf_ptr;
    *   committed.inst_leaf_ptr   = potential.inst_leaf_ptr;
    */
   brw_nir_memcpy_global(b,
                         nir_iadd_imm(b, committed_addr, 16), 16,
                         nir_iadd_imm(b, potential_addr, 16), 16,
                         16);
}

struct brw_nir_rt_mem_ray_defs {
   nir_def *orig;
   nir_def *dir;
   nir_def *t_near;
   nir_def *t_far;
   nir_def *root_node_ptr;
   nir_def *ray_flags;
   nir_def *hit_group_sr_base_ptr;
   nir_def *hit_group_sr_stride;
   nir_def *miss_sr_ptr;
   nir_def *shader_index_multiplier;
   nir_def *inst_leaf_ptr;
   nir_def *ray_mask;
};

static inline void
brw_nir_rt_store_mem_ray_query_at_addr(nir_builder *b,
                                       nir_def *ray_addr,
                                       const struct brw_nir_rt_mem_ray_defs *defs)
{
   assert_def_size(defs->orig, 3, 32);
   assert_def_size(defs->dir, 3, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 0), 16,
      nir_vec4(b, nir_channel(b, defs->orig, 0),
                  nir_channel(b, defs->orig, 1),
                  nir_channel(b, defs->orig, 2),
                  nir_channel(b, defs->dir, 0)),
      ~0 /* write mask */);

   assert_def_size(defs->t_near, 1, 32);
   assert_def_size(defs->t_far, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 16), 16,
      nir_vec4(b, nir_channel(b, defs->dir, 1),
                  nir_channel(b, defs->dir, 2),
                  defs->t_near,
                  defs->t_far),
      ~0 /* write mask */);

   assert_def_size(defs->root_node_ptr, 1, 64);
   assert_def_size(defs->ray_flags, 1, 16);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
      nir_vec2(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
                     defs->ray_flags)),
      0x3 /* write mask */);

   /* leaf_ptr is optional */
   nir_def *inst_leaf_ptr;
   if (defs->inst_leaf_ptr) {
      inst_leaf_ptr = defs->inst_leaf_ptr;
   } else {
      inst_leaf_ptr = nir_imm_int64(b, 0);
   }

   assert_def_size(inst_leaf_ptr, 1, 64);
   assert_def_size(defs->ray_mask, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 56), 8,
      nir_vec2(b, nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
                     nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
      ~0 /* write mask */);
}

static inline void
brw_nir_rt_store_mem_ray(nir_builder *b,
                         const struct brw_nir_rt_mem_ray_defs *defs,
                         enum brw_rt_bvh_level bvh_level)
{
   nir_def *ray_addr =
      brw_nir_rt_mem_ray_addr(b, brw_nir_rt_stack_addr(b), bvh_level);

   assert_def_size(defs->orig, 3, 32);
   assert_def_size(defs->dir, 3, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 0), 16,
      nir_vec4(b, nir_channel(b, defs->orig, 0),
                  nir_channel(b, defs->orig, 1),
                  nir_channel(b, defs->orig, 2),
                  nir_channel(b, defs->dir, 0)),
      ~0 /* write mask */);

   assert_def_size(defs->t_near, 1, 32);
   assert_def_size(defs->t_far, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 16), 16,
      nir_vec4(b, nir_channel(b, defs->dir, 1),
                  nir_channel(b, defs->dir, 2),
                  defs->t_near,
                  defs->t_far),
      ~0 /* write mask */);

   assert_def_size(defs->root_node_ptr, 1, 64);
   assert_def_size(defs->ray_flags, 1, 16);
   assert_def_size(defs->hit_group_sr_base_ptr, 1, 64);
   assert_def_size(defs->hit_group_sr_stride, 1, 16);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
      nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
                     defs->ray_flags),
                  nir_unpack_64_2x32_split_x(b, defs->hit_group_sr_base_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->hit_group_sr_base_ptr),
                     defs->hit_group_sr_stride)),
      ~0 /* write mask */);

   /* leaf_ptr is optional */
   nir_def *inst_leaf_ptr;
   if (defs->inst_leaf_ptr) {
      inst_leaf_ptr = defs->inst_leaf_ptr;
   } else {
      inst_leaf_ptr = nir_imm_int64(b, 0);
   }

   assert_def_size(defs->miss_sr_ptr, 1, 64);
   assert_def_size(defs->shader_index_multiplier, 1, 32);
   assert_def_size(inst_leaf_ptr, 1, 64);
   assert_def_size(defs->ray_mask, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 48), 16,
      nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->miss_sr_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->miss_sr_ptr),
                     nir_unpack_32_2x16_split_x(b,
                        nir_ishl(b, defs->shader_index_multiplier,
                                    nir_imm_int(b, 8)))),
                  nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
                     nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
      ~0 /* write mask */);
}

static inline void
brw_nir_rt_load_mem_ray_from_addr(nir_builder *b,
                                  struct brw_nir_rt_mem_ray_defs *defs,
                                  nir_def *ray_base_addr,
                                  enum brw_rt_bvh_level bvh_level)
{
   nir_def *ray_addr = brw_nir_rt_mem_ray_addr(b,
                                               ray_base_addr,
                                               bvh_level);

   nir_def *data[4] = {
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr,  0), 16, 4, 32),
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 16), 16, 4, 32),
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 32), 16, 4, 32),
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 48), 16, 4, 32),
   };

   defs->orig = nir_trim_vector(b, data[0], 3);
   defs->dir = nir_vec3(b, nir_channel(b, data[0], 3),
                           nir_channel(b, data[1], 0),
                           nir_channel(b, data[1], 1));
   defs->t_near = nir_channel(b, data[1], 2);
   defs->t_far = nir_channel(b, data[1], 3);
   defs->root_node_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[2], 0),
                                nir_extract_i16(b, nir_channel(b, data[2], 1),
                                                   nir_imm_int(b, 0)));
   defs->ray_flags =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 1));
   defs->hit_group_sr_base_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[2], 2),
                                nir_extract_i16(b, nir_channel(b, data[2], 3),
                                                   nir_imm_int(b, 0)));
   defs->hit_group_sr_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 3));
   defs->miss_sr_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[3], 0),
                                nir_extract_i16(b, nir_channel(b, data[3], 1),
                                                   nir_imm_int(b, 0)));
   defs->shader_index_multiplier =
      nir_ushr(b, nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 1)),
                  nir_imm_int(b, 8));
   defs->inst_leaf_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[3], 2),
                                nir_extract_i16(b, nir_channel(b, data[3], 3),
                                                   nir_imm_int(b, 0)));
   defs->ray_mask =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 3));
}

static inline void
brw_nir_rt_load_mem_ray(nir_builder *b,
                        struct brw_nir_rt_mem_ray_defs *defs,
                        enum brw_rt_bvh_level bvh_level)
{
   brw_nir_rt_load_mem_ray_from_addr(b, defs, brw_nir_rt_stack_addr(b),
                                     bvh_level);
}

struct brw_nir_rt_bvh_instance_leaf_defs {
   nir_def *shader_index;
   nir_def *contribution_to_hit_group_index;
   nir_def *world_to_object[4];
   nir_def *instance_id;
   nir_def *instance_index;
   nir_def *object_to_world[4];
};

static inline void
brw_nir_rt_load_bvh_instance_leaf(nir_builder *b,
                                  struct brw_nir_rt_bvh_instance_leaf_defs *defs,
                                  nir_def *leaf_addr)
{
   nir_def *leaf_desc = brw_nir_rt_load(b, leaf_addr, 4, 2, 32);

   defs->shader_index =
      nir_iand_imm(b, nir_channel(b, leaf_desc, 0), (1 << 24) - 1);
   defs->contribution_to_hit_group_index =
      nir_iand_imm(b, nir_channel(b, leaf_desc, 1), (1 << 24) - 1);

   defs->world_to_object[0] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 16), 4, 3, 32);
   defs->world_to_object[1] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 28), 4, 3, 32);
   defs->world_to_object[2] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 40), 4, 3, 32);
   /* The last column of the matrices is swapped between the two, probably
    * because it makes things easier/faster for the hardware somehow.
    */
   defs->object_to_world[3] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 52), 4, 3, 32);

   nir_def *data =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 64), 4, 4, 32);
   defs->instance_id = nir_channel(b, data, 2);
   defs->instance_index = nir_channel(b, data, 3);

   defs->object_to_world[0] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 80), 4, 3, 32);
   defs->object_to_world[1] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 92), 4, 3, 32);
   defs->object_to_world[2] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 104), 4, 3, 32);
   defs->world_to_object[3] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 116), 4, 3, 32);
}
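
/* Illustrative sketch (hypothetical, not an upstream helper): applying the
 * instance transform loaded above to a point.  This assumes, based on the
 * offsets used above, that world_to_object[0..2] are the matrix columns and
 * world_to_object[3] is the translation, so a point transforms as
 * p.x * col0 + p.y * col1 + p.z * col2 + translation.
 */
static inline nir_def *
brw_nir_rt_example_world_to_object_point(nir_builder *b,
                                         const struct brw_nir_rt_bvh_instance_leaf_defs *leaf,
                                         nir_def *world_p /* vec3 */)
{
   nir_def *res = leaf->world_to_object[3];
   for (unsigned i = 0; i < 3; i++) {
      /* Scale each column by the matching point component and accumulate. */
      res = nir_ffma(b, leaf->world_to_object[i],
                     nir_channel(b, world_p, i), res);
   }
   return res;
}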

struct brw_nir_rt_bvh_primitive_leaf_defs {
   nir_def *shader_index;
   nir_def *geom_mask;
   nir_def *geom_index;
   nir_def *type;
   nir_def *geom_flags;
};

static inline void
brw_nir_rt_load_bvh_primitive_leaf(nir_builder *b,
                                   struct brw_nir_rt_bvh_primitive_leaf_defs *defs,
                                   nir_def *leaf_addr)
{
   nir_def *desc = brw_nir_rt_load(b, leaf_addr, 4, 2, 32);

   defs->shader_index =
      nir_ubitfield_extract(b, nir_channel(b, desc, 0),
                            nir_imm_int(b, 23), nir_imm_int(b, 0));
   defs->geom_mask =
      nir_ubitfield_extract(b, nir_channel(b, desc, 0),
                            nir_imm_int(b, 31), nir_imm_int(b, 24));

   defs->geom_index =
      nir_ubitfield_extract(b, nir_channel(b, desc, 1),
                            nir_imm_int(b, 28), nir_imm_int(b, 0));
   defs->type =
      nir_ubitfield_extract(b, nir_channel(b, desc, 1),
                            nir_imm_int(b, 29), nir_imm_int(b, 29));
   defs->geom_flags =
      nir_ubitfield_extract(b, nir_channel(b, desc, 1),
                            nir_imm_int(b, 31), nir_imm_int(b, 30));
}

struct brw_nir_rt_bvh_primitive_leaf_positions_defs {
   nir_def *positions[3];
};

static inline void
brw_nir_rt_load_bvh_primitive_leaf_positions(nir_builder *b,
                                             struct brw_nir_rt_bvh_primitive_leaf_positions_defs *defs,
                                             nir_def *leaf_addr)
{
   for (unsigned i = 0; i < ARRAY_SIZE(defs->positions); i++) {
      defs->positions[i] =
         brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 16 + i * 4 * 3), 4, 3, 32);
   }
}

static inline nir_def *
brw_nir_rt_load_primitive_id_from_hit(nir_builder *b,
                                      nir_def *is_procedural,
                                      const struct brw_nir_rt_mem_hit_defs *defs)
{
   if (!is_procedural) {
      is_procedural =
         nir_ieq_imm(b, defs->leaf_type,
                        BRW_RT_BVH_NODE_TYPE_PROCEDURAL);
   }

   nir_def *prim_id_proc, *prim_id_quad;
   nir_push_if(b, is_procedural);
   {
      /* For procedural leaves, the index is in dw[3]. */
      nir_def *offset =
         nir_iadd_imm(b, nir_ishl_imm(b, defs->prim_leaf_index, 2), 12);
      prim_id_proc = nir_load_global(b, nir_iadd(b, defs->prim_leaf_ptr,
                                                 nir_u2u64(b, offset)),
                                     4, /* align */ 1, 32);
   }
   nir_push_else(b, NULL);
   {
      /* For quad leaves, the index is in dw[2] and there is an additional
       * 16-bit offset in dw[3].
       */
      prim_id_quad = nir_load_global(b, nir_iadd_imm(b, defs->prim_leaf_ptr, 8),
                                     4, /* align */ 1, 32);
      prim_id_quad = nir_iadd(b,
                              prim_id_quad,
                              defs->prim_index_delta);
   }
   nir_pop_if(b, NULL);

   return nir_if_phi(b, prim_id_proc, prim_id_quad);
}

static inline nir_def *
brw_nir_rt_acceleration_structure_to_root_node(nir_builder *b,
                                               nir_def *as_addr)
{
   /* The HW memory structure in which we specify what acceleration structure
    * to traverse takes the address of the root node in the acceleration
    * structure, not the acceleration structure itself. To find that, we have
    * to read the root node offset from the acceleration structure, which is
    * the first QWord.
    *
    * But if the acceleration structure pointer is NULL, then we should return
    * NULL as the root node pointer.
    *
    * TODO: we could optimize this by assuming that for a given version of the
    * BVH, we can find the root node at a given offset.
    */
   nir_def *root_node_ptr, *null_node_ptr;
   nir_push_if(b, nir_ieq_imm(b, as_addr, 0));
   {
      null_node_ptr = nir_imm_int64(b, 0);
   }
   nir_push_else(b, NULL);
   {
      root_node_ptr =
         nir_iadd(b, as_addr, brw_nir_rt_load(b, as_addr, 256, 1, 64));
   }
   nir_pop_if(b, NULL);

   return nir_if_phi(b, null_node_ptr, root_node_ptr);
}

#endif /* BRW_NIR_RT_BUILDER_H */