1 extern crate proc_macro;
2 use proc_macro::Delimiter;
3 use proc_macro::TokenStream;
4 use proc_macro::TokenTree::Group;
5 use proc_macro::TokenTree::Ident;
6 use proc_macro::TokenTree::Punct;
7
8 /// Macro for generating the C API stubs for normal functions
9 #[proc_macro_attribute]
cl_entrypoint(attr: TokenStream, item: TokenStream) -> TokenStream10 pub fn cl_entrypoint(attr: TokenStream, item: TokenStream) -> TokenStream {
11 let mut name = None;
12 let mut args = None;
13 let mut ret_type = None;
14
15 let mut iter = item.clone().into_iter();
16 while let Some(item) = iter.next() {
17 match item {
18 Ident(ident) => match ident.to_string().as_str() {
19 // extract the function name
20 "fn" => name = Some(iter.next().unwrap().to_string()),
21
22 // extract inner type
23 "CLResult" => {
24 // skip the `<`
25 iter.next();
26 let mut ret_type_tmp = String::new();
27
28 for ident in iter.by_ref() {
29 if ident.to_string() == ">" {
30 break;
31 }
32
33 if ret_type_tmp.ends_with("mut") || ret_type_tmp.ends_with("const") {
34 ret_type_tmp.push(' ');
35 }
36
37 ret_type_tmp.push_str(ident.to_string().as_str());
38 }
39
40 ret_type = Some(ret_type_tmp);
41 }
42 _ => {}
43 },
44 Group(group) => {
45 if args.is_some() {
46 continue;
47 }
48
49 if group.delimiter() != Delimiter::Parenthesis {
50 continue;
51 }
52
53 // the first group are our function args :)
54 args = Some(group.stream());
55 }
56 _ => {}
57 }
58 }
59
60 let name = name.as_ref().expect("no name found!");
61 let args = args.as_ref().expect("no args found!");
62 let ret_type = ret_type.as_ref().expect("no ret_type found!");
63
64 let mut arg_names = Vec::new();
65 let mut collect = true;
66
67 // extract the variable names of our function arguments
68 for item in args.clone() {
69 match item {
70 Ident(ident) => {
71 if collect {
72 arg_names.push(ident);
73 }
74 }
75
76 // we ignore everything between a `:` and a `,` as those are the argument types
77 Punct(punct) => match punct.as_char() {
78 ':' => collect = false,
79 ',' => collect = true,
80 _ => {}
81 },
82
83 _ => {}
84 }
85 }
86
87 // convert to string and strip `mut` specifiers
88 let arg_names: Vec<_> = arg_names
89 .clone()
90 .into_iter()
91 .map(|ident| ident.to_string())
92 .filter(|ident| ident != "mut")
93 .collect();
94
95 let arg_names_str = arg_names.join(",");
96 let mut args = args.to_string();
97 if !args.ends_with(',') {
98 args.push(',');
99 }
100
101 // depending on the return type we have to generate a different match case
102 let mut res: TokenStream = if ret_type == "()" {
103 // trivial case: return the `Err(err)` as is
104 format!(
105 "pub extern \"C\" fn {attr}(
106 {args}
107 ) -> cl_int {{
108 match {name}({arg_names_str}) {{
109 Ok(_) => CL_SUCCESS as cl_int,
110 Err(e) => e,
111 }}
112 }}"
113 )
114 } else {
115 // here we write the error code into the last argument, which we also add. All OpenCL APIs
116 // which return an object do have the `errcode_ret: *mut cl_int` argument last, so we can
117 // just make use of this here.
118 format!(
119 "pub extern \"C\" fn {attr}(
120 {args}
121 errcode_ret: *mut cl_int,
122 ) -> {ret_type} {{
123 let (ptr, err) = match {name}({arg_names_str}) {{
124 Ok(o) => (o, CL_SUCCESS as cl_int),
125 Err(e) => (std::ptr::null_mut(), e),
126 }};
127 if !errcode_ret.is_null() {{
128 unsafe {{
129 *errcode_ret = err;
130 }}
131 }}
132 ptr
133 }}"
134 )
135 }
136 .parse()
137 .unwrap();
138
139 res.extend(item);
140 res
141 }
142
143 /// Special macro for generating C function stubs to call into our `CLInfo` trait
144 #[proc_macro_attribute]
cl_info_entrypoint(attr: TokenStream, item: TokenStream) -> TokenStream145 pub fn cl_info_entrypoint(attr: TokenStream, item: TokenStream) -> TokenStream {
146 let mut name = None;
147 let mut args = Vec::new();
148 let mut iter = item.clone().into_iter();
149
150 let mut collect = false;
151
152 // we have to extract the type name we implement the trait for and the type of the input
153 // parameters. The input Parameters are defined as `T` inside `CLInfo<T>` or `CLInfoObj<T, ..>`
154 while let Some(item) = iter.next() {
155 match item {
156 Ident(ident) => {
157 if collect {
158 args.push(ident);
159 } else if ident.to_string() == "for" {
160 name = Some(iter.next().unwrap().to_string());
161 }
162 }
163 Punct(punct) => match punct.as_char() {
164 '<' => collect = true,
165 '>' => collect = false,
166 _ => {}
167 },
168 _ => {}
169 }
170 }
171
172 let name = name.as_ref().expect("no name found!");
173 assert!(!args.is_empty());
174
175 // the 1st argument is special as it's the actual property being queried. The remaining
176 // arguments are additional input data being passed before the property.
177 let arg = &args[0];
178 let (args_values, args) = args[1..]
179 .iter()
180 .enumerate()
181 .map(|(idx, arg)| (format!("arg{idx},"), format!("arg{idx}: {arg},")))
182 .reduce(|(a1, b1), (a2, b2)| (a1 + &a2, b1 + &b2))
183 .unwrap_or_default();
184
185 // depending on the amount of arguments we have a different trait implementation
186 let method = if args.len() > 1 {
187 "get_info_obj"
188 } else {
189 "get_info"
190 };
191
192 let mut res: TokenStream = format!(
193 "pub extern \"C\" fn {attr}(
194 input: {name},
195 {args}
196 param_name: {arg},
197 param_value_size: usize,
198 param_value: *mut ::std::ffi::c_void,
199 param_value_size_ret: *mut usize,
200 ) -> cl_int {{
201 match input.{method}(
202 {args_values}
203 param_name,
204 param_value_size,
205 param_value,
206 param_value_size_ret,
207 ) {{
208 Ok(_) => CL_SUCCESS as cl_int,
209 Err(e) => e,
210 }}
211 }}"
212 )
213 .parse()
214 .unwrap();
215
216 res.extend(item);
217 res
218 }
219