xref: /aosp_15_r20/external/mesa3d/src/intel/compiler/brw_nir_lower_intersection_shader.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright (c) 2020 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "brw_nir_rt.h"
25 #include "brw_nir_rt_builder.h"
26 
27 static nir_function_impl *
lower_any_hit_for_intersection(nir_shader * any_hit)28 lower_any_hit_for_intersection(nir_shader *any_hit)
29 {
30    nir_function_impl *impl = nir_shader_get_entrypoint(any_hit);
31 
32    /* Any-hit shaders need three parameters */
33    assert(impl->function->num_params == 0);
34    nir_parameter params[] = {
35       {
36          /* A pointer to a boolean value for whether or not the hit was
37           * accepted.
38           */
39          .num_components = 1,
40          .bit_size = 32,
41       },
42       {
43          /* The hit T value */
44          .num_components = 1,
45          .bit_size = 32,
46       },
47       {
48          /* The hit kind */
49          .num_components = 1,
50          .bit_size = 32,
51       },
52    };
53    impl->function->num_params = ARRAY_SIZE(params);
54    impl->function->params =
55       ralloc_array(any_hit, nir_parameter, ARRAY_SIZE(params));
56    memcpy(impl->function->params, params, sizeof(params));
57 
58    nir_builder build = nir_builder_at(nir_before_impl(impl));
59    nir_builder *b = &build;
60 
61    nir_def *commit_ptr = nir_load_param(b, 0);
62    nir_def *hit_t = nir_load_param(b, 1);
63    nir_def *hit_kind = nir_load_param(b, 2);
64 
65    nir_deref_instr *commit =
66       nir_build_deref_cast(b, commit_ptr, nir_var_function_temp,
67                            glsl_bool_type(), 0);
68 
69    nir_foreach_block_safe(block, impl) {
70       nir_foreach_instr_safe(instr, block) {
71          switch (instr->type) {
72          case nir_instr_type_intrinsic: {
73             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
74             switch (intrin->intrinsic) {
75             case nir_intrinsic_ignore_ray_intersection:
76                b->cursor = nir_instr_remove(&intrin->instr);
77                /* We put the newly emitted code inside a dummy if because it's
78                 * going to contain a jump instruction and we don't want to
79                 * deal with that mess here.  It'll get dealt with by our
80                 * control-flow optimization passes.
81                 */
82                nir_store_deref(b, commit, nir_imm_false(b), 0x1);
83                nir_push_if(b, nir_imm_true(b));
84                nir_jump(b, nir_jump_return);
85                nir_pop_if(b, NULL);
86                break;
87 
88             case nir_intrinsic_terminate_ray:
89                /* The "normal" handling of terminateRay works fine in
90                 * intersection shaders.
91                 */
92                break;
93 
94             case nir_intrinsic_load_ray_t_max:
95                nir_def_replace(&intrin->def, hit_t);
96                break;
97 
98             case nir_intrinsic_load_ray_hit_kind:
99                nir_def_replace(&intrin->def, hit_kind);
100                break;
101 
102             default:
103                break;
104             }
105             break;
106          }
107 
108          case nir_instr_type_jump: {
109             /* Stomp any halts to returns since they only return from the
110              * any-hit shader and not necessarily from the intersection
111              * shader.  This is safe to do because we've already asserted
112              * that we only have the one function.
113              */
114             nir_jump_instr *jump = nir_instr_as_jump(instr);
115             if (jump->type == nir_jump_halt)
116                jump->type = nir_jump_return;
117             break;
118          }
119 
120          default:
121             break;
122          }
123       }
124    }
125 
126    nir_validate_shader(any_hit, "after initial any-hit lowering");
127 
128    nir_lower_returns_impl(impl);
129 
130    nir_validate_shader(any_hit, "after lowering returns");
131 
132    return impl;
133 }
134 
135 static void
build_accept_ray(nir_builder * b)136 build_accept_ray(nir_builder *b)
137 {
138    /* Set the "valid" bit in mem_hit */
139    nir_def *ray_addr = brw_nir_rt_mem_hit_addr(b, false /* committed */);
140    nir_def *flags_dw_addr = nir_iadd_imm(b, ray_addr, 12);
141    nir_store_global(b, flags_dw_addr, 4,
142                     nir_ior(b, nir_load_global(b, flags_dw_addr, 4, 1, 32),
143                             nir_imm_int(b, 1 << 16)), 0x1 /* write_mask */);
144 
145    nir_accept_ray_intersection(b);
146 }
147 
148 void
brw_nir_lower_intersection_shader(nir_shader * intersection,const nir_shader * any_hit,const struct intel_device_info * devinfo)149 brw_nir_lower_intersection_shader(nir_shader *intersection,
150                                   const nir_shader *any_hit,
151                                   const struct intel_device_info *devinfo)
152 {
153    void *dead_ctx = ralloc_context(intersection);
154 
155    nir_function_impl *any_hit_impl = NULL;
156    struct hash_table *any_hit_var_remap = NULL;
157    if (any_hit) {
158       nir_shader *any_hit_tmp = nir_shader_clone(dead_ctx, any_hit);
159       NIR_PASS_V(any_hit_tmp, nir_opt_dce);
160       any_hit_impl = lower_any_hit_for_intersection(any_hit_tmp);
161       any_hit_var_remap = _mesa_pointer_hash_table_create(dead_ctx);
162    }
163 
164    nir_function_impl *impl = nir_shader_get_entrypoint(intersection);
165 
166    nir_builder build = nir_builder_at(nir_before_impl(impl));
167    nir_builder *b = &build;
168 
169    nir_def *t_addr = brw_nir_rt_mem_hit_addr(b, false /* committed */);
170    nir_variable *commit =
171       nir_local_variable_create(impl, glsl_bool_type(), "ray_commit");
172    nir_store_var(b, commit, nir_imm_false(b), 0x1);
173 
174    assert(impl->end_block->predecessors->entries == 1);
175    set_foreach(impl->end_block->predecessors, block_entry) {
176       struct nir_block *block = (void *)block_entry->key;
177       b->cursor = nir_after_block_before_jump(block);
178       nir_push_if(b, nir_load_var(b, commit));
179       {
180          build_accept_ray(b);
181       }
182       nir_push_else(b, NULL);
183       {
184          nir_ignore_ray_intersection(b);
185       }
186       nir_pop_if(b, NULL);
187       break;
188    }
189 
190    nir_foreach_block_safe(block, impl) {
191       nir_foreach_instr_safe(instr, block) {
192          switch (instr->type) {
193          case nir_instr_type_intrinsic: {
194             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
195             switch (intrin->intrinsic) {
196             case nir_intrinsic_report_ray_intersection: {
197                b->cursor = nir_instr_remove(&intrin->instr);
198                nir_def *hit_t = intrin->src[0].ssa;
199                nir_def *hit_kind = intrin->src[1].ssa;
200                nir_def *min_t = nir_load_ray_t_min(b);
201 
202                struct brw_nir_rt_mem_ray_defs ray_def;
203                brw_nir_rt_load_mem_ray(b, &ray_def, BRW_RT_BVH_LEVEL_WORLD);
204 
205                struct brw_nir_rt_mem_hit_defs hit_in = {};
206                brw_nir_rt_load_mem_hit(b, &hit_in, false);
207 
208                nir_def *max_t = ray_def.t_far;
209 
210                /* bool commit_tmp = false; */
211                nir_variable *commit_tmp =
212                   nir_local_variable_create(impl, glsl_bool_type(),
213                                             "commit_tmp");
214                nir_store_var(b, commit_tmp, nir_imm_false(b), 0x1);
215 
216                nir_push_if(b, nir_iand(b, nir_fge(b, hit_t, min_t),
217                                           nir_fge(b, max_t, hit_t)));
218                {
219                   /* Any-hit defaults to commit */
220                   nir_store_var(b, commit_tmp, nir_imm_true(b), 0x1);
221 
222                   if (any_hit_impl != NULL) {
223                      nir_push_if(b, nir_inot(b, nir_load_leaf_opaque_intel(b)));
224                      {
225                         nir_def *params[] = {
226                            &nir_build_deref_var(b, commit_tmp)->def,
227                            hit_t,
228                            hit_kind,
229                         };
230                         nir_inline_function_impl(b, any_hit_impl, params,
231                                                  any_hit_var_remap);
232                      }
233                      nir_pop_if(b, NULL);
234                   }
235 
236                   nir_push_if(b, nir_load_var(b, commit_tmp));
237                   {
238                      nir_store_var(b, commit, nir_imm_true(b), 0x1);
239 
240                      nir_def *ray_addr =
241                         brw_nir_rt_mem_ray_addr(b, brw_nir_rt_stack_addr(b), BRW_RT_BVH_LEVEL_WORLD);
242 
243                      nir_store_global(b, nir_iadd_imm(b, ray_addr, 16 + 12), 4,  hit_t, 0x1);
244                      nir_store_global(b, t_addr, 4,
245                                       nir_vec2(b, nir_fmin(b, hit_t, hit_in.t), hit_kind),
246                                       0x3);
247 
248                      /* There may be multiple reportIntersection() calls in
249                       * the shader, so if terminateOnFirstHit was requested,
250                       * accept the hit now. The lowering of
251                       * accept_ray_intersection will handle the rest.
252                       */
253                      nir_def *terminate = nir_test_mask(b, nir_load_ray_flags(b),
254                                                         BRW_RT_RAY_FLAG_TERMINATE_ON_FIRST_HIT);
255                      nir_push_if(b, terminate);
256                      {
257                         build_accept_ray(b);
258                      }
259                      nir_pop_if(b, NULL);
260                   }
261                   nir_pop_if(b, NULL);
262                }
263                nir_pop_if(b, NULL);
264 
265                nir_def *accepted = nir_load_var(b, commit_tmp);
266                nir_def_rewrite_uses(&intrin->def,
267                                         accepted);
268                break;
269             }
270 
271             default:
272                break;
273             }
274             break;
275          }
276 
277          default:
278             break;
279          }
280       }
281    }
282    nir_metadata_preserve(impl, nir_metadata_none);
283 
284    /* We did some inlining; have to re-index SSA defs */
285    nir_index_ssa_defs(impl);
286 
287    ralloc_free(dead_ctx);
288 }
289