/*
 * Copyright © 2015-2023 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"
#include "spirv_info.h"
#include "util/u_math.h"

/* Handle SPIR-V structured control flow, mapping SPIR-V constructs into
 * equivalent NIR constructs.
 *
 * Because SPIR-V can represent more complex control flow than NIR, some
 * constructs are mapped into a combination of nir_if and nir_loop nodes.  For
 * example, a selection construct with an "if-break" (an early branch into
 * the end of the construct) will be mapped into NIR as a loop (to allow the
 * break) with a nested if (to handle the actual selection).
 *
 * Note that using NIR loops this way requires us to propagate breaks and
 * continues that are meant for outer constructs when a nir_loop is used for a
 * SPIR-V construct other than Loop.
 *
 * The process of identifying and ordering the blocks before the NIR
 * translation is similar to what's done in Tint, using the "reverse
 * structured post-order traversal".  See also the file comments in
 * src/reader/spirv/function.cc in the Tint repository.
 */
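
/* For illustration only (a hedged sketch, not actual compiler output): a
 * selection construct containing an "if-break", i.e. a conditional branch
 * straight to the declared merge block from inside one of its sides, such as
 *
 *    %header: OpSelectionMerge %merge None
 *             OpBranchConditional %c1 %then %else
 *    %then:   OpBranchConditional %c2 %merge %then2   ; the "if-break"
 *    %then2:  OpBranch %merge
 *    %else:   OpBranch %merge
 *    %merge:  ...
 *
 * could be emitted roughly as the following NIR-like structure:
 *
 *    loop {
 *       if (%c1) {
 *          if (%c2)
 *             break;      // the if-break, expressed via the extra loop
 *          ...            // %then2 contents
 *       } else {
 *          ...            // %else contents
 *       }
 *       break;            // natural exit; the loop never iterates twice
 *    }
 */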

enum vtn_construct_type {
   /* Not formally a SPIR-V construct but used to represent the entire
    * function.
    */
   vtn_construct_type_function,

   /* Selection construct uses a nir_if and optionally a nir_loop to handle
    * if-breaks.
    */
   vtn_construct_type_selection,

   /* Loop construct uses a nir_loop and optionally a nir_if to handle an
    * OpBranchConditional as part of the head of the loop.
    */
   vtn_construct_type_loop,

   /* Continue construct maps to the NIR continue construct of the corresponding
    * loop.  For convenience, unlike in SPIR-V, the parent of this construct is
    * always the loop construct.  The continue construct is omitted for
    * single-block loops.
    */
   vtn_construct_type_continue,

   /* Switch construct is not directly mapped into any NIR structure; the work
    * is handled by the case constructs.  It does keep a nir_variable for
    * handling case fallthrough logic.
    */
   vtn_construct_type_switch,

   /* Case construct uses a nir_if and optionally a nir_loop to handle early
    * breaks.  Note switch_breaks are handled by each case.
    */
   vtn_construct_type_case,
};

static const char *
vtn_construct_type_to_string(enum vtn_construct_type t)
{
#define CASE(typ) case vtn_construct_type_##typ: return #typ
   switch (t) {
   CASE(function);
   CASE(selection);
   CASE(loop);
   CASE(continue);
   CASE(switch);
   CASE(case);
   }
#undef CASE
   unreachable("invalid construct type");
   return "";
}

struct vtn_construct {
   enum vtn_construct_type type;

   bool needs_nloop;
   bool needs_break_propagation;
   bool needs_continue_propagation;
   bool needs_fallthrough;

   struct vtn_construct *parent;

   struct vtn_construct *innermost_loop;
   struct vtn_construct *innermost_switch;
   struct vtn_construct *innermost_case;

   unsigned start_pos;
   unsigned end_pos;

   /* Usually the same as end_pos, but may be different in case of an "early
    * merge" after divergence caused by an OpBranchConditional.  This can
    * happen in selection and loop constructs.
    */
   unsigned merge_pos;

   /* Valid when not zero; indicate the blocks that start the then and else
    * paths of a condition.  These may be used by selection constructs.
    */
   unsigned then_pos;
   unsigned else_pos;

   /* Indicates where the continue block is, marking the end of the body of
    * the loop.  Note the block ordering will always give us first the loop
    * body blocks, then the continue block.  Used by loop constructs.
    */
   unsigned continue_pos;

   /* For the list of all constructs in vtn_function. */
   struct list_head link;

   /* NIR nodes that are associated with this construct.  See
    * vtn_construct_type for an overview.
    */
   nir_loop *nloop;
   nir_if *nif;

   /* This variable will be set by an inner construct to indicate that a break
    * is necessary.  We need to use variables here in situations where the
    * inner construct has a loop of its own for other reasons.
    */
   nir_variable *break_var;

   /* Same logic but for continue. */
   nir_variable *continue_var;

   /* This is used by each case to force entering the case regardless of
    * the condition.  We always set it when handling a branch that is a
    * switch_break or a switch_fallthrough.
    */
   nir_variable *fallthrough_var;

   unsigned index;
};
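
/* A hedged sketch (not actual output) of how break_var cooperates with
 * needs_break_propagation: when an inner construct got its own nir_loop but
 * a branch targets the break of an outer loop, the inner loop is exited with
 * a plain nir_jump_break while the outer construct's break_var records that
 * it must break too:
 *
 *    loop {                      // outer construct's nloop
 *       loop {                   // inner construct's nloop
 *          ...
 *          outer.break_var = true;
 *          break;                // only leaves the inner loop
 *       }
 *       if (outer.break_var)
 *          break;                // break propagated to the outer loop
 *       ...
 *    }
 */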

enum vtn_branch_type {
   vtn_branch_type_none,
   vtn_branch_type_forward,
   vtn_branch_type_if_break,
   vtn_branch_type_switch_break,
   vtn_branch_type_switch_fallthrough,
   vtn_branch_type_loop_break,
   vtn_branch_type_loop_continue,
   vtn_branch_type_loop_back_edge,
   vtn_branch_type_discard,
   vtn_branch_type_terminate_invocation,
   vtn_branch_type_ignore_intersection,
   vtn_branch_type_terminate_ray,
   vtn_branch_type_emit_mesh_tasks,
   vtn_branch_type_return,
};

static const char *
vtn_branch_type_to_string(enum vtn_branch_type t)
{
#define CASE(typ) case vtn_branch_type_##typ: return #typ
   switch (t) {
   CASE(none);
   CASE(forward);
   CASE(if_break);
   CASE(switch_break);
   CASE(switch_fallthrough);
   CASE(loop_break);
   CASE(loop_continue);
   CASE(loop_back_edge);
   CASE(discard);
   CASE(terminate_invocation);
   CASE(ignore_intersection);
   CASE(terminate_ray);
   CASE(emit_mesh_tasks);
   CASE(return);
   }
#undef CASE
   unreachable("unknown branch type");
   return "";
}

struct vtn_successor {
   struct vtn_block *block;
   enum vtn_branch_type branch_type;
};

static bool
vtn_is_single_block_loop(const struct vtn_construct *c)
{
   return c->type == vtn_construct_type_loop &&
          c->start_pos == c->continue_pos;
}
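
/* For example (an illustrative fragment, not taken from a real shader), a
 * loop that declares its own header as the Continue Target:
 *
 *    %header: OpLoopMerge %merge %header None
 *             OpBranchConditional %cond %header %merge
 *
 * is a single-block loop: start_pos == continue_pos, and no separate
 * continue construct is created for it.
 */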

static struct vtn_construct *
vtn_find_innermost(enum vtn_construct_type type, struct vtn_construct *c)
{
   while (c && c->type != type)
      c = c->parent;
   return c;
}

static void
print_ordered_blocks(const struct vtn_function *func)
{
   for (unsigned i = 0; i < func->ordered_blocks_count; i++) {
      struct vtn_block *block = func->ordered_blocks[i];
      printf("[id=%-6u] %4u", block->label[1], block->pos);
      if (block->successors_count > 0) {
         printf(" ->");
         for (unsigned j = 0; j < block->successors_count; j++) {
            printf(" ");
            if (block->successors[j].block)
               printf("%u/", block->successors[j].block->pos);
            printf("%s", vtn_branch_type_to_string(block->successors[j].branch_type));
         }
      }
      if (!block->visited)
         printf("  NOT VISITED");
      printf("\n");
   }
}

static struct vtn_case *
vtn_find_fallthrough_target(struct vtn_builder *b, const uint32_t *switch_merge,
                            struct vtn_block *source_block, struct vtn_block *block)
{
   if (block->visited)
      return NULL;

   if (block->label[1] == switch_merge[1])
      return NULL;

   /* Don't consider the initial source block a fallthrough target of itself. */
   if (block->switch_case && block != source_block)
      return block->switch_case;

   if (block->merge)
      return vtn_find_fallthrough_target(b, switch_merge, source_block,
                                         vtn_block(b, block->merge[1]));

   const uint32_t *branch = block->branch;
   vtn_assert(branch);

   switch (branch[0] & SpvOpCodeMask) {
   case SpvOpBranch:
      return vtn_find_fallthrough_target(b, switch_merge, source_block,
                                         vtn_block(b, branch[1]));
   case SpvOpBranchConditional: {
      struct vtn_case *target =
         vtn_find_fallthrough_target(b, switch_merge, source_block,
                                     vtn_block(b, branch[2]));
      if (!target)
         target = vtn_find_fallthrough_target(b, switch_merge, source_block,
                                              vtn_block(b, branch[3]));
      return target;
   }
   default:
      return NULL;
   }
}
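
/* Illustrative use (a sketch under assumed SPIR-V input): if the Default
 * block ends with "OpBranch %case1_block", the walk above follows that
 * branch until it reaches a block already owned by another case and returns
 * that case.  structured_post_order_traversal() then moves Default right
 * before the returned case so the fallthrough blocks stay consecutive in
 * the final ordering.
 */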

static void
structured_post_order_traversal(struct vtn_builder *b, struct vtn_block *block)
{
   if (block->visited)
      return;

   block->visited = true;

   if (block->merge) {
      structured_post_order_traversal(b, vtn_block(b, block->merge[1]));

      SpvOp merge_op = block->merge[0] & SpvOpCodeMask;
      if (merge_op == SpvOpLoopMerge) {
         struct vtn_block *continue_block = vtn_block(b, block->merge[2]);
         structured_post_order_traversal(b, continue_block);
      }
   }

   const uint32_t *branch = block->branch;
   vtn_assert(branch);

   switch (branch[0] & SpvOpCodeMask) {
   case SpvOpBranch:
      block->successors_count = 1;
      block->successors = vtn_zalloc(b, struct vtn_successor);
      block->successors[0].block = vtn_block(b, branch[1]);
      structured_post_order_traversal(b, block->successors[0].block);
      break;

   case SpvOpBranchConditional:
      block->successors_count = 2;
      block->successors = vtn_zalloc_array(b, struct vtn_successor, 2);
      block->successors[0].block = vtn_block(b, branch[2]);
      block->successors[1].block = vtn_block(b, branch[3]);

      /* The result of the traversal will be reversed, so to provide a
       * more natural order, with THEN blocks appearing before ELSE blocks,
       * we need to traverse them in the reversed order.
       */
      int order[] = { 1, 0 };

      /* There's a catch when traversing case fallthroughs: we want to avoid
       * walking part of a case construct, then the fallthrough -- possibly
       * visiting another entire case construct, and back to the other part
       * of that original case construct.  So if the THEN path is a
       * fallthrough, swap the visit order.
       */
      if (block->successors[0].block->switch_case) {
         order[0] = !order[0];
         order[1] = !order[1];
      }

      structured_post_order_traversal(b, block->successors[order[0]].block);
      structured_post_order_traversal(b, block->successors[order[1]].block);
      break;

   case SpvOpSwitch: {
      /* TODO: Save this to use during Switch construct creation. */
      struct list_head cases;
      list_inithead(&cases);
      vtn_parse_switch(b, block->branch, &cases);

      block->successors_count = list_length(&cases);
      block->successors = vtn_zalloc_array(b, struct vtn_successor, block->successors_count);

      /* The 'Rules for Structured Control-flow constructs' already guarantee
       * that the labels of the targets are ordered in a way that if
       * there is a fallthrough, they will appear consecutively.  The only
       * exception is Default, which is always the first in the list.
       *
       * Because we are doing a DFS from the end of the cases, the
       * traversal already handles a Case falling through Default.
       *
       * The scenario that needs fixing is when no case falls to Default, but
       * Default falls to another case.  For that scenario we move the Default
       * right before the case it falls to.
       */

      struct vtn_case *default_case = list_first_entry(&cases, struct vtn_case, link);
      vtn_assert(default_case && default_case->is_default);

      struct vtn_case *fall_target =
         vtn_find_fallthrough_target(b, block->merge, default_case->block,
                                     default_case->block);
      if (fall_target)
         list_move_to(&default_case->link, &fall_target->link);

      /* Because the result of the traversal will be reversed, loop backwards
       * in the case list.
       */
      unsigned i = 0;
      list_for_each_entry_rev(struct vtn_case, cse, &cases, link) {
         structured_post_order_traversal(b, cse->block);
         block->successors[i].block = cse->block;
         i++;
      }

      break;
   }

   case SpvOpKill:
   case SpvOpTerminateInvocation:
   case SpvOpIgnoreIntersectionKHR:
   case SpvOpTerminateRayKHR:
   case SpvOpReturn:
   case SpvOpReturnValue:
   case SpvOpEmitMeshTasksEXT:
   case SpvOpUnreachable:
      block->successors_count = 1;
      block->successors = vtn_zalloc(b, struct vtn_successor);
      break;

   default:
      unreachable("invalid branch opcode");
   }

   b->func->ordered_blocks[b->func->ordered_blocks_count++] = block;
}

static void
sort_blocks(struct vtn_builder *b)
{
   struct vtn_block **ordered_blocks =
      vtn_zalloc_array(b, struct vtn_block *, b->func->block_count);

   b->func->ordered_blocks = ordered_blocks;

   structured_post_order_traversal(b, b->func->start_block);

   /* Reverse it, so that blocks appear before their successors. */
   unsigned count = b->func->ordered_blocks_count;
   for (unsigned i = 0; i < (count / 2); i++) {
      unsigned j = count - i - 1;
      struct vtn_block *tmp = ordered_blocks[i];
      ordered_blocks[i] = ordered_blocks[j];
      ordered_blocks[j] = tmp;
   }

   for (unsigned i = 0; i < count; i++)
      ordered_blocks[i]->pos = i;
}
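
/* A small worked example of the resulting order, assuming this CFG (which
 * is hypothetical, not from a real shader):
 *
 *         A          A: OpSelectionMerge %D, OpBranchConditional -> B, C
 *        / \
 *       B   C        B, C: OpBranch -> D
 *        \ /
 *         D          D: OpReturn
 *
 * The traversal visits the merge block first and finishes blocks in
 * post-order D, C, B, A; reversing yields A, B, C, D, so every block
 * precedes its successors and the THEN block (B) precedes the ELSE
 * block (C).
 */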

static void
print_construct(const struct vtn_function *func,
                const struct vtn_construct *c)
{
   for (const struct vtn_construct *p = c->parent; p; p = p->parent)
      printf("    ");
   printf("C%u/%s ", c->index, vtn_construct_type_to_string(c->type));
   printf("  %u->%u", c->start_pos, c->end_pos);
   if (c->merge_pos)
      printf("  merge=%u", c->merge_pos);
   if (c->then_pos)
      printf("  then=%u", c->then_pos);
   if (c->else_pos)
      printf("  else=%u", c->else_pos);
   if (c->needs_nloop)
      printf("  nloop");
   if (c->needs_break_propagation)
      printf("  break_prop");
   if (c->needs_continue_propagation)
      printf("  continue_prop");
   if (c->type == vtn_construct_type_loop) {
      if (vtn_is_single_block_loop(c))
         printf("  single_block_loop");
      else
         printf("  cont=%u", c->continue_pos);
   }
   if (c->type == vtn_construct_type_case) {
      struct vtn_block *block = func->ordered_blocks[c->start_pos];
      if (block->switch_case->is_default) {
         printf(" [default]");
      } else {
         printf(" [values:");
         util_dynarray_foreach(&block->switch_case->values, uint64_t, val)
            printf(" %" PRIu64, *val);
         printf("]");
      }
   }
   printf("\n");
}

static void
print_constructs(struct vtn_function *func)
{
   list_for_each_entry(struct vtn_construct, c, &func->constructs, link)
      print_construct(func, c);
}

struct vtn_construct_stack {
   /* Array of `struct vtn_construct *`. */
   struct util_dynarray data;
};

static inline void
init_construct_stack(struct vtn_construct_stack *stack, void *mem_ctx)
{
   assert(mem_ctx);
   util_dynarray_init(&stack->data, mem_ctx);
}

static inline unsigned
count_construct_stack(struct vtn_construct_stack *stack)
{
   return util_dynarray_num_elements(&stack->data, struct vtn_construct *);
}

static inline struct vtn_construct *
top_construct(struct vtn_construct_stack *stack)
{
   assert(count_construct_stack(stack) > 0);
   return util_dynarray_top(&stack->data, struct vtn_construct *);
}

static inline void
pop_construct(struct vtn_construct_stack *stack)
{
   assert(count_construct_stack(stack) > 0);
   (void)util_dynarray_pop(&stack->data, struct vtn_construct *);
}

static inline void
push_construct(struct vtn_construct_stack *stack, struct vtn_construct *c)
{
   util_dynarray_append(&stack->data, struct vtn_construct *, c);
}

static int
cmp_succ_block_pos(const void *pa, const void *pb)
{
   const struct vtn_successor *sa = pa;
   const struct vtn_successor *sb = pb;
   const unsigned a = sa->block->pos;
   const unsigned b = sb->block->pos;
   if (a < b)
      return -1;
   if (a > b)
      return 1;
   return 0;
}

static void
create_constructs(struct vtn_builder *b)
{
   struct vtn_construct *func_construct = vtn_zalloc(b, struct vtn_construct);
   func_construct->type = vtn_construct_type_function;
   func_construct->start_pos = 0;
   func_construct->end_pos = b->func->ordered_blocks_count;

   for (unsigned i = 0; i < b->func->ordered_blocks_count; i++) {
      struct vtn_block *block = b->func->ordered_blocks[i];

      if (block->merge) {
         SpvOp merge_op = block->merge[0] & SpvOpCodeMask;
         SpvOp branch_op = block->branch[0] & SpvOpCodeMask;

         const unsigned end_pos = vtn_block(b, block->merge[1])->pos;

         if (merge_op == SpvOpLoopMerge) {
            struct vtn_construct *loop = vtn_zalloc(b, struct vtn_construct);
            loop->type = vtn_construct_type_loop;
            loop->start_pos = block->pos;
            loop->end_pos = end_pos;

            loop->parent = block->parent;
            block->parent = loop;

            struct vtn_block *continue_block = vtn_block(b, block->merge[2]);
            loop->continue_pos = continue_block->pos;

            if (!vtn_is_single_block_loop(loop)) {
               struct vtn_construct *cont = vtn_zalloc(b, struct vtn_construct);
               cont->type = vtn_construct_type_continue;
               cont->start_pos = loop->continue_pos;
               cont->end_pos = end_pos;

               cont->parent = loop;
               continue_block->parent = cont;
            }

            /* Not all combinations of OpLoopMerge and OpBranchConditional are
             * valid; work around invalid combinations by injecting an extra
             * selection.
             *
             * Old versions of dxil-spirv generated this.
             */
            if (branch_op == SpvOpBranchConditional) {
               vtn_assert(block->successors_count == 2);
               const unsigned then_pos = block->successors[0].block ?
                                         block->successors[0].block->pos : 0;
               const unsigned else_pos = block->successors[1].block ?
                                         block->successors[1].block->pos : 0;

               if (then_pos > loop->start_pos && then_pos < loop->continue_pos &&
                   else_pos > loop->start_pos && else_pos < loop->continue_pos) {
                  vtn_warn("An OpSelectionMerge instruction is required to precede "
                           "an OpBranchConditional instruction that has different "
                           "True Label and False Label operands where neither are "
                           "declared merge blocks or Continue Targets.");
                  struct vtn_construct *sel = vtn_zalloc(b, struct vtn_construct);
                  sel->type = vtn_construct_type_selection;
                  sel->start_pos = loop->start_pos;
                  sel->end_pos = loop->continue_pos;
                  sel->then_pos = then_pos;
                  sel->else_pos = else_pos;
                  sel->parent = loop;
                  block->parent = sel;
               }
            }

         } else if (branch_op == SpvOpSwitch) {
            vtn_assert(merge_op == SpvOpSelectionMerge);

            struct vtn_construct *swtch = vtn_zalloc(b, struct vtn_construct);
            swtch->type = vtn_construct_type_switch;
            swtch->start_pos = block->pos;
            swtch->end_pos = end_pos;

            swtch->parent = block->parent;
            block->parent = swtch;

            struct list_head cases;
            list_inithead(&cases);
            vtn_parse_switch(b, block->branch, &cases);

            vtn_foreach_case_safe(cse, &cases) {
               if (cse->block->pos < end_pos) {
                  struct vtn_block *case_block = cse->block;
                  struct vtn_construct *c = vtn_zalloc(b, struct vtn_construct);
                  c->type = vtn_construct_type_case;
                  c->parent = swtch;
                  c->start_pos = case_block->pos;

                  /* Upper bound, will be updated right after. */
                  c->end_pos = swtch->end_pos;

                  vtn_assert(case_block->parent == NULL || case_block->parent == swtch);
                  case_block->parent = c;
               } else {
                  /* A target in OpSwitch must point either to one of the case
                   * constructs or to the Merge block.  No outer break/continue
                   * is allowed.
                   */
                  vtn_assert(cse->block->pos == end_pos);
               }
               list_delinit(&cse->link);
            }

            /* Case constructs don't overlap, so each one ends where the next
             * begins.
             */
            qsort(block->successors, block->successors_count,
                  sizeof(struct vtn_successor), cmp_succ_block_pos);
            for (unsigned succ_idx = 1; succ_idx < block->successors_count; succ_idx++) {
               unsigned succ_pos = block->successors[succ_idx].block->pos;
               /* The successors are ordered, so once we see a successor that
                * points to the merge block, we are done fixing the cases.
                */
               if (succ_pos >= swtch->end_pos)
                  break;
               struct vtn_construct *prev_cse =
                  vtn_find_innermost(vtn_construct_type_case,
                                     block->successors[succ_idx - 1].block->parent);
               vtn_assert(prev_cse);
               prev_cse->end_pos = succ_pos;
            }

         } else {
            vtn_assert(merge_op == SpvOpSelectionMerge);
            vtn_assert(branch_op == SpvOpBranchConditional);

            struct vtn_construct *sel = vtn_zalloc(b, struct vtn_construct);
            sel->type = vtn_construct_type_selection;
            sel->start_pos = block->pos;
            sel->end_pos = end_pos;
            sel->parent = block->parent;
            block->parent = sel;

            vtn_assert(block->successors_count == 2);
            struct vtn_block *then_block = block->successors[0].block;
            struct vtn_block *else_block = block->successors[1].block;

            sel->then_pos = then_block ? then_block->pos : 0;
            sel->else_pos = else_block ? else_block->pos : 0;
         }
      }
   }

   /* Link the constructs with their parents and with the remaining blocks
    * that do not start one.  This will also build the ordered list of
    * constructs.
    */
   struct vtn_construct_stack stack;
   init_construct_stack(&stack, b);
   push_construct(&stack, func_construct);
   list_addtail(&func_construct->link, &b->func->constructs);

   for (unsigned i = 0; i < b->func->ordered_blocks_count; i++) {
      struct vtn_block *block = b->func->ordered_blocks[i];

      while (block->pos == top_construct(&stack)->end_pos)
         pop_construct(&stack);

      /* Identify the start of a continue construct. */
      if (top_construct(&stack)->type == vtn_construct_type_loop &&
          !vtn_is_single_block_loop(top_construct(&stack)) &&
          top_construct(&stack)->continue_pos == block->pos) {
         struct vtn_construct *c = vtn_find_innermost(vtn_construct_type_continue, block->parent);
         vtn_assert(c);
         vtn_assert(c->parent == top_construct(&stack));

         list_addtail(&c->link, &b->func->constructs);
         push_construct(&stack, c);
      }

      if (top_construct(&stack)->type == vtn_construct_type_switch) {
         struct vtn_block *header = b->func->ordered_blocks[top_construct(&stack)->start_pos];
         for (unsigned succ_idx = 0; succ_idx < header->successors_count; succ_idx++) {
            struct vtn_successor *succ = &header->successors[succ_idx];
            if (block == succ->block) {
               struct vtn_construct *c = vtn_find_innermost(vtn_construct_type_case, succ->block->parent);
               if (c) {
                  vtn_assert(c->parent == top_construct(&stack));

                  list_addtail(&c->link, &b->func->constructs);
                  push_construct(&stack, c);
               }
               break;
            }
         }
      }

      if (block->merge) {
         switch (block->merge[0] & SpvOpCodeMask) {
         case SpvOpSelectionMerge: {
            struct vtn_construct *c = block->parent;
            vtn_assert(c->type == vtn_construct_type_selection ||
                       c->type == vtn_construct_type_switch);

            c->parent = top_construct(&stack);

            list_addtail(&c->link, &b->func->constructs);
            push_construct(&stack, c);
            break;
         }

         case SpvOpLoopMerge: {
            struct vtn_construct *c = block->parent;
            struct vtn_construct *loop = c;

            /* A loop might have an extra selection injected, skip it. */
            if (c->type == vtn_construct_type_selection)
               loop = c->parent;

            vtn_assert(loop->type == vtn_construct_type_loop);
            loop->parent = top_construct(&stack);

            list_addtail(&loop->link, &b->func->constructs);
            push_construct(&stack, loop);

            if (loop != c) {
               /* Make sure we also "enter" the extra construct. */
               list_addtail(&c->link, &b->func->constructs);
               push_construct(&stack, c);
            }
            break;
         }

         default:
            unreachable("invalid merge opcode");
         }
      }

      block->parent = top_construct(&stack);
   }

   vtn_assert(count_construct_stack(&stack) == 1);
   vtn_assert(top_construct(&stack)->type == vtn_construct_type_function);

   unsigned index = 0;
   list_for_each_entry(struct vtn_construct, c, &b->func->constructs, link)
      c->index = index++;
}

static void
validate_constructs(struct vtn_builder *b)
{
   list_for_each_entry(struct vtn_construct, c, &b->func->constructs, link) {
      if (c->type == vtn_construct_type_function)
         vtn_assert(c->parent == NULL);
      else
         vtn_assert(c->parent);

      switch (c->type) {
      case vtn_construct_type_continue:
         vtn_assert(c->parent->type == vtn_construct_type_loop);
         break;
      case vtn_construct_type_case:
         vtn_assert(c->parent->type == vtn_construct_type_switch);
         break;
      default:
         /* Nothing to do. */
         break;
      }
   }
}

static void
find_innermost_constructs(struct vtn_builder *b)
{
   list_for_each_entry(struct vtn_construct, c, &b->func->constructs, link) {
      if (c->type == vtn_construct_type_function) {
         c->innermost_loop = NULL;
         c->innermost_switch = NULL;
         c->innermost_case = NULL;
         continue;
      }

      if (c->type == vtn_construct_type_loop)
         c->innermost_loop = c;
      else
         c->innermost_loop = c->parent->innermost_loop;

      if (c->type == vtn_construct_type_switch)
         c->innermost_switch = c;
      else
         c->innermost_switch = c->parent->innermost_switch;

      if (c->type == vtn_construct_type_case)
         c->innermost_case = c;
      else
         c->innermost_case = c->parent->innermost_case;
   }

   list_for_each_entry(struct vtn_construct, c, &b->func->constructs, link) {
      vtn_assert(vtn_find_innermost(vtn_construct_type_loop, c) == c->innermost_loop);
      vtn_assert(vtn_find_innermost(vtn_construct_type_switch, c) == c->innermost_switch);
      vtn_assert(vtn_find_innermost(vtn_construct_type_case, c) == c->innermost_case);
   }
}

static void
set_needs_continue_propagation(struct vtn_construct *c)
{
   for (; c != c->innermost_loop; c = c->parent)
      c->needs_continue_propagation = true;
}

static void
set_needs_break_propagation(struct vtn_construct *c,
                            struct vtn_construct *to_break)
{
   for (; c != to_break; c = c->parent)
      c->needs_break_propagation = true;
}

static enum vtn_branch_type
branch_type_for_successor(struct vtn_builder *b, struct vtn_block *block,
                          struct vtn_successor *succ)
{
   unsigned pos = block->pos;
   unsigned succ_pos = succ->block->pos;

   struct vtn_construct *inner = block->parent;
   vtn_assert(inner);

   /* Identify the types of branches, applying the "Rules for Structured
    * Control-flow Constructs" from the SPIR-V spec.
    */

   struct vtn_construct *innermost_loop = inner->innermost_loop;
   if (innermost_loop) {
      /* Entering the innermost loop's continue construct. */
      if (!vtn_is_single_block_loop(innermost_loop) &&
          succ_pos == innermost_loop->continue_pos) {
         set_needs_continue_propagation(inner);
         return vtn_branch_type_loop_continue;
      }

      /* Breaking from the innermost loop (and branching from back-edge block
       * to loop merge).
       */
      if (succ_pos == innermost_loop->end_pos) {
         set_needs_break_propagation(inner, innermost_loop);
         return vtn_branch_type_loop_break;
      }

      /* Next loop iteration.  There can be only a single loop back-edge
       * for each loop construct.
       */
      if (succ_pos == innermost_loop->start_pos) {
         vtn_assert(inner->type == vtn_construct_type_continue ||
                    vtn_is_single_block_loop(innermost_loop));
         return vtn_branch_type_loop_back_edge;
      }
   }

   struct vtn_construct *innermost_switch = inner->innermost_switch;
   if (innermost_switch) {
      struct vtn_construct *innermost_cse = inner->innermost_case;

      /* Breaking from the innermost switch construct. */
      if (succ_pos == innermost_switch->end_pos) {
         /* Use an nloop if this is not a natural exit from a case construct. */
         if (innermost_cse && pos != innermost_cse->end_pos - 1) {
            innermost_cse->needs_nloop = true;
            set_needs_break_propagation(inner, innermost_cse);
         }
         return vtn_branch_type_switch_break;
      }

      /* Branching from one case construct to another. */
      if (inner != innermost_switch) {
         vtn_assert(innermost_cse);
         vtn_assert(innermost_cse->parent == innermost_switch);

         if (succ->block->switch_case) {
            /* Both cases should be from the same Switch construct. */
            struct vtn_construct *target_cse = succ->block->parent->innermost_case;
            vtn_assert(target_cse->parent == innermost_switch);
            target_cse->needs_fallthrough = true;
            return vtn_branch_type_switch_fallthrough;
         }
      }
   }

   if (inner->type == vtn_construct_type_selection) {
      /* Branches from the header block that were not categorized above will
       * follow to the then/else paths or to the merge block, and are handled
       * by the nir_if node.
       */
      if (block->merge)
         return vtn_branch_type_forward;

      /* Breaking from a selection construct. */
      if (succ_pos == inner->end_pos) {
         /* Identify cases where the break would be a natural flow in the NIR
          * construct.  We don't need the extra loop in such cases.
          *
          * Because then/else are not ordered, we need to find which one happens
          * later.  When there is no early merge, the branch from the last block
          * before the second side of the if starts will also jump naturally to
          * the end of the if.
          */
         const bool has_early_merge = inner->merge_pos != inner->end_pos;
         const unsigned second_pos = MAX2(inner->then_pos, inner->else_pos);

         const bool natural_exit_from_if =
            pos + 1 == inner->end_pos ||
            (!has_early_merge && (pos + 1 == second_pos));

         inner->needs_nloop = !natural_exit_from_if;
         return vtn_branch_type_if_break;
      }
   }

   if (succ_pos < inner->end_pos)
      return vtn_branch_type_forward;

   const enum nir_spirv_debug_level level = NIR_SPIRV_DEBUG_LEVEL_ERROR;
   const size_t offset = 0;

   vtn_logf(b, level, offset,
            "SPIR-V parsing FAILED:\n"
            "    Unrecognized branch from block pos %u (id=%u) "
            "to block pos %u (id=%u)",
            block->pos, block->label[1],
            succ->block->pos, succ->block->label[1]);

   vtn_logf(b, level, offset,
            "    Inner construct '%s': %u -> %u  (merge=%u then=%u else=%u)",
            vtn_construct_type_to_string(inner->type),
            inner->start_pos, inner->end_pos, inner->merge_pos, inner->then_pos, inner->else_pos);

   struct vtn_construct *outer = inner->parent;
   if (outer) {
      vtn_logf(b, level, offset,
               "    Outer construct '%s': %u -> %u  (merge=%u then=%u else=%u)",
               vtn_construct_type_to_string(outer->type),
               outer->start_pos, outer->end_pos, outer->merge_pos, outer->then_pos, outer->else_pos);
   }

   vtn_fail("Unable to identify branch type");
   return vtn_branch_type_none;
}

static enum vtn_branch_type
branch_type_for_terminator(struct vtn_builder *b, struct vtn_block *block)
{
   vtn_assert(block->successors_count == 1);
   vtn_assert(block->successors[0].block == NULL);

   switch (block->branch[0] & SpvOpCodeMask) {
   case SpvOpKill:
      return vtn_branch_type_discard;
   case SpvOpTerminateInvocation:
      return vtn_branch_type_terminate_invocation;
   case SpvOpIgnoreIntersectionKHR:
      return vtn_branch_type_ignore_intersection;
   case SpvOpTerminateRayKHR:
      return vtn_branch_type_terminate_ray;
   case SpvOpEmitMeshTasksEXT:
      return vtn_branch_type_emit_mesh_tasks;
   case SpvOpReturn:
   case SpvOpReturnValue:
   case SpvOpUnreachable:
      return vtn_branch_type_return;
   default:
      unreachable("unexpected terminator operation");
      return vtn_branch_type_none;
   }
}

static void
set_branch_types(struct vtn_builder *b)
{
   for (unsigned i = 0; i < b->func->ordered_blocks_count; i++) {
      struct vtn_block *block = b->func->ordered_blocks[i];
      for (unsigned j = 0; j < block->successors_count; j++) {
         struct vtn_successor *succ = &block->successors[j];

         if (succ->block)
            succ->branch_type = branch_type_for_successor(b, block, succ);
         else
            succ->branch_type = branch_type_for_terminator(b, block);

         vtn_assert(succ->branch_type != vtn_branch_type_none);
      }
   }
}

static void
find_merge_pos(struct vtn_builder *b)
{
   /* Merges are at the end of the construct by construction... */
   list_for_each_entry(struct vtn_construct, c, &b->func->constructs, link)
      c->merge_pos = c->end_pos;

   /* ...except when we have an "early merge", i.e. a branch that converges
    * before the declared merge point.  For these cases the actual merge is
    * stored in merge_pos.
    *
    * Look at all header blocks for constructs that may have such an early
    * merge, and check whether one actually occurs.
    */
   for (unsigned i = 0; i < b->func->ordered_blocks_count; i++) {
      if (!b->func->ordered_blocks[i]->merge)
         continue;

      struct vtn_block *header = b->func->ordered_blocks[i];
      if (header->successors_count != 2)
         continue;

      /* Ignore single-block loops (i.e. a header that's in a continue
       * construct).  Because the loop has no body, no block will
       * be identified in the then/else sides, so the vtn_emit_branch
       * calls will be enough.
       */

      struct vtn_construct *c = header->parent;
      if (c->type != vtn_construct_type_selection)
         continue;

      const unsigned first_pos = MIN2(c->then_pos, c->else_pos);
      const unsigned second_pos = MAX2(c->then_pos, c->else_pos);

      /* The first side ends where the second starts.  The second side ends
       * either at the continue position (which is guaranteed to appear after
       * the body of a loop) or at the actual end of the construct.
       *
       * Because of the way we ordered the blocks, if there's an early merge,
       * the first side of the if will have a branch inside the second side.
       */
      const unsigned first_end = second_pos;
      const unsigned second_end = c->end_pos;

      unsigned early_merge_pos = 0;
      for (unsigned pos = first_pos; pos < first_end; pos++) {
         /* For each block in first... */
         struct vtn_block *block = b->func->ordered_blocks[pos];
         for (unsigned s = 0; s < block->successors_count; s++) {
            if (block->successors[s].block) {
               /* ...see if one of its successors branches to the second side. */
               const unsigned succ_pos = block->successors[s].block->pos;
               if (succ_pos >= second_pos && succ_pos < second_end) {
                  vtn_fail_if(early_merge_pos,
                              "A single selection construct cannot "
                              "have multiple early merges");
                  early_merge_pos = succ_pos;
               }
            }
         }

         if (early_merge_pos) {
            c->merge_pos = early_merge_pos;
            break;
         }
      }
   }
}
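
/* An example of an early merge (hypothetical input, for illustration):
 * both sides of a selection converge on a block before the declared merge
 * block:
 *
 *    %h:   OpSelectionMerge %end None
 *          OpBranchConditional %c %t %e
 *    %t:   OpBranch %m        ; first side branches into the second side
 *    %e:   OpBranch %m
 *    %m:   OpBranch %end      ; the "early merge" point
 *    %end: ...
 *
 * With the block order %h %t %e %m %end, the successor of %t lands at %m,
 * which is inside the second side's range [%e, %end), so merge_pos is set
 * to %m's position while end_pos stays at %end.
 */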

void
vtn_build_structured_cfg(struct vtn_builder *b, const uint32_t *words, const uint32_t *end)
{
   vtn_foreach_function(func, &b->functions) {
      b->func = func;

      sort_blocks(b);

      create_constructs(b);

      validate_constructs(b);

      find_innermost_constructs(b);

      find_merge_pos(b);

      set_branch_types(b);

      if (MESA_SPIRV_DEBUG(STRUCTURED)) {
         printf("\nBLOCKS (%u):\n", func->ordered_blocks_count);
         print_ordered_blocks(func);
         printf("\nCONSTRUCTS (%u):\n", list_length(&func->constructs));
         print_constructs(func);
         printf("\n");
      }
   }
}

static int
vtn_set_break_vars_between(struct vtn_builder *b,
                           struct vtn_construct *from,
                           struct vtn_construct *to)
{
   vtn_assert(from);
   vtn_assert(to);

   int count = 0;
   for (struct vtn_construct *c = from; c != to; c = c->parent) {
      if (c->break_var) {
         vtn_assert(c->nloop);
         count++;

         /* There's no need to set break_var for `from` itself; an actual
          * break will be emitted by the callers.
          */
         if (c != from)
            nir_store_var(&b->nb, c->break_var, nir_imm_true(&b->nb), 1);
      } else {
         /* There's a 1:1 correspondence between break_vars and nloops. */
         vtn_assert(!c->nloop);
      }
   }

   return count;
}

static void
vtn_emit_break_for_construct(struct vtn_builder *b,
                             const struct vtn_block *block,
                             struct vtn_construct *to_break)
{
   vtn_assert(to_break);
   vtn_assert(to_break->nloop);

   bool has_intermediate = vtn_set_break_vars_between(b, block->parent, to_break);
   if (has_intermediate)
      nir_store_var(&b->nb, to_break->break_var, nir_imm_true(&b->nb), 1);

   nir_jump(&b->nb, nir_jump_break);
}

static void
vtn_emit_continue_for_construct(struct vtn_builder *b,
                                const struct vtn_block *block,
                                struct vtn_construct *to_continue)
{
   vtn_assert(to_continue);
   vtn_assert(to_continue->type == vtn_construct_type_loop);
   vtn_assert(to_continue->nloop);

   bool has_intermediate = vtn_set_break_vars_between(b, block->parent, to_continue);
   if (has_intermediate) {
      nir_store_var(&b->nb, to_continue->continue_var, nir_imm_true(&b->nb), 1);
      nir_jump(&b->nb, nir_jump_break);
   } else {
      nir_jump(&b->nb, nir_jump_continue);
   }
}

static void
vtn_emit_branch(struct vtn_builder *b, const struct vtn_block *block,
                const struct vtn_successor *succ)
{
   switch (succ->branch_type) {
   case vtn_branch_type_none:
      vtn_assert(!"invalid branch type");
      break;

   case vtn_branch_type_forward:
      /* Nothing to do. */
      break;

   case vtn_branch_type_if_break: {
      struct vtn_construct *inner_if = block->parent;
      vtn_assert(inner_if->type == vtn_construct_type_selection);
      if (inner_if->nloop) {
         vtn_emit_break_for_construct(b, block, inner_if);
      } else {
         /* Nothing to do.  This is a natural exit from an if construct. */
      }
      break;
   }

   case vtn_branch_type_switch_break: {
      struct vtn_construct *swtch = block->parent->innermost_switch;
      vtn_assert(swtch);

      struct vtn_construct *cse = block->parent->innermost_case;
      if (cse && cse->parent == swtch && cse->nloop) {
         vtn_emit_break_for_construct(b, block, cse);
      } else {
         /* Nothing to do.  This case doesn't have a loop, so this is a
          * natural break from a case.
          */
      }
      break;
   }

   case vtn_branch_type_switch_fallthrough: {
      struct vtn_construct *cse = block->parent->innermost_case;
      vtn_assert(cse);

      struct vtn_construct *swtch = cse->parent;
      vtn_assert(swtch->type == vtn_construct_type_switch);

      /* Successor is the start of another case construct with the same parent
       * switch construct.
       */
      vtn_assert(succ->block->switch_case != NULL);
      struct vtn_construct *target = succ->block->parent->innermost_case;
      vtn_assert(target != NULL && target->type == vtn_construct_type_case);
      vtn_assert(target->parent == swtch);
      vtn_assert(target->fallthrough_var);

      nir_store_var(&b->nb, target->fallthrough_var, nir_imm_true(&b->nb), 1);
      if (cse->nloop)
         vtn_emit_break_for_construct(b, block, cse);
      break;
   }

   case vtn_branch_type_loop_break: {
      struct vtn_construct *loop = block->parent->innermost_loop;
      vtn_assert(loop);
      vtn_emit_break_for_construct(b, block, loop);
      break;
   }

   case vtn_branch_type_loop_continue: {
      struct vtn_construct *loop = block->parent->innermost_loop;
      vtn_assert(loop);
      vtn_emit_continue_for_construct(b, block, loop);
      break;
   }

   case vtn_branch_type_loop_back_edge:
      /* Nothing to do: naturally handled by the NIR loop node. */
      break;

   case vtn_branch_type_return:
      vtn_assert(block);
      vtn_emit_ret_store(b, block);
      nir_jump(&b->nb, nir_jump_return);
      break;

   case vtn_branch_type_discard:
      if (b->convert_discard_to_demote) {
         nir_demote(&b->nb);

         /* Workaround for outdated test cases from CTS and Tint which assume
          * that OpKill always terminates the invocation.  Break from the
          * current loop if it exists in order to prevent infinite loops.
          */
         struct vtn_construct *loop = block->parent->innermost_loop;
         if (loop)
            vtn_emit_break_for_construct(b, block, loop);
      } else {
         nir_discard(&b->nb);
      }
      break;

   case vtn_branch_type_terminate_invocation:
      nir_terminate(&b->nb);
      break;

   case vtn_branch_type_ignore_intersection:
      nir_ignore_ray_intersection(&b->nb);
      nir_jump(&b->nb, nir_jump_halt);
      break;

   case vtn_branch_type_terminate_ray:
      nir_terminate_ray(&b->nb);
      nir_jump(&b->nb, nir_jump_halt);
      break;

   case vtn_branch_type_emit_mesh_tasks: {
      vtn_assert(block);
      vtn_assert(block->branch);

      const uint32_t *w = block->branch;
      vtn_assert((w[0] & SpvOpCodeMask) == SpvOpEmitMeshTasksEXT);

      /* Launches mesh shader workgroups from the task shader.
       * Arguments are: vec(x, y, z), payload pointer
       */
      nir_def *dimensions =
         nir_vec3(&b->nb, vtn_get_nir_ssa(b, w[1]),
                          vtn_get_nir_ssa(b, w[2]),
                          vtn_get_nir_ssa(b, w[3]));

      /* The payload variable is optional.
       * We don't have a NULL deref in NIR, so just emit the explicit
       * intrinsic when there is no payload.
       */
      const unsigned count = w[0] >> SpvWordCountShift;
      if (count == 4)
         nir_launch_mesh_workgroups(&b->nb, dimensions);
      else if (count == 5)
         nir_launch_mesh_workgroups_with_payload_deref(&b->nb, dimensions,
                                                       vtn_get_nir_ssa(b, w[4]));
      else
         vtn_fail("Invalid EmitMeshTasksEXT.");

      nir_jump(&b->nb, nir_jump_halt);
      break;
   }

   default:
      vtn_fail("Invalid branch type");
   }
}

static nir_selection_control
vtn_selection_control(struct vtn_builder *b, SpvSelectionControlMask control)
{
   if (control == SpvSelectionControlMaskNone)
      return nir_selection_control_none;
   else if (control & SpvSelectionControlDontFlattenMask)
      return nir_selection_control_dont_flatten;
   else if (control & SpvSelectionControlFlattenMask)
      return nir_selection_control_flatten;
   else
      vtn_fail("Invalid selection control");
}
1337 
1338 static void
vtn_emit_block(struct vtn_builder * b,struct vtn_block * block,vtn_instruction_handler handler)1339 vtn_emit_block(struct vtn_builder *b, struct vtn_block *block,
1340                vtn_instruction_handler handler)
1341 {
1342    const uint32_t *block_start = block->label;
1343    const uint32_t *block_end = block->merge ? block->merge :
1344                                               block->branch;
1345 
1346    block_start = vtn_foreach_instruction(b, block_start, block_end,
1347                                          vtn_handle_phis_first_pass);
1348 
1349    vtn_foreach_instruction(b, block_start, block_end, handler);
1350 
1351    block->end_nop = nir_nop(&b->nb);
1352 
1353    if (block->parent->type == vtn_construct_type_switch) {
1354       /* Switch is handled as a sequence of NIR if for each of the cases. */
1355 
1356    } else if (block->successors_count == 1) {
1357       vtn_assert(block->successors[0].branch_type != vtn_branch_type_none);
1358       vtn_emit_branch(b, block, &block->successors[0]);
1359 
1360    } else if (block->successors_count == 2) {
1361       struct vtn_successor *then_succ = &block->successors[0];
1362       struct vtn_successor *else_succ = &block->successors[1];
1363       struct vtn_construct *c = block->parent;
1364 
1365       nir_def *cond = vtn_get_nir_ssa(b, block->branch[1]);
1366       if (then_succ->block == else_succ->block)
1367          cond = nir_imm_true(&b->nb);
1368 
1369       /* The branches will already be emitted here, so for paths that
1370        * doesn't have blocks inside the construct, e.g. that are an
1371        * exit from the construct, nothing else is needed.
1372        */
      nir_if *sel = nir_push_if(&b->nb, cond);
      vtn_emit_branch(b, block, then_succ);
      if (then_succ->block != else_succ->block) {
         nir_push_else(&b->nb, NULL);
         vtn_emit_branch(b, block, else_succ);
      }
      nir_pop_if(&b->nb, NULL);

      if (c->type == vtn_construct_type_selection &&
          block->pos == c->start_pos) {
         /* This is the start of a selection construct. Record the nir_if in
          * the construct so we can close it properly and handle the then and
          * else cases in block iteration.
          */
         vtn_assert(c->nif == NULL);
         c->nif = sel;

         vtn_assert(block->merge != NULL);

         SpvOp merge_op = block->merge[0] & SpvOpCodeMask;
         if (merge_op == SpvOpSelectionMerge)
            sel->control = vtn_selection_control(b, block->merge[2]);

         /* In most cases, vtn_emit_cf_func_structured() will place the cursor
          * in the correct side of the nir_if. However, in the case where the
          * selection construct is empty, we need to ensure that the cursor is
          * at least inside the nir_if or NIR will assert when we try to close
          * it with nir_pop_if().
          */
         b->nb.cursor = nir_before_cf_list(&sel->then_list);
      } else {
         vtn_fail_if(then_succ->branch_type == vtn_branch_type_forward &&
                     else_succ->branch_type == vtn_branch_type_forward &&
                     then_succ->block != else_succ->block,
                     "An OpSelectionMerge instruction is required to precede "
                     "an OpBranchConditional instruction that has different "
                     "True Label and False Label operands where neither are "
                     "declared merge blocks or Continue Targets.");

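         /* At most one successor still has blocks inside this construct;
          * park the cursor on that side of the nir_if so those blocks are
          * emitted there.
          */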
         if (then_succ->branch_type == vtn_branch_type_forward) {
            b->nb.cursor = nir_before_cf_list(&sel->then_list);
         } else if (else_succ->branch_type == vtn_branch_type_forward) {
            b->nb.cursor = nir_before_cf_list(&sel->else_list);
         } else {
            /* Leave it alone */
         }
      }
   }
}

static nir_def *
vtn_switch_case_condition(struct vtn_builder *b, struct vtn_construct *swtch,
                          nir_def *sel, struct vtn_case *cse)
{
   vtn_assert(swtch->type == vtn_construct_type_switch);

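   /* The default case fires exactly when no other case matches: OR together
    * the conditions of every non-default case of this switch and negate the
    * result.
    */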
   if (cse->is_default) {
      nir_def *any = nir_imm_false(&b->nb);

      struct vtn_block *header = b->func->ordered_blocks[swtch->start_pos];

      for (unsigned j = 0; j < header->successors_count; j++) {
         struct vtn_successor *succ = &header->successors[j];
         struct vtn_case *other = succ->block->switch_case;

         if (other->is_default)
            continue;

         any = nir_ior(&b->nb, any,
                       vtn_switch_case_condition(b, swtch, sel, other));
      }

      return nir_inot(&b->nb, any);
   } else {
      nir_def *cond = nir_imm_false(&b->nb);
      util_dynarray_foreach(&cse->values, uint64_t, val)
         cond = nir_ior(&b->nb, cond, nir_ieq_imm(&b->nb, sel, *val));
      return cond;
   }
}

static nir_loop_control
vtn_loop_control(struct vtn_builder *b, SpvLoopControlMask control)
{
   if (control == SpvLoopControlMaskNone)
      return nir_loop_control_none;
   else if (control & SpvLoopControlDontUnrollMask)
      return nir_loop_control_dont_unroll;
   else if (control & SpvLoopControlUnrollMask)
      return nir_loop_control_unroll;
   else if ((control & SpvLoopControlDependencyInfiniteMask) ||
            (control & SpvLoopControlDependencyLengthMask) ||
            (control & SpvLoopControlMinIterationsMask) ||
            (control & SpvLoopControlMaxIterationsMask) ||
            (control & SpvLoopControlIterationMultipleMask) ||
            (control & SpvLoopControlPeelCountMask) ||
            (control & SpvLoopControlPartialCountMask)) {
      /* We do not do anything special with these yet. */
      return nir_loop_control_none;
   } else {
      vtn_fail("Invalid loop control");
   }
}

static void
vtn_emit_control_flow_propagation(struct vtn_builder *b,
                                  struct vtn_construct *top)
{
   if (top->type == vtn_construct_type_function ||
       top->type == vtn_construct_type_continue ||
       top->type == vtn_construct_type_switch)
      return;

   /* Find the innermost parent with a NIR loop. */
   struct vtn_construct *parent_with_nloop = NULL;
   for (struct vtn_construct *c = top->parent; c; c = c->parent) {
      if (c->nloop) {
         parent_with_nloop = c;
         break;
      }
   }
   if (parent_with_nloop == NULL)
      return;

   /* There's another nloop in the parent chain, so decide whether we need to
    * emit a conditional continue/break after the top construct is closed.
    */

   if (top->needs_continue_propagation &&
       parent_with_nloop == top->innermost_loop) {
      struct vtn_construct *loop = top->innermost_loop;
      vtn_assert(loop);
      vtn_assert(loop != top);

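      /* In NIR terms this emits
       *
       *    if (loop_continue)
       *       continue;
       *
       * right after the construct wrapping top is closed, so a continue that
       * originated inside top reaches the nir_loop of the SPIR-V loop.
       */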
      nir_push_if(&b->nb, nir_load_var(&b->nb, loop->continue_var));
      nir_jump(&b->nb, nir_jump_continue);
      nir_pop_if(&b->nb, NULL);
   }

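   /* nir_break_if() emits "if (break_var) break;", forwarding a break that
    * originated inside top out through the enclosing nir_loop.
    */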
   if (top->needs_break_propagation) {
      vtn_assert(parent_with_nloop->break_var);
      nir_break_if(&b->nb, nir_load_var(&b->nb, parent_with_nloop->break_var));
   }
}

static inline nir_variable *
vtn_create_local_bool(struct vtn_builder *b, const char *name)
{
   return nir_local_variable_create(b->nb.impl, glsl_bool_type(), name);
}

void
vtn_emit_cf_func_structured(struct vtn_builder *b, struct vtn_function *func,
                            vtn_instruction_handler handler)
{
   struct vtn_construct *current =
      list_first_entry(&func->constructs, struct vtn_construct, link);
   vtn_assert(current->type == vtn_construct_type_function);

   /* Walk the blocks in order, keeping track of the constructs that have
    * started but haven't ended yet.  When constructs start and end, add
    * extra code to set up the NIR control flow (different for each
    * construct); also add extra code for propagating certain branch types.
    */

   struct vtn_construct_stack stack;
   init_construct_stack(&stack, b);
   push_construct(&stack, current);

   for (unsigned i = 0; i < func->ordered_blocks_count; i++) {
      struct vtn_block *block = func->ordered_blocks[i];
      struct vtn_construct *top = top_construct(&stack);

      /* Close out any past constructs and make sure the cursor is at the
       * right place to start this block. For each block, there are three
       * cases we care about here:
       *
       *  1. It is the block at the end (in our reverse structured post-order
       *     traversal) of one or more constructs and closes them.
       *
       *  2. It is an early merge of a selection construct.
       *
       *  3. It is the start of the then or else case of a selection construct
       *     and we may have previously been emitting code in the other side.
       */

      /* Close (or early merge) any constructs that end at this block. */
      bool merged_any_constructs = false;
      while (top->end_pos == block->pos || top->merge_pos == block->pos) {
         merged_any_constructs = true;
         if (top->nif) {
            const bool has_early_merge = top->merge_pos != top->end_pos;

            if (!has_early_merge) {
               nir_pop_if(&b->nb, top->nif);
            } else if (block->pos == top->merge_pos) {
               /* This is an early merge. */

               nir_pop_if(&b->nb, top->nif);

               /* The extra dummy "if (true)" for the merged part avoids
                * generating multiple jumps in sequence and upsetting
                * NIR rules.  We'll pop it in the case below when we reach
                * the end_pos block.
                */
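               /* Sketch of the NIR emitted around an early merge:
                *
                *    if (cond) { ... } else { ... }   <- top->nif, popped here
                *    if (true) {                      <- dummy if
                *       ... blocks between merge_pos and end_pos ...
                *    }                                <- popped at end_pos
                */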
               nir_push_if(&b->nb, nir_imm_true(&b->nb));

               /* Stop since this construct still has more blocks. */
               break;
            } else {
               /* Pop the dummy if added for the blocks after the early merge. */
               vtn_assert(block->pos == top->end_pos);
               nir_pop_if(&b->nb, NULL);
            }
         }

         if (top->nloop) {
            /* For constructs that are not SPIR-V loops, a NIR loop may be
             * used to provide richer control flow.  So we add a NIR break to
             * make the loop stop after the first iteration, unless there's
             * already a jump at the end of the last block.
             */
            if (top->type != vtn_construct_type_loop) {
               nir_block *last = nir_loop_last_block(top->nloop);
               if (!nir_block_ends_in_jump(last)) {
                  b->nb.cursor = nir_after_block(last);
                  nir_jump(&b->nb, nir_jump_break);
               }
            }

            nir_pop_loop(&b->nb, top->nloop);
         }

         vtn_emit_control_flow_propagation(b, top);

         pop_construct(&stack);
         top = top_construct(&stack);
      }

      /* We are fully inside the current top. */
      vtn_assert(block->pos < top->end_pos);

      /* Move the cursor to the correct side of a selection construct.
       *
       * If we merged any constructs, we don't need to move because either:
       * this is an early merge and we already set the cursor above; or a
       * construct ended, and this is a 'merge block' for that construct, so
       * it can't also be a 'Target' for an outer conditional.
       */
      if (!merged_any_constructs && top->type == vtn_construct_type_selection &&
          (block->pos == top->then_pos || block->pos == top->else_pos)) {
         vtn_assert(top->nif);

         struct vtn_block *header = func->ordered_blocks[top->start_pos];
         vtn_assert(header->successors_count == 2);

         if (block->pos == top->then_pos)
            b->nb.cursor = nir_before_cf_list(&top->nif->then_list);
         else
            b->nb.cursor = nir_before_cf_list(&top->nif->else_list);
      }

      /* Open any constructs which start at this block.
       *
       * Constructs which are designated by Op*Merge are considered to start
       * at the block which contains the merge instruction.  This means that
       * loop constructs start at the first block inside the loop, while
       * selection and switch constructs start at the block containing the
       * OpBranchConditional or OpSwitch.
       */
      while (current->link.next != &func->constructs) {
         struct vtn_construct *next =
            list_entry(current->link.next, struct vtn_construct, link);

         /* Stop once we find a construct that doesn't start in this block. */
         if (next->start_pos != block->pos)
            break;

         switch (next->type) {
         case vtn_construct_type_function:
            unreachable("should've already entered function construct");
            break;

         case vtn_construct_type_selection: {
            /* Add the wrapper loop now; the nir_if, along with the contents
             * of this entire block, will get added inside the loop as part
             * of vtn_emit_block() below.
             */
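            /* Sketch of the resulting NIR when needs_nloop is set (the
             * trailing break is added when the construct is closed, see the
             * nloop handling above):
             *
             *    if_break = false;
             *    loop {
             *       if (cond) { ... } else { ... }
             *       break;
             *    }
             */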
            if (next->needs_nloop) {
               next->break_var = vtn_create_local_bool(b, "if_break");
               nir_store_var(&b->nb, next->break_var, nir_imm_false(&b->nb), 1);
               next->nloop = nir_push_loop(&b->nb);
            }
            break;
         }

         case vtn_construct_type_loop: {
            next->break_var = vtn_create_local_bool(b, "loop_break");
            next->continue_var = vtn_create_local_bool(b, "loop_continue");

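            /* break_var is initialized outside the loop since it's read after
             * the loop closes (for break propagation); continue_var is reset
             * inside, at the top of every iteration.
             */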
            nir_store_var(&b->nb, next->break_var, nir_imm_false(&b->nb), 1);
            next->nloop = nir_push_loop(&b->nb);
            nir_store_var(&b->nb, next->continue_var, nir_imm_false(&b->nb), 1);

            next->nloop->control = vtn_loop_control(b, block->merge[3]);

            break;
         }

         case vtn_construct_type_continue: {
            struct vtn_construct *loop = next->parent;
            assert(loop->type == vtn_construct_type_loop);
            assert(!vtn_is_single_block_loop(loop));

            nir_push_continue(&b->nb, loop->nloop);

            break;
         }

         case vtn_construct_type_switch: {
            /* Switch is not translated to any NIR node; everything is
             * handled by the individual case constructs.
             */
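            /* Pre-create the fallthrough flags here: a case that can be
             * entered by falling through tests this boolean in addition to
             * its own selector match (see vtn_construct_type_case below).
             */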
            for (unsigned j = 0; j < block->successors_count; j++) {
               struct vtn_successor *s = &block->successors[j];
               if (s->block && s->block->pos < next->end_pos) {
                  struct vtn_construct *c = s->block->parent->innermost_case;
                  vtn_assert(c->type == vtn_construct_type_case);
                  if (c->needs_fallthrough) {
                     c->fallthrough_var = vtn_create_local_bool(b, "fallthrough");
                     nir_store_var(&b->nb, c->fallthrough_var, nir_imm_false(&b->nb), 1);
                  }
               }
            }
            break;
         }

         case vtn_construct_type_case: {
            struct vtn_construct *swtch = next->parent;
            struct vtn_block *header = func->ordered_blocks[swtch->start_pos];

            nir_def *sel = vtn_get_nir_ssa(b, header->branch[1]);
            nir_def *case_condition =
               vtn_switch_case_condition(b, swtch, sel, block->switch_case);
            if (next->fallthrough_var) {
               case_condition =
                  nir_ior(&b->nb, case_condition,
                          nir_load_var(&b->nb, next->fallthrough_var));
            }

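            /* Like a selection, a case gets a wrapper loop when it needs one
             * (needs_nloop), e.g. so an early break out of the case can be
             * expressed as a NIR break.
             */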
            if (next->needs_nloop) {
               next->break_var = vtn_create_local_bool(b, "case_break");
               nir_store_var(&b->nb, next->break_var, nir_imm_false(&b->nb), 1);
               next->nloop = nir_push_loop(&b->nb);
            }

            next->nif = nir_push_if(&b->nb, case_condition);

            break;
         }
         }

         current = next;
         push_construct(&stack, next);
      }

      vtn_emit_block(b, block, handler);
   }

   vtn_assert(count_construct_stack(&stack) == 1);
}