/*
 * Copyright © 2022 Collabora, Ltd.
 * SPDX-License-Identifier: MIT
 */

#include "nak_private.h"
#include "nir_builder.h"

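/*
 * nak_nir_lower_cf: Lowers NIR's structured control flow (ifs and loops) to
 * unstructured blocks with explicit gotos for NAK.  Wherever control flow
 * diverges and later re-converges, the pass brackets the region with
 * bar_set_nv/bar_sync_nv so the warp re-converges at the merge block, and it
 * uses a per-invocation "escape" register to route break/continue out
 * through a cascade of those syncs.
 *
 * Schematically (eliding the escape-register bookkeeping), a divergent
 *
 *    if (cond) { A } else { B }
 *    C
 *
 * becomes roughly
 *
 *    %bar = bar_set_nv
 *    goto_if then_block, cond, else_block
 * then_block:
 *    A
 *    goto merge_block
 * else_block:
 *    B
 *    goto merge_block
 * merge_block:
 *    bar_sync_nv %bar
 *    C
 */
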
/* Appends a block to the end of the impl being built and leaves the builder
 * cursor at the end of that block.
 */
static void
push_block(nir_builder *b, nir_block *block, bool divergent)
{
   assert(nir_cursors_equal(b->cursor, nir_after_impl(b->impl)));
   block->divergent = divergent;
   block->cf_node.parent = &b->impl->cf_node;
   exec_list_push_tail(&b->impl->body, &block->cf_node.node);
   b->cursor = nir_after_block(block);
}

enum scope_type {
   SCOPE_TYPE_SHADER,
   SCOPE_TYPE_IF_MERGE,
   SCOPE_TYPE_LOOP_BREAK,
   SCOPE_TYPE_LOOP_CONT,
};

struct scope {
   enum scope_type type;

   struct scope *parent;
   uint32_t depth;

   /**
    * True if control-flow ever diverges within this scope, not accounting
    * for divergence in child scopes.
    */
   bool divergent;

   /** Merge block where control-flow re-converges when this scope exits */
   nir_block *merge;

   /** Barrier from bar_set_nv, or NULL if this scope needs no warp sync */
   nir_def *bar;

   /** Number of jumps which escape through this scope to an outer scope */
   uint32_t escapes;
};

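/* Opens a new scope nested inside parent.  If needs_sync is set, a warp
 * barrier is allocated with bar_set_nv here so that pop_scope() can
 * re-converge the warp with bar_sync_nv at the merge block.
 */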
static struct scope
push_scope(nir_builder *b,
           enum scope_type scope_type,
           struct scope *parent,
           bool divergent,
           bool needs_sync,
           nir_block *merge_block)
{
   struct scope scope = {
      .type = scope_type,
      .parent = parent,
      .depth = parent->depth + 1,
      .divergent = parent->divergent || divergent,
      .merge = merge_block,
   };

   if (needs_sync)
      scope.bar = nir_bar_set_nv(b);

   return scope;
}

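/* Closes a scope.  If the scope has a warp barrier, this syncs on it and, if
 * any jump escaped through this scope, emits a check of the escape register
 * which forwards the escape on to the nearest enclosing scope with a sync.
 */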
static void
pop_scope(nir_builder *b, nir_def *esc_reg, struct scope scope)
{
   if (scope.bar == NULL)
      return;

   nir_bar_sync_nv(b, scope.bar, scope.bar);

   if (scope.escapes > 0) {
      /* Find the nearest scope with a sync. */
      nir_block *parent_merge = b->impl->end_block;
      for (struct scope *p = scope.parent; p != NULL; p = p->parent) {
         if (p->bar != NULL) {
            parent_merge = p->merge;
            break;
         }
      }

      /* No escape is ~0, halt is 0, and we choose outer scope indices such
       * that outer scopes always have lower indices than inner scopes.
       */
      nir_def *esc = nir_ult_imm(b, nir_load_reg(b, esc_reg), scope.depth);

      /* We have to put the escape in its own block to avoid critical edges.
       * If we just did goto_if, we would end up with multiple successors,
       * including a jump to the parent's merge block which has multiple
       * predecessors.
       */
      nir_block *esc_block = nir_block_create(b->shader);
      nir_block *next_block = nir_block_create(b->shader);
      nir_goto_if(b, esc_block, esc, next_block);
      push_block(b, esc_block, false);
      nir_goto(b, parent_merge);
      push_block(b, next_block, scope.parent->divergent);
   }
}

static enum scope_type
jump_target_scope_type(nir_jump_type jump_type)
{
   switch (jump_type) {
   case nir_jump_break:    return SCOPE_TYPE_LOOP_BREAK;
   case nir_jump_continue: return SCOPE_TYPE_LOOP_CONT;
   default:
      unreachable("Unknown jump type");
   }
}

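/* Lowers a break or continue.  The target scope's depth is recorded in the
 * escape register and we jump to the nearest merge block that has a warp
 * sync; pop_scope() then cascades the escape outward until the target scope
 * is reached.  Every scope we escape through has its escape count bumped so
 * that pop_scope() knows to emit the cascade check.
 */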
static void
break_scopes(nir_builder *b, nir_def *esc_reg,
             struct scope *current_scope,
             nir_jump_type jump_type)
{
   nir_block *first_sync = NULL;
   uint32_t target_depth = UINT32_MAX;
   enum scope_type target_scope_type = jump_target_scope_type(jump_type);
   for (struct scope *scope = current_scope; scope; scope = scope->parent) {
      if (first_sync == NULL && scope->bar != NULL)
         first_sync = scope->merge;

      if (scope->type == target_scope_type) {
         if (first_sync == NULL) {
            first_sync = scope->merge;
         } else {
            /* In order for our cascade to work, we need to have the invariant
             * that anything which escapes any scope with a warp sync needs to
             * target a scope with a warp sync.
             */
            assert(scope->bar != NULL);
         }
         target_depth = scope->depth;
         break;
      } else {
         scope->escapes++;
      }
   }
   assert(target_depth < UINT32_MAX);

   nir_store_reg(b, nir_imm_int(b, target_depth), esc_reg);
   nir_goto(b, first_sync);
}

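/* Fall-through exit from a scope: if the current block does not already end
 * in a jump, mark "no escape" (~0) in the escape register and jump to the
 * merge block.
 */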
static void
normal_exit(nir_builder *b, nir_def *esc_reg, nir_block *merge_block)
{
   assert(nir_cursors_equal(b->cursor, nir_after_impl(b->impl)));
   nir_block *block = nir_cursor_current_block(b->cursor);

   if (!nir_block_ends_in_jump(block)) {
      nir_store_reg(b, nir_imm_int(b, ~0), esc_reg);
      nir_goto(b, merge_block);
   }
}

/* This is a heuristic for what instructions are allowed before we sync.
 * Annoyingly, we've gotten rid of phis so it's not as simple as "is it a
 * phi?".
 */
static bool
instr_is_allowed_before_sync(nir_instr *instr)
{
   switch (instr->type) {
   case nir_instr_type_alu: {
      nir_alu_instr *alu = nir_instr_as_alu(instr);
      /* We could probably allow more ALU as long as it doesn't contain
       * derivatives but let's be conservative and only allow mov for now.
       */
      return alu->op == nir_op_mov;
   }

   case nir_instr_type_intrinsic: {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      return intrin->intrinsic == nir_intrinsic_load_reg ||
             intrin->intrinsic == nir_intrinsic_store_reg;
   }

   default:
      return false;
   }
}

/** Returns true if our successor will sync for us
 *
 * This is a bit of a heuristic
 */
static bool
parent_scope_will_sync(nir_cf_node *node, struct scope *parent_scope)
{
   /* First search forward to see if there's anything non-trivial after this
    * node within the parent scope.
    */
   nir_block *block = nir_cf_node_as_block(nir_cf_node_next(node));
   nir_foreach_instr(instr, block) {
      if (!instr_is_allowed_before_sync(instr))
         return false;
   }

   /* There's another loop or if following and we didn't find a sync */
   if (nir_cf_node_next(&block->cf_node))
      return false;

   /* See if the parent scope will sync for us. */
   if (parent_scope->bar != NULL)
      return true;

   switch (parent_scope->type) {
   case SCOPE_TYPE_SHADER:
      return true;

   case SCOPE_TYPE_IF_MERGE:
      return parent_scope_will_sync(block->cf_node.parent,
                                    parent_scope->parent);

   case SCOPE_TYPE_LOOP_CONT:
      /* In this case, the loop doesn't have a sync of its own so we're
       * expected to be uniform before we hit the continue.
       */
      return false;

   case SCOPE_TYPE_LOOP_BREAK:
      unreachable("Loops must have a continue scope");

   default:
      unreachable("Unknown scope type");
   }
}

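/* Returns true if the block has more than one reachable predecessor, i.e. it
 * is a merge point of potentially divergent control flow.  Unreachable
 * blocks (those with no immediate dominator) are ignored.
 */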
static bool
block_is_merge(const nir_block *block)
{
   /* If it's unreachable, there is no merge */
   if (block->imm_dom == NULL)
      return false;

   unsigned num_preds = 0;
   set_foreach(block->predecessors, entry) {
      const nir_block *pred = entry->key;

      /* We don't care about unreachable blocks */
      if (pred->imm_dom == NULL)
         continue;

      num_preds++;
   }

   return num_preds > 1;
}

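/* The core of the pass.  Walks a structured CF list and re-emits it into the
 * new, unstructured impl: block contents are re-inserted as-is (with
 * break/continue turned into escape cascades), ifs become goto_if with
 * explicit then/else/merge blocks, and loops become head/continue/break
 * blocks, each wrapped in a scope that knows whether it needs a warp sync.
 */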
static void
lower_cf_list(nir_builder *b, nir_def *esc_reg, struct scope *parent_scope,
              struct exec_list *cf_list)
{
   foreach_list_typed_safe(nir_cf_node, node, node, cf_list) {
      switch (node->type) {
      case nir_cf_node_block: {
         nir_block *block = nir_cf_node_as_block(node);
         if (exec_list_is_empty(&block->instr_list))
            break;

         nir_cursor start = nir_before_block(block);
         nir_cursor end = nir_after_block(block);

         nir_jump_instr *jump = NULL;
         nir_instr *last_instr = nir_block_last_instr(block);
         if (last_instr->type == nir_instr_type_jump) {
            jump = nir_instr_as_jump(last_instr);
            end = nir_before_instr(&jump->instr);
         }

         nir_cf_list instrs;
         nir_cf_extract(&instrs, start, end);
         b->cursor = nir_cf_reinsert(&instrs, b->cursor);

         if (jump != NULL) {
            if (jump->type == nir_jump_halt) {
               /* Halt instructions map to OpExit on NVIDIA hardware and
                * exited lanes never block a bsync.
                */
               nir_instr_remove(&jump->instr);
               nir_builder_instr_insert(b, &jump->instr);
            } else {
               /* Everything else needs a break cascade */
               break_scopes(b, esc_reg, parent_scope, jump->type);
            }
         }
         break;
      }

      case nir_cf_node_if: {
         nir_if *nif = nir_cf_node_as_if(node);

         nir_def *cond = nif->condition.ssa;
         nir_instr_clear_src(NULL, &nif->condition);

         nir_block *then_block = nir_block_create(b->shader);
         nir_block *else_block = nir_block_create(b->shader);
         nir_block *merge_block = nir_block_create(b->shader);

         const bool needs_sync = cond->divergent &&
            block_is_merge(nir_cf_node_as_block(nir_cf_node_next(node))) &&
            !parent_scope_will_sync(&nif->cf_node, parent_scope);

         struct scope scope = push_scope(b, SCOPE_TYPE_IF_MERGE,
                                         parent_scope, cond->divergent,
                                         needs_sync, merge_block);

         nir_goto_if(b, then_block, cond, else_block);

         push_block(b, then_block, scope.divergent);
         lower_cf_list(b, esc_reg, &scope, &nif->then_list);
         normal_exit(b, esc_reg, merge_block);

         push_block(b, else_block, scope.divergent);
         lower_cf_list(b, esc_reg, &scope, &nif->else_list);
         normal_exit(b, esc_reg, merge_block);

         push_block(b, merge_block, parent_scope->divergent);
         pop_scope(b, esc_reg, scope);

         break;
      }

      case nir_cf_node_loop: {
         nir_loop *loop = nir_cf_node_as_loop(node);

         nir_block *head_block = nir_block_create(b->shader);
         nir_block *break_block = nir_block_create(b->shader);
         nir_block *cont_block = nir_block_create(b->shader);

         /* TODO: We can potentially avoid the break sync for loops when the
          * parent scope syncs for us.  However, we still need to handle the
          * continue clause cascading to the break.  If there is a
          * nir_jump_halt involved, then we have a real cascade where it needs
          * to then jump to the next scope.  Getting all these cases right
          * while avoiding an extra sync for the loop break is tricky at best.
          */
         struct scope break_scope = push_scope(b, SCOPE_TYPE_LOOP_BREAK,
                                               parent_scope, loop->divergent,
                                               loop->divergent, break_block);

         nir_goto(b, head_block);
         push_block(b, head_block, break_scope.divergent);

         struct scope cont_scope = push_scope(b, SCOPE_TYPE_LOOP_CONT,
                                              &break_scope, loop->divergent,
                                              loop->divergent, cont_block);

         lower_cf_list(b, esc_reg, &cont_scope, &loop->body);
         normal_exit(b, esc_reg, cont_block);

         push_block(b, cont_block, break_scope.divergent);

         pop_scope(b, esc_reg, cont_scope);

         lower_cf_list(b, esc_reg, &break_scope, &loop->continue_list);

         nir_goto(b, head_block);
         push_block(b, break_block, parent_scope->divergent);

         pop_scope(b, esc_reg, break_scope);

         break;
      }

      default:
         unreachable("Unknown CF node type");
      }
   }
}

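/* Recomputes divergence for the phis introduced when SSA is restored on the
 * unstructured CFG.  Iterates to a fixed point, marking a phi divergent if
 * any of its sources, their defining blocks, or its predecessor blocks are
 * divergent.
 */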
static void
recompute_phi_divergence_impl(nir_function_impl *impl)
{
   bool progress;
   do {
      progress = false;
      nir_foreach_block_unstructured(block, impl) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_phi)
               break;

            nir_phi_instr *phi = nir_instr_as_phi(instr);

            bool divergent = false;
            nir_foreach_phi_src(phi_src, phi) {
               /* There is a tricky case we need to care about here where a
                * convergent block has a divergent dominator.  This can happen
                * if, for instance, you have the following loop:
                *
                *    loop {
                *       if (div) {
                *          %20 = load_ubo(0, 0);
                *       } else {
                *          terminate;
                *       }
                *    }
                *    use(%20);
                *
                * In this case, the load_ubo() dominates the use() even though
                * the load_ubo() exists in divergent control-flow.  In this
                * case, we simply flag the whole phi divergent because we
                * don't want to deal with inserting a r2ur somewhere.
                */
               if (phi_src->pred->divergent || phi_src->src.ssa->divergent ||
                   phi_src->src.ssa->parent_instr->block->divergent) {
                  divergent = true;
                  break;
               }
            }

            if (divergent != phi->def.divergent) {
               phi->def.divergent = divergent;
               progress = true;
            }
         }
      }
   } while (progress);
}

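/* Lowers one function: phis are temporarily lowered to registers, the body
 * is re-emitted into a fresh unstructured nir_function_impl, and then blocks
 * are sorted and SSA (along with phi divergence) is rebuilt.
 */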
static bool
lower_cf_func(nir_function *func)
{
   if (func->impl == NULL)
      return false;

   if (exec_list_is_singular(&func->impl->body)) {
      nir_metadata_preserve(func->impl, nir_metadata_all);
      return false;
   }

   nir_function_impl *old_impl = func->impl;

   /* We use this in block_is_merge() */
   nir_metadata_require(old_impl, nir_metadata_dominance);

   /* First, we temporarily get rid of SSA.  This will make all our block
    * motion way easier.
    */
   nir_foreach_block(block, old_impl)
      nir_lower_phis_to_regs_block(block);

   /* We create a whole new nir_function_impl and copy the contents over */
   func->impl = NULL;
   nir_function_impl *new_impl = nir_function_impl_create(func);
   new_impl->structured = false;

   /* We copy defs from the old impl */
   new_impl->ssa_alloc = old_impl->ssa_alloc;

   nir_builder b = nir_builder_at(nir_before_impl(new_impl));
   nir_def *esc_reg = nir_decl_reg(&b, 1, 32, 0);

   /* Having a function scope makes everything easier */
   struct scope scope = {
      .type = SCOPE_TYPE_SHADER,
      .merge = new_impl->end_block,
   };
   lower_cf_list(&b, esc_reg, &scope, &old_impl->body);
   normal_exit(&b, esc_reg, new_impl->end_block);

   /* Now sort by reverse PDFS and restore SSA
    *
    * Note: Since we created a new nir_function_impl, there is no metadata,
    * dirty or otherwise, so we have no need to call nir_metadata_preserve().
    */
   nir_sort_unstructured_blocks(new_impl);
   nir_repair_ssa_impl(new_impl);
   nir_lower_reg_intrinsics_to_ssa_impl(new_impl);
   recompute_phi_divergence_impl(new_impl);

   return true;
}

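/* Pass entry point: lowers structured control flow to unstructured blocks
 * with explicit re-convergence for every function in the shader.
 */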
bool
nak_nir_lower_cf(nir_shader *nir)
{
   bool progress = false;

   nir_foreach_function(func, nir) {
      if (lower_cf_func(func))
         progress = true;
   }

   return progress;
}