/*
 * Copyright © 2022 Collabora, Ltd.
 * SPDX-License-Identifier: MIT
 */

#include "nak_private.h"
#include "nir_builder.h"

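/* Appends @block to the end of the (unstructured) impl body and leaves the
 * builder cursor at the end of that block.
 */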
static void
push_block(nir_builder *b, nir_block *block, bool divergent)
{
   assert(nir_cursors_equal(b->cursor, nir_after_impl(b->impl)));
   block->divergent = divergent;
   block->cf_node.parent = &b->impl->cf_node;
   exec_list_push_tail(&b->impl->body, &block->cf_node.node);
   b->cursor = nir_after_block(block);
}

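/* While lowering, we maintain a stack of scopes mirroring the structured
 * control flow being flattened: the shader itself, the merge of an if, and
 * the break and continue targets of a loop.
 */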
enum scope_type {
   SCOPE_TYPE_SHADER,
   SCOPE_TYPE_IF_MERGE,
   SCOPE_TYPE_LOOP_BREAK,
   SCOPE_TYPE_LOOP_CONT,
};

struct scope {
   enum scope_type type;

   struct scope *parent;
   uint32_t depth;

   /**
    * True if control-flow ever diverges within this scope, not accounting
    * for divergence in child scopes.
    */
   bool divergent;

   nir_block *merge;
   nir_def *bar;

   uint32_t escapes;
};

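/* Opens a new scope nested inside @parent.  If @needs_sync is set, we
 * allocate a convergence barrier with bar_set so that pop_scope() can
 * re-converge at the scope's merge block.
 */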
static struct scope
push_scope(nir_builder *b,
           enum scope_type scope_type,
           struct scope *parent,
           bool divergent,
           bool needs_sync,
           nir_block *merge_block)
{
   struct scope scope = {
      .type = scope_type,
      .parent = parent,
      .depth = parent->depth + 1,
      .divergent = parent->divergent || divergent,
      .merge = merge_block,
   };

   if (needs_sync)
      scope.bar = nir_bar_set_nv(b);

   return scope;
}

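/* Closes @scope at its merge block.  If the scope has a barrier, sync on it
 * and, if any jumps escaped this scope, emit the next step of the cascade:
 * compare esc_reg against this scope's depth and forward any escaping lanes
 * on to the nearest enclosing scope that also syncs.
 */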
static void
pop_scope(nir_builder *b, nir_def *esc_reg, struct scope scope)
{
   if (scope.bar == NULL)
      return;

   nir_bar_sync_nv(b, scope.bar, scope.bar);

   if (scope.escapes > 0) {
      /* Find the nearest scope with a sync. */
      nir_block *parent_merge = b->impl->end_block;
      for (struct scope *p = scope.parent; p != NULL; p = p->parent) {
         if (p->bar != NULL) {
            parent_merge = p->merge;
            break;
         }
      }

      /* No escape is ~0, halt is 0, and we choose outer scope indices such
       * that outer scopes always have lower indices than inner scopes.
       */
      nir_def *esc = nir_ult_imm(b, nir_load_reg(b, esc_reg), scope.depth);

      /* We have to put the escape in its own block to avoid critical edges.
       * If we just did goto_if, we would end up with multiple successors,
       * including a jump to the parent's merge block which has multiple
       * predecessors.
       */
      nir_block *esc_block = nir_block_create(b->shader);
      nir_block *next_block = nir_block_create(b->shader);
      nir_goto_if(b, esc_block, esc, next_block);
      push_block(b, esc_block, false);
      nir_goto(b, parent_merge);
      push_block(b, next_block, scope.parent->divergent);
   }
}

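/* Maps a break or continue jump to the scope type it targets. */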
static enum scope_type
jump_target_scope_type(nir_jump_type jump_type)
{
   switch (jump_type) {
   case nir_jump_break:    return SCOPE_TYPE_LOOP_BREAK;
   case nir_jump_continue: return SCOPE_TYPE_LOOP_CONT;
   default:
      unreachable("Unknown jump type");
   }
}

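/* Lowers a break or continue.  We record the depth of the target scope in
 * esc_reg, bump the escape count of every scope we jump out of, and goto the
 * nearest merge block that has a sync; pop_scope() cascades from there.
 */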
static void
break_scopes(nir_builder *b, nir_def *esc_reg,
             struct scope *current_scope,
             nir_jump_type jump_type)
{
   nir_block *first_sync = NULL;
   uint32_t target_depth = UINT32_MAX;
   enum scope_type target_scope_type = jump_target_scope_type(jump_type);
   for (struct scope *scope = current_scope; scope; scope = scope->parent) {
      if (first_sync == NULL && scope->bar != NULL)
         first_sync = scope->merge;

      if (scope->type == target_scope_type) {
         if (first_sync == NULL) {
            first_sync = scope->merge;
         } else {
            /* In order for our cascade to work, we need to have the invariant
             * that anything which escapes any scope with a warp sync needs to
             * target a scope with a warp sync.
             */
            assert(scope->bar != NULL);
         }
         target_depth = scope->depth;
         break;
      } else {
         scope->escapes++;
      }
   }
   assert(target_depth < UINT32_MAX);

   nir_store_reg(b, nir_imm_int(b, target_depth), esc_reg);
   nir_goto(b, first_sync);
}

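/* Ends the current block by falling through to @merge_block, unless it
 * already ends in a jump.  Writing ~0 to esc_reg marks "no escape" so the
 * cascade in pop_scope() does not fire for this path.
 */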
static void
normal_exit(nir_builder *b, nir_def *esc_reg, nir_block *merge_block)
{
   assert(nir_cursors_equal(b->cursor, nir_after_impl(b->impl)));
   nir_block *block = nir_cursor_current_block(b->cursor);

   if (!nir_block_ends_in_jump(block)) {
      nir_store_reg(b, nir_imm_int(b, ~0), esc_reg);
      nir_goto(b, merge_block);
   }
}

/* This is a heuristic for what instructions are allowed before we sync.
 * Annoyingly, we've gotten rid of phis so it's not as simple as "is it a
 * phi?".
 */
static bool
instr_is_allowed_before_sync(nir_instr *instr)
{
   switch (instr->type) {
   case nir_instr_type_alu: {
      nir_alu_instr *alu = nir_instr_as_alu(instr);
      /* We could probably allow more ALU as long as it doesn't contain
       * derivatives but let's be conservative and only allow mov for now.
       */
      return alu->op == nir_op_mov;
   }

   case nir_instr_type_intrinsic: {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      return intrin->intrinsic == nir_intrinsic_load_reg ||
             intrin->intrinsic == nir_intrinsic_store_reg;
   }

   default:
      return false;
   }
}

/** Returns true if our successor will sync for us
 *
 * This is a bit of a heuristic
 */
static bool
parent_scope_will_sync(nir_cf_node *node, struct scope *parent_scope)
{
   /* First search forward to see if there's anything non-trivial after this
    * node within the parent scope.
    */
   nir_block *block = nir_cf_node_as_block(nir_cf_node_next(node));
   nir_foreach_instr(instr, block) {
      if (!instr_is_allowed_before_sync(instr))
         return false;
   }

   /* There's another loop or if following and we didn't find a sync */
   if (nir_cf_node_next(&block->cf_node))
      return false;

   /* See if the parent scope will sync for us. */
   if (parent_scope->bar != NULL)
      return true;

   switch (parent_scope->type) {
   case SCOPE_TYPE_SHADER:
      return true;

   case SCOPE_TYPE_IF_MERGE:
      return parent_scope_will_sync(block->cf_node.parent,
                                    parent_scope->parent);

   case SCOPE_TYPE_LOOP_CONT:
      /* In this case, the loop doesn't have a sync of its own so we're
       * expected to be uniform before we hit the continue.
       */
      return false;

   case SCOPE_TYPE_LOOP_BREAK:
      unreachable("Loops must have a continue scope");

   default:
      unreachable("Unknown scope type");
   }
}

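/* A block is a merge point if it has more than one reachable predecessor. */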
static bool
block_is_merge(const nir_block *block)
{
   /* If it's unreachable, there is no merge */
   if (block->imm_dom == NULL)
      return false;

   unsigned num_preds = 0;
   set_foreach(block->predecessors, entry) {
      const nir_block *pred = entry->key;

      /* We don't care about unreachable blocks */
      if (pred->imm_dom == NULL)
         continue;

      num_preds++;
   }

   return num_preds > 1;
}

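/* The core of the lowering: walks a structured CF list and re-emits its
 * contents as unstructured blocks connected with goto/goto_if, wrapping ifs
 * and loops in scopes that sync where re-convergence is needed.
 */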
static void
lower_cf_list(nir_builder *b, nir_def *esc_reg, struct scope *parent_scope,
              struct exec_list *cf_list)
{
   foreach_list_typed_safe(nir_cf_node, node, node, cf_list) {
      switch (node->type) {
      case nir_cf_node_block: {
         nir_block *block = nir_cf_node_as_block(node);
         if (exec_list_is_empty(&block->instr_list))
            break;

         nir_cursor start = nir_before_block(block);
         nir_cursor end = nir_after_block(block);

         nir_jump_instr *jump = NULL;
         nir_instr *last_instr = nir_block_last_instr(block);
         if (last_instr->type == nir_instr_type_jump) {
            jump = nir_instr_as_jump(last_instr);
            end = nir_before_instr(&jump->instr);
         }

         nir_cf_list instrs;
         nir_cf_extract(&instrs, start, end);
         b->cursor = nir_cf_reinsert(&instrs, b->cursor);

         if (jump != NULL) {
            if (jump->type == nir_jump_halt) {
               /* Halt instructions map to OpExit on NVIDIA hardware and
                * exited lanes never block a bsync.
                */
               nir_instr_remove(&jump->instr);
               nir_builder_instr_insert(b, &jump->instr);
            } else {
               /* Everything else needs a break cascade */
               break_scopes(b, esc_reg, parent_scope, jump->type);
            }
         }
         break;
      }

      case nir_cf_node_if: {
         nir_if *nif = nir_cf_node_as_if(node);

         nir_def *cond = nif->condition.ssa;
         nir_instr_clear_src(NULL, &nif->condition);

         nir_block *then_block = nir_block_create(b->shader);
         nir_block *else_block = nir_block_create(b->shader);
         nir_block *merge_block = nir_block_create(b->shader);

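         /* Only pay for a bar_set/bar_sync pair if the condition is
          * divergent, the two paths actually re-converge at the next block,
          * and no enclosing scope is already going to sync for us.
          */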
         const bool needs_sync = cond->divergent &&
            block_is_merge(nir_cf_node_as_block(nir_cf_node_next(node))) &&
            !parent_scope_will_sync(&nif->cf_node, parent_scope);

         struct scope scope = push_scope(b, SCOPE_TYPE_IF_MERGE,
                                         parent_scope, cond->divergent,
                                         needs_sync, merge_block);

         nir_goto_if(b, then_block, cond, else_block);

         push_block(b, then_block, scope.divergent);
         lower_cf_list(b, esc_reg, &scope, &nif->then_list);
         normal_exit(b, esc_reg, merge_block);

         push_block(b, else_block, scope.divergent);
         lower_cf_list(b, esc_reg, &scope, &nif->else_list);
         normal_exit(b, esc_reg, merge_block);

         push_block(b, merge_block, parent_scope->divergent);
         pop_scope(b, esc_reg, scope);

         break;
      }

      case nir_cf_node_loop: {
         nir_loop *loop = nir_cf_node_as_loop(node);

         nir_block *head_block = nir_block_create(b->shader);
         nir_block *break_block = nir_block_create(b->shader);
         nir_block *cont_block = nir_block_create(b->shader);

         /* TODO: We can potentially avoid the break sync for loops when the
          * parent scope syncs for us. However, we still need to handle the
          * continue clause cascading to the break. If there is a
          * nir_jump_halt involved, then we have a real cascade where it needs
          * to then jump to the next scope. Getting all these cases right
          * while avoiding an extra sync for the loop break is tricky at best.
          */
         struct scope break_scope = push_scope(b, SCOPE_TYPE_LOOP_BREAK,
                                               parent_scope, loop->divergent,
                                               loop->divergent, break_block);

         nir_goto(b, head_block);
         push_block(b, head_block, break_scope.divergent);

         struct scope cont_scope = push_scope(b, SCOPE_TYPE_LOOP_CONT,
                                              &break_scope, loop->divergent,
                                              loop->divergent, cont_block);

         lower_cf_list(b, esc_reg, &cont_scope, &loop->body);
         normal_exit(b, esc_reg, cont_block);

         push_block(b, cont_block, break_scope.divergent);

         pop_scope(b, esc_reg, cont_scope);

         lower_cf_list(b, esc_reg, &break_scope, &loop->continue_list);

         nir_goto(b, head_block);
         push_block(b, break_block, parent_scope->divergent);

         pop_scope(b, esc_reg, break_scope);

         break;
      }

      default:
         unreachable("Unknown CF node type");
      }
   }
}

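/* Re-derives phi divergence to a fixed point.  In the new unstructured CFG,
 * a phi is flagged divergent if any source value, source predecessor, or
 * source's defining block is divergent, so we iterate until nothing changes.
 */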
static void
recompute_phi_divergence_impl(nir_function_impl *impl)
{
   bool progress;
   do {
      progress = false;
      nir_foreach_block_unstructured(block, impl) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_phi)
               break;

            nir_phi_instr *phi = nir_instr_as_phi(instr);

            bool divergent = false;
            nir_foreach_phi_src(phi_src, phi) {
               /* There is a tricky case we need to care about here where a
                * convergent block has a divergent dominator. This can happen
                * if, for instance, you have the following loop:
                *
                *    loop {
                *       if (div) {
                *          %20 = load_ubo(0, 0);
                *       } else {
                *          terminate;
                *       }
                *    }
                *    use(%20);
                *
                * In this case, the load_ubo() dominates the use() even though
                * the load_ubo() exists in divergent control-flow. In this
                * case, we simply flag the whole phi divergent because we
                * don't want to deal with inserting a r2ur somewhere.
                */
               if (phi_src->pred->divergent || phi_src->src.ssa->divergent ||
                   phi_src->src.ssa->parent_instr->block->divergent) {
                  divergent = true;
                  break;
               }
            }

            if (divergent != phi->def.divergent) {
               phi->def.divergent = divergent;
               progress = true;
            }
         }
      }
   } while (progress);
}

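/* Per-function driver for the pass: require dominance info, temporarily
 * lower phis to registers, then rebuild the function body as a new
 * unstructured nir_function_impl via lower_cf_list().  Afterwards the blocks
 * are sorted, SSA is repaired, and phi divergence is recomputed.
 */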
static bool
lower_cf_func(nir_function *func)
{
   if (func->impl == NULL)
      return false;

   if (exec_list_is_singular(&func->impl->body)) {
      nir_metadata_preserve(func->impl, nir_metadata_all);
      return false;
   }

   nir_function_impl *old_impl = func->impl;

   /* We use this in block_is_merge() */
   nir_metadata_require(old_impl, nir_metadata_dominance);

   /* First, we temporarily get rid of SSA. This will make all our block
    * motion way easier.
    */
   nir_foreach_block(block, old_impl)
      nir_lower_phis_to_regs_block(block);

   /* We create a whole new nir_function_impl and copy the contents over */
   func->impl = NULL;
   nir_function_impl *new_impl = nir_function_impl_create(func);
   new_impl->structured = false;

   /* We copy defs from the old impl */
   new_impl->ssa_alloc = old_impl->ssa_alloc;

   nir_builder b = nir_builder_at(nir_before_impl(new_impl));
   nir_def *esc_reg = nir_decl_reg(&b, 1, 32, 0);

   /* Having a function scope makes everything easier */
   struct scope scope = {
      .type = SCOPE_TYPE_SHADER,
      .merge = new_impl->end_block,
   };
   lower_cf_list(&b, esc_reg, &scope, &old_impl->body);
   normal_exit(&b, esc_reg, new_impl->end_block);

   /* Now sort by reverse PDFS and restore SSA
    *
    * Note: Since we created a new nir_function_impl, there is no metadata,
    * dirty or otherwise, so we have no need to call nir_metadata_preserve().
    */
   nir_sort_unstructured_blocks(new_impl);
   nir_repair_ssa_impl(new_impl);
   nir_lower_reg_intrinsics_to_ssa_impl(new_impl);
   recompute_phi_divergence_impl(new_impl);

   return true;
}

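/* Lowers structured NIR control flow to unstructured blocks connected with
 * goto/goto_if, inserting bar_set_nv/bar_sync_nv convergence barriers so
 * that divergent control flow re-converges at merge points.
 */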
bool
nak_nir_lower_cf(nir_shader *nir)
{
   bool progress = false;

   nir_foreach_function(func, nir) {
      if (lower_cf_func(func))
         progress = true;
   }

   return progress;
}