1 // Copyright © 2023 Mel Henning
2 // SPDX-License-Identifier: MIT
3
4 use crate::ir::*;
5 use compiler::cfg::CFGBuilder;
6 use std::collections::HashMap;
7
clone_branch(op: &Op) -> Op8 fn clone_branch(op: &Op) -> Op {
9 match op {
10 Op::Bra(b) => Op::Bra(b.clone()),
11 Op::Exit(e) => Op::Exit(e.clone()),
12 _ => unreachable!(),
13 }
14 }
15
jump_thread(func: &mut Function) -> bool16 fn jump_thread(func: &mut Function) -> bool {
17 // Let's call a basic block "trivial" if its only instruction is an
18 // unconditional branch. If a block is trivial, we can update all of its
19 // predecessors to jump to its sucessor.
20 //
21 // A single reverse pass over the basic blocks is enough to update all of
22 // the edges we're interested in. Roughly, if we assume that all loops in
23 // the shader can terminate, then loop heads are never trivial and we
24 // never replace a backward edge. Therefore, in each step we only need to
25 // make sure that later control flow has been replaced in order to update
26 // the current block as much as possible.
27 //
28 // We additionally try to update a branch-to-empty-block to point to the
29 // block's successor, which along with block dce/reordering can sometimes
30 // enable a later optimization that converts branches to fallthrough.
31 let mut progress = false;
32
33 // A branch to label can be replaced with Op
34 let mut replacements: HashMap<Label, Op> = HashMap::new();
35
36 // Invariant 1: At the end of each loop iteration,
37 // every trivial block with an index in [i, blocks.len())
38 // is represented in replacements.keys()
39 // Invariant 2: replacements.values() never contains
40 // a branch to a trivial block
41 for i in (0..func.blocks.len()).rev() {
42 // Replace the branch if possible
43 if let Some(instr) = func.blocks[i].instrs.last_mut() {
44 if let Op::Bra(OpBra { target }) = instr.op {
45 if let Some(replacement) = replacements.get(&target) {
46 instr.op = clone_branch(replacement);
47 progress = true;
48 }
49 // If the branch target was previously a trivial block then the
50 // branch was previously a forward edge (see above) and by
51 // invariants 1 and 2 we just updated the branch to target
52 // a nontrivial block
53 }
54 }
55
56 // Is this block trivial?
57 let block_label = func.blocks[i].label;
58 match &func.blocks[i].instrs[..] {
59 [instr] => {
60 if instr.is_branch() && instr.pred.is_true() {
61 // Upholds invariant 2 because we updated the branch above
62 replacements.insert(block_label, clone_branch(&instr.op));
63 }
64 }
65 [] => {
66 // Empty block - falls through
67 // Our successor might be trivial, so we need to
68 // apply the rewrite map to uphold invariant 2
69 let target_label = func.blocks[i + 1].label;
70 let replacement = replacements
71 .get(&target_label)
72 .map(clone_branch)
73 .unwrap_or_else(|| {
74 Op::Bra(OpBra {
75 target: target_label,
76 })
77 });
78 replacements.insert(block_label, replacement);
79 }
80 _ => (),
81 }
82 }
83
84 if progress {
85 // We don't update the CFG above, so rewrite it if we made progress
86 rewrite_cfg(func);
87 }
88
89 progress
90 }
91
rewrite_cfg(func: &mut Function)92 fn rewrite_cfg(func: &mut Function) {
93 // CFGBuilder takes care of removing dead blocks for us
94 // We use the basic block's label to identify it
95 let mut builder = CFGBuilder::new();
96
97 for i in 0..func.blocks.len() {
98 let block = &func.blocks[i];
99 // Note: fall-though must be first edge
100 if block.falls_through() {
101 let next_block = &func.blocks[i + 1];
102 builder.add_edge(block.label, next_block.label);
103 }
104 if let Some(control_flow) = block.branch() {
105 match &control_flow.op {
106 Op::Bra(bra) => {
107 builder.add_edge(block.label, bra.target);
108 }
109 Op::Exit(_) => (),
110 _ => unreachable!(),
111 };
112 }
113 }
114
115 for block in func.blocks.drain() {
116 builder.add_node(block.label, block);
117 }
118 let _ = std::mem::replace(&mut func.blocks, builder.as_cfg());
119 }
120
121 /// Replace jumps to the following block with fall-through
opt_fall_through(func: &mut Function)122 fn opt_fall_through(func: &mut Function) {
123 for i in 0..func.blocks.len() - 1 {
124 let remove_last_instr = match func.blocks[i].branch() {
125 Some(b) => match b.op {
126 Op::Bra(OpBra { target }) => target == func.blocks[i + 1].label,
127 _ => false,
128 },
129 None => false,
130 };
131
132 if remove_last_instr {
133 func.blocks[i].instrs.pop();
134 }
135 }
136 }
137
138 impl Function {
opt_jump_thread(&mut self)139 pub fn opt_jump_thread(&mut self) {
140 if jump_thread(self) {
141 opt_fall_through(self);
142 }
143 }
144 }
145
146 impl Shader<'_> {
147 /// A simple jump threading pass
148 ///
149 /// Note that this can introduce critical edges, so it cannot be run before RA
opt_jump_thread(&mut self)150 pub fn opt_jump_thread(&mut self) {
151 for f in &mut self.functions {
152 f.opt_jump_thread();
153 }
154 }
155 }
156