1#!/usr/bin/env python3
2
3# Copyright 2023 Google LLC
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import argparse
18from dataclasses import dataclass, field
19import json
20from pathlib import Path
21import sys
22from textwrap import dedent
23from typing import List, Tuple, Union, Optional
24
25from pdl import ast, core
26from pdl.utils import indent, to_pascal_case
27
28
def mask(width: int) -> str:
    """Return the hexadecimal literal of a mask selecting the `width` low bits."""
    all_ones = (1 << width) - 1
    return hex(all_ones)
31
32
def deref(var: Optional[str], id: str) -> str:
    """Qualify the member expression `id` with the object expression `var`.

    When `var` is None or empty, `id` is returned unqualified."""
    if not var:
        return id
    return f'{var}.{id}'
35
36
def get_cxx_scalar_type(width: int) -> str:
    """Return the cxx scalar type to be used to back a PDL type.

    Args:
        width: width of the PDL type, in bits.

    Returns:
        The smallest of `uint8_t`, `uint16_t`, `uint32_t`, `uint64_t` that can
        hold `width` bits.

    Raises:
        ValueError: the width exceeds 64 bits. An explicit exception is used
            instead of `assert`, which is stripped under `python -O` and would
            make the function silently return None.
    """
    for n in (8, 16, 32, 64):
        if width <= n:
            return f'uint{n}_t'
    # PDL type does not fit on non-extended scalar types.
    raise ValueError(f'PDL type of width {width} bits does not fit in a 64-bit scalar type')
44
45
@dataclass
class FieldParser:
    """Generate C++ code parsing the fields of a PDL declaration from a
    `pdl::packet::slice` named `span`.

    Fields are fed in order to `parse`, and `done` must be called after the
    last one; the generated statements accumulate in `code`. Every generated
    failure path is a `return false;` statement.

    Fixed-size reads are batched: they are first collected in `unchecked_code`
    while `offset` counts the bytes they consume, then `check_code_` emits a
    single `span.size() < offset` guard ahead of them. Bit fields accumulate
    in `chunk` until an octet boundary is reached and are then extracted from
    one backing integer read.
    """

    # Suffix of the `span.read_<byteorder>` accessors used for integer reads.
    byteorder: str
    # Bytes consumed by the statements pending in `unchecked_code`.
    offset: int = 0
    # Bits accumulated in the current bit-field chunk.
    shift: int = 0
    # True: arrays and payloads are copied into their members ("full" parsers).
    # False: only the sub-span covering them is recorded ("lite" parsers).
    extract_arrays: bool = field(default=False)
    # Pending bit fields, as (bit offset in chunk, bit width, field) tuples.
    chunk: List[Tuple[int, int, ast.Field]] = field(default_factory=list)
    # Counter giving each multi-field chunk local a unique name.
    chunk_nr: int = 0
    # Generated statements still waiting for their span size guard.
    unchecked_code: List[str] = field(default_factory=list)
    # Generated statements, in emission order.
    code: List[str] = field(default_factory=list)

    def unchecked_append_(self, line: str):
        """Append unchecked field parsing code.
        The function check_code_ must be called to generate a size guard
        after parsing is completed."""
        self.unchecked_code.append(line)

    def append_(self, line: str):
        """Append field parsing code.
        There must be no unchecked code left before this function is called."""
        assert len(self.unchecked_code) == 0
        self.code.append(line)

    def check_size_(self, size: str):
        """Generate a guard returning false when fewer than `size` bytes remain
        in the span; `size` is pasted verbatim as a C++ expression."""
        self.append_(f"if (span.size() < {size}) {{")
        self.append_("    return false;")
        self.append_("}")

    def check_code_(self):
        """Flush pending unchecked code behind a single span size guard,
        then reset the byte offset to 0."""
        if not self.unchecked_code:
            return
        assert len(self.chunk) == 0
        # Detach the pending code first: append_ refuses to run while
        # unchecked code is outstanding.
        unchecked_code, self.unchecked_code = self.unchecked_code, []
        self.check_size_(str(self.offset))
        self.code.extend(unchecked_code)
        self.offset = 0

    def parse_bit_field_(self, field: ast.Field):
        """Parse the selected field as a bit field.
        The field is added to the current chunk. When a byte boundary
        is reached all saved fields are extracted together."""

        # Add to current chunk.
        width = core.get_field_size(field)
        self.chunk.append((self.shift, width, field))
        self.shift += width

        # Wait for more fields if not on a byte boundary.
        if self.shift % 8 != 0:
            return

        # Parse the backing integer using the configured endianness,
        # extract field values.
        size = self.shift // 8
        backing_type = get_cxx_scalar_type(self.shift)

        if all(isinstance(f, ast.ReservedField) for (_, _, f) in self.chunk):
            # No field is actually used from the chunk: skip its bytes.
            self.unchecked_append_(f"span.skip({size}); // skip reserved fields")
        else:
            # Several fields share the backing integer: read it once into a
            # named local and extract each field from it.
            if len(self.chunk) > 1:
                value = f"chunk{self.chunk_nr}"
                self.unchecked_append_(
                    f"{backing_type} {value} = span.read_{self.byteorder}<{backing_type}, {size}>();")
                self.chunk_nr += 1
            else:
                value = f"span.read_{self.byteorder}<{backing_type}, {size}>()"

            for shift, width, chunk_field in self.chunk:
                v = (value if len(self.chunk) == 1 and shift == 0 else f"({value} >> {shift}) & {mask(width)}")

                if chunk_field.cond_for:
                    # Condition flag of an optional field: kept in a local
                    # that the optional field parser tests later.
                    self.unchecked_append_(f"uint8_t {chunk_field.id} = {v};")
                elif isinstance(chunk_field, ast.ScalarField):
                    self.unchecked_append_(f"{chunk_field.id}_ = {v};")
                elif isinstance(chunk_field, ast.FixedField) and chunk_field.enum_id:
                    self.unchecked_append_(
                        f"if ({chunk_field.enum_id}({v}) != {chunk_field.enum_id}::{chunk_field.tag_id}) {{")
                    self.unchecked_append_("    return false;")
                    self.unchecked_append_("}")
                elif isinstance(chunk_field, ast.FixedField):
                    self.unchecked_append_(f"if (({v}) != {hex(chunk_field.value)}) {{")
                    self.unchecked_append_("    return false;")
                    self.unchecked_append_("}")
                elif isinstance(chunk_field, ast.TypedefField):
                    self.unchecked_append_(f"{chunk_field.id}_ = {chunk_field.type_id}({v});")
                elif isinstance(chunk_field, ast.SizeField):
                    self.unchecked_append_(f"{chunk_field.field_id}_size = {v};")
                elif isinstance(chunk_field, ast.CountField):
                    self.unchecked_append_(f"{chunk_field.field_id}_count = {v};")
                elif isinstance(chunk_field, ast.ReservedField):
                    pass
                else:
                    raise Exception(f'Unsupported bit field type {chunk_field.kind}')

        # Reset state.
        self.offset += size
        self.shift = 0
        self.chunk = []

    def parse_typedef_field_(self, field: ast.TypedefField):
        """Parse a typedef field, to the exclusion of Enum fields.
        The generated code delegates to the static `Parse` method of the type."""
        if self.shift != 0:
            raise Exception('Typedef field does not start on an octet boundary')

        self.check_code_()
        self.append_(
            dedent("""\
            if (!{field_type}::Parse(span, &{field_id}_)) {{
                return false;
            }}""".format(field_type=field.type.id, field_id=field.id)))

    def parse_optional_field_(self, field: ast.Field):
        """Parse the selected optional field.
        Optional fields must start and end on a byte boundary; the value is
        only read when the condition flag local equals the condition value.

        Raises:
            Exception: the field does not start on an octet boundary, or its
                kind is not supported as an optional field.
        """
        # A pending bit chunk would be emitted after this field, so the
        # generated code would test a condition flag not declared yet.
        if self.shift != 0:
            raise Exception('Optional field does not start on an octet boundary')

        self.check_code_()

        if isinstance(field, ast.ScalarField):
            backing_type = get_cxx_scalar_type(field.width)
            self.append_(dedent("""
            if ({cond_id} == {cond_value}) {{
                if (span.size() < {size}) {{
                    return false;
                }}
                {field_id}_ = std::make_optional(
                    span.read_{byteorder}<{backing_type}, {size}>());
            }}
            """.format(size=field.width // 8,
                       backing_type=backing_type,
                       field_id=field.id,
                       cond_id=field.cond.id,
                       cond_value=field.cond.value,
                       byteorder=self.byteorder)))

        elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
            backing_type = get_cxx_scalar_type(field.type.width)
            self.append_(dedent("""
            if ({cond_id} == {cond_value}) {{
                if (span.size() < {size}) {{
                    return false;
                }}
                {field_id}_ = std::make_optional({type_id}(
                    span.read_{byteorder}<{backing_type}, {size}>()));
            }}
            """.format(size=field.type.width // 8,
                       backing_type=backing_type,
                       type_id=field.type_id,
                       field_id=field.id,
                       cond_id=field.cond.id,
                       cond_value=field.cond.value,
                       byteorder=self.byteorder)))

        elif isinstance(field, ast.TypedefField):
            self.append_(dedent("""
            if ({cond_id} == {cond_value}) {{
                auto& output = {field_id}_.emplace();
                if (!{type_id}::Parse(span, &output)) {{
                    return false;
                }}
            }}
            """.format(type_id=field.type_id,
                       field_id=field.id,
                       cond_id=field.cond.id,
                       cond_value=field.cond.value)))

        else:
            raise Exception(f"unsupported field type {field.__class__.__name__}")

    def begin_array_field_(self, field: ast.ArrayField) -> Tuple[Optional[str], Optional[Union[int, str]]]:
        """Emit the code shared by the lite and full array parsers before the
        array elements are handled, and return the array bounds.

        Returns:
            A `(size, count)` pair: `size` is a C++ expression of the array size
            in bytes and `count` the element count (int or C++ expression);
            either is None when unknown.

        When the array is padded, a C++ scope is opened in which `span` is
        restricted to the padded size; `end_array_field_` closes it.
        """
        array_size = core.get_array_field_size(field)
        element_width = core.get_array_element_size(field)
        if element_width:
            element_width = element_width // 8

        size: Optional[str] = None
        count: Optional[Union[int, str]] = None
        if isinstance(array_size, int):
            count = array_size
        elif isinstance(array_size, ast.SizeField):
            size = f'{field.id}_size'
        elif isinstance(array_size, ast.CountField):
            count = f'{field.id}_count'

        # Flush pending fixed-size reads so the span starts at the array.
        self.check_code_()

        # Apply the size modifier.
        if field.size_modifier and size:
            self.append_(f"{size} = {size} - {field.size_modifier};")

        # Compute the array size if the count and element width are known.
        if count is not None and element_width is not None:
            size = f"{count} * {element_width}"

        # Parse from the padded array if padding is present.
        padded_size = field.padded_size
        if padded_size:
            self.check_size_(padded_size)
            self.append_("{")
            self.append_(
                f"pdl::packet::slice remaining_span = span.subrange({padded_size}, span.size() - {padded_size});")
            self.append_(f"span = span.subrange(0, {padded_size});")

        return size, count

    def end_array_field_(self, field: ast.ArrayField):
        """Close the padded scope opened by `begin_array_field_`, if any,
        restoring `span` to the bytes following the padded array."""
        if field.padded_size:
            self.append_("span = remaining_span;")
            self.append_("}")

    def parse_array_field_lite_(self, field: ast.ArrayField):
        """Parse the selected array field.
        This function does not attempt to parse all elements but just to
        identify the span of the array, stored into the `<field>_` member."""
        size, count = self.begin_array_field_(field)

        # The array size is known in bytes.
        if size is not None:
            self.check_size_(size)
            self.append_(f"{field.id}_ = span.subrange(0, {size});")
            self.append_(f"span.skip({size});")

        # The array count is known. The element width is dynamic.
        # Parse each element iteratively and derive the array span.
        elif count is not None:
            self.append_("{")
            self.append_("pdl::packet::slice temp_span = span;")
            self.append_(f"for (size_t n = 0; n < {count}; n++) {{")
            self.append_(f"    {field.type_id} element;")
            self.append_(f"    if (!{field.type_id}::Parse(temp_span, &element)) {{")
            self.append_("        return false;")
            self.append_("    }")
            self.append_("}")
            self.append_(f"{field.id}_ = span.subrange(0, span.size() - temp_span.size());")
            self.append_(f"span.skip({field.id}_.size());")
            self.append_("}")

        # The array size is not known, assume the array takes the
        # full remaining space. TODO support having fixed sized fields
        # following the array.
        else:
            self.append_(f"{field.id}_ = span;")
            self.append_("span.clear();")

        self.end_array_field_(field)

    def parse_array_field_full_(self, field: ast.ArrayField):
        """Parse the selected array field.
        Every element is parsed and stored into the `<field>_` member: by index
        when the element count is static, by appending otherwise."""
        element_type = field.type_id or get_cxx_scalar_type(field.width)
        size, count = self.begin_array_field_(field)

        # The array count is known statically, elements are scalar.
        if field.width and field.size:
            assert size is not None
            self.check_size_(size)
            element_size = field.width // 8
            self.append_(f"for (size_t n = 0; n < {field.size}; n++) {{")
            self.append_(f"    {field.id}_[n] = span.read_{self.byteorder}<{element_type}, {element_size}>();")
            self.append_("}")

        # The array count is known statically, elements are enum values.
        elif isinstance(field.type, ast.EnumDeclaration) and field.size:
            assert size is not None
            self.check_size_(size)
            element_size = field.type.width // 8
            backing_type = get_cxx_scalar_type(field.type.width)
            self.append_(f"for (size_t n = 0; n < {field.size}; n++) {{")
            self.append_(
                f"    {field.id}_[n] = {element_type}(span.read_{self.byteorder}<{backing_type}, {element_size}>());")
            self.append_("}")

        # The array count is known statically, elements have variable size.
        elif field.size:
            self.append_(f"for (size_t n = 0; n < {field.size}; n++) {{")
            self.append_(f"    if (!{element_type}::Parse(span, &{field.id}_[n])) {{")
            self.append_("        return false;")
            self.append_("    }")
            self.append_("}")

        # The array size is known in bytes.
        elif size is not None:
            self.check_size_(size)
            self.append_("{")
            self.append_(f"pdl::packet::slice temp_span = span.subrange(0, {size});")
            self.append_(f"span.skip({size});")
            self.append_("while (temp_span.size() > 0) {")
            if field.width:
                element_size = field.width // 8
                self.append_(f"    if (temp_span.size() < {element_size}) {{")
                self.append_("        return false;")
                self.append_("    }")
                self.append_(
                    f"    {field.id}_.push_back(temp_span.read_{self.byteorder}<{element_type}, {element_size}>());")
            elif isinstance(field.type, ast.EnumDeclaration):
                backing_type = get_cxx_scalar_type(field.type.width)
                element_size = field.type.width // 8
                self.append_(f"    if (temp_span.size() < {element_size}) {{")
                self.append_("        return false;")
                self.append_("    }")
                self.append_(
                    f"    {field.id}_.push_back({element_type}(temp_span.read_{self.byteorder}<{backing_type}, {element_size}>()));"
                )
            else:
                self.append_(f"    {element_type} element;")
                self.append_(f"    if (!{element_type}::Parse(temp_span, &element)) {{")
                self.append_("        return false;")
                self.append_("    }")
                self.append_(f"    {field.id}_.emplace_back(std::move(element));")
            self.append_("}")
            self.append_("}")

        # The array count is known. The element width is dynamic.
        # Parse each element iteratively.
        elif count is not None:
            self.append_(f"for (size_t n = 0; n < {count}; n++) {{")
            self.append_(f"    {element_type} element;")
            self.append_(f"    if (!{field.type_id}::Parse(span, &element)) {{")
            self.append_("        return false;")
            self.append_("    }")
            self.append_(f"    {field.id}_.emplace_back(std::move(element));")
            self.append_("}")

        # The array size is not known, assume the array takes the
        # full remaining space. TODO support having fixed sized fields
        # following the array.
        elif field.width:
            element_size = field.width // 8
            self.append_("while (span.size() > 0) {")
            self.append_(f"    if (span.size() < {element_size}) {{")
            self.append_("        return false;")
            self.append_("    }")
            self.append_(f"    {field.id}_.push_back(span.read_{self.byteorder}<{element_type}, {element_size}>());")
            self.append_("}")
        elif isinstance(field.type, ast.EnumDeclaration):
            element_size = field.type.width // 8
            backing_type = get_cxx_scalar_type(field.type.width)
            self.append_("while (span.size() > 0) {")
            self.append_(f"    if (span.size() < {element_size}) {{")
            self.append_("        return false;")
            self.append_("    }")
            self.append_(
                f"    {field.id}_.push_back({element_type}(span.read_{self.byteorder}<{backing_type}, {element_size}>()));"
            )
            self.append_("}")
        else:
            self.append_("while (span.size() > 0) {")
            self.append_(f"    {element_type} element;")
            self.append_(f"    if (!{element_type}::Parse(span, &element)) {{")
            self.append_("        return false;")
            self.append_("    }")
            self.append_(f"    {field.id}_.emplace_back(std::move(element));")
            self.append_("}")

        self.end_array_field_(field)

    def begin_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]):
        """Emit the code shared by the lite and full payload parsers.

        Returns the `(payload_size, offset_from_end)` pair computed by `core`
        for the field; `offset_from_end` is in bits.

        Raises:
            Exception: the payload does not start on an octet boundary.
        """
        if self.shift != 0:
            raise Exception('Payload field does not start on an octet boundary')

        payload_size = core.get_payload_field_size(field)
        offset_from_end = core.get_field_offset_from_end(field)
        self.check_code_()

        if payload_size and getattr(field, 'size_modifier', None):
            self.append_(f"{field.id}_size -= {field.size_modifier};")

        return payload_size, offset_from_end

    def parse_payload_field_lite_(self, field: Union[ast.BodyField, ast.PayloadField]):
        """Parse body and payload fields, recording their span into `payload_`."""
        payload_size, offset_from_end = self.begin_payload_field_(field)

        # The payload or body has a known size.
        # Consume the payload and update the span in case
        # fields are placed after the payload.
        if payload_size:
            self.check_size_(f"{field.id}_size")
            self.append_(f"payload_ = span.subrange(0, {field.id}_size);")
            self.append_(f"span.skip({field.id}_size);")
        # The payload or body is the last field of a packet,
        # consume the remaining span.
        elif offset_from_end == 0:
            self.append_("payload_ = span;")
            self.append_("span.clear();")
        # The payload or body is followed by fields of static size.
        # Consume the span that is not reserved for the following fields.
        elif offset_from_end:
            if offset_from_end % 8 != 0:
                raise Exception('Payload field offset from end of packet is not a multiple of 8')
            offset_from_end //= 8
            self.check_size_(f'{offset_from_end}')
            self.append_(f"payload_ = span.subrange(0, span.size() - {offset_from_end});")
            self.append_("span.skip(payload_.size());")

    def parse_payload_field_full_(self, field: Union[ast.BodyField, ast.PayloadField]):
        """Parse body and payload fields, copying their bytes into `payload_`."""
        payload_size, offset_from_end = self.begin_payload_field_(field)

        # The payload or body has a known size.
        # Consume the payload and update the span in case
        # fields are placed after the payload.
        if payload_size:
            self.check_size_(f"{field.id}_size")
            self.append_(f"for (size_t n = 0; n < {field.id}_size; n++) {{")
            self.append_(f"    payload_.push_back(span.read_{self.byteorder}<uint8_t>());")
            self.append_("}")
        # The payload or body is the last field of a packet,
        # consume the remaining span.
        elif offset_from_end == 0:
            self.append_("while (span.size() > 0) {")
            self.append_(f"    payload_.push_back(span.read_{self.byteorder}<uint8_t>());")
            self.append_("}")
        # The payload or body is followed by fields of static size.
        # Consume the span that is not reserved for the following fields.
        elif offset_from_end is not None:
            if offset_from_end % 8 != 0:
                raise Exception('Payload field offset from end of packet is not a multiple of 8')
            offset_from_end //= 8
            self.check_size_(f'{offset_from_end}')
            self.append_(f"while (span.size() > {offset_from_end}) {{")
            self.append_(f"    payload_.push_back(span.read_{self.byteorder}<uint8_t>());")
            self.append_("}")

    def parse(self, field: ast.Field):
        """Generate the parsing code for the next field of the declaration.

        Raises:
            Exception: the field kind is not supported.
        """
        # Optional fields are gated on their condition flag.
        if field.cond:
            self.parse_optional_field_(field)

        # Field has bit granularity.
        # Append the field to the current chunk,
        # check if a byte boundary was reached.
        elif core.is_bit_field(field):
            self.parse_bit_field_(field)

        # Padding fields emit no code of their own (arrays apply their
        # `padded_size` themselves).
        elif isinstance(field, ast.PaddingField):
            pass

        # Array fields.
        elif isinstance(field, ast.ArrayField) and self.extract_arrays:
            self.parse_array_field_full_(field)

        elif isinstance(field, ast.ArrayField) and not self.extract_arrays:
            self.parse_array_field_lite_(field)

        # Other typedef fields.
        elif isinstance(field, ast.TypedefField):
            self.parse_typedef_field_(field)

        # Payload and body fields.
        elif isinstance(field, (ast.PayloadField, ast.BodyField)) and self.extract_arrays:
            self.parse_payload_field_full_(field)

        elif isinstance(field, (ast.PayloadField, ast.BodyField)) and not self.extract_arrays:
            self.parse_payload_field_lite_(field)

        else:
            raise Exception(f'Unsupported field type {field.kind}')

    def done(self):
        """Flush the pending unchecked code once all fields have been parsed."""
        self.check_code_()
554
555
556@dataclass
557class FieldSerializer:
558    byteorder: str
559    shift: int = 0
560    value: List[Tuple[str, int]] = field(default_factory=lambda: [])
561    code: List[str] = field(default_factory=lambda: [])
562    indent: int = 0
563
564    def indent_(self):
565        self.indent += 1
566
567    def unindent_(self):
568        self.indent -= 1
569
570    def append_(self, line: str):
571        """Append field serializing code."""
572        lines = line.split('\n')
573        self.code.extend(['    ' * self.indent + line for line in lines])
574
575    def get_payload_field_size(self, var: Optional[str], payload: ast.PayloadField, decl: ast.Declaration) -> str:
576        """Compute the size of the selected payload field, with the information
577        of the builder for the selected declaration. The payload field can be
578        the payload of any of the parent declarations, or the current declaration."""
579
580        if payload.parent.id == decl.id:
581            return deref(var, 'payload_.size()')
582
583        # Get the child packet declaration that will match the current
584        # declaration further down.
585        child = decl
586        while child.parent_id != payload.parent.id:
587            child = child.parent
588
589        # The payload is the result of serializing the children fields.
590        constant_width = 0
591        variable_width = []
592        for f in child.fields:
593            field_size = core.get_field_size(f)
594            if field_size is not None:
595                constant_width += field_size
596            elif isinstance(f, (ast.PayloadField, ast.BodyField)):
597                variable_width.append(self.get_payload_field_size(var, f, decl))
598            elif isinstance(f, ast.TypedefField):
599                variable_width.append(f"{f.id}_.GetSize()")
600            elif isinstance(f, ast.ArrayField):
601                variable_width.append(f"Get{to_pascal_case(f.id)}Size()")
602            else:
603                raise Exception("Unsupported field type")
604
605        constant_width = int(constant_width / 8)
606        if constant_width and not variable_width:
607            return str(constant_width)
608
609        temp_var = f'{payload.parent.id.lower()}_payload_size'
610        self.append_(f"size_t {temp_var} = {constant_width};")
611        for dyn in variable_width:
612            self.append_(f"{temp_var} += {dyn};")
613        return temp_var
614
615    def serialize_array_element_(self, field: ast.ArrayField, var: str):
616        """Serialize a single array field element."""
617        if field.width:
618            backing_type = get_cxx_scalar_type(field.width)
619            element_size = int(field.width / 8)
620            self.append_(
621                f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {element_size}>(output, {var});")
622        elif isinstance(field.type, ast.EnumDeclaration):
623            backing_type = get_cxx_scalar_type(field.type.width)
624            element_size = int(field.type.width / 8)
625            self.append_(f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {element_size}>(" +
626                         f"output, static_cast<{backing_type}>({var}));")
627        else:
628            self.append_(f"{var}.Serialize(output);")
629
630    def serialize_array_field_(self, field: ast.ArrayField, var: str):
631        """Serialize the selected array field."""
632        if field.padded_size:
633            self.append_(f"size_t {field.id}_end = output.size() + {field.padded_size};")
634
635        if field.width == 8:
636            self.append_(f"output.insert(output.end(), {var}.begin(), {var}.end());")
637        else:
638            self.append_(f"for (size_t n = 0; n < {var}.size(); n++) {{")
639            self.indent_()
640            self.serialize_array_element_(field, f'{var}[n]')
641            self.unindent_()
642            self.append_("}")
643
644        if field.padded_size:
645            self.append_(f"while (output.size() < {field.id}_end) {{")
646            self.append_("    output.push_back(0);")
647            self.append_("}")
648
649    def serialize_bit_field_(self, field: ast.Field, parent_var: Optional[str], var: Optional[str],
650                             decl: ast.Declaration):
651        """Serialize the selected field as a bit field.
652        The field is added to the current chunk. When a byte boundary
653        is reached all saved fields are serialized together."""
654
655        # Add to current chunk.
656        width = core.get_field_size(field)
657        shift = self.shift
658
659        if field.cond_for:
660            value_present = field.cond_for.cond.value
661            value_absent = 0 if field.cond_for.cond.value else 1
662            self.value.append((f"({field.cond_for.id}_.has_value() ? {value_present} : {value_absent})", shift))
663        elif isinstance(field, ast.ScalarField):
664            self.value.append((f"{var} & {mask(field.width)}", shift))
665        elif isinstance(field, ast.FixedField) and field.enum_id:
666            self.value.append((f"{field.enum_id}::{field.tag_id}", shift))
667        elif isinstance(field, ast.FixedField):
668            self.value.append((f"{field.value}", shift))
669        elif isinstance(field, ast.TypedefField):
670            self.value.append((f"{var}", shift))
671
672        elif isinstance(field, ast.SizeField):
673            max_size = (1 << field.width) - 1
674            value_field = core.get_packet_field(field.parent, field.field_id)
675            size_modifier = ''
676
677            if getattr(value_field, 'size_modifier', None):
678                size_modifier = f' + {value_field.size_modifier}'
679
680            if isinstance(value_field, (ast.PayloadField, ast.BodyField)):
681                array_size = self.get_payload_field_size(var, field, decl) + size_modifier
682
683            elif isinstance(value_field, ast.ArrayField):
684                accessor_name = to_pascal_case(field.field_id)
685                array_size = deref(var, f'Get{accessor_name}Size()') + size_modifier
686
687            self.value.append((f"{array_size}", shift))
688
689        elif isinstance(field, ast.CountField):
690            max_count = (1 << field.width) - 1
691            self.value.append((f"{field.field_id}_.size()", shift))
692
693        elif isinstance(field, ast.ReservedField):
694            pass
695        else:
696            raise Exception(f'Unsupported bit field type {field.kind}')
697
698        # Check if a byte boundary is reached.
699        self.shift += width
700        if (self.shift % 8) == 0:
701            self.pack_bit_fields_()
702
703    def pack_bit_fields_(self):
704        """Pack serialized bit fields."""
705
706        # Should have an integral number of bytes now.
707        assert (self.shift % 8) == 0
708
709        # Generate the backing integer, and serialize it
710        # using the configured endiannes,
711        size = int(self.shift / 8)
712        backing_type = get_cxx_scalar_type(self.shift)
713        value = [f"(static_cast<{backing_type}>({v[0]}) << {v[1]})" for v in self.value]
714
715        if len(value) == 0:
716            self.append_(f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {size}>(output, 0);")
717        elif len(value) == 1:
718            self.append_(f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {size}>(output, {value[0]});")
719        else:
720            self.append_(
721                f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {size}>(output, {' | '.join(value)});")
722
723        # Reset state.
724        self.shift = 0
725        self.value = []
726
727    def serialize_typedef_field_(self, field: ast.TypedefField, var: str):
728        """Serialize a typedef field, to the exclusion of Enum fields."""
729
730        if self.shift != 0:
731            raise Exception('Typedef field does not start on an octet boundary')
732        if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None):
733            raise Exception('Derived struct used in typedef field')
734
735        self.append_(f"{var}.Serialize(output);")
736
737    def serialize_optional_field_(self, field: ast.Field):
738        """Serialize optional scalar or typedef fields."""
739
740        if isinstance(field, ast.ScalarField):
741            backing_type = get_cxx_scalar_type(field.width)
742            self.append_(dedent(
743                """
744                if ({field_id}_.has_value()) {{
745                    pdl::packet::Builder::write_{byteorder}<{backing_type}, {size}>(output, {field_id}_.value());
746                }}""".format(field_id=field.id,
747                            size=int(field.width / 8),
748                            backing_type=backing_type,
749                            byteorder=self.byteorder)))
750
751        elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
752            backing_type = get_cxx_scalar_type(field.type.width)
753            self.append_(dedent(
754                """
755                if ({field_id}_.has_value()) {{
756                    pdl::packet::Builder::write_{byteorder}<{backing_type}, {size}>(
757                        output, static_cast<{backing_type}>({field_id}_.value()));
758                }}""".format(field_id=field.id,
759                            size=int(field.type.width / 8),
760                            backing_type=backing_type,
761                            byteorder=self.byteorder)))
762
763        elif isinstance(field, ast.TypedefField):
764            self.append_(dedent(
765                """
766                if ({field_id}_.has_value()) {{
767                    {field_id}_->Serialize(output);
768                }}""".format(field_id=field.id)))
769
770        else:
771            raise Exception(f"unsupported field type {field.__class__.__name__}")
772
773    def serialize_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField], var: str):
774        """Serialize body and payload fields."""
775
776        if self.shift != 0:
777            raise Exception('Payload field does not start on an octet boundary')
778
779        self.append_(f"output.insert(output.end(), {var}.begin(), {var}.end());")
780
781    def serialize(self, field: ast.Field, decl: ast.Declaration, var: Optional[str] = None):
782        field_var = deref(var, f'{field.id}_') if hasattr(field, 'id') else None
783
784        if field.cond:
785            self.serialize_optional_field_(field)
786
787        # Field has bit granularity.
788        # Append the field to the current chunk,
789        # check if a byte boundary was reached.
790        elif core.is_bit_field(field):
791            self.serialize_bit_field_(field, var, field_var, decl)
792
793        # Padding fields.
794        elif isinstance(field, ast.PaddingField):
795            pass
796
797        # Array fields.
798        elif isinstance(field, ast.ArrayField):
799            self.serialize_array_field_(field, field_var)
800
801        # Other typedef fields.
802        elif isinstance(field, ast.TypedefField):
803            self.serialize_typedef_field_(field, field_var)
804
805        # Payload and body fields.
806        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
807            self.serialize_payload_field_(field, deref(var, 'payload_'))
808
809        else:
810            raise Exception(f'Unimplemented field type {field.kind}')
811
812
def generate_enum_declaration(decl: ast.EnumDeclaration) -> str:
    """Generate the C++ `enum class` definition for a PDL enum declaration.

    The enum is backed by the smallest uintN_t type (N in 8/16/32/64) that
    fits the declared width. Tags without a value (default tags,
    `DEFAULT = ..`) are left out. Every other tag is declared with its
    hexadecimal value.

    :param decl: enum declaration
    :returns: C++ source text of the enum definition"""

    backing_type = get_cxx_scalar_type(decl.width)
    enumerators = [f"{tag.id} = {hex(tag.value)}," for tag in decl.tags if tag.value is not None]
    template = dedent("""\

        enum class {enum_name} : {enum_type} {{
            {tag_decls}
        }};
        """)
    return template.format(enum_name=decl.id, enum_type=backing_type, tag_decls=indent(enumerators, 1))
830
831
def generate_enum_to_text(decl: ast.EnumDeclaration) -> str:
    """Generate `<Enum>Text`, an inline C++ helper mapping enum tags to names.

    Tags without a value (default tags) get no switch case. Any other value
    falls through to a message containing its numeric value.

    :param decl: enum declaration
    :returns: C++ source text of the helper"""

    name = decl.id
    cases = [f"case {name}::{tag.id}: return \"{tag.id}\";" for tag in decl.tags if tag.value is not None]
    return dedent("""\

        inline std::string {enum_name}Text({enum_name} tag) {{
            switch (tag) {{
                {tag_cases}
                default:
                    return std::string("Unknown {enum_name}: " +
                           std::to_string(static_cast<uint64_t>(tag)));
            }}
        }}
        """).format(enum_name=name, tag_cases=indent(cases, 2))
853
854
def generate_packet_view_field_members(decl: ast.Declaration) -> List[str]:
    """List the member declarations backing the fields of a view class.

    Backed fields are the named fields of the declaration and of its
    parents whose value is not constrained. Payload and array fields are
    kept as unparsed slices, and optional fields are wrapped in
    std::optional. Fields that act as the condition of an optional field
    are skipped.

    :param decl: target declaration
    :returns: C++ member declarations"""

    members = []
    for field in core.get_unconstrained_parent_fields(decl) + decl.fields:
        if field.cond_for:
            # Condition flags are tied to the presence of their optional
            # field, so they are treated as fixed and not stored.
            continue

        if isinstance(field, (ast.PayloadField, ast.BodyField)):
            members.append("pdl::packet::slice payload_;")
        elif isinstance(field, ast.ArrayField):
            members.append(f"pdl::packet::slice {field.id}_;")
        elif isinstance(field, ast.ScalarField):
            scalar_type = get_cxx_scalar_type(field.width)
            if field.cond:
                members.append(f"std::optional<{scalar_type}> {field.id}_{{}};")
            else:
                members.append(f"{scalar_type} {field.id}_{{0}};")
        elif isinstance(field, ast.TypedefField):
            if field.cond:
                members.append(f"std::optional<{field.type_id}> {field.id}_{{}};")
            elif isinstance(field.type, ast.EnumDeclaration):
                # Enum members are initialized to the first declared tag.
                first_tag = field.type.tags[0].id
                members.append(f"{field.type_id} {field.id}_{{{field.type_id}::{first_tag}}};")
            else:
                members.append(f"{field.type_id} {field.id}_;")

    return members
888
889
def generate_packet_field_members(decl: ast.Declaration) -> List[str]:
    """Return the member declarations of the builder class for `decl`.

    Only the fields declared directly in `decl` are considered. Parent
    fields are not included here. Array fields become std::array (fixed
    size) or std::vector, optional fields are wrapped in std::optional, and
    fields acting as the condition of an optional field are skipped. The
    payload_ member is declared only when `decl` has no parent; derived
    builders presumably inherit it from their parent builder (TODO confirm
    against the builder class generator).

    :param decl: target declaration
    :returns: C++ member declarations"""

    members = []
    for field in decl.fields:
        if field.cond_for:
            # Scalar fields used as condition for optional fields are treated
            # as fixed fields since their value is tied to the value of the
            # optional field.
            pass
        elif isinstance(field, (ast.PayloadField, ast.BodyField)) and not decl.parent:
            members.append("std::vector<uint8_t> payload_;")
        elif isinstance(field, ast.ArrayField) and field.size:
            element_type = field.type_id or get_cxx_scalar_type(field.width)
            members.append(f"std::array<{element_type}, {field.size}> {field.id}_;")
        elif isinstance(field, ast.ArrayField):
            element_type = field.type_id or get_cxx_scalar_type(field.width)
            members.append(f"std::vector<{element_type}> {field.id}_;")
        elif isinstance(field, ast.ScalarField) and field.cond:
            members.append(f"std::optional<{get_cxx_scalar_type(field.width)}> {field.id}_{{}};")
        elif isinstance(field, ast.TypedefField) and field.cond:
            members.append(f"std::optional<{field.type_id}> {field.id}_{{}};")
        elif isinstance(field, ast.ScalarField):
            members.append(f"{get_cxx_scalar_type(field.width)} {field.id}_{{0}};")
        elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
            # Enum members are initialized to the first declared tag.
            members.append(f"{field.type_id} {field.id}_{{{field.type_id}::{field.type.tags[0].id}}};")
        elif isinstance(field, ast.TypedefField):
            members.append(f"{field.type_id} {field.id}_;")

    return members
926
927
def generate_packet_field_serializers(packet: ast.Declaration) -> List[str]:
    """Generate the code that serializes the fields of a packet builder or struct.

    Fields whose value is fixed by a parent constraint are replaced by a
    synthesized FixedField. That FixedField carries the enum tag or the
    scalar value, so the constant is written instead of a member.

    :param packet: packet or struct declaration
    :returns: lines of C++ serialization code"""

    serializer = FieldSerializer(byteorder=packet.file.byteorder_short)
    constraint_by_id = {c.id: c for c in core.get_parent_constraints(packet)}

    for field in core.get_packet_fields(packet):
        constraint = constraint_by_id.get(getattr(field, 'id', None))
        target = field
        if constraint:
            if constraint.tag_id:
                target = ast.FixedField(enum_id=field.type_id,
                                        tag_id=constraint.tag_id,
                                        loc=field.loc,
                                        kind='fixed_field')
            else:
                target = ast.FixedField(width=field.width,
                                        value=constraint.value,
                                        loc=field.loc,
                                        kind='fixed_field')
            target.parent = field.parent
        serializer.serialize(target, packet)

    return serializer.code
948
949
def generate_scalar_array_field_accessor(field: ast.ArrayField) -> str:
    """Generate the body of the getter for an array of scalar elements.

    Elements are read from the field's slice in the file byte order.
    Fixed-size arrays read exactly `field.size` elements into a std::array.
    Other arrays read elements into a std::vector while at least one full
    element remains.

    :param field: scalar array field
    :returns: C++ statements returning the decoded elements"""

    params = {
        'field_id': field.id,
        'element_size': int(field.width / 8),
        'backing_type': get_cxx_scalar_type(field.width),
        'byteorder': field.parent.file.byteorder_short,
    }

    if not field.size:
        return dedent("""\
            pdl::packet::slice span = {field_id}_;
            std::vector<{backing_type}> elements;
            while (span.size() >= {element_size}) {{
                elements.push_back(span.read_{byteorder}<{backing_type}, {element_size}>());
            }}
            return elements;""").format(**params)

    return dedent("""\
        pdl::packet::slice span = {field_id}_;
        std::array<{backing_type}, {array_size}> elements;
        for (int n = 0; n < {array_size}; n++) {{
            elements[n] = span.read_{byteorder}<{backing_type}, {element_size}>();
        }}
        return elements;""").format(array_size=field.size, **params)
978
979
def generate_enum_array_field_accessor(field: ast.ArrayField) -> str:
    """Generate the body of the getter for an array of enum elements.

    Each element is read as a scalar of the enum width in the file byte
    order, then converted to the enum type. Fixed-size arrays fill a
    std::array. Other arrays append to a std::vector while at least one
    full element remains.

    :param field: enum array field
    :returns: C++ statements returning the decoded elements"""

    enum_width = field.type.width
    params = {
        'field_id': field.id,
        'element_size': int(enum_width / 8),
        'backing_type': get_cxx_scalar_type(enum_width),
        'byteorder': field.parent.file.byteorder_short,
    }

    if field.size:
        return dedent("""\
            pdl::packet::slice span = {field_id}_;
            std::array<{enum_type}, {array_size}> elements;
            for (int n = 0; n < {array_size}; n++) {{
                elements[n] = {enum_type}(span.read_{byteorder}<{backing_type}, {element_size}>());
            }}
            return elements;""").format(enum_type=field.type.id, array_size=field.size, **params)

    return dedent("""\
        pdl::packet::slice span = {field_id}_;
        std::vector<{enum_type}> elements;
        while (span.size() >= {element_size}) {{
            elements.push_back({enum_type}(span.read_{byteorder}<{backing_type}, {element_size}>()));
        }}
        return elements;""").format(enum_type=field.type_id, **params)
1010
1011
def generate_typedef_array_field_accessor(field: ast.ArrayField) -> str:
    """Generate the body of the getter for an array of struct elements.

    Fixed-size arrays parse exactly `field.size` elements into a
    std::array. Other arrays parse elements into a std::vector until the
    element type's Parse reports failure.

    :param field: typedef array field
    :returns: C++ statements returning the parsed elements"""

    if not field.size:
        template = dedent("""\
            pdl::packet::slice span = {field_id}_;
            std::vector<{struct_type}> elements;
            for (;;) {{
                {struct_type} element;
                if (!{struct_type}::Parse(span, &element)) {{
                    break;
                }}
                elements.emplace_back(std::move(element));
            }}
            return elements;""")
        return template.format(field_id=field.id, struct_type=field.type_id)

    # NOTE(review): the Parse result is ignored for fixed-size arrays;
    # presumably the view was validated beforehand. Confirm.
    template = dedent("""\
        pdl::packet::slice span = {field_id}_;
        std::array<{struct_type}, {array_size}> elements;
        for (int n = 0; n < {array_size}; n++) {{
            {struct_type}::Parse(span, &elements[n]);
        }}
        return elements;""")
    return template.format(field_id=field.id, struct_type=field.type_id, array_size=field.size)
1034
1035
def generate_array_field_accessor(field: ast.ArrayField) -> str:
    """Generate the getter body for an array field, selected by element kind.

    Arrays with an explicit element width hold scalars. Otherwise the
    element type is either an enum or another typedef (struct).

    :param field: array field
    :returns: C++ statements returning the array elements"""

    if field.width is None:
        if isinstance(field.type, ast.EnumDeclaration):
            return generate_enum_array_field_accessor(field)
        return generate_typedef_array_field_accessor(field)
    return generate_scalar_array_field_accessor(field)
1045
1046
def generate_array_field_size_getters(decl: ast.Declaration) -> str:
    """Generate one `Get<Field>Size()` C++ getter per array field.

    Each getter returns the serialized size of the array in bytes. When the
    element width is statically known, the size is either a constant (fixed
    size arrays) or the element count times the element size. Otherwise
    the getter sums GetSize() over every element.

    :param decl: packet or struct declaration; unconstrained parent fields
        are included
    :returns: the getters joined by newlines"""

    static_template = dedent("""\
        size_t Get{accessor_name}Size() const {{
            return {size};
        }}
        """)
    summing_template = dedent("""\
        size_t Get{accessor_name}Size() const {{
            size_t array_size = 0;
            for (size_t n = 0; n < {field_id}_.size(); n++) {{
                array_size += {field_id}_[n].GetSize();
            }}
            return array_size;
        }}
        """)

    getters = []
    for field in core.get_unconstrained_parent_fields(decl) + decl.fields:
        if not isinstance(field, ast.ArrayField):
            continue

        element_width = field.width or core.get_declaration_size(field.type)
        if not element_width:
            size = None
        elif field.size:
            size = int(element_width * field.size / 8)
        else:
            size = f"{field.id}_.size() * {int(element_width / 8)}"

        accessor_name = to_pascal_case(field.id)
        if size:
            getters.append(static_template.format(accessor_name=accessor_name, size=size))
        else:
            getters.append(summing_template.format(accessor_name=accessor_name, field_id=field.id))

    return '\n'.join(getters)
1085
1086
def generate_packet_size_getter(decl: ast.Declaration) -> List[str]:
    """Generate the body of the size getter of the current packet.

    The generated C++ statements return the serialized size of the packet
    in bytes. The result is the sum of all statically sized fields plus
    one term per field whose size is only known at runtime: optional,
    payload, typedef and array fields.

    :param decl: packet or struct declaration
    :returns: lines of C++ code forming the getter body
    :raises Exception: for field kinds without a size rule"""

    constant_width = 0   # Statically known size, accumulated in bits.
    variable_width = []  # C++ expressions, each evaluating to a size in bytes.
    for f in core.get_packet_fields(decl):
        field_size = core.get_field_size(f)
        if f.cond:
            # Optional fields contribute their size only when present.
            # Scalar/enum widths are in bits and are converted to bytes to
            # match the other variable terms (payload_.size(), GetSize()).
            if isinstance(f, ast.ScalarField):
                variable_width.append(f"({f.id}_.has_value() ? {int(f.width / 8)} : 0)")
            elif isinstance(f, ast.TypedefField) and isinstance(f.type, ast.EnumDeclaration):
                variable_width.append(f"({f.id}_.has_value() ? {int(f.type.width / 8)} : 0)")
            elif isinstance(f, ast.TypedefField):
                variable_width.append(f"({f.id}_.has_value() ? {f.id}_->GetSize() : 0)")
            else:
                raise Exception(f"unsupported field type {f.__class__.__name__}")
        elif field_size is not None:
            constant_width += field_size
        elif isinstance(f, (ast.PayloadField, ast.BodyField)):
            variable_width.append("payload_.size()")
        elif isinstance(f, ast.TypedefField):
            variable_width.append(f"{f.id}_.GetSize()")
        elif isinstance(f, ast.ArrayField):
            variable_width.append(f"Get{to_pascal_case(f.id)}Size()")
        else:
            raise Exception("Unsupported field type")

    constant_width = int(constant_width / 8)
    if not variable_width:
        return [f"return {constant_width};"]
    if len(variable_width) == 1:
        if constant_width:
            return [f"return {variable_width[0]} + {constant_width};"]
        return [f"return {variable_width[0]};"]

    # Several variable terms: emit one term per line inside parentheses.
    terms = " +\n    ".join(variable_width).split("\n")
    opening = f"return {constant_width} + (" if constant_width else "return ("
    return [opening] + terms + [");"]
1128
1129
def generate_packet_view_field_accessors(packet: ast.PacketDeclaration) -> str:
    """Return the C++ definitions of the getters of a packet view.

    One getter is generated per backed field: the unconstrained parent
    fields followed by the packet's own fields, excluding fields that act
    as the condition of an optional field. Fields whose value is fixed by
    a parent constraint get a getter that returns that constant instead.

    :param packet: packet declaration
    :returns: the getter definitions concatenated into a single string"""

    accessors = []

    # Add accessors for the backed fields.
    fields = core.get_unconstrained_parent_fields(packet) + packet.fields
    for field in fields:
        if field.cond_for:
            # Scalar fields used as condition for optional fields are treated
            # as fixed fields since their value is tied to the value of the
            # optional field.
            pass
        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
            accessors.append(
                dedent("""\
                std::vector<uint8_t> GetPayload() const {
                    _ASSERT_VALID(valid_);
                    return payload_.bytes();
                }

                """))
        elif isinstance(field, ast.ArrayField):
            element_type = field.type_id or get_cxx_scalar_type(field.width)
            array_type = (f"std::array<{element_type}, {field.size}>" if field.size else f"std::vector<{element_type}>")
            accessor_name = to_pascal_case(field.id)
            accessors.append(
                dedent("""\
                {array_type} Get{accessor_name}() const {{
                    _ASSERT_VALID(valid_);
                    {accessor}
                }}

                """).format(array_type=array_type,
                            accessor_name=accessor_name,
                            accessor=indent(generate_array_field_accessor(field), 1)))
        elif isinstance(field, ast.ScalarField):
            field_type = get_cxx_scalar_type(field.width)
            field_type = f"std::optional<{field_type}>" if field.cond else field_type
            accessor_name = to_pascal_case(field.id)
            accessors.append(
                dedent("""\
                {field_type} Get{accessor_name}() const {{
                    _ASSERT_VALID(valid_);
                    return {member_name}_;
                }}

                """).format(field_type=field_type, accessor_name=accessor_name, member_name=field.id))
        elif isinstance(field, ast.TypedefField):
            # Enums are returned by value, other typedefs by const reference.
            field_qualifier = "" if isinstance(field.type, ast.EnumDeclaration) else " const&"
            field_type = f"std::optional<{field.type_id}>" if field.cond else field.type_id
            accessor_name = to_pascal_case(field.id)
            accessors.append(
                dedent("""\
                {field_type}{field_qualifier} Get{accessor_name}() const {{
                    _ASSERT_VALID(valid_);
                    return {member_name}_;
                }}

                """).format(field_type=field_type,
                            field_qualifier=field_qualifier,
                            accessor_name=accessor_name,
                            member_name=field.id))

    # Add accessors for constrained parent fields.
    # The accessors return a constant value in this case.
    for c in core.get_parent_constraints(packet):
        field = core.get_packet_field(packet, c.id)
        if isinstance(field, ast.ScalarField):
            field_type = get_cxx_scalar_type(field.width)
            accessor_name = to_pascal_case(field.id)
            accessors.append(
                dedent("""\
                {field_type} Get{accessor_name}() const {{
                    return {value};
                }}

                """).format(field_type=field_type, accessor_name=accessor_name, value=c.value))
        else:
            accessor_name = to_pascal_case(field.id)
            accessors.append(
                dedent("""\
                {field_type} Get{accessor_name}() const {{
                    return {field_type}::{tag_id};
                }}

                """).format(field_type=field.type_id, accessor_name=accessor_name, tag_id=c.tag_id))

    return "".join(accessors)
1219
1220
def generate_packet_stringifier(packet: ast.PacketDeclaration) -> str:
    """Generate the ToString() member of a packet view.

    Placeholder implementation: the generated method returns an empty
    string for every packet, and `packet` is not used yet.
    TODO: print the packet fields."""

    lines = [
        "std::string ToString() const {",
        '    return "";',
        "}",
    ]
    return "\n".join(lines) + "\n"
1228
1229
def generate_packet_view_field_parsers(packet: ast.PacketDeclaration) -> str:
    """Generate the body of the view's Parse method.

    The validator extracts the fields it can in a pre-parsing phase. The
    generated code runs in this order:
    1. For derived packets, check that the parent view is valid.
    2. Copy the unconstrained parent field values.
    3. Check this packet's constraints on the parent fields.
    4. Parse the packet's own fields from the parent payload, or from the
       input slice for root packets.

    :param packet: packet declaration
    :returns: C++ statements ending with `return true;`"""

    code = []

    # Generate code to check the validity of the parent,
    # and import parent fields that do not have a fixed value in the
    # current packet.
    if packet.parent:
        code.append(
            dedent("""\
            // Check validity of parent packet.
            if (!parent.IsValid()) {
                return false;
            }
            """))
        parent_fields = core.get_unconstrained_parent_fields(packet)
        if parent_fields:
            code.append("// Copy parent field values.")
            for f in parent_fields:
                code.append(f"{f.id}_ = parent.{f.id}_;")
            code.append("")
        span = "parent.payload_"
    else:
        span = "parent"

    # Validate parent constraints.
    for c in packet.constraints:
        if c.tag_id:
            enum_type = core.get_packet_field(packet.parent, c.id).type_id
            code.append(
                dedent("""\
                if (parent.{field_id}_ != {enum_type}::{tag_id}) {{
                    return false;
                }}
                """).format(field_id=c.id, enum_type=enum_type, tag_id=c.tag_id))
        else:
            code.append(
                dedent("""\
                if (parent.{field_id}_ != {value}) {{
                    return false;
                }}
                """).format(field_id=c.id, value=c.value))

    # Parse fields linearly.
    if packet.fields:
        code.append("// Parse packet field values.")
        code.append(f"pdl::packet::slice span = {span};")
        # Declare locals for size and count fields; presumably filled in by
        # the code FieldParser emits below (TODO confirm).
        for f in packet.fields:
            if isinstance(f, ast.SizeField):
                code.append(f"{get_cxx_scalar_type(f.width)} {f.field_id}_size;")
            elif isinstance(f, ast.CountField):
                code.append(f"{get_cxx_scalar_type(f.width)} {f.field_id}_count;")
        parser = FieldParser(extract_arrays=False, byteorder=packet.file.byteorder_short)
        for f in packet.fields:
            parser.parse(f)
        parser.done()
        code.extend(parser.code)

    code.append("return true;")
    return '\n'.join(code)
1292
1293
def generate_packet_view_friend_classes(packet: ast.PacketDeclaration) -> List[str]:
    """Generate the friend declarations of a packet view.

    One `friend class <Child>View;` line is produced per direct child
    packet (derived packets are collected with traverse=False).

    :param packet: packet declaration
    :returns: list of C++ friend declarations"""

    return [f"friend class {decl.id}View;" for (_, decl) in core.get_derived_packets(packet, traverse=False)]
1299
1300
def generate_packet_view(packet: ast.PacketDeclaration) -> str:
    """Generate the C++ `<Packet>View` class for a packet declaration.

    The view is constructed from the parent view, or from a raw
    pdl::packet::slice for root packets. Its constructor runs Parse and
    stores the result in valid_. Getters, backing members, the Parse body
    and friend declarations for child views come from the dedicated
    generators.

    :param packet: packet declaration
    :returns: C++ source text of the view class"""

    if packet.parent:
        # Derived views are built from the parent view and share its bytes.
        parent_class = f"{packet.parent.id}View"
        bytes_initializer = "parent.bytes_"
    else:
        parent_class = "pdl::packet::slice"
        bytes_initializer = "parent"

    members = generate_packet_view_field_members(packet)
    accessors = generate_packet_view_field_accessors(packet)
    parsers = generate_packet_view_field_parsers(packet)
    friends = generate_packet_view_friend_classes(packet)
    to_string = generate_packet_stringifier(packet)

    template = dedent("""\

        class {packet_name}View {{
        public:
            static {packet_name}View Create({parent_class} const& parent) {{
                return {packet_name}View(parent);
            }}

            {field_accessors}
            {stringifier}

            bool IsValid() const {{
                return valid_;
            }}

            pdl::packet::slice bytes() const {{
                return bytes_;
            }}

        protected:
            explicit {packet_name}View({parent_class} const& parent)
                  : bytes_({bytes_initializer}) {{
                valid_ = Parse(parent);
            }}

            bool Parse({parent_class} const& parent) {{
                {field_parsers}
            }}

            bool valid_{{false}};
            pdl::packet::slice bytes_;
            {field_members}

            {friend_classes}
        }};
        """)

    return template.format(packet_name=packet.id,
                           parent_class=parent_class,
                           bytes_initializer=bytes_initializer,
                           field_accessors=indent(accessors, 1),
                           field_members=indent(members, 1),
                           field_parsers=indent(parsers, 2),
                           friend_classes=indent(friends, 1),
                           stringifier=indent(to_string, 1))
1356
1357
1358def generate_packet_constructor(struct: ast.StructDeclaration, constructor_name: str) -> str:
1359    """Generate the implementation of the constructor for a
1360    struct declaration."""
1361
1362    constructor_params = []
1363    constructor_initializers = []
1364    inherited_fields = core.get_unconstrained_parent_fields(struct)
1365    payload_initializer = ''
1366    parent_initializer = []
1367
1368    for field in inherited_fields:
1369        if isinstance(field, ast.ArrayField) and field.size:
1370            element_type = field.type_id or get_cxx_scalar_type(field.width)
1371            constructor_params.append(f"std::array<{element_type}, {field.size}> {field.id}")
1372        elif isinstance(field, ast.ArrayField):
1373            element_type = field.type_id or get_cxx_scalar_type(field.width)
1374            constructor_params.append(f"std::vector<{element_type}> {field.id}")
1375        elif isinstance(field, ast.ScalarField):
1376            backing_type = get_cxx_scalar_type(field.width)
1377            constructor_params.append(f"{backing_type} {field.id}")
1378        elif (isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration)):
1379            constructor_params.append(f"{field.type_id} {field.id}")
1380        elif isinstance(field, ast.TypedefField):
1381            constructor_params.append(f"{field.type_id} {field.id}")
1382
1383    for field in struct.fields:
1384        if field.cond_for:
1385            pass
1386        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
1387            constructor_params.append("std::vector<uint8_t> payload")
1388            if struct.parent:
1389                payload_initializer = f"payload_ = std::move(payload);"
1390            else:
1391                constructor_initializers.append("payload_(std::move(payload))")
1392        elif isinstance(field, ast.ArrayField) and field.size:
1393            element_type = field.type_id or get_cxx_scalar_type(field.width)
1394            constructor_params.append(f"std::array<{element_type}, {field.size}> {field.id}")
1395            constructor_initializers.append(f"{field.id}_(std::move({field.id}))")
1396        elif isinstance(field, ast.ArrayField):
1397            element_type = field.type_id or get_cxx_scalar_type(field.width)
1398            constructor_params.append(f"std::vector<{element_type}> {field.id}")
1399            constructor_initializers.append(f"{field.id}_(std::move({field.id}))")
1400        elif isinstance(field, ast.ScalarField):
1401            backing_type = get_cxx_scalar_type(field.width)
1402            field_type = f"std::optional<{backing_type}>" if field.cond else backing_type
1403            constructor_params.append(f"{field_type} {field.id}")
1404            constructor_initializers.append(f"{field.id}_({field.id})")
1405        elif (isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration)):
1406            field_type = f"std::optional<{field.type_id}>" if field.cond else field.type_id
1407            constructor_params.append(f"{field_type} {field.id}")
1408            constructor_initializers.append(f"{field.id}_({field.id})")
1409        elif isinstance(field, ast.TypedefField):
1410            field_type = f"std::optional<{field.type_id}>" if field.cond else field.type_id
1411            constructor_params.append(f"{field_type} {field.id}")
1412            constructor_initializers.append(f"{field.id}_(std::move({field.id}))")
1413
1414    if not constructor_params:
1415        return ""
1416
1417    if struct.parent:
1418        fields = core.get_unconstrained_parent_fields(struct.parent) + struct.parent.fields
1419        parent_constructor_params = []
1420        for field in fields:
1421            constraints = [c for c in struct.constraints if c.id == getattr(field, 'id', None)]
1422            if field.cond_for:
1423                pass
1424            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
1425                parent_constructor_params.append("std::vector<uint8_t>{}")
1426            elif isinstance(field, ast.ArrayField):
1427                parent_constructor_params.append(f"std::move({field.id})")
1428            elif isinstance(field, ast.ScalarField) and constraints:
1429                parent_constructor_params.append(f"{constraints[0].value}")
1430            elif isinstance(field, ast.ScalarField):
1431                parent_constructor_params.append(f"{field.id}")
1432            elif (isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration) and constraints):
1433                parent_constructor_params.append(f"{field.type_id}::{constraints[0].tag_id}")
1434            elif (isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration)):
1435                parent_constructor_params.append(f"{field.id}")
1436            elif isinstance(field, ast.TypedefField):
1437                parent_constructor_params.append(f"std::move({field.id})")
1438        parent_constructor_params = ', '.join(parent_constructor_params)
1439        parent_initializer = [f"{struct.parent_id}Builder({parent_constructor_params})"]
1440
1441    explicit = 'explicit ' if len(constructor_params) == 1 else ''
1442    constructor_params = ', '.join(constructor_params)
1443    constructor_initializers = ', '.join(parent_initializer + constructor_initializers)
1444
1445    return dedent("""\
1446        {explicit}{constructor_name}({constructor_params})
1447            : {constructor_initializers} {{
1448        {payload_initializer}
1449    }}""").format(explicit=explicit,
1450                  constructor_name=constructor_name,
1451                  constructor_params=constructor_params,
1452                  payload_initializer=payload_initializer,
1453                  constructor_initializers=constructor_initializers)
1454
1455
def generate_packet_creator(packet: ast.PacketDeclaration) -> str:
    """Generate the static `Create` factory method of a packet builder.

    The factory takes one argument per field returned by
    `core.get_unconstrained_parent_fields(packet)`, followed by one per
    field of the packet itself. It forwards them to the `<Packet>Builder`
    constructor through `std::make_unique`. Fields that are the condition
    of an optional field get no argument, and neither do fields whose kind
    has no branch below.

    Container and non-enum typedef arguments are forwarded with
    `std::move`. Scalar and enum arguments are forwarded by name.
    """

    constructor_name = f"{packet.id}Builder"
    creator_params = []
    constructor_params = []
    fields = core.get_unconstrained_parent_fields(packet) + packet.fields

    for field in fields:
        if field.cond_for:
            # Scalar fields used as condition for optional fields are treated
            # as fixed fields since their value is tied to the value of the
            # optional field.
            pass
        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
            creator_params.append("std::vector<uint8_t> payload")
            constructor_params.append("std::move(payload)")
        elif isinstance(field, ast.ArrayField):
            element_type = field.type_id or get_cxx_scalar_type(field.width)
            # Arrays with a static size map to std::array, others to std::vector.
            if field.size:
                container = f"std::array<{element_type}, {field.size}>"
            else:
                container = f"std::vector<{element_type}>"
            creator_params.append(f"{container} {field.id}")
            constructor_params.append(f"std::move({field.id})")
        elif isinstance(field, ast.ScalarField):
            backing_type = get_cxx_scalar_type(field.width)
            field_type = f"std::optional<{backing_type}>" if field.cond else backing_type
            creator_params.append(f"{field_type} {field.id}")
            constructor_params.append(f"{field.id}")
        elif isinstance(field, ast.TypedefField):
            field_type = f"std::optional<{field.type_id}>" if field.cond else field.type_id
            creator_params.append(f"{field_type} {field.id}")
            # Enum values are forwarded as-is; other typedefs are moved.
            if isinstance(field.type, ast.EnumDeclaration):
                constructor_params.append(f"{field.id}")
            else:
                constructor_params.append(f"std::move({field.id})")

    creator_params = ', '.join(creator_params)
    constructor_params = ', '.join(constructor_params)

    return dedent("""\
        static std::unique_ptr<{constructor_name}> Create({creator_params}) {{
            return std::make_unique<{constructor_name}>({constructor_params});
        }}""").format(constructor_name=constructor_name,
                      creator_params=creator_params,
                      constructor_params=constructor_params)
1505
1506
def generate_packet_builder(packet: ast.PacketDeclaration) -> str:
    """Return the C++ definition of the `<Packet>Builder` class.

    The builder derives from the parent packet's builder when `packet`
    has a parent id, and from `pdl::packet::Builder` otherwise. Its body
    is assembled from these generators' outputs: constructor, static
    creator, serializer, size getter, array size getters and member
    fields.
    """

    builder_name = f'{packet.id}Builder'
    if packet.parent_id:
        base_name = f'{packet.parent_id}Builder'
    else:
        base_name = "pdl::packet::Builder"

    constructor_code = generate_packet_constructor(packet, constructor_name=builder_name)
    creator_code = generate_packet_creator(packet)
    members_code = generate_packet_field_members(packet)
    serializers_code = generate_packet_field_serializers(packet)
    size_code = generate_packet_size_getter(packet)
    array_sizes_code = generate_array_field_size_getters(packet)

    template = dedent("""\

        class {class_name} : public {parent_class} {{
        public:
            ~{class_name}() override = default;
            {class_name}() = default;
            {class_name}({class_name} const&) = default;
            {class_name}({class_name}&&) = default;
            {class_name}& operator=({class_name} const&) = default;
            {builder_constructor}
            {builder_creator}

            void Serialize(std::vector<uint8_t>& output) const override {{
                {field_serializers}
            }}

            size_t GetSize() const override {{
                {size_getter}
            }}

            {array_field_size_getters}
            {field_members}
        }};
        """)

    # NOTE(review): the size getter is indented with depth 1 while the
    # serializers use depth 2, although both are placed inside method
    # bodies; kept as-is — confirm against generate_packet_size_getter.
    return template.format(
        class_name=builder_name,
        parent_class=base_name,
        builder_constructor=constructor_code,
        builder_creator=creator_code,
        field_members=indent(members_code, 1),
        field_serializers=indent(serializers_code, 2),
        size_getter=indent(size_code, 1),
        array_field_size_getters=indent(array_sizes_code, 1),
    )
1551
1552
def generate_struct_field_parsers(struct: ast.StructDeclaration) -> str:
    """Generate the body of the static `Parse` method of a struct class.

    The generated statements do the following, in order:
      1. Declare one local variable per struct field, suffixed with `_`.
         Size and count fields declare `<field_id>_size` and
         `<field_id>_count` locals instead.
      2. Run the extraction code produced by `FieldParser` over all fields.
      3. Construct `*output` from the parsed locals and `return true;`.

    The constructor arguments follow the struct field order. Condition
    fields of optional fields are skipped.
    """

    code = []
    parsed_fields = []

    for field in struct.fields:
        if field.cond_for:
            # Scalar fields used as condition for optional fields are treated
            # as fixed fields since their value is tied to the value of the
            # optional field.
            pass
        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
            code.append("std::vector<uint8_t> payload_;")
            parsed_fields.append("std::move(payload_)")
        elif isinstance(field, ast.ArrayField):
            element_type = field.type_id or get_cxx_scalar_type(field.width)
            # Arrays with a static size map to std::array, others to std::vector.
            if field.size:
                code.append(f"std::array<{element_type}, {field.size}> {field.id}_;")
            else:
                code.append(f"std::vector<{element_type}> {field.id}_;")
            parsed_fields.append(f"std::move({field.id}_)")
        elif isinstance(field, ast.ScalarField):
            backing_type = get_cxx_scalar_type(field.width)
            field_type = f"std::optional<{backing_type}>" if field.cond else backing_type
            code.append(f"{field_type} {field.id}_;")
            parsed_fields.append(f"{field.id}_")
        elif isinstance(field, ast.TypedefField):
            field_type = f"std::optional<{field.type_id}>" if field.cond else field.type_id
            code.append(f"{field_type} {field.id}_;")
            # Enum values are passed as-is; other typedefs are moved.
            if isinstance(field.type, ast.EnumDeclaration):
                parsed_fields.append(f"{field.id}_")
            else:
                parsed_fields.append(f"std::move({field.id}_)")
        elif isinstance(field, ast.SizeField):
            code.append(f"{get_cxx_scalar_type(field.width)} {field.field_id}_size;")
        elif isinstance(field, ast.CountField):
            code.append(f"{get_cxx_scalar_type(field.width)} {field.field_id}_count;")

    parser = FieldParser(extract_arrays=True, byteorder=struct.file.byteorder_short)
    for field in struct.fields:
        parser.parse(field)
    parser.done()
    code.extend(parser.code)

    constructor_args = ', '.join(parsed_fields)
    code.append(f"*output = {struct.id}({constructor_args});")
    code.append("return true;")
    return '\n'.join(code)
1606
1607
def generate_struct_declaration(struct: ast.StructDeclaration) -> str:
    """Return the C++ definition of the class backing a struct declaration.

    The class derives from `pdl::packet::Builder`. It contains the
    constructor snippet (empty when the struct has no constructor
    parameters), a static `Parse` method, `Serialize`, `GetSize`, array
    size getters, a stringifier and the member fields.

    Raises:
        Exception: if the struct declares a parent.
    """

    # Struct inheritance is rejected outright by this backend.
    if struct.parent:
        raise Exception("Struct declaration with parents are not supported")

    name = struct.id
    constructor_code = generate_packet_constructor(struct, constructor_name=name)
    members_code = generate_packet_field_members(struct)
    parsers_code = generate_struct_field_parsers(struct)
    serializers_code = generate_packet_field_serializers(struct)
    size_code = generate_packet_size_getter(struct)
    array_sizes_code = generate_array_field_size_getters(struct)
    stringifier_code = generate_packet_stringifier(struct)

    template = dedent("""\

        class {struct_name} : public pdl::packet::Builder {{
        public:
            ~{struct_name}() override = default;
            {struct_name}() = default;
            {struct_name}({struct_name} const&) = default;
            {struct_name}({struct_name}&&) = default;
            {struct_name}& operator=({struct_name} const&) = default;
            {struct_constructor}

            static bool Parse(pdl::packet::slice& span, {struct_name}* output) {{
                {field_parsers}
            }}

            void Serialize(std::vector<uint8_t>& output) const override {{
                {field_serializers}
            }}

            size_t GetSize() const override {{
                {size_getter}
            }}

            {array_field_size_getters}
            {stringifier}
            {field_members}
        }};
        """)

    return template.format(struct_name=name,
                           struct_constructor=constructor_code,
                           field_members=indent(members_code, 1),
                           field_parsers=indent(parsers_code, 2),
                           field_serializers=indent(serializers_code, 2),
                           stringifier=indent(stringifier_code, 1),
                           size_getter=indent(size_code, 1),
                           array_field_size_getters=indent(array_sizes_code, 1))
1658
1659
def run(input: argparse.FileType, output: argparse.FileType, namespace: Optional[str], include_header: List[str],
        using_namespace: List[str], exclude_declaration: List[str]):
    """Write the C++ header generated from a PDL-JSON description.

    Args:
        input: readable stream holding the PDL-JSON document. Its `name`
            is recorded in the generated banner.
        output: writable stream that receives the generated header.
        namespace: optional C++ namespace that wraps the generated
            declarations.
        include_header: headers emitted as extra `#include <...>` lines.
        using_namespace: namespaces emitted as `using namespace ...;` lines.
        exclude_declaration: ids of the enum, packet and struct
            declarations for which no code is generated.
    """

    pdl_file = ast.File.from_json(json.load(input))
    core.desugar(pdl_file)

    # Extra directive lines. The comprehension variable is `ns` so it is
    # not confused with the `namespace` parameter.
    include_lines = '\n'.join(f'#include <{header}>' for header in include_header)
    using_lines = '\n'.join(f'using namespace {ns};' for ns in using_namespace)

    if namespace:
        open_namespace = f"namespace {namespace} {{"
        close_namespace = f"}}  // {namespace}"
    else:
        open_namespace = close_namespace = ""

    preamble = dedent("""\
        // File generated from {input_name}, with the command:
        //  {input_command}
        // /!\\ Do not edit by hand

        #pragma once

        #include <cstdint>
        #include <string>
        #include <optional>
        #include <utility>
        #include <vector>

        #include <packet_runtime.h>

        {include_header}
        {using_namespace}

        #ifndef _ASSERT_VALID
        #ifdef ASSERT
        #define _ASSERT_VALID ASSERT
        #else
        #include <cassert>
        #define _ASSERT_VALID assert
        #endif  // ASSERT
        #endif  // !_ASSERT_VALID

        {open_namespace}
        """)
    output.write(preamble.format(input_name=input.name,
                                 input_command=' '.join(sys.argv),
                                 include_header=include_lines,
                                 using_namespace=using_lines,
                                 open_namespace=open_namespace))

    # Forward declarations for packet view classes, required by the
    # generated friend class specifiers.
    # NOTE(review): packets listed in exclude_declaration are still forward
    # declared here — confirm this is intended.
    for decl in pdl_file.declarations:
        if isinstance(decl, ast.PacketDeclaration):
            output.write(f"class {decl.id}View;\n")

    for decl in pdl_file.declarations:
        if decl.id in exclude_declaration:
            continue
        if isinstance(decl, ast.EnumDeclaration):
            output.write(generate_enum_declaration(decl))
            output.write(generate_enum_to_text(decl))
        elif isinstance(decl, ast.PacketDeclaration):
            output.write(generate_packet_view(decl))
            output.write(generate_packet_builder(decl))
        elif isinstance(decl, ast.StructDeclaration):
            output.write(generate_struct_declaration(decl))

    output.write(f"{close_namespace}\n")
1726
1727
def main() -> int:
    """Generate cxx PDL backend."""
    # The module has no docstring, so module-level __doc__ is None and the
    # --help output had no description; use this function's docstring.
    parser = argparse.ArgumentParser(description=main.__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output C++ file')
    parser.add_argument('--namespace', type=str, help='Generated module namespace')
    parser.add_argument('--include-header', type=str, default=[], action='append', help='Added include directives')
    parser.add_argument('--using-namespace',
                        type=str,
                        default=[],
                        action='append',
                        help='Added using namespace statements')
    parser.add_argument('--exclude-declaration',
                        type=str,
                        default=[],
                        action='append',
                        help='Exclude declaration from the generated output')
    run(**vars(parser.parse_args()))
    # run() does not return a status. Report success explicitly so the
    # declared `int` return type holds (sys.exit(0) == sys.exit(None)).
    return 0
1746
1747
if __name__ == '__main__':
    # Script entry point: main()'s return value becomes the exit status.
    raise SystemExit(main())
1750