xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_functions.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2015 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "nir_control_flow.h"
27 #include "nir_vla.h"
28 
29 /*
30  * TODO: write a proper inliner for GPUs.
31  * This heuristic just inlines small functions,
32  * and tail calls get inlined as well.
33  */
34 static bool
nir_function_can_inline(nir_function * function)35 nir_function_can_inline(nir_function *function)
36 {
37    bool can_inline = true;
38    if (!function->should_inline) {
39       if (function->impl) {
40          if (function->impl->num_blocks > 2)
41             can_inline = false;
42          if (function->impl->ssa_alloc > 45)
43             can_inline = false;
44       }
45    }
46    return can_inline;
47 }
48 
49 static bool
function_ends_in_jump(nir_function_impl * impl)50 function_ends_in_jump(nir_function_impl *impl)
51 {
52    nir_block *last_block = nir_impl_last_block(impl);
53    return nir_block_ends_in_jump(last_block);
54 }
55 
56 /* A cast is used to deref function in/out params. However the bindless
57  * textures spec allows both uniforms and functions temps to be passed to a
58  * function param defined the same way. To deal with this we need to update
59  * this when we inline and know what variable mode we are dealing with.
60  */
61 static void
fixup_cast_deref_mode(nir_deref_instr * deref)62 fixup_cast_deref_mode(nir_deref_instr *deref)
63 {
64    nir_deref_instr *parent = nir_src_as_deref(deref->parent);
65    if (parent && parent->modes & nir_var_uniform &&
66        deref->modes & nir_var_function_temp) {
67       deref->modes |= nir_var_uniform;
68       deref->modes ^= nir_var_function_temp;
69 
70       nir_foreach_use(use, &deref->def) {
71          if (nir_src_parent_instr(use)->type != nir_instr_type_deref)
72             continue;
73 
74          /* Recurse into children */
75          fixup_cast_deref_mode(nir_instr_as_deref(nir_src_parent_instr(use)));
76       }
77    }
78 }
79 
80 void
nir_inline_function_impl(struct nir_builder * b,const nir_function_impl * impl,nir_def ** params,struct hash_table * shader_var_remap)81 nir_inline_function_impl(struct nir_builder *b,
82                          const nir_function_impl *impl,
83                          nir_def **params,
84                          struct hash_table *shader_var_remap)
85 {
86    nir_function_impl *copy = nir_function_impl_clone(b->shader, impl);
87 
88    exec_list_append(&b->impl->locals, &copy->locals);
89 
90    nir_foreach_block(block, copy) {
91       nir_foreach_instr_safe(instr, block) {
92          switch (instr->type) {
93          case nir_instr_type_deref: {
94             nir_deref_instr *deref = nir_instr_as_deref(instr);
95 
96             /* Note: This shouldn't change the mode of anything but the
97              * replaced nir_intrinsic_load_param intrinsics handled later in
98              * this switch table. Any incorrect modes should have already been
99              * detected by previous nir_vaidate calls.
100              */
101             if (deref->deref_type == nir_deref_type_cast) {
102                fixup_cast_deref_mode(deref);
103                break;
104             }
105 
106             if (deref->deref_type != nir_deref_type_var)
107                break;
108 
109             /* We don't need to remap function variables.  We already cloned
110              * them as part of nir_function_impl_clone and appended them to
111              * b->impl->locals.
112              */
113             if (deref->var->data.mode == nir_var_function_temp)
114                break;
115 
116             /* If no map is provided, we assume that there are either no
117              * shader variables or they already live b->shader (this is the
118              * case for function inlining within a single shader.
119              */
120             if (shader_var_remap == NULL)
121                break;
122 
123             struct hash_entry *entry =
124                _mesa_hash_table_search(shader_var_remap, deref->var);
125             if (entry == NULL) {
126                nir_variable *nvar = nir_variable_clone(deref->var, b->shader);
127                nir_shader_add_variable(b->shader, nvar);
128                entry = _mesa_hash_table_insert(shader_var_remap,
129                                                deref->var, nvar);
130             }
131             deref->var = entry->data;
132             break;
133          }
134 
135          case nir_instr_type_intrinsic: {
136             nir_intrinsic_instr *load = nir_instr_as_intrinsic(instr);
137             if (load->intrinsic != nir_intrinsic_load_param)
138                break;
139 
140             unsigned param_idx = nir_intrinsic_param_idx(load);
141             assert(param_idx < impl->function->num_params);
142             nir_def_replace(&load->def, params[param_idx]);
143             break;
144          }
145 
146          case nir_instr_type_jump:
147             /* Returns have to be lowered for this to work */
148             assert(nir_instr_as_jump(instr)->type != nir_jump_return);
149             break;
150 
151          default:
152             break;
153          }
154       }
155    }
156 
157    bool nest_if = function_ends_in_jump(copy);
158 
159    /* Pluck the body out of the function and place it here */
160    nir_cf_list body;
161    nir_cf_list_extract(&body, &copy->body);
162 
163    if (nest_if) {
164       nir_if *cf = nir_push_if(b, nir_imm_true(b));
165       nir_cf_reinsert(&body, nir_after_cf_list(&cf->then_list));
166       nir_pop_if(b, cf);
167    } else {
168       /* Insert a nop at the cursor so we can keep track of where things are as
169        * we add/remove stuff from the CFG.
170        */
171       nir_intrinsic_instr *nop = nir_nop(b);
172       nir_cf_reinsert(&body, nir_before_instr(&nop->instr));
173       b->cursor = nir_instr_remove(&nop->instr);
174    }
175 }
176 
177 static bool inline_function_impl(nir_function_impl *impl, struct set *inlined);
178 
inline_functions_pass(nir_builder * b,nir_instr * instr,void * cb_data)179 static bool inline_functions_pass(nir_builder *b,
180                                   nir_instr *instr,
181                                   void *cb_data)
182 {
183    struct set *inlined = cb_data;
184    if (instr->type != nir_instr_type_call)
185       return false;
186 
187    nir_call_instr *call = nir_instr_as_call(instr);
188    assert(call->callee->impl);
189 
190    if (b->shader->options->driver_functions &&
191        b->shader->info.stage == MESA_SHADER_KERNEL) {
192       bool last_instr = (instr == nir_block_last_instr(instr->block));
193       if (!nir_function_can_inline(call->callee) && !last_instr) {
194          return false;
195       }
196    }
197 
198    /* Make sure that the function we're calling is already inlined */
199    inline_function_impl(call->callee->impl, inlined);
200 
201    b->cursor = nir_instr_remove(&call->instr);
202 
203    /* Rewrite all of the uses of the callee's parameters to use the call
204     * instructions sources.  In order to ensure that the "load" happens
205     * here and not later (for register sources), we make sure to convert it
206     * to an SSA value first.
207     */
208    const unsigned num_params = call->num_params;
209    NIR_VLA(nir_def *, params, num_params);
210    for (unsigned i = 0; i < num_params; i++) {
211       params[i] = call->params[i].ssa;
212    }
213 
214    nir_inline_function_impl(b, call->callee->impl, params, NULL);
215    return true;
216 }
217 
218 static bool
inline_function_impl(nir_function_impl * impl,struct set * inlined)219 inline_function_impl(nir_function_impl *impl, struct set *inlined)
220 {
221    if (_mesa_set_search(inlined, impl))
222       return false; /* Already inlined */
223 
224    bool progress;
225    progress = nir_function_instructions_pass(impl, inline_functions_pass,
226                                              nir_metadata_none, inlined);
227    if (progress) {
228       /* Indices are completely messed up now */
229       nir_index_ssa_defs(impl);
230    }
231 
232    _mesa_set_add(inlined, impl);
233 
234    return progress;
235 }
236 
237 /** A pass to inline all functions in a shader into their callers
238  *
239  * For most use-cases, function inlining is a multi-step process.  The general
240  * pattern employed by SPIR-V consumers and others is as follows:
241  *
242  *  1. nir_lower_variable_initializers(shader, nir_var_function_temp)
243  *
244  *     This is needed because local variables from the callee are simply added
245  *     to the locals list for the caller and the information about where the
246  *     constant initializer logically happens is lost.  If the callee is
247  *     called in a loop, this can cause the variable to go from being
248  *     initialized once per loop iteration to being initialized once at the
249  *     top of the caller and values to persist from one invocation of the
250  *     callee to the next.  The simple solution to this problem is to get rid
251  *     of constant initializers before function inlining.
252  *
253  *  2. nir_lower_returns(shader)
254  *
255  *     nir_inline_functions assumes that all functions end "naturally" by
256  *     execution reaching the end of the function without any return
257  *     instructions causing instant jumps to the end.  Thanks to NIR being
258  *     structured, we can't represent arbitrary jumps to various points in the
259  *     program which is what an early return in the callee would have to turn
260  *     into when we inline it into the caller.  Instead, we require returns to
261  *     be lowered which lets us just copy+paste the callee directly into the
262  *     caller.
263  *
264  *  3. nir_inline_functions(shader)
265  *
266  *     This does the actual function inlining and the resulting shader will
267  *     contain no call instructions.
268  *
269  *  4. nir_opt_deref(shader)
270  *
271  *     Most functions contain pointer parameters where the result of a deref
272  *     instruction is passed in as a parameter, loaded via a load_param
273  *     intrinsic, and then turned back into a deref via a cast.  Function
274  *     inlining will get rid of the load_param but we are still left with a
275  *     cast.  Running nir_opt_deref gets rid of the intermediate cast and
276  *     results in a whole deref chain again.  This is currently required by a
277  *     number of optimizations and lowering passes at least for certain
278  *     variable modes.
279  *
280  *  5. Loop over the functions and delete all but the main entrypoint.
281  *
282  *     In the Intel Vulkan driver this looks like this:
283  *
284  *        nir_remove_non_entrypoints(nir);
285  *
286  *    While nir_inline_functions does get rid of all call instructions, it
287  *    doesn't get rid of any functions because it doesn't know what the "root
288  *    function" is.  Instead, it's up to the individual driver to know how to
289  *    decide on a root function and delete the rest.  With SPIR-V,
290  *    spirv_to_nir returns the root function and so we can just use == whereas
291  *    with GL, you may have to look for a function named "main".
292  *
293  *  6. nir_lower_variable_initializers(shader, ~nir_var_function_temp)
294  *
295  *     Lowering constant initializers on inputs, outputs, global variables,
296  *     etc. requires that we know the main entrypoint so that we know where to
297  *     initialize them.  Otherwise, we would have to assume that anything
298  *     could be a main entrypoint and initialize them at the start of every
299  *     function but that would clearly be wrong if any of those functions were
300  *     ever called within another function.  Simply requiring a single-
301  *     entrypoint function shader is the best way to make it well-defined.
302  */
303 bool
nir_inline_functions(nir_shader * shader)304 nir_inline_functions(nir_shader *shader)
305 {
306    struct set *inlined = _mesa_pointer_set_create(NULL);
307    bool progress = false;
308 
309    nir_foreach_function_impl(impl, shader) {
310       progress = inline_function_impl(impl, inlined) || progress;
311    }
312 
313    _mesa_set_destroy(inlined, NULL);
314 
315    return progress;
316 }
317 
318 struct lower_link_state {
319    struct hash_table *shader_var_remap;
320    const nir_shader *link_shader;
321    unsigned printf_index_offset;
322 };
323 
324 static bool
lower_calls_vars_instr(struct nir_builder * b,nir_instr * instr,void * cb_data)325 lower_calls_vars_instr(struct nir_builder *b,
326                        nir_instr *instr,
327                        void *cb_data)
328 {
329    struct lower_link_state *state = cb_data;
330 
331    switch (instr->type) {
332    case nir_instr_type_deref: {
333       nir_deref_instr *deref = nir_instr_as_deref(instr);
334       if (deref->deref_type != nir_deref_type_var)
335          return false;
336       if (deref->var->data.mode == nir_var_function_temp)
337          return false;
338 
339       assert(state->shader_var_remap);
340       struct hash_entry *entry =
341          _mesa_hash_table_search(state->shader_var_remap, deref->var);
342       if (entry == NULL) {
343          nir_variable *nvar = nir_variable_clone(deref->var, b->shader);
344          nir_shader_add_variable(b->shader, nvar);
345          entry = _mesa_hash_table_insert(state->shader_var_remap,
346                                          deref->var, nvar);
347       }
348       deref->var = entry->data;
349       break;
350    }
351    case nir_instr_type_call: {
352       nir_call_instr *ncall = nir_instr_as_call(instr);
353       if (!ncall->callee->name)
354          return false;
355 
356       nir_function *func = nir_shader_get_function_for_name(b->shader, ncall->callee->name);
357       if (func) {
358          ncall->callee = func;
359          break;
360       }
361 
362       nir_function *new_func;
363       new_func = nir_shader_get_function_for_name(state->link_shader, ncall->callee->name);
364       if (new_func)
365          ncall->callee = nir_function_clone(b->shader, new_func);
366       break;
367    }
368    case nir_instr_type_intrinsic: {
369       /* Reindex the offset of the printf intrinsic by the number of already
370        * present printfs in the shader where functions are linked into.
371        */
372       if (state->printf_index_offset == 0)
373          return false;
374 
375       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
376       if (intrin->intrinsic != nir_intrinsic_printf)
377          return false;
378 
379       b->cursor = nir_before_instr(instr);
380       nir_src_rewrite(&intrin->src[0],
381                       nir_iadd_imm(b, intrin->src[0].ssa,
382                                       state->printf_index_offset));
383       break;
384    }
385    default:
386       break;
387    }
388    return true;
389 }
390 
391 static bool
lower_call_function_impl(struct nir_builder * b,nir_function * callee,const nir_function_impl * impl,struct lower_link_state * state)392 lower_call_function_impl(struct nir_builder *b,
393                          nir_function *callee,
394                          const nir_function_impl *impl,
395                          struct lower_link_state *state)
396 {
397    nir_function_impl *copy = nir_function_impl_clone(b->shader, impl);
398    copy->function = callee;
399    callee->impl = copy;
400 
401    return nir_function_instructions_pass(copy,
402                                          lower_calls_vars_instr,
403                                          nir_metadata_none,
404                                          state);
405 }
406 
407 static bool
function_link_pass(struct nir_builder * b,nir_instr * instr,void * cb_data)408 function_link_pass(struct nir_builder *b,
409                    nir_instr *instr,
410                    void *cb_data)
411 {
412    struct lower_link_state *state = cb_data;
413 
414    if (instr->type != nir_instr_type_call)
415       return false;
416 
417    nir_call_instr *call = nir_instr_as_call(instr);
418    nir_function *func = NULL;
419 
420    if (!call->callee->name)
421       return false;
422 
423    if (call->callee->impl)
424       return false;
425 
426    func = nir_shader_get_function_for_name(state->link_shader, call->callee->name);
427    if (!func || !func->impl) {
428       return false;
429    }
430    return lower_call_function_impl(b, call->callee,
431                                    func->impl,
432                                    state);
433 }
434 
435 bool
nir_link_shader_functions(nir_shader * shader,const nir_shader * link_shader)436 nir_link_shader_functions(nir_shader *shader,
437                           const nir_shader *link_shader)
438 {
439    void *ra_ctx = ralloc_context(NULL);
440    struct hash_table *copy_vars = _mesa_pointer_hash_table_create(ra_ctx);
441    bool progress = false, overall_progress = false;
442 
443    struct lower_link_state state = {
444       .shader_var_remap = copy_vars,
445       .link_shader = link_shader,
446       .printf_index_offset = shader->printf_info_count,
447    };
448    /* do progress passes inside the pass */
449    do {
450       progress = false;
451       nir_foreach_function_impl(impl, shader) {
452          bool this_progress = nir_function_instructions_pass(impl,
453                                                              function_link_pass,
454                                                              nir_metadata_none,
455                                                              &state);
456          if (this_progress)
457             nir_index_ssa_defs(impl);
458          progress |= this_progress;
459       }
460       overall_progress |= progress;
461    } while (progress);
462 
463    if (overall_progress && link_shader->printf_info_count > 0) {
464       shader->printf_info = reralloc(shader, shader->printf_info,
465                                      u_printf_info,
466                                      shader->printf_info_count +
467                                      link_shader->printf_info_count);
468 
469       for (unsigned i = 0; i < link_shader->printf_info_count; i++){
470          const u_printf_info *src_info = &link_shader->printf_info[i];
471          u_printf_info *dst_info = &shader->printf_info[shader->printf_info_count++];
472 
473          dst_info->num_args = src_info->num_args;
474          dst_info->arg_sizes = ralloc_array(shader, unsigned, dst_info->num_args);
475          memcpy(dst_info->arg_sizes, src_info->arg_sizes,
476                 sizeof(dst_info->arg_sizes[0]) * dst_info->num_args);
477 
478          dst_info->string_size = src_info->string_size;
479          dst_info->strings = ralloc_memdup(shader, src_info->strings,
480                                            dst_info->string_size);
481       }
482    }
483 
484    ralloc_free(ra_ctx);
485 
486    return overall_progress;
487 }
488 
489 static void
490 nir_mark_used_functions(struct nir_function *func, struct set *used_funcs);
491 
mark_used_pass_cb(struct nir_builder * b,nir_instr * instr,void * data)492 static bool mark_used_pass_cb(struct nir_builder *b,
493                               nir_instr *instr, void *data)
494 {
495    struct set *used_funcs = data;
496    if (instr->type != nir_instr_type_call)
497       return false;
498    nir_call_instr *call = nir_instr_as_call(instr);
499 
500    _mesa_set_add(used_funcs, call->callee);
501 
502    nir_mark_used_functions(call->callee, used_funcs);
503    return true;
504 }
505 
506 static void
nir_mark_used_functions(struct nir_function * func,struct set * used_funcs)507 nir_mark_used_functions(struct nir_function *func, struct set *used_funcs)
508 {
509    if (func->impl) {
510       nir_function_instructions_pass(func->impl,
511                                      mark_used_pass_cb,
512                                      nir_metadata_none,
513                                      used_funcs);
514    }
515 }
516 
517 void
nir_cleanup_functions(nir_shader * nir)518 nir_cleanup_functions(nir_shader *nir)
519 {
520    if (!nir->options->driver_functions) {
521       nir_remove_non_entrypoints(nir);
522       return;
523    }
524 
525    struct set *used_funcs = _mesa_set_create(NULL, _mesa_hash_pointer,
526                                              _mesa_key_pointer_equal);
527    foreach_list_typed_safe(nir_function, func, node, &nir->functions) {
528       if (func->is_entrypoint) {
529          _mesa_set_add(used_funcs, func);
530          nir_mark_used_functions(func, used_funcs);
531       }
532    }
533    foreach_list_typed_safe(nir_function, func, node, &nir->functions) {
534       if (!_mesa_set_search(used_funcs, func))
535          exec_node_remove(&func->node);
536    }
537    _mesa_set_destroy(used_funcs, NULL);
538 }
539