xref: /aosp_15_r20/external/mesa3d/src/nouveau/compiler/nak/opt_jump_thread.rs (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 // Copyright © 2023 Mel Henning
2 // SPDX-License-Identifier: MIT
3 
4 use crate::ir::*;
5 use compiler::cfg::CFGBuilder;
6 use std::collections::HashMap;
7 
clone_branch(op: &Op) -> Op8 fn clone_branch(op: &Op) -> Op {
9     match op {
10         Op::Bra(b) => Op::Bra(b.clone()),
11         Op::Exit(e) => Op::Exit(e.clone()),
12         _ => unreachable!(),
13     }
14 }
15 
jump_thread(func: &mut Function) -> bool16 fn jump_thread(func: &mut Function) -> bool {
17     // Let's call a basic block "trivial" if its only instruction is an
18     // unconditional branch. If a block is trivial, we can update all of its
19     // predecessors to jump to its sucessor.
20     //
21     // A single reverse pass over the basic blocks is enough to update all of
22     // the edges we're interested in. Roughly, if we assume that all loops in
23     // the shader can terminate, then loop heads are never trivial and we
24     // never replace a backward edge. Therefore, in each step we only need to
25     // make sure that later control flow has been replaced in order to update
26     // the current block as much as possible.
27     //
28     // We additionally try to update a branch-to-empty-block to point to the
29     // block's successor, which along with block dce/reordering can sometimes
30     // enable a later optimization that converts branches to fallthrough.
31     let mut progress = false;
32 
33     // A branch to label can be replaced with Op
34     let mut replacements: HashMap<Label, Op> = HashMap::new();
35 
36     // Invariant 1: At the end of each loop iteration,
37     //              every trivial block with an index in [i, blocks.len())
38     //              is represented in replacements.keys()
39     // Invariant 2: replacements.values() never contains
40     //              a branch to a trivial block
41     for i in (0..func.blocks.len()).rev() {
42         // Replace the branch if possible
43         if let Some(instr) = func.blocks[i].instrs.last_mut() {
44             if let Op::Bra(OpBra { target }) = instr.op {
45                 if let Some(replacement) = replacements.get(&target) {
46                     instr.op = clone_branch(replacement);
47                     progress = true;
48                 }
49                 // If the branch target was previously a trivial block then the
50                 // branch was previously a forward edge (see above) and by
51                 // invariants 1 and 2 we just updated the branch to target
52                 // a nontrivial block
53             }
54         }
55 
56         // Is this block trivial?
57         let block_label = func.blocks[i].label;
58         match &func.blocks[i].instrs[..] {
59             [instr] => {
60                 if instr.is_branch() && instr.pred.is_true() {
61                     // Upholds invariant 2 because we updated the branch above
62                     replacements.insert(block_label, clone_branch(&instr.op));
63                 }
64             }
65             [] => {
66                 // Empty block - falls through
67                 // Our successor might be trivial, so we need to
68                 // apply the rewrite map to uphold invariant 2
69                 let target_label = func.blocks[i + 1].label;
70                 let replacement = replacements
71                     .get(&target_label)
72                     .map(clone_branch)
73                     .unwrap_or_else(|| {
74                         Op::Bra(OpBra {
75                             target: target_label,
76                         })
77                     });
78                 replacements.insert(block_label, replacement);
79             }
80             _ => (),
81         }
82     }
83 
84     if progress {
85         // We don't update the CFG above, so rewrite it if we made progress
86         rewrite_cfg(func);
87     }
88 
89     progress
90 }
91 
rewrite_cfg(func: &mut Function)92 fn rewrite_cfg(func: &mut Function) {
93     // CFGBuilder takes care of removing dead blocks for us
94     // We use the basic block's label to identify it
95     let mut builder = CFGBuilder::new();
96 
97     for i in 0..func.blocks.len() {
98         let block = &func.blocks[i];
99         // Note: fall-though must be first edge
100         if block.falls_through() {
101             let next_block = &func.blocks[i + 1];
102             builder.add_edge(block.label, next_block.label);
103         }
104         if let Some(control_flow) = block.branch() {
105             match &control_flow.op {
106                 Op::Bra(bra) => {
107                     builder.add_edge(block.label, bra.target);
108                 }
109                 Op::Exit(_) => (),
110                 _ => unreachable!(),
111             };
112         }
113     }
114 
115     for block in func.blocks.drain() {
116         builder.add_node(block.label, block);
117     }
118     let _ = std::mem::replace(&mut func.blocks, builder.as_cfg());
119 }
120 
121 /// Replace jumps to the following block with fall-through
opt_fall_through(func: &mut Function)122 fn opt_fall_through(func: &mut Function) {
123     for i in 0..func.blocks.len() - 1 {
124         let remove_last_instr = match func.blocks[i].branch() {
125             Some(b) => match b.op {
126                 Op::Bra(OpBra { target }) => target == func.blocks[i + 1].label,
127                 _ => false,
128             },
129             None => false,
130         };
131 
132         if remove_last_instr {
133             func.blocks[i].instrs.pop();
134         }
135     }
136 }
137 
138 impl Function {
opt_jump_thread(&mut self)139     pub fn opt_jump_thread(&mut self) {
140         if jump_thread(self) {
141             opt_fall_through(self);
142         }
143     }
144 }
145 
146 impl Shader<'_> {
147     /// A simple jump threading pass
148     ///
149     /// Note that this can introduce critical edges, so it cannot be run before RA
opt_jump_thread(&mut self)150     pub fn opt_jump_thread(&mut self) {
151         for f in &mut self.functions {
152             f.opt_jump_thread();
153         }
154     }
155 }
156