1 // Copyright (c) 2021 The Vulkano developers
2 // Licensed under the Apache License, Version 2.0
3 // <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT
5 // license <LICENSE-MIT or https://opensource.org/licenses/MIT>,
6 // at your option. All files in the project carrying such
7 // notice may not be copied, modified, or distributed except
8 // according to those terms.
9 
10 use super::{write_file, SpirvGrammar};
11 use ahash::{HashMap, HashSet};
12 use heck::ToSnakeCase;
13 use once_cell::sync::Lazy;
14 use proc_macro2::{Ident, TokenStream};
15 use quote::{format_ident, quote};
16 
17 static SPEC_CONSTANT_OP: Lazy<HashSet<&'static str>> = Lazy::new(|| {
18     HashSet::from_iter([
19         "SConvert",
20         "FConvert",
21         "SNegate",
22         "Not",
23         "IAdd",
24         "ISub",
25         "IMul",
26         "UDiv",
27         "SDiv",
28         "UMod",
29         "SRem",
30         "SMod",
31         "ShiftRightLogical",
32         "ShiftRightArithmetic",
33         "ShiftLeftLogical",
34         "BitwiseOr",
35         "BitwiseXor",
36         "BitwiseAnd",
37         "VectorShuffle",
38         "CompositeExtract",
39         "CompositeInsert",
40         "LogicalOr",
41         "LogicalAnd",
42         "LogicalNot",
43         "LogicalEqual",
44         "LogicalNotEqual",
45         "Select",
46         "IEqual",
47         "INotEqual",
48         "ULessThan",
49         "SLessThan",
50         "UGreaterThan",
51         "SGreaterThan",
52         "ULessThanEqual",
53         "SLessThanEqual",
54         "UGreaterThanEqual",
55         "SGreaterThanEqual",
56         "QuantizeToF16",
57         "ConvertFToS",
58         "ConvertSToF",
59         "ConvertFToU",
60         "ConvertUToF",
61         "UConvert",
62         "ConvertPtrToU",
63         "ConvertUToPtr",
64         "GenericCastToPtr",
65         "PtrCastToGeneric",
66         "Bitcast",
67         "FNegate",
68         "FAdd",
69         "FSub",
70         "FMul",
71         "FDiv",
72         "FRem",
73         "FMod",
74         "AccessChain",
75         "InBoundsAccessChain",
76         "PtrAccessChain",
77         "InBoundsPtrAccessChain",
78     ])
79 });
80 
write(grammar: &SpirvGrammar)81 pub fn write(grammar: &SpirvGrammar) {
82     let mut instr_members = instruction_members(grammar);
83     let instr_output = instruction_output(&instr_members, false);
84 
85     instr_members.retain(|member| SPEC_CONSTANT_OP.contains(member.name.to_string().as_str()));
86     instr_members.iter_mut().for_each(|member| {
87         if member.has_result_type_id {
88             member.operands.remove(0);
89         }
90         if member.has_result_id {
91             member.operands.remove(0);
92         }
93     });
94     let spec_constant_instr_output = instruction_output(&instr_members, true);
95 
96     let bit_enum_output = bit_enum_output(&bit_enum_members(grammar));
97     let value_enum_output = value_enum_output(&value_enum_members(grammar));
98 
99     write_file(
100         "spirv_parse.rs",
101         format!(
102             "SPIR-V grammar version {}.{}.{}",
103             grammar.major_version, grammar.minor_version, grammar.revision
104         ),
105         quote! {
106             #instr_output
107             #spec_constant_instr_output
108             #bit_enum_output
109             #value_enum_output
110         },
111     );
112 }
113 
/// Description of one instruction from the grammar, used to generate an enum variant and
/// its parsing code.
#[derive(Clone, Debug)]
struct InstructionMember {
    // Variant name: the grammar's `opname` with the `Op` prefix stripped.
    name: Ident,
    // Set when one of the operands has the kind `IdResult`.
    has_result_id: bool,
    // Set when one of the operands has the kind `IdResultType`.
    has_result_type_id: bool,
    // Opcode value taken from the grammar; used as the match pattern when parsing.
    opcode: u16,
    // Operands in grammar order; each becomes a field of the variant.
    operands: Vec<OperandMember>,
}
122 
/// A single operand of an instruction, or a single parameter of an enumerant.
#[derive(Clone, Debug)]
struct OperandMember {
    // Field name used in the generated code.
    name: Ident,
    // Tokens of the Rust type of the generated field.
    ty: TokenStream,
    // Tokens of the expression that reads this operand in the generated `parse` function.
    parse: TokenStream,
}
129 
instruction_output(members: &[InstructionMember], spec_constant: bool) -> TokenStream130 fn instruction_output(members: &[InstructionMember], spec_constant: bool) -> TokenStream {
131     let struct_items = members
132         .iter()
133         .map(|InstructionMember { name, operands, .. }| {
134             if operands.is_empty() {
135                 quote! { #name, }
136             } else {
137                 let operands = operands.iter().map(|OperandMember { name, ty, .. }| {
138                     quote! { #name: #ty, }
139                 });
140                 quote! {
141                     #name {
142                         #(#operands)*
143                     },
144                 }
145             }
146         });
147     let parse_items = members.iter().map(
148         |InstructionMember {
149              name,
150              opcode,
151              operands,
152              ..
153          }| {
154             if operands.is_empty() {
155                 quote! {
156                     #opcode => Self::#name,
157                 }
158             } else {
159                 let operands_items =
160                     operands.iter().map(|OperandMember { name, parse, .. }| {
161                         quote! {
162                             #name: #parse,
163                         }
164                     });
165 
166                 quote! {
167                     #opcode => Self::#name {
168                         #(#operands_items)*
169                     },
170                 }
171             }
172         },
173     );
174 
175     let doc = if spec_constant {
176         "An instruction that is used as the operand of the `SpecConstantOp` instruction."
177     } else {
178         "A parsed SPIR-V instruction."
179     };
180 
181     let enum_name = if spec_constant {
182         format_ident!("SpecConstantInstruction")
183     } else {
184         format_ident!("Instruction")
185     };
186 
187     let result_fns = if spec_constant {
188         quote! {}
189     } else {
190         let result_id_items = members.iter().filter_map(
191             |InstructionMember {
192                  name,
193                  has_result_id,
194                  ..
195              }| {
196                 if *has_result_id {
197                     Some(quote! { Self::#name { result_id, .. } })
198                 } else {
199                     None
200                 }
201             },
202         );
203 
204         quote! {
205             /// Returns the `Id` that is assigned by this instruction, if any.
206             pub fn result_id(&self) -> Option<Id> {
207                 match self {
208                     #(#result_id_items)|* => Some(*result_id),
209                     _ => None
210                 }
211             }
212         }
213     };
214 
215     let opcode_error = if spec_constant {
216         format_ident!("UnknownSpecConstantOpcode")
217     } else {
218         format_ident!("UnknownOpcode")
219     };
220 
221     quote! {
222         #[derive(Clone, Debug, PartialEq, Eq)]
223         #[doc=#doc]
224         pub enum #enum_name {
225             #(#struct_items)*
226         }
227 
228         impl #enum_name {
229             #[allow(dead_code)]
230             fn parse(reader: &mut InstructionReader<'_>) -> Result<Self, ParseError> {
231                 let opcode = (reader.next_u32()? & 0xffff) as u16;
232 
233                 Ok(match opcode {
234                     #(#parse_items)*
235                     opcode => return Err(reader.map_err(ParseErrors::#opcode_error(opcode))),
236                 })
237             }
238 
239             #result_fns
240         }
241     }
242 }
243 
instruction_members(grammar: &SpirvGrammar) -> Vec<InstructionMember>244 fn instruction_members(grammar: &SpirvGrammar) -> Vec<InstructionMember> {
245     let operand_kinds = kinds_to_types(grammar);
246     grammar
247         .instructions
248         .iter()
249         .map(|instruction| {
250             let name = format_ident!("{}", instruction.opname.strip_prefix("Op").unwrap());
251             let mut has_result_id = false;
252             let mut has_result_type_id = false;
253             let mut operand_names = HashMap::default();
254 
255             let mut operands = instruction
256                 .operands
257                 .iter()
258                 .map(|operand| {
259                     let name = if operand.kind == "IdResult" {
260                         has_result_id = true;
261                         format_ident!("result_id")
262                     } else if operand.kind == "IdResultType" {
263                         has_result_type_id = true;
264                         format_ident!("result_type_id")
265                     } else {
266                         to_member_name(&operand.kind, operand.name.as_deref())
267                     };
268 
269                     *operand_names.entry(name.clone()).or_insert(0) += 1;
270 
271                     let (ty, parse) = &operand_kinds[operand.kind.as_str()];
272                     let ty = match operand.quantifier {
273                         Some('?') => quote! { Option<#ty> },
274                         Some('*') => quote! { Vec<#ty> },
275                         _ => ty.clone(),
276                     };
277                     let parse = match operand.quantifier {
278                         Some('?') => quote! {
279                             if !reader.is_empty() {
280                                 Some(#parse)
281                             } else {
282                                 None
283                             }
284                         },
285                         Some('*') => quote! {{
286                             let mut vec = Vec::new();
287                             while !reader.is_empty() {
288                                 vec.push(#parse);
289                             }
290                             vec
291                         }},
292                         _ => parse.clone(),
293                     };
294 
295                     OperandMember { name, ty, parse }
296                 })
297                 .collect::<Vec<_>>();
298 
299             // Add number to operands with identical names
300             for name in operand_names
301                 .into_iter()
302                 .filter_map(|(n, c)| if c > 1 { Some(n) } else { None })
303             {
304                 let mut num = 1;
305 
306                 for operand in operands.iter_mut().filter(|o| o.name == name) {
307                     operand.name = format_ident!("{}{}", name, format!("{}", num));
308                     num += 1;
309                 }
310             }
311 
312             InstructionMember {
313                 name,
314                 has_result_id,
315                 has_result_type_id,
316                 opcode: instruction.opcode,
317                 operands,
318             }
319         })
320         .collect()
321 }
322 
/// One enumerant of a `BitEnum` or `ValueEnum` operand kind.
#[derive(Clone, Debug)]
struct KindEnumMember {
    // Field name (bit enums) or variant name (value enums) in the generated code.
    name: Ident,
    // Numeric value of the enumerant: the mask tested against the parsed word for bit
    // enums, the match pattern for value enums.
    value: u32,
    // Extra parameters that accompany the enumerant; empty when it has none.
    parameters: Vec<OperandMember>,
}
329 
bit_enum_output(enums: &[(Ident, Vec<KindEnumMember>)]) -> TokenStream330 fn bit_enum_output(enums: &[(Ident, Vec<KindEnumMember>)]) -> TokenStream {
331     let enum_items = enums.iter().map(|(name, members)| {
332         let members_items = members.iter().map(
333             |KindEnumMember {
334                  name, parameters, ..
335              }| {
336                 if parameters.is_empty() {
337                     quote! {
338                         pub #name: bool,
339                     }
340                 } else if let [OperandMember { ty, .. }] = parameters.as_slice() {
341                     quote! {
342                         pub #name: Option<#ty>,
343                     }
344                 } else {
345                     let params = parameters.iter().map(|OperandMember { ty, .. }| {
346                         quote! { #ty }
347                     });
348                     quote! {
349                         pub #name: Option<(#(#params),*)>,
350                     }
351                 }
352             },
353         );
354         let parse_items = members.iter().map(
355             |KindEnumMember {
356                  name,
357                  value,
358                  parameters,
359                  ..
360              }| {
361                 if parameters.is_empty() {
362                     quote! {
363                         #name: value & #value != 0,
364                     }
365                 } else {
366                     let some = if let [OperandMember { parse, .. }] = parameters.as_slice() {
367                         quote! { #parse }
368                     } else {
369                         let parse = parameters.iter().map(|OperandMember { parse, .. }| parse);
370                         quote! { (#(#parse),*) }
371                     };
372 
373                     quote! {
374                         #name: if value & #value != 0 {
375                             Some(#some)
376                         } else {
377                             None
378                         },
379                     }
380                 }
381             },
382         );
383 
384         quote! {
385             #[derive(Clone, Copy, Debug, PartialEq, Eq)]
386             #[allow(non_camel_case_types)]
387             pub struct #name {
388                 #(#members_items)*
389             }
390 
391             impl #name {
392                 #[allow(dead_code)]
393                 fn parse(reader: &mut InstructionReader<'_>) -> Result<#name, ParseError> {
394                     let value = reader.next_u32()?;
395 
396                     Ok(Self {
397                         #(#parse_items)*
398                     })
399                 }
400             }
401         }
402     });
403 
404     quote! {
405         #(#enum_items)*
406     }
407 }
408 
bit_enum_members(grammar: &SpirvGrammar) -> Vec<(Ident, Vec<KindEnumMember>)>409 fn bit_enum_members(grammar: &SpirvGrammar) -> Vec<(Ident, Vec<KindEnumMember>)> {
410     let parameter_kinds = kinds_to_types(grammar);
411 
412     grammar
413         .operand_kinds
414         .iter()
415         .filter(|operand_kind| operand_kind.category == "BitEnum")
416         .map(|operand_kind| {
417             let mut previous_value = None;
418 
419             let members = operand_kind
420                 .enumerants
421                 .iter()
422                 .filter_map(|enumerant| {
423                     // Skip enumerants with the same value as the previous.
424                     if previous_value == Some(&enumerant.value) {
425                         return None;
426                     }
427 
428                     previous_value = Some(&enumerant.value);
429 
430                     let value = enumerant
431                         .value
432                         .as_str()
433                         .unwrap()
434                         .strip_prefix("0x")
435                         .unwrap();
436                     let value = u32::from_str_radix(value, 16).unwrap();
437 
438                     if value == 0 {
439                         return None;
440                     }
441 
442                     let name = match enumerant.enumerant.to_snake_case().as_str() {
443                         "const" => format_ident!("constant"),
444                         "not_na_n" => format_ident!("not_nan"),
445                         name => format_ident!("{}", name),
446                     };
447 
448                     let parameters = enumerant
449                         .parameters
450                         .iter()
451                         .map(|param| {
452                             let name = to_member_name(&param.kind, param.name.as_deref());
453                             let (ty, parse) = parameter_kinds[param.kind.as_str()].clone();
454 
455                             OperandMember { name, ty, parse }
456                         })
457                         .collect();
458 
459                     Some(KindEnumMember {
460                         name,
461                         value,
462                         parameters,
463                     })
464                 })
465                 .collect();
466 
467             (format_ident!("{}", operand_kind.kind), members)
468         })
469         .collect()
470 }
471 
value_enum_output(enums: &[(Ident, Vec<KindEnumMember>)]) -> TokenStream472 fn value_enum_output(enums: &[(Ident, Vec<KindEnumMember>)]) -> TokenStream {
473     let enum_items = enums.iter().map(|(name, members)| {
474         let members_items = members.iter().map(
475             |KindEnumMember {
476                  name, parameters, ..
477              }| {
478                 if parameters.is_empty() {
479                     quote! {
480                         #name,
481                     }
482                 } else {
483                     let params = parameters.iter().map(|OperandMember { name, ty, .. }| {
484                         quote! { #name: #ty, }
485                     });
486                     quote! {
487                         #name {
488                             #(#params)*
489                         },
490                     }
491                 }
492             },
493         );
494         let parse_items = members.iter().map(
495             |KindEnumMember {
496                  name,
497                  value,
498                  parameters,
499                  ..
500              }| {
501                 if parameters.is_empty() {
502                     quote! {
503                         #value => Self::#name,
504                     }
505                 } else {
506                     let params_items =
507                         parameters.iter().map(|OperandMember { name, parse, .. }| {
508                             quote! {
509                                 #name: #parse,
510                             }
511                         });
512 
513                     quote! {
514                         #value => Self::#name {
515                             #(#params_items)*
516                         },
517                     }
518                 }
519             },
520         );
521         let name_string = name.to_string();
522 
523         let derives = match name_string.as_str() {
524             "ExecutionModel" => quote! { #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] },
525             "Decoration" => quote! { #[derive(Clone, Debug, PartialEq, Eq)] },
526             _ => quote! { #[derive(Clone, Copy, Debug, PartialEq, Eq)] },
527         };
528 
529         quote! {
530             #derives
531             #[allow(non_camel_case_types)]
532             pub enum #name {
533                 #(#members_items)*
534             }
535 
536             impl #name {
537                 #[allow(dead_code)]
538                 fn parse(reader: &mut InstructionReader<'_>) -> Result<#name, ParseError> {
539                     Ok(match reader.next_u32()? {
540                         #(#parse_items)*
541                         value => return Err(reader.map_err(ParseErrors::UnknownEnumerant(#name_string, value))),
542                     })
543                 }
544             }
545         }
546     });
547 
548     quote! {
549         #(#enum_items)*
550     }
551 }
552 
value_enum_members(grammar: &SpirvGrammar) -> Vec<(Ident, Vec<KindEnumMember>)>553 fn value_enum_members(grammar: &SpirvGrammar) -> Vec<(Ident, Vec<KindEnumMember>)> {
554     let parameter_kinds = kinds_to_types(grammar);
555 
556     grammar
557         .operand_kinds
558         .iter()
559         .filter(|operand_kind| operand_kind.category == "ValueEnum")
560         .map(|operand_kind| {
561             let mut previous_value = None;
562 
563             let members = operand_kind
564                 .enumerants
565                 .iter()
566                 .filter_map(|enumerant| {
567                     // Skip enumerants with the same value as the previous.
568                     if previous_value == Some(&enumerant.value) {
569                         return None;
570                     }
571 
572                     previous_value = Some(&enumerant.value);
573 
574                     let name = match enumerant.enumerant.as_str() {
575                         "1D" => format_ident!("Dim1D"),
576                         "2D" => format_ident!("Dim2D"),
577                         "3D" => format_ident!("Dim3D"),
578                         name => format_ident!("{}", name),
579                     };
580                     let parameters = enumerant
581                         .parameters
582                         .iter()
583                         .map(|param| {
584                             let name = to_member_name(&param.kind, param.name.as_deref());
585                             let (ty, parse) = parameter_kinds[param.kind.as_str()].clone();
586 
587                             OperandMember { name, ty, parse }
588                         })
589                         .collect();
590 
591                     Some(KindEnumMember {
592                         name,
593                         value: enumerant.value.as_u64().unwrap() as u32,
594                         parameters,
595                     })
596                 })
597                 .collect();
598 
599             (format_ident!("{}", operand_kind.kind), members)
600         })
601         .collect()
602 }
603 
to_member_name(kind: &str, name: Option<&str>) -> Ident604 fn to_member_name(kind: &str, name: Option<&str>) -> Ident {
605     if let Some(name) = name {
606         let name = name.to_snake_case();
607 
608         // Fix some weird names
609         match name.as_str() {
610             "argument_0_argument_1" => format_ident!("arguments"),
611             "member_0_type_member_1_type" => format_ident!("member_types"),
612             "operand_1_operand_2" => format_ident!("operands"),
613             "parameter_0_type_parameter_1_type" => format_ident!("parameter_types"),
614             "the_name_of_the_opaque_type" => format_ident!("name"),
615             "d_ref" => format_ident!("dref"),
616             "type" => format_ident!("ty"), // type is a keyword
617             _ => format_ident!("{}", name.replace("operand_", "operand")),
618         }
619     } else {
620         format_ident!("{}", kind.to_snake_case())
621     }
622 }
623 
kinds_to_types(grammar: &SpirvGrammar) -> HashMap<&str, (TokenStream, TokenStream)>624 fn kinds_to_types(grammar: &SpirvGrammar) -> HashMap<&str, (TokenStream, TokenStream)> {
625     grammar
626         .operand_kinds
627         .iter()
628         .map(|k| {
629             let (ty, parse) = match k.kind.as_str() {
630                 "LiteralContextDependentNumber" => {
631                     (quote! { Vec<u32> }, quote! { reader.remainder() })
632                 }
633                 "LiteralExtInstInteger" | "LiteralInteger" | "LiteralInt32" => {
634                     (quote! { u32 }, quote! { reader.next_u32()? })
635                 }
636                 "LiteralInt64" => (quote! { u64 }, quote! { reader.next_u64()? }),
637                 "LiteralFloat32" => (
638                     quote! { f32 },
639                     quote! { f32::from_bits(reader.next_u32()?) },
640                 ),
641                 "LiteralFloat64" => (
642                     quote! { f64 },
643                     quote! { f64::from_bits(reader.next_u64()?) },
644                 ),
645                 "LiteralSpecConstantOpInteger" => (
646                     quote! { SpecConstantInstruction },
647                     quote! { SpecConstantInstruction::parse(reader)? },
648                 ),
649                 "LiteralString" => (quote! { String }, quote! { reader.next_string()? }),
650                 "PairIdRefIdRef" => (
651                     quote! { (Id, Id) },
652                     quote! {
653                         (
654                             Id(reader.next_u32()?),
655                             Id(reader.next_u32()?),
656                         )
657                     },
658                 ),
659                 "PairIdRefLiteralInteger" => (
660                     quote! { (Id, u32) },
661                     quote! {
662                         (
663                             Id(reader.next_u32()?),
664                             reader.next_u32()?
665                         )
666                     },
667                 ),
668                 "PairLiteralIntegerIdRef" => (
669                     quote! { (u32, Id) },
670                     quote! {
671                     (
672                         reader.next_u32()?,
673                         Id(reader.next_u32()?)),
674                     },
675                 ),
676                 _ if k.kind.starts_with("Id") => (quote! { Id }, quote! { Id(reader.next_u32()?) }),
677                 ident => {
678                     let ident = format_ident!("{}", ident);
679                     (quote! { #ident }, quote! { #ident::parse(reader)? })
680                 }
681             };
682 
683             (k.kind.as_str(), (ty, parse))
684         })
685         .collect()
686 }
687