xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_lower_terminate_to_demote.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2024 Collabora, Ltd.
3  * SPDX-License-Identifier: MIT
4  */
5 
6 #include "nir.h"
7 #include "nir_builder.h"
8 
9 static bool
nir_lower_terminate_cf_list(nir_builder * b,struct exec_list * cf_list)10 nir_lower_terminate_cf_list(nir_builder *b, struct exec_list *cf_list)
11 {
12    bool progress = false;
13 
14    foreach_list_typed_safe(nir_cf_node, node, node, cf_list) {
15       switch (node->type) {
16       case nir_cf_node_block: {
17          nir_block *block = nir_cf_node_as_block(node);
18 
19          nir_foreach_instr_safe(instr, block) {
20             if (instr->type != nir_instr_type_intrinsic)
21                continue;
22 
23             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
24             switch (intrin->intrinsic) {
25             case nir_intrinsic_terminate: {
26                /* Everything after the terminate is dead */
27                nir_cf_list dead_cf;
28                nir_cf_extract(&dead_cf, nir_after_instr(&intrin->instr),
29                                         nir_after_cf_list(cf_list));
30                nir_cf_delete(&dead_cf);
31 
32                intrin->intrinsic = nir_intrinsic_demote;
33                b->cursor = nir_after_instr(&intrin->instr);
34                nir_jump(b, nir_jump_halt);
35 
36                /* We just removed the remainder of this list of CF nodes.
37                 * It's not safe to continue iterating.
38                 */
39                return true;
40             }
41 
42             case nir_intrinsic_terminate_if:
43                b->cursor = nir_before_instr(&intrin->instr);
44                nir_push_if(b, intrin->src[0].ssa);
45                {
46                   nir_demote(b);
47                   nir_jump(b, nir_jump_halt);
48                }
49                nir_instr_remove(&intrin->instr);
50                progress = true;
51                break;
52 
53             default:
54                break;
55             }
56          }
57          break;
58       }
59 
60       case nir_cf_node_if: {
61          nir_if *nif = nir_cf_node_as_if(node);
62          progress |= nir_lower_terminate_cf_list(b, &nif->then_list);
63          progress |= nir_lower_terminate_cf_list(b, &nif->else_list);
64          break;
65       }
66 
67       case nir_cf_node_loop: {
68          nir_loop *loop = nir_cf_node_as_loop(node);
69          progress |= nir_lower_terminate_cf_list(b, &loop->body);
70          progress |= nir_lower_terminate_cf_list(b, &loop->continue_list);
71          break;
72       }
73 
74       default:
75          unreachable("Unknown CF node type");
76       }
77    }
78 
79    return progress;
80 }
81 
82 static bool
nir_lower_terminate_impl(nir_function_impl * impl)83 nir_lower_terminate_impl(nir_function_impl *impl)
84 {
85    nir_builder b = nir_builder_create(impl);
86    bool progress = nir_lower_terminate_cf_list(&b, &impl->body);
87 
88    if (progress) {
89       nir_metadata_preserve(impl, nir_metadata_none);
90    } else {
91       nir_metadata_preserve(impl, nir_metadata_all);
92    }
93 
94    return progress;
95 }
96 
97 /** Lowers nir_intrinsic_terminate to demote + halt
98  *
99  * The semantics of nir_intrinsic_terminate require that threads immediately
100  * exit.  In SPIR-V, terminate is branch instruction even though it's only an
101  * intrinsic in NIR.  This pass lowers terminate to demote + halt.  Since halt
102  * is a jump instruction in NIR, this restores those semantics and NIR can
103  * reason about dead threads after a halt.  It allows lets back-ends to only
104  * implement nir_intrinsic_demote as long as they also implement nir_jump_halt.
105  */
106 bool
nir_lower_terminate_to_demote(nir_shader * nir)107 nir_lower_terminate_to_demote(nir_shader *nir)
108 {
109    bool progress = false;
110 
111    nir_foreach_function_impl(impl, nir) {
112       if (nir_lower_terminate_impl(impl))
113          progress = true;
114    }
115 
116    return progress;
117 }
118