xref: /aosp_15_r20/external/mesa3d/src/nouveau/compiler/nak/opt_prmt.rs (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2023 Collabora, Ltd.
3  * SPDX-License-Identifier: MIT
4  */
5 
6 use std::collections::HashMap;
7 
8 use crate::ir::*;
9 
10 struct PrmtSrcs {
11     srcs: [SrcRef; 2],
12     num_srcs: usize,
13     imm_src: usize,
14     num_imm_bytes: usize,
15 }
16 
17 impl PrmtSrcs {
new() -> PrmtSrcs18     fn new() -> PrmtSrcs {
19         PrmtSrcs {
20             srcs: [SrcRef::Zero; 2],
21             num_srcs: 0,
22             imm_src: usize::MAX,
23             num_imm_bytes: 0,
24         }
25     }
26 
try_add_src(&mut self, src: SrcRef) -> Option<usize>27     fn try_add_src(&mut self, src: SrcRef) -> Option<usize> {
28         for i in 0..self.num_srcs {
29             if self.srcs[i] == src {
30                 return Some(i);
31             }
32         }
33 
34         if self.num_srcs < 2 {
35             let i = self.num_srcs;
36             self.num_srcs += 1;
37             self.srcs[i] = src;
38             Some(i)
39         } else {
40             None
41         }
42     }
43 
try_add_imm_u8(&mut self, u: u8) -> Option<usize>44     fn try_add_imm_u8(&mut self, u: u8) -> Option<usize> {
45         if self.imm_src == usize::MAX {
46             if self.num_srcs >= 2 {
47                 return None;
48             }
49             self.imm_src = self.num_srcs;
50             self.num_srcs += 1;
51         }
52 
53         match &mut self.srcs[self.imm_src] {
54             SrcRef::Zero => {
55                 if u == 0 {
56                     // Common case, just leave it as a SrcRef::Zero
57                     debug_assert!(self.num_imm_bytes <= 1);
58                     self.num_imm_bytes = 1;
59                     Some(0)
60                 } else {
61                     let b = self.num_imm_bytes;
62                     self.num_imm_bytes += 1;
63                     let imm = u32::from(u) << (b * 8);
64                     self.srcs[self.imm_src] = SrcRef::Imm32(imm);
65                     Some(b)
66                 }
67             }
68             SrcRef::Imm32(imm) => {
69                 let b = self.num_imm_bytes;
70                 self.num_imm_bytes += 1;
71                 *imm |= u32::from(u) << (b * 8);
72                 Some(b)
73             }
74             _ => panic!("We said this was the imm src"),
75         }
76     }
77 }
78 
79 struct PrmtEntry {
80     sel: PrmtSel,
81     srcs: [SrcRef; 2],
82 }
83 
84 struct PrmtPass {
85     ssa_prmt: HashMap<SSAValue, PrmtEntry>,
86 }
87 
88 impl PrmtPass {
new() -> PrmtPass89     fn new() -> PrmtPass {
90         PrmtPass {
91             ssa_prmt: HashMap::new(),
92         }
93     }
94 
add_prmt(&mut self, op: &OpPrmt)95     fn add_prmt(&mut self, op: &OpPrmt) {
96         let Dst::SSA(dst_ssa) = op.dst else {
97             return;
98         };
99         debug_assert!(dst_ssa.comps() == 1);
100         let dst_ssa = dst_ssa[0];
101 
102         let Some(sel) = op.get_sel() else {
103             return;
104         };
105 
106         debug_assert!(op.srcs[0].src_mod.is_none());
107         debug_assert!(op.srcs[1].src_mod.is_none());
108         let srcs = [op.srcs[0].src_ref, op.srcs[1].src_ref];
109 
110         self.ssa_prmt.insert(dst_ssa, PrmtEntry { sel, srcs });
111     }
112 
get_prmt(&self, ssa: &SSAValue) -> Option<&PrmtEntry>113     fn get_prmt(&self, ssa: &SSAValue) -> Option<&PrmtEntry> {
114         self.ssa_prmt.get(ssa)
115     }
116 
get_prmt_for_src(&self, src: &Src) -> Option<&PrmtEntry>117     fn get_prmt_for_src(&self, src: &Src) -> Option<&PrmtEntry> {
118         debug_assert!(src.src_mod.is_none());
119         if let SrcRef::SSA(vec) = &src.src_ref {
120             debug_assert!(vec.comps() == 1);
121             self.get_prmt(&vec[0])
122         } else {
123             None
124         }
125     }
126 
127     /// Try to optimize for the OpPrmt of OpPrmt case where only one source of
128     /// the inner OpPrmt is used
try_opt_prmt_src(&mut self, op: &mut OpPrmt, src_idx: usize) -> bool129     fn try_opt_prmt_src(&mut self, op: &mut OpPrmt, src_idx: usize) -> bool {
130         let Some(op_sel) = op.get_sel() else {
131             return false;
132         };
133 
134         let Some(src_prmt) = self.get_prmt_for_src(&op.srcs[src_idx]) else {
135             return false;
136         };
137 
138         let mut new_sel = [PrmtSelByte::INVALID; 4];
139         let mut src_prmt_src = usize::MAX;
140         for i in 0..4 {
141             let op_sel_byte = op_sel.get(i);
142             if op_sel_byte.src() != src_idx {
143                 new_sel[i] = op_sel_byte;
144                 continue;
145             }
146 
147             let src_sel_byte = src_prmt.sel.get(op_sel_byte.byte());
148 
149             if src_prmt_src != usize::MAX && src_prmt_src != src_sel_byte.src()
150             {
151                 return false;
152             }
153             src_prmt_src = src_sel_byte.src();
154 
155             new_sel[i] = PrmtSelByte::new(
156                 src_idx,
157                 src_sel_byte.byte(),
158                 op_sel_byte.msb() | src_sel_byte.msb(),
159             );
160         }
161 
162         let new_sel = PrmtSel::new(new_sel);
163 
164         op.sel = new_sel.into();
165         if src_prmt_src == usize::MAX {
166             // This source is unused
167             op.srcs[src_idx] = 0.into();
168         } else {
169             op.srcs[src_idx] = src_prmt.srcs[src_prmt_src].into();
170         }
171         true
172     }
173 
174     /// Try to optimize for the OpPrmt of OpPrmt case as if we're considering a
175     /// full 4-way OpPrmt in which some sources may be duplicates
try_opt_prmt4(&mut self, op: &mut OpPrmt) -> bool176     fn try_opt_prmt4(&mut self, op: &mut OpPrmt) -> bool {
177         let Some(op_sel) = op.get_sel() else {
178             return false;
179         };
180 
181         let mut srcs = PrmtSrcs::new();
182         let mut new_sel = [PrmtSelByte::INVALID; 4];
183         for i in 0..4 {
184             let op_sel_byte = op_sel.get(i);
185             let src = &op.srcs[op_sel_byte.src()];
186 
187             if let Some(src_prmt) = self.get_prmt_for_src(src) {
188                 let src_sel_byte = src_prmt.sel.get(op_sel_byte.byte());
189                 let src_prmt_src = &src_prmt.srcs[src_sel_byte.src()];
190                 if let Some(u) = src_prmt_src.as_u32() {
191                     let mut imm_u8 = src_sel_byte.fold_u32(u);
192                     if op_sel_byte.msb() {
193                         imm_u8 = ((imm_u8 as i8) >> 7) as u8;
194                     }
195 
196                     let Some(byte_idx) = srcs.try_add_imm_u8(imm_u8) else {
197                         return false;
198                     };
199 
200                     new_sel[i] =
201                         PrmtSelByte::new(srcs.imm_src, byte_idx, false);
202                 } else {
203                     let Some(src_idx) = srcs.try_add_src(*src_prmt_src) else {
204                         return false;
205                     };
206 
207                     new_sel[i] = PrmtSelByte::new(
208                         src_idx,
209                         src_sel_byte.byte(),
210                         op_sel_byte.msb() | src_sel_byte.msb(),
211                     );
212                 }
213             } else if let Some(u) = src.as_u32() {
214                 let imm_u8 = op_sel_byte.fold_u32(u);
215                 let Some(byte_idx) = srcs.try_add_imm_u8(imm_u8) else {
216                     return false;
217                 };
218 
219                 new_sel[i] = PrmtSelByte::new(srcs.imm_src, byte_idx, false);
220             } else {
221                 debug_assert!(src.src_mod.is_none());
222                 let Some(src_idx) = srcs.try_add_src(src.src_ref) else {
223                     return false;
224                 };
225 
226                 new_sel[i] = PrmtSelByte::new(
227                     src_idx,
228                     op_sel_byte.byte(),
229                     op_sel_byte.msb(),
230                 );
231             }
232         }
233 
234         let new_sel = PrmtSel::new(new_sel);
235         if new_sel == op_sel
236             && srcs.srcs[0] == op.srcs[0].src_ref
237             && srcs.srcs[1] == op.srcs[1].src_ref
238         {
239             return false;
240         }
241 
242         op.sel = new_sel.into();
243         op.srcs[0] = srcs.srcs[0].into();
244         op.srcs[1] = srcs.srcs[1].into();
245         true
246     }
247 
opt_prmt(&mut self, op: &mut OpPrmt)248     fn opt_prmt(&mut self, op: &mut OpPrmt) {
249         for i in 0..2 {
250             loop {
251                 if !self.try_opt_prmt_src(op, i) {
252                     break;
253                 }
254             }
255         }
256 
257         loop {
258             if !self.try_opt_prmt4(op) {
259                 break;
260             }
261         }
262 
263         self.add_prmt(op);
264     }
265 
run(&mut self, f: &mut Function)266     fn run(&mut self, f: &mut Function) {
267         for b in &mut f.blocks {
268             for instr in &mut b.instrs {
269                 if let Op::Prmt(op) = &mut instr.op {
270                     self.opt_prmt(op);
271                 }
272             }
273         }
274     }
275 }
276 
277 impl Shader<'_> {
opt_prmt(&mut self)278     pub fn opt_prmt(&mut self) {
279         for f in &mut self.functions {
280             PrmtPass::new().run(f);
281         }
282     }
283 }
284