xref: /aosp_15_r20/external/mesa3d/src/gallium/frontends/rusticl/proc/lib.rs (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 extern crate proc_macro;
2 use proc_macro::Delimiter;
3 use proc_macro::TokenStream;
4 use proc_macro::TokenTree::Group;
5 use proc_macro::TokenTree::Ident;
6 use proc_macro::TokenTree::Punct;
7 
8 /// Macro for generating the C API stubs for normal functions
9 #[proc_macro_attribute]
cl_entrypoint(attr: TokenStream, item: TokenStream) -> TokenStream10 pub fn cl_entrypoint(attr: TokenStream, item: TokenStream) -> TokenStream {
11     let mut name = None;
12     let mut args = None;
13     let mut ret_type = None;
14 
15     let mut iter = item.clone().into_iter();
16     while let Some(item) = iter.next() {
17         match item {
18             Ident(ident) => match ident.to_string().as_str() {
19                 // extract the function name
20                 "fn" => name = Some(iter.next().unwrap().to_string()),
21 
22                 // extract inner type
23                 "CLResult" => {
24                     // skip the `<`
25                     iter.next();
26                     let mut ret_type_tmp = String::new();
27 
28                     for ident in iter.by_ref() {
29                         if ident.to_string() == ">" {
30                             break;
31                         }
32 
33                         if ret_type_tmp.ends_with("mut") || ret_type_tmp.ends_with("const") {
34                             ret_type_tmp.push(' ');
35                         }
36 
37                         ret_type_tmp.push_str(ident.to_string().as_str());
38                     }
39 
40                     ret_type = Some(ret_type_tmp);
41                 }
42                 _ => {}
43             },
44             Group(group) => {
45                 if args.is_some() {
46                     continue;
47                 }
48 
49                 if group.delimiter() != Delimiter::Parenthesis {
50                     continue;
51                 }
52 
53                 // the first group are our function args :)
54                 args = Some(group.stream());
55             }
56             _ => {}
57         }
58     }
59 
60     let name = name.as_ref().expect("no name found!");
61     let args = args.as_ref().expect("no args found!");
62     let ret_type = ret_type.as_ref().expect("no ret_type found!");
63 
64     let mut arg_names = Vec::new();
65     let mut collect = true;
66 
67     // extract the variable names of our function arguments
68     for item in args.clone() {
69         match item {
70             Ident(ident) => {
71                 if collect {
72                     arg_names.push(ident);
73                 }
74             }
75 
76             // we ignore everything between a `:` and a `,` as those are the argument types
77             Punct(punct) => match punct.as_char() {
78                 ':' => collect = false,
79                 ',' => collect = true,
80                 _ => {}
81             },
82 
83             _ => {}
84         }
85     }
86 
87     // convert to string and strip `mut` specifiers
88     let arg_names: Vec<_> = arg_names
89         .clone()
90         .into_iter()
91         .map(|ident| ident.to_string())
92         .filter(|ident| ident != "mut")
93         .collect();
94 
95     let arg_names_str = arg_names.join(",");
96     let mut args = args.to_string();
97     if !args.ends_with(',') {
98         args.push(',');
99     }
100 
101     // depending on the return type we have to generate a different match case
102     let mut res: TokenStream = if ret_type == "()" {
103         // trivial case: return the `Err(err)` as is
104         format!(
105             "pub extern \"C\" fn {attr}(
106                 {args}
107             ) -> cl_int {{
108                 match {name}({arg_names_str}) {{
109                     Ok(_) => CL_SUCCESS as cl_int,
110                     Err(e) => e,
111                 }}
112             }}"
113         )
114     } else {
115         // here we write the error code into the last argument, which we also add. All OpenCL APIs
116         // which return an object do have the `errcode_ret: *mut cl_int` argument last, so we can
117         // just make use of this here.
118         format!(
119             "pub extern \"C\" fn {attr}(
120                 {args}
121                 errcode_ret: *mut cl_int,
122             ) -> {ret_type} {{
123                 let (ptr, err) = match {name}({arg_names_str}) {{
124                     Ok(o) => (o, CL_SUCCESS as cl_int),
125                     Err(e) => (std::ptr::null_mut(), e),
126                 }};
127                 if !errcode_ret.is_null() {{
128                     unsafe {{
129                         *errcode_ret = err;
130                     }}
131                 }}
132                 ptr
133             }}"
134         )
135     }
136     .parse()
137     .unwrap();
138 
139     res.extend(item);
140     res
141 }
142 
143 /// Special macro for generating C function stubs to call into our `CLInfo` trait
144 #[proc_macro_attribute]
cl_info_entrypoint(attr: TokenStream, item: TokenStream) -> TokenStream145 pub fn cl_info_entrypoint(attr: TokenStream, item: TokenStream) -> TokenStream {
146     let mut name = None;
147     let mut args = Vec::new();
148     let mut iter = item.clone().into_iter();
149 
150     let mut collect = false;
151 
152     // we have to extract the type name we implement the trait for and the type of the input
153     // parameters. The input Parameters are defined as `T` inside `CLInfo<T>` or `CLInfoObj<T, ..>`
154     while let Some(item) = iter.next() {
155         match item {
156             Ident(ident) => {
157                 if collect {
158                     args.push(ident);
159                 } else if ident.to_string() == "for" {
160                     name = Some(iter.next().unwrap().to_string());
161                 }
162             }
163             Punct(punct) => match punct.as_char() {
164                 '<' => collect = true,
165                 '>' => collect = false,
166                 _ => {}
167             },
168             _ => {}
169         }
170     }
171 
172     let name = name.as_ref().expect("no name found!");
173     assert!(!args.is_empty());
174 
175     // the 1st argument is special as it's the actual property being queried. The remaining
176     // arguments are additional input data being passed before the property.
177     let arg = &args[0];
178     let (args_values, args) = args[1..]
179         .iter()
180         .enumerate()
181         .map(|(idx, arg)| (format!("arg{idx},"), format!("arg{idx}: {arg},")))
182         .reduce(|(a1, b1), (a2, b2)| (a1 + &a2, b1 + &b2))
183         .unwrap_or_default();
184 
185     // depending on the amount of arguments we have a different trait implementation
186     let method = if args.len() > 1 {
187         "get_info_obj"
188     } else {
189         "get_info"
190     };
191 
192     let mut res: TokenStream = format!(
193         "pub extern \"C\" fn {attr}(
194             input: {name},
195             {args}
196             param_name: {arg},
197             param_value_size: usize,
198             param_value: *mut ::std::ffi::c_void,
199             param_value_size_ret: *mut usize,
200         ) -> cl_int {{
201             match input.{method}(
202                 {args_values}
203                 param_name,
204                 param_value_size,
205                 param_value,
206                 param_value_size_ret,
207             ) {{
208                 Ok(_) => CL_SUCCESS as cl_int,
209                 Err(e) => e,
210             }}
211         }}"
212     )
213     .parse()
214     .unwrap();
215 
216     res.extend(item);
217     res
218 }
219