xref: /aosp_15_r20/external/mesa3d/src/gallium/frontends/rusticl/mesa/compiler/nir.rs (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 use mesa_rust_gen::*;
2 use mesa_rust_util::bitset;
3 use mesa_rust_util::offset_of;
4 
5 use std::convert::TryInto;
6 use std::ffi::CString;
7 use std::marker::PhantomData;
8 use std::ptr;
9 use std::ptr::NonNull;
10 use std::slice;
11 
12 pub struct ExecListIter<'a, T> {
13     n: &'a mut exec_node,
14     offset: usize,
15     _marker: PhantomData<T>,
16 }
17 
18 impl<'a, T> ExecListIter<'a, T> {
new(l: &'a mut exec_list, offset: usize) -> Self19     fn new(l: &'a mut exec_list, offset: usize) -> Self {
20         Self {
21             n: &mut l.head_sentinel,
22             offset: offset,
23             _marker: PhantomData,
24         }
25     }
26 }
27 
28 impl<'a, T: 'a> Iterator for ExecListIter<'a, T> {
29     type Item = &'a mut T;
30 
next(&mut self) -> Option<Self::Item>31     fn next(&mut self) -> Option<Self::Item> {
32         self.n = unsafe { &mut *self.n.next };
33         if self.n.next.is_null() {
34             None
35         } else {
36             let t: *mut _ = self.n;
37             Some(unsafe { &mut *(t.byte_sub(self.offset).cast()) })
38         }
39     }
40 }
41 
/// Debug-build driver for a single NIR pass (compiled only with `debug_assertions`).
///
/// `$shader` is the `NirShader` binding, `$runner` is the `passN` method matching the
/// argument count and `$pass_fn` is the C pass function. If `should_skip_nir` reports the
/// pass as skipped, a message is printed and no pass runs. Otherwise the metadata validation
/// flag is set, the pass runs, and on progress the shader is validated, optionally printed,
/// and the validation flag is checked. Afterwards, whether or not the pass ran, the
/// `NIR_DEBUG_CLONE` / `NIR_DEBUG_SERIALIZE` bits of `nir_debug` trigger the clone and
/// serialize/deserialize round-trips. Evaluates to whether the pass made progress.
#[macro_export]
#[cfg(debug_assertions)]
macro_rules! nir_pass_impl {
    ($shader:ident, $runner:ident, $pass_fn:ident $(,$extra:expr)* $(,)?) => {{
        let pass_name = ::std::stringify!($pass_fn);
        let pass_name_c = ::std::ffi::CString::new(pass_name).unwrap();

        // SAFETY: `pass_name_c` is a valid C string that lives until the end of this block.
        let skip = unsafe { should_skip_nir(pass_name_c.as_ptr()) };

        let progress = if skip {
            println!("skipping {}", pass_name);
            false
        } else {
            $shader.metadata_set_validation_flag();
            if $shader.should_print() {
                println!("{}", pass_name);
            }

            let made_progress = $shader.$runner($pass_fn $(,$extra)*);
            if made_progress {
                $shader.validate(&format!("after {} in {}:{}", pass_name, file!(), line!()));
                if $shader.should_print() {
                    $shader.print();
                }
                $shader.metadata_check_validation_flag();
            }
            made_progress
        };

        // SAFETY: mutable static can't be read safely, but this value isn't going to change
        let debug_flags = unsafe { nir_debug };
        if debug_flags & NIR_DEBUG_CLONE != 0 {
            $shader.validate_clone();
        }
        if debug_flags & NIR_DEBUG_SERIALIZE != 0 {
            $shader.validate_serialize_deserialize();
        }

        progress
    }};
}
83 
/// Release-build driver for a single NIR pass: no skipping, validation or printing, just a
/// direct call of the `passN` method `$runner` with `$pass_fn` and the extra arguments.
/// Evaluates to the pass's return value.
#[macro_export]
#[cfg(not(debug_assertions))]
macro_rules! nir_pass_impl {
    ($shader:ident, $runner:ident, $pass_fn:ident $(, $extra:expr)* $(,)?) => {
        $shader.$runner($pass_fn $(, $extra)*)
    };
}
91 
/// Runs the NIR pass `$pass_fn` on the `NirShader` binding `$shader` with zero to three
/// extra arguments, selecting the matching `NirShader::pass0`..`pass3` method and delegating
/// to `nir_pass_impl!`. A trailing comma is accepted.
#[macro_export]
macro_rules! nir_pass {
    // no extra arguments
    ($shader:ident, $pass_fn:ident $(,)?) => {
        $crate::nir_pass_impl!($shader, pass0, $pass_fn)
    };

    // one extra argument
    ($shader:ident, $pass_fn:ident, $arg0:expr $(,)?) => {
        $crate::nir_pass_impl!($shader, pass1, $pass_fn, $arg0)
    };

    // two extra arguments
    ($shader:ident, $pass_fn:ident, $arg0:expr, $arg1:expr $(,)?) => {
        $crate::nir_pass_impl!($shader, pass2, $pass_fn, $arg0, $arg1)
    };

    // three extra arguments
    ($shader:ident, $pass_fn:ident, $arg0:expr, $arg1:expr, $arg2:expr $(,)?) => {
        $crate::nir_pass_impl!($shader, pass3, $pass_fn, $arg0, $arg1, $arg2)
    };
}
110 
111 pub struct NirPrintfInfo {
112     count: usize,
113     printf_info: *mut u_printf_info,
114 }
115 
116 // SAFETY: `u_printf_info` is considered immutable
117 unsafe impl Send for NirPrintfInfo {}
118 unsafe impl Sync for NirPrintfInfo {}
119 
120 impl NirPrintfInfo {
u_printf(&self, buf: &[u8])121     pub fn u_printf(&self, buf: &[u8]) {
122         unsafe {
123             u_printf(
124                 stdout_ptr(),
125                 buf.as_ptr().cast(),
126                 buf.len(),
127                 self.printf_info.cast(),
128                 self.count as u32,
129             );
130         }
131     }
132 }
133 
134 impl Drop for NirPrintfInfo {
drop(&mut self)135     fn drop(&mut self) {
136         unsafe {
137             ralloc_free(self.printf_info.cast());
138         };
139     }
140 }
141 
142 pub struct NirShader {
143     nir: NonNull<nir_shader>,
144 }
145 
146 // SAFETY: It's safe to share a nir_shader between threads.
147 unsafe impl Send for NirShader {}
148 
149 // SAFETY: We do not allow interior mutability with &NirShader
150 unsafe impl Sync for NirShader {}
151 
152 impl NirShader {
new(nir: *mut nir_shader) -> Option<Self>153     pub fn new(nir: *mut nir_shader) -> Option<Self> {
154         NonNull::new(nir).map(|nir| Self { nir: nir })
155     }
156 
deserialize( blob: &mut blob_reader, options: *const nir_shader_compiler_options, ) -> Option<Self>157     pub fn deserialize(
158         blob: &mut blob_reader,
159         options: *const nir_shader_compiler_options,
160     ) -> Option<Self> {
161         unsafe { Self::new(nir_deserialize(ptr::null_mut(), options, blob)) }
162     }
163 
serialize(&self, blob: &mut blob)164     pub fn serialize(&self, blob: &mut blob) {
165         unsafe {
166             nir_serialize(blob, self.nir.as_ptr(), false);
167         }
168     }
169 
print(&self)170     pub fn print(&self) {
171         unsafe { nir_print_shader(self.nir.as_ptr(), stderr_ptr()) };
172     }
173 
get_nir(&self) -> *mut nir_shader174     pub fn get_nir(&self) -> *mut nir_shader {
175         self.nir.as_ptr()
176     }
177 
dup_for_driver(&self) -> *mut nir_shader178     pub fn dup_for_driver(&self) -> *mut nir_shader {
179         unsafe { nir_shader_clone(ptr::null_mut(), self.nir.as_ptr()) }
180     }
181 
sweep_mem(&mut self)182     pub fn sweep_mem(&mut self) {
183         unsafe { nir_sweep(self.nir.as_ptr()) }
184     }
185 
pass0<R>(&mut self, pass: unsafe extern "C" fn(*mut nir_shader) -> R) -> R186     pub fn pass0<R>(&mut self, pass: unsafe extern "C" fn(*mut nir_shader) -> R) -> R {
187         unsafe { pass(self.nir.as_ptr()) }
188     }
189 
pass1<R, A>( &mut self, pass: unsafe extern "C" fn(*mut nir_shader, a: A) -> R, a: A, ) -> R190     pub fn pass1<R, A>(
191         &mut self,
192         pass: unsafe extern "C" fn(*mut nir_shader, a: A) -> R,
193         a: A,
194     ) -> R {
195         unsafe { pass(self.nir.as_ptr(), a) }
196     }
197 
pass2<R, A, B>( &mut self, pass: unsafe extern "C" fn(*mut nir_shader, a: A, b: B) -> R, a: A, b: B, ) -> R198     pub fn pass2<R, A, B>(
199         &mut self,
200         pass: unsafe extern "C" fn(*mut nir_shader, a: A, b: B) -> R,
201         a: A,
202         b: B,
203     ) -> R {
204         unsafe { pass(self.nir.as_ptr(), a, b) }
205     }
206 
pass3<R, A, B, C>( &mut self, pass: unsafe extern "C" fn(*mut nir_shader, a: A, b: B, c: C) -> R, a: A, b: B, c: C, ) -> R207     pub fn pass3<R, A, B, C>(
208         &mut self,
209         pass: unsafe extern "C" fn(*mut nir_shader, a: A, b: B, c: C) -> R,
210         a: A,
211         b: B,
212         c: C,
213     ) -> R {
214         unsafe { pass(self.nir.as_ptr(), a, b, c) }
215     }
216 
217     #[cfg(debug_assertions)]
metadata_check_validation_flag(&self)218     pub fn metadata_check_validation_flag(&self) {
219         unsafe { nir_metadata_check_validation_flag(self.nir.as_ptr()) }
220     }
221 
222     #[cfg(debug_assertions)]
metadata_set_validation_flag(&mut self)223     pub fn metadata_set_validation_flag(&mut self) {
224         unsafe { nir_metadata_set_validation_flag(self.nir.as_ptr()) }
225     }
226 
227     #[cfg(debug_assertions)]
validate(&self, when: &str)228     pub fn validate(&self, when: &str) {
229         let cstr = CString::new(when).unwrap();
230         unsafe { nir_validate_shader(self.nir.as_ptr(), cstr.as_ptr()) }
231     }
232 
should_print(&self) -> bool233     pub fn should_print(&self) -> bool {
234         unsafe { should_print_nir(self.nir.as_ptr()) }
235     }
236 
validate_serialize_deserialize(&mut self)237     pub fn validate_serialize_deserialize(&mut self) {
238         unsafe { nir_shader_serialize_deserialize(self.nir.as_ptr()) }
239     }
240 
validate_clone(&mut self)241     pub fn validate_clone(&mut self) {
242         unsafe {
243             let nir_ptr = self.nir.as_ptr();
244             let clone = nir_shader_clone(ralloc_parent(nir_ptr.cast()), nir_ptr);
245             nir_shader_replace(nir_ptr, clone)
246         }
247     }
248 
entrypoint(&self) -> *mut nir_function_impl249     pub fn entrypoint(&self) -> *mut nir_function_impl {
250         unsafe { nir_shader_get_entrypoint(self.nir.as_ptr()) }
251     }
252 
structurize(&mut self)253     pub fn structurize(&mut self) {
254         nir_pass!(self, nir_lower_goto_ifs);
255         nir_pass!(self, nir_opt_dead_cf);
256     }
257 
inline(&mut self, libclc: &NirShader)258     pub fn inline(&mut self, libclc: &NirShader) {
259         nir_pass!(
260             self,
261             nir_lower_variable_initializers,
262             nir_variable_mode::nir_var_function_temp,
263         );
264         nir_pass!(self, nir_lower_returns);
265         nir_pass!(self, nir_link_shader_functions, libclc.nir.as_ptr());
266         nir_pass!(self, nir_inline_functions);
267     }
268 
gather_info(&mut self)269     pub fn gather_info(&mut self) {
270         unsafe { nir_shader_gather_info(self.nir.as_ptr(), self.entrypoint()) }
271     }
272 
remove_non_entrypoints(&mut self)273     pub fn remove_non_entrypoints(&mut self) {
274         unsafe { nir_remove_non_entrypoints(self.nir.as_ptr()) };
275     }
276 
cleanup_functions(&mut self)277     pub fn cleanup_functions(&mut self) {
278         unsafe { nir_cleanup_functions(self.nir.as_ptr()) };
279     }
280 
variables(&mut self) -> ExecListIter<nir_variable>281     pub fn variables(&mut self) -> ExecListIter<nir_variable> {
282         ExecListIter::new(
283             &mut unsafe { self.nir.as_mut() }.variables,
284             offset_of!(nir_variable, node),
285         )
286     }
287 
num_images(&self) -> u8288     pub fn num_images(&self) -> u8 {
289         unsafe { (*self.nir.as_ptr()).info.num_images }
290     }
291 
num_textures(&self) -> u8292     pub fn num_textures(&self) -> u8 {
293         unsafe { (*self.nir.as_ptr()).info.num_textures }
294     }
295 
reset_scratch_size(&mut self)296     pub fn reset_scratch_size(&mut self) {
297         unsafe {
298             (*self.nir.as_ptr()).scratch_size = 0;
299         }
300     }
301 
scratch_size(&self) -> u32302     pub fn scratch_size(&self) -> u32 {
303         unsafe { (*self.nir.as_ptr()).scratch_size }
304     }
305 
reset_shared_size(&mut self)306     pub fn reset_shared_size(&mut self) {
307         unsafe {
308             (*self.nir.as_ptr()).info.shared_size = 0;
309         }
310     }
shared_size(&self) -> u32311     pub fn shared_size(&self) -> u32 {
312         unsafe { (*self.nir.as_ptr()).info.shared_size }
313     }
314 
workgroup_size(&self) -> [u16; 3]315     pub fn workgroup_size(&self) -> [u16; 3] {
316         unsafe { (*self.nir.as_ptr()).info.workgroup_size }
317     }
318 
subgroup_size(&self) -> u8319     pub fn subgroup_size(&self) -> u8 {
320         let subgroup_size = unsafe { (*self.nir.as_ptr()).info.subgroup_size };
321         let valid_subgroup_sizes = [
322             gl_subgroup_size::SUBGROUP_SIZE_REQUIRE_8,
323             gl_subgroup_size::SUBGROUP_SIZE_REQUIRE_16,
324             gl_subgroup_size::SUBGROUP_SIZE_REQUIRE_32,
325             gl_subgroup_size::SUBGROUP_SIZE_REQUIRE_64,
326             gl_subgroup_size::SUBGROUP_SIZE_REQUIRE_128,
327         ];
328 
329         if valid_subgroup_sizes.contains(&subgroup_size) {
330             subgroup_size as u8
331         } else {
332             0
333         }
334     }
335 
num_subgroups(&self) -> u8336     pub fn num_subgroups(&self) -> u8 {
337         unsafe { (*self.nir.as_ptr()).info.num_subgroups }
338     }
339 
set_workgroup_size_variable_if_zero(&mut self)340     pub fn set_workgroup_size_variable_if_zero(&mut self) {
341         let nir = self.nir.as_ptr();
342         unsafe {
343             (*nir)
344                 .info
345                 .set_workgroup_size_variable((*nir).info.workgroup_size[0] == 0);
346         }
347     }
348 
set_workgroup_size(&mut self, size: [u16; 3])349     pub fn set_workgroup_size(&mut self, size: [u16; 3]) {
350         let nir = unsafe { self.nir.as_mut() };
351         nir.info.set_workgroup_size_variable(false);
352         nir.info.workgroup_size = size;
353     }
354 
workgroup_size_variable(&self) -> bool355     pub fn workgroup_size_variable(&self) -> bool {
356         unsafe { self.nir.as_ref() }.info.workgroup_size_variable()
357     }
358 
workgroup_size_hint(&self) -> [u16; 3]359     pub fn workgroup_size_hint(&self) -> [u16; 3] {
360         unsafe { self.nir.as_ref().info.anon_1.cs.workgroup_size_hint }
361     }
362 
set_has_variable_shared_mem(&mut self, val: bool)363     pub fn set_has_variable_shared_mem(&mut self, val: bool) {
364         unsafe {
365             self.nir
366                 .as_mut()
367                 .info
368                 .anon_1
369                 .cs
370                 .set_has_variable_shared_mem(val)
371         }
372     }
373 
variables_with_mode( &mut self, mode: nir_variable_mode, ) -> impl Iterator<Item = &mut nir_variable>374     pub fn variables_with_mode(
375         &mut self,
376         mode: nir_variable_mode,
377     ) -> impl Iterator<Item = &mut nir_variable> {
378         self.variables()
379             .filter(move |v| v.data.mode() & mode.0 != 0)
380     }
381 
extract_constant_initializers(&mut self)382     pub fn extract_constant_initializers(&mut self) {
383         let nir = self.nir.as_ptr();
384         unsafe {
385             if (*nir).constant_data_size > 0 {
386                 assert!((*nir).constant_data.is_null());
387                 (*nir).constant_data = rzalloc_size(nir.cast(), (*nir).constant_data_size as usize);
388                 nir_gather_explicit_io_initializers(
389                     nir,
390                     (*nir).constant_data,
391                     (*nir).constant_data_size as usize,
392                     nir_variable_mode::nir_var_mem_constant,
393                 );
394             }
395         }
396     }
397 
has_constant(&self) -> bool398     pub fn has_constant(&self) -> bool {
399         unsafe {
400             !self.nir.as_ref().constant_data.is_null() && self.nir.as_ref().constant_data_size > 0
401         }
402     }
403 
has_printf(&self) -> bool404     pub fn has_printf(&self) -> bool {
405         unsafe {
406             !self.nir.as_ref().printf_info.is_null() && self.nir.as_ref().printf_info_count != 0
407         }
408     }
409 
take_printf_info(&mut self) -> Option<NirPrintfInfo>410     pub fn take_printf_info(&mut self) -> Option<NirPrintfInfo> {
411         let nir = unsafe { self.nir.as_mut() };
412 
413         let info = nir.printf_info;
414         if info.is_null() {
415             return None;
416         }
417         let count = nir.printf_info_count as usize;
418 
419         unsafe {
420             ralloc_steal(ptr::null(), info.cast());
421 
422             for i in 0..count {
423                 ralloc_steal(info.cast(), (*info.add(i)).arg_sizes.cast());
424                 ralloc_steal(info.cast(), (*info.add(i)).strings.cast());
425             }
426         };
427 
428         let result = Some(NirPrintfInfo {
429             count: count,
430             printf_info: info,
431         });
432 
433         nir.printf_info_count = 0;
434         nir.printf_info = ptr::null_mut();
435 
436         result
437     }
438 
get_constant_buffer(&self) -> &[u8]439     pub fn get_constant_buffer(&self) -> &[u8] {
440         unsafe {
441             let nir = self.nir.as_ref();
442             // Sometimes, constant_data can be a null pointer if the size is 0
443             if nir.constant_data_size == 0 {
444                 &[]
445             } else {
446                 slice::from_raw_parts(nir.constant_data.cast(), nir.constant_data_size as usize)
447             }
448         }
449     }
450 
preserve_fp16_denorms(&mut self)451     pub fn preserve_fp16_denorms(&mut self) {
452         unsafe {
453             self.nir.as_mut().info.float_controls_execution_mode |=
454                 float_controls::FLOAT_CONTROLS_DENORM_PRESERVE_FP16 as u32;
455         }
456     }
457 
set_fp_rounding_mode_rtne(&mut self)458     pub fn set_fp_rounding_mode_rtne(&mut self) {
459         unsafe {
460             self.nir.as_mut().info.float_controls_execution_mode |=
461                 float_controls::FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP16 as u32
462                     | float_controls::FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP32 as u32
463                     | float_controls::FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP64 as u32;
464         }
465     }
466 
reads_sysval(&self, sysval: gl_system_value) -> bool467     pub fn reads_sysval(&self, sysval: gl_system_value) -> bool {
468         let nir = unsafe { self.nir.as_ref() };
469         bitset::test_bit(&nir.info.system_values_read, sysval as u32)
470     }
471 
add_var( &mut self, mode: nir_variable_mode, glsl_type: *const glsl_type, loc: usize, name: &str, )472     pub fn add_var(
473         &mut self,
474         mode: nir_variable_mode,
475         glsl_type: *const glsl_type,
476         loc: usize,
477         name: &str,
478     ) {
479         let name = CString::new(name).unwrap();
480         unsafe {
481             let var = nir_variable_create(self.nir.as_ptr(), mode, glsl_type, name.as_ptr());
482             (*var).data.location = loc.try_into().unwrap();
483         }
484     }
485 }
486 
487 impl Clone for NirShader {
clone(&self) -> Self488     fn clone(&self) -> Self {
489         Self {
490             nir: unsafe { NonNull::new_unchecked(self.dup_for_driver()) },
491         }
492     }
493 }
494 
495 impl Drop for NirShader {
drop(&mut self)496     fn drop(&mut self) {
497         unsafe { ralloc_free(self.nir.as_ptr().cast()) };
498     }
499 }
500