1#!/usr/bin/env python3
2
3# Copyright 2023 Google LLC
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import argparse
18from dataclasses import dataclass, field
19import json
20from pathlib import Path
21import sys
22from textwrap import dedent
23from typing import List, Tuple, Union, Optional
24
25from pdl import ast, core
26from pdl.utils import indent
27
28
def mask(width: int) -> str:
    """Return a hexadecimal literal for a bit mask of `width` low bits set."""
    all_ones = (1 << width) - 1
    return hex(all_ones)
31
32
def generate_prelude() -> str:
    """Return the static Python source emitted at the top of every generated file.

    The prelude defines the `Packet` base class shared by all generated
    packet classes: `parse_all` (parse and reject trailing bytes), a `size`
    property stub overridden by subclasses, and a tree-style `show`
    pretty-printer for debugging."""
    return dedent("""\
        from dataclasses import dataclass, field, fields
        from typing import Optional, List, Tuple, Union
        import enum
        import inspect
        import math

        @dataclass
        class Packet:
            payload: Optional[bytes] = field(repr=False, default_factory=bytes, compare=False)

            @classmethod
            def parse_all(cls, span: bytes) -> 'Packet':
                packet, remain = getattr(cls, 'parse')(span)
                if len(remain) > 0:
                    raise Exception('Unexpected parsing remainder')
                return packet

            @property
            def size(self) -> int:
                pass

            def show(self, prefix: str = ''):
                print(f'{self.__class__.__name__}')

                def print_val(p: str, pp: str, name: str, align: int, typ, val):
                    if name == 'payload':
                        pass

                    # Scalar fields.
                    elif typ is int:
                        print(f'{p}{name:{align}} = {val} (0x{val:x})')

                    # Byte fields.
                    elif typ is bytes:
                        print(f'{p}{name:{align}} = [', end='')
                        line = ''
                        n_pp = ''
                        for (idx, b) in enumerate(val):
                            if idx > 0 and idx % 8 == 0:
                                print(f'{n_pp}{line}')
                                line = ''
                                n_pp = pp + (' ' * (align + 4))
                            line += f' {b:02x}'
                        print(f'{n_pp}{line} ]')

                    # Enum fields.
                    elif inspect.isclass(typ) and issubclass(typ, enum.IntEnum):
                        print(f'{p}{name:{align}} = {typ.__name__}::{val.name} (0x{val:x})')

                    # Struct fields.
                    elif inspect.isclass(typ) and issubclass(typ, globals().get('Packet')):
                        print(f'{p}{name:{align}} = ', end='')
                        val.show(prefix=pp)

                    # Array fields.
                    elif getattr(typ, '__origin__', None) == list:
                        print(f'{p}{name:{align}}')
                        last = len(val) - 1
                        align = 5
                        for (idx, elt) in enumerate(val):
                            n_p  = pp + ('├── ' if idx != last else '└── ')
                            n_pp = pp + ('│   ' if idx != last else '    ')
                            print_val(n_p, n_pp, f'[{idx}]', align, typ.__args__[0], val[idx])

                    # Custom fields.
                    elif inspect.isclass(typ):
                        print(f'{p}{name:{align}} = {repr(val)}')

                    else:
                        print(f'{p}{name:{align}} = ##{typ}##')

                last = len(fields(self)) - 1
                align = max(len(f.name) for f in fields(self) if f.name != 'payload')

                for (idx, f) in enumerate(fields(self)):
                    p  = prefix + ('├── ' if idx != last else '└── ')
                    pp = prefix + ('│   ' if idx != last else '    ')
                    val = getattr(self, f.name)

                    print_val(p, pp, f.name, align, f.type, val)
        """)
116
117
@dataclass
class FieldParser:
    """Accumulates generated Python parsing code for the fields of a packet
    or struct declaration.

    Fields are visited in declaration order through parse(); the generated
    statements operate on a local `span` bytes value and fill a `fields`
    dict. The caller reads the finished statements from `code` after
    calling done().
    """
    byteorder: str  # 'little' or 'big'; endianness of multi-byte reads.
    offset: int = 0  # Current read offset (in bytes) into `span`.
    shift: int = 0  # Bit offset within the current chunk of bit fields.
    chunk: List[Tuple[int, int, ast.Field]] = field(default_factory=lambda: [])  # Pending bit fields as (shift, width, field).
    unchecked_code: List[str] = field(default_factory=lambda: [])  # Emitted code awaiting its size guard.
    code: List[str] = field(default_factory=lambda: [])  # Completed parsing code lines.

    def unchecked_append_(self, line: str) -> None:
        """Append unchecked field parsing code.
        The function check_size_ must be called to generate a size guard
        after parsing is completed."""
        self.unchecked_code.append(line)

    def append_(self, code: str) -> None:
        """Append field parsing code.
        There must be no unchecked code left before this function is called."""
        assert len(self.unchecked_code) == 0
        self.code.extend(code.split('\n'))

    def check_size_(self, size: str) -> None:
        """Generate a check of the current span size."""
        self.append_(f"if len(span) < {size}:")
        self.append_(f"    raise Exception('Invalid packet size')")

    def check_code_(self) -> None:
        """Generate a size check for pending field parsing."""
        if len(self.unchecked_code) > 0:
            assert len(self.chunk) == 0
            unchecked_code = self.unchecked_code
            self.unchecked_code = []
            # The guard is emitted first, then the pending code it protects.
            self.check_size_(str(self.offset))
            self.code.extend(unchecked_code)

    def consume_span_(self, keep: int = 0) -> None:
        """Skip consumed span bytes.
        `keep` bytes at the end of the consumed range are retained in the
        span (used for non byte aligned payloads)."""
        if self.offset > 0:
            self.check_code_()
            self.append_(f'span = span[{self.offset - keep}:]')
            self.offset = 0

    def parse_array_element_dynamic_(self, field: ast.ArrayField, span: str) -> None:
        """Parse a single array field element of variable size.
        The generated code is placed inside a loop body, hence the
        hard-coded four space indent."""
        if isinstance(field.type, ast.StructDeclaration):
            self.append_(f"    element, {span} = {field.type_id}.parse({span})")
            self.append_(f"    {field.id}.append(element)")
        else:
            raise Exception(f'Unexpected array element type {field.type_id} {field.width}')

    def parse_array_element_static_(self, field: ast.ArrayField, span: str) -> None:
        """Parse a single array field element of constant size.
        The generated code is placed inside a loop body, hence the
        hard-coded four space indent."""
        if field.width is not None:
            # Scalar element.
            element = f"int.from_bytes({span}, byteorder='{self.byteorder}')"
            self.append_(f"    {field.id}.append({element})")
        elif isinstance(field.type, ast.EnumDeclaration):
            # Enum element: parse the backing integer then wrap it.
            element = f"int.from_bytes({span}, byteorder='{self.byteorder}')"
            element = f"{field.type_id}({element})"
            self.append_(f"    {field.id}.append({element})")
        else:
            # Struct or custom element of known size.
            element = f"{field.type_id}.parse_all({span})"
            self.append_(f"    {field.id}.append({element})")

    def parse_byte_array_field_(self, field: ast.ArrayField) -> None:
        """Parse the selected u8 array field."""
        array_size = core.get_array_field_size(field)
        padded_size = field.padded_size

        # Shift the span to reset the offset to 0.
        self.consume_span_()

        # Derive the array size.
        if isinstance(array_size, int):
            size = array_size
        elif isinstance(array_size, ast.SizeField):
            size = f'{field.id}_size - {field.size_modifier}' if field.size_modifier else f'{field.id}_size'
        elif isinstance(array_size, ast.CountField):
            # For u8 elements the element count equals the byte size.
            size = f'{field.id}_count'
        else:
            size = None

        # Parse from the padded array if padding is present.
        if padded_size and size is not None:
            self.check_size_(padded_size)
            self.append_(f"if {size} > {padded_size}:")
            self.append_("    raise Exception('Array size is larger than the padding size')")
            self.append_(f"fields['{field.id}'] = list(span[:{size}])")
            self.append_(f"span = span[{padded_size}:]")

        elif size is not None:
            self.check_size_(size)
            self.append_(f"fields['{field.id}'] = list(span[:{size}])")
            self.append_(f"span = span[{size}:]")

        else:
            # Unknown size: the array extends to the end of the span.
            self.append_(f"fields['{field.id}'] = list(span)")
            self.append_(f"span = bytes()")

    def parse_array_field_(self, field: ast.ArrayField) -> None:
        """Parse the selected array field."""
        array_size = core.get_array_field_size(field)
        element_width = core.get_array_element_size(field)
        padded_size = field.padded_size

        if element_width:
            if element_width % 8 != 0:
                raise Exception('Array element size is not a multiple of 8')
            # Convert the element width from bits to bytes.
            element_width = int(element_width / 8)

        if isinstance(array_size, int):
            size = None
            count = array_size
        elif isinstance(array_size, ast.SizeField):
            size = f'{field.id}_size'
            count = None
        elif isinstance(array_size, ast.CountField):
            size = None
            count = f'{field.id}_count'
        else:
            size = None
            count = None

        # Shift the span to reset the offset to 0.
        self.consume_span_()

        # Apply the size modifier.
        if field.size_modifier and size:
            self.append_(f"{size} = {size} - {field.size_modifier}")

        # Parse from the padded array if padding is present.
        if padded_size:
            self.check_size_(padded_size)
            self.append_(f"remaining_span = span[{padded_size}:]")
            self.append_(f"span = span[:{padded_size}]")

        # The element width is not known, but the array full octet size
        # is known by size field. Parse elements item by item as a vector.
        if element_width is None and size is not None:
            self.check_size_(size)
            self.append_(f"array_span = span[:{size}]")
            self.append_(f"{field.id} = []")
            self.append_("while len(array_span) > 0:")
            self.parse_array_element_dynamic_(field, 'array_span')
            self.append_(f"fields['{field.id}'] = {field.id}")
            self.append_(f"span = span[{size}:]")

        # The element width is not known, but the array element count
        # is known statically or by count field.
        # Parse elements item by item as a vector.
        elif element_width is None and count is not None:
            self.append_(f"{field.id} = []")
            self.append_(f"for n in range({count}):")
            self.parse_array_element_dynamic_(field, 'span')
            self.append_(f"fields['{field.id}'] = {field.id}")

        # Neither the count nor size is known,
        # parse elements until the end of the span.
        elif element_width is None:
            self.append_(f"{field.id} = []")
            self.append_("while len(span) > 0:")
            self.parse_array_element_dynamic_(field, 'span')
            self.append_(f"fields['{field.id}'] = {field.id}")

        # The element width is known, and the array element count is known
        # statically, or by count field.
        elif count is not None:
            array_size = (f'{count}' if element_width == 1 else f'{count} * {element_width}')
            self.check_size_(array_size)
            self.append_(f"{field.id} = []")
            self.append_(f"for n in range({count}):")
            span = ('span[n:n + 1]' if element_width == 1 else f'span[n * {element_width}:(n + 1) * {element_width}]')
            self.parse_array_element_static_(field, span)
            self.append_(f"fields['{field.id}'] = {field.id}")
            self.append_(f"span = span[{array_size}:]")

        # The element width is known, and the array full size is known
        # by size field, or unknown (in which case it is the remaining span
        # length).
        else:
            if size is not None:
                self.check_size_(size)
            array_size = size or 'len(span)'
            if element_width != 1:
                self.append_(f"if {array_size} % {element_width} != 0:")
                self.append_("    raise Exception('Array size is not a multiple of the element size')")
                self.append_(f"{field.id}_count = int({array_size} / {element_width})")
                array_count = f'{field.id}_count'
            else:
                array_count = array_size
            self.append_(f"{field.id} = []")
            self.append_(f"for n in range({array_count}):")
            span = ('span[n:n + 1]' if element_width == 1 else f'span[n * {element_width}:(n + 1) * {element_width}]')
            self.parse_array_element_static_(field, span)
            self.append_(f"fields['{field.id}'] = {field.id}")
            if size is not None:
                self.append_(f"span = span[{size}:]")
            else:
                self.append_(f"span = bytes()")

        # Drop the padding
        if padded_size:
            self.append_(f"span = remaining_span")

    def parse_optional_field_(self, field: ast.Field) -> None:
        """Parse the selected optional field.
        Optional fields must start and end on a byte boundary.
        The field is only parsed when its condition flag matches the
        configured value; otherwise it is left absent from `fields`."""

        if self.shift != 0:
            raise Exception('Optional field does not start on an octet boundary')
        if (isinstance(field, ast.TypedefField) and
            isinstance(field.type, ast.StructDeclaration) and
            field.type.parent_id is not None):
            raise Exception('Derived struct used in optional typedef field')

        self.consume_span_()

        if isinstance(field, ast.ScalarField):
            self.append_(dedent("""
            if {cond_id} == {cond_value}:
                if len(span) < {size}:
                    raise Exception('Invalid packet size')
                fields['{field_id}'] = int.from_bytes(span[:{size}], byteorder='{byteorder}')
                span = span[{size}:]
            """.format(size=int(field.width / 8),
                       field_id=field.id,
                       cond_id=field.cond.id,
                       cond_value=field.cond.value,
                       byteorder=self.byteorder)))

        elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
            self.append_(dedent("""
            if {cond_id} == {cond_value}:
                if len(span) < {size}:
                    raise Exception('Invalid packet size')
                fields['{field_id}'] = {type_id}(
                    int.from_bytes(span[:{size}], byteorder='{byteorder}'))
                span = span[{size}:]
            """.format(size=int(field.type.width / 8),
                       field_id=field.id,
                       type_id=field.type_id,
                       cond_id=field.cond.id,
                       cond_value=field.cond.value,
                       byteorder=self.byteorder)))

        elif isinstance(field, ast.TypedefField):
            self.append_(dedent("""
            if {cond_id} == {cond_value}:
                {field_id}, span = {type_id}.parse(span)
                fields['{field_id}'] = {field_id}
            """.format(field_id=field.id,
                       type_id=field.type_id,
                       cond_id=field.cond.id,
                       cond_value=field.cond.value)))

        else:
            raise Exception(f"unsupported field type {field.__class__.__name__}")

    def parse_bit_field_(self, field: ast.Field) -> None:
        """Parse the selected field as a bit field.
        The field is added to the current chunk. When a byte boundary
        is reached all saved fields are extracted together."""

        # Add to current chunk.
        width = core.get_field_size(field)
        self.chunk.append((self.shift, width, field))
        self.shift += width

        # Wait for more fields if not on a byte boundary.
        if (self.shift % 8) != 0:
            return

        # Parse the backing integer using the configured endianness,
        # extract field values.
        size = int(self.shift / 8)
        end_offset = self.offset + size

        if size == 1:
            # Single byte: index directly, no integer conversion needed.
            value = f"span[{self.offset}]"
        else:
            span = f"span[{self.offset}:{end_offset}]"
            self.unchecked_append_(f"value_ = int.from_bytes({span}, byteorder='{self.byteorder}')")
            value = "value_"

        for shift, width, field in self.chunk:
            # A lone byte-aligned field needs no shift or mask.
            v = (value if len(self.chunk) == 1 and shift == 0 else f"({value} >> {shift}) & {mask(width)}")

            if field.cond_for:
                # Condition flag for an optional field: bind a local,
                # not an entry in `fields`.
                self.unchecked_append_(f"{field.id} = {v}")
            elif isinstance(field, ast.ScalarField):
                self.unchecked_append_(f"fields['{field.id}'] = {v}")
            elif isinstance(field, ast.FixedField) and field.enum_id:
                self.unchecked_append_(f"if {v} != {field.enum_id}.{field.tag_id}:")
                self.unchecked_append_(f"    raise Exception('Unexpected fixed field value')")
            elif isinstance(field, ast.FixedField):
                self.unchecked_append_(f"if {v} != {hex(field.value)}:")
                self.unchecked_append_(f"    raise Exception('Unexpected fixed field value')")
            elif isinstance(field, ast.TypedefField):
                self.unchecked_append_(f"fields['{field.id}'] = {field.type_id}.from_int({v})")
            elif isinstance(field, ast.SizeField):
                self.unchecked_append_(f"{field.field_id}_size = {v}")
            elif isinstance(field, ast.CountField):
                self.unchecked_append_(f"{field.field_id}_count = {v}")
            elif isinstance(field, ast.ReservedField):
                pass
            else:
                raise Exception(f'Unsupported bit field type {field.kind}')

        # Reset state.
        self.offset = end_offset
        self.shift = 0
        self.chunk = []

    def parse_typedef_field_(self, field: ast.TypedefField) -> None:
        """Parse a typedef field, to the exclusion of Enum fields."""

        if self.shift != 0:
            raise Exception('Typedef field does not start on an octet boundary')
        if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None):
            raise Exception('Derived struct used in typedef field')

        width = core.get_declaration_size(field.type)
        if width is None:
            # Variable size type: delegate to its parse function, which
            # returns the remaining span.
            self.consume_span_()
            self.append_(f"{field.id}, span = {field.type_id}.parse(span)")
            self.append_(f"fields['{field.id}'] = {field.id}")
        else:
            if width % 8 != 0:
                raise Exception('Typedef field type size is not a multiple of 8')
            width = int(width / 8)
            end_offset = self.offset + width
            # Checksum value field is generated alongside checksum start.
            # Deal with this field as padding.
            if not isinstance(field.type, ast.ChecksumDeclaration):
                span = f'span[{self.offset}:{end_offset}]'
                self.unchecked_append_(f"fields['{field.id}'] = {field.type_id}.parse_all({span})")
            self.offset = end_offset

    def parse_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]) -> None:
        """Parse body and payload fields."""

        payload_size = core.get_payload_field_size(field)
        offset_from_end = core.get_field_offset_from_end(field)

        # If the payload is not byte aligned, do parse the bit fields
        # that can be extracted, but do not consume the input bytes as
        # they will also be included in the payload span.
        if self.shift != 0:
            if payload_size:
                raise Exception("Unexpected payload size for non byte aligned payload")

            rounded_size = int((self.shift + 7) / 8)
            padding_bits = 8 * rounded_size - self.shift
            self.parse_bit_field_(core.make_reserved_field(padding_bits))
            # Keep the partially parsed bytes inside the payload span.
            self.consume_span_(rounded_size)
        else:
            self.consume_span_()

        # The payload or body has a known size.
        # Consume the payload and update the span in case
        # fields are placed after the payload.
        if payload_size:
            if getattr(field, 'size_modifier', None):
                self.append_(f"{field.id}_size -= {field.size_modifier}")
            self.check_size_(f'{field.id}_size')
            self.append_(f"payload = span[:{field.id}_size]")
            self.append_(f"span = span[{field.id}_size:]")
        # The payload or body is the last field of a packet,
        # consume the remaining span.
        elif offset_from_end == 0:
            self.append_(f"payload = span")
            self.append_(f"span = bytes([])")
        # The payload or body is followed by fields of static size.
        # Consume the span that is not reserved for the following fields.
        elif offset_from_end is not None:
            if (offset_from_end % 8) != 0:
                raise Exception('Payload field offset from end of packet is not a multiple of 8')
            offset_from_end = int(offset_from_end / 8)
            self.check_size_(f'{offset_from_end}')
            self.append_(f"payload = span[:-{offset_from_end}]")
            self.append_(f"span = span[-{offset_from_end}:]")
        self.append_(f"fields['payload'] = payload")

    def parse_checksum_field_(self, field: ast.ChecksumField) -> None:
        """Generate a checksum check."""

        # The checksum value field can be read starting from the current
        # offset if the fields in between are of fixed size, or from the end
        # of the span otherwise.
        self.consume_span_()
        value_field = core.get_packet_field(field.parent, field.field_id)
        offset_from_start = 0
        offset_from_end = 0
        start_index = field.parent.fields.index(field)
        value_index = field.parent.fields.index(value_field)
        value_size = int(core.get_field_size(value_field) / 8)

        # Sum the sizes of the fields separating the checksum start from
        # the checksum value field; None if any has dynamic size.
        for f in field.parent.fields[start_index + 1:value_index]:
            size = core.get_field_size(f)
            if size is None:
                offset_from_start = None
                break
            else:
                offset_from_start += size

        # Sum the sizes of the value field and all trailing fields;
        # None if any has dynamic size.
        trailing_fields = field.parent.fields[value_index:]
        trailing_fields.reverse()
        for f in trailing_fields:
            size = core.get_field_size(f)
            if size is None:
                offset_from_end = None
                break
            else:
                offset_from_end += size

        if offset_from_start is not None:
            if offset_from_start % 8 != 0:
                raise Exception('Checksum value field is not aligned to an octet boundary')
            offset_from_start = int(offset_from_start / 8)
            checksum_span = f'span[:{offset_from_start}]'
            if value_size > 1:
                start = offset_from_start
                end = offset_from_start + value_size
                value = f"int.from_bytes(span[{start}:{end}], byteorder='{self.byteorder}')"
            else:
                value = f'span[{offset_from_start}]'
            self.check_size_(offset_from_start + value_size)

        elif offset_from_end is not None:
            # NOTE(review): `sign` is never used below — candidate for removal.
            sign = ''
            if offset_from_end % 8 != 0:
                raise Exception('Checksum value field is not aligned to an octet boundary')
            offset_from_end = int(offset_from_end / 8)
            checksum_span = f'span[:-{offset_from_end}]'
            if value_size > 1:
                start = offset_from_end
                end = offset_from_end - value_size
                # NOTE(review): if the value field is the last field of the
                # packet, `end` is 0 and the emitted slice `span[-start:-0]`
                # is empty — verify this branch cannot be reached with
                # end == 0.
                value = f"int.from_bytes(span[-{start}:-{end}], byteorder='{self.byteorder}')"
            else:
                value = f'span[-{offset_from_end}]'
            self.check_size_(offset_from_end)

        else:
            raise Exception('Checksum value field cannot be read at constant offset')

        self.append_(f"{value_field.id} = {value}")
        self.append_(f"fields['{value_field.id}'] = {value_field.id}")
        self.append_(f"computed_{value_field.id} = {value_field.type.function}({checksum_span})")
        self.append_(f"if computed_{value_field.id} != {value_field.id}:")
        self.append_("    raise Exception(f'Invalid checksum computation:" +
                     f" {{computed_{value_field.id}}} != {{{value_field.id}}}')")

    def parse(self, field: ast.Field) -> None:
        """Dispatch parsing of `field` to the handler for its kind."""
        if field.cond:
            self.parse_optional_field_(field)

        # Field has bit granularity.
        # Append the field to the current chunk,
        # check if a byte boundary was reached.
        elif core.is_bit_field(field):
            self.parse_bit_field_(field)

        # Padding fields.
        elif isinstance(field, ast.PaddingField):
            pass

        # Array fields.
        elif isinstance(field, ast.ArrayField) and field.width == 8:
            self.parse_byte_array_field_(field)

        elif isinstance(field, ast.ArrayField):
            self.parse_array_field_(field)

        # Other typedef fields.
        elif isinstance(field, ast.TypedefField):
            self.parse_typedef_field_(field)

        # Payload and body fields.
        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
            self.parse_payload_field_(field)

        # Checksum fields.
        elif isinstance(field, ast.ChecksumField):
            self.parse_checksum_field_(field)

        else:
            raise Exception(f'Unimplemented field type {field.kind}')

    def done(self) -> None:
        """Flush any pending span consumption after the last field."""
        self.consume_span_()
607
608
609@dataclass
610class FieldSerializer:
611    byteorder: str
612    shift: int = 0
613    value: List[str] = field(default_factory=lambda: [])
614    code: List[str] = field(default_factory=lambda: [])
615    indent: int = 0
616
617    def indent_(self):
618        self.indent += 1
619
620    def unindent_(self):
621        self.indent -= 1
622
623    def append_(self, line: str):
624        """Append field serializing code."""
625        lines = line.split('\n')
626        self.code.extend(['    ' * self.indent + line for line in lines])
627
628    def extend_(self, value: str, length: int):
629        """Append data to the span being constructed."""
630        if length == 1:
631            self.append_(f"_span.append({value})")
632        else:
633            self.append_(f"_span.extend(int.to_bytes({value}, length={length}, byteorder='{self.byteorder}'))")
634
635    def serialize_array_element_(self, field: ast.ArrayField):
636        """Serialize a single array field element."""
637        if field.width is not None:
638            length = int(field.width / 8)
639            self.extend_('_elt', length)
640        elif isinstance(field.type, ast.EnumDeclaration):
641            length = int(field.type.width / 8)
642            self.extend_('_elt', length)
643        else:
644            self.append_("_span.extend(_elt.serialize())")
645
    def serialize_array_field_(self, field: ast.ArrayField):
        """Serialize the selected array field."""
        if field.padded_size:
            # Remember where the array starts so the padding length can be
            # computed after the elements are written.
            self.append_(f"_{field.id}_start = len(_span)")

        if field.width == 8:
            # Byte arrays are appended wholesale.
            self.append_(f"_span.extend(self.{field.id})")
        else:
            # Other element types are serialized one element at a time.
            self.append_(f"for _elt in self.{field.id}:")
            self.indent_()
            self.serialize_array_element_(field)
            self.unindent_()

        if field.padded_size:
            # Fill up to the declared padded size with zero bytes.
            self.append_(f"_span.extend([0] * ({field.padded_size} - len(_span) + _{field.id}_start))")
661
    def serialize_optional_field_(self, field: ast.Field):
        """Serialize the selected optional field.
        The field is only written when its value is present (not None)."""
        # Optional scalar: write the integer value at its declared width.
        if isinstance(field, ast.ScalarField):
            self.append_(dedent(
                """
                if self.{field_id} is not None:
                    _span.extend(int.to_bytes(self.{field_id}, length={size}, byteorder='{byteorder}'))
                """.format(field_id=field.id,
                            size=int(field.width / 8),
                            byteorder=self.byteorder)))

        # Optional enum: write the backing integer at the enum's width.
        elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
            self.append_(dedent(
                """
                if self.{field_id} is not None:
                    _span.extend(int.to_bytes(self.{field_id}, length={size}, byteorder='{byteorder}'))
                """.format(field_id=field.id,
                            size=int(field.type.width / 8),
                            byteorder=self.byteorder)))

        # Optional struct: delegate to its own serializer.
        elif isinstance(field, ast.TypedefField):
            self.append_(dedent(
                """
                if self.{field_id} is not None:
                    _span.extend(self.{field_id}.serialize())
                """.format(field_id=field.id)))

        else:
            raise Exception(f"unsupported field type {field.__class__.__name__}")
690
    def serialize_bit_field_(self, field: ast.Field):
        """Serialize the selected field as a bit field.
        The field is added to the current chunk. When a byte boundary
        is reached all saved fields are serialized together.

        Appends to self.value a python expression (as a string) for the
        field's contribution, pre-shifted to its bit offset within the
        current chunk; fixed-size fields advance self.shift by their
        width.
        """

        # Add to current chunk.
        width = core.get_field_size(field)
        shift = self.shift

        if isinstance(field, str):
            # A plain string is used verbatim as the value expression
            # (presumably pre-rendered by the caller — TODO confirm).
            self.value.append(f"({field} << {shift})")
        elif field.cond_for:
            # Scalar field used as condition for an optional field.
            # The width is always 1, the value is determined from
            # the presence or absence of the optional field.
            value_present = field.cond_for.cond.value
            value_absent = 0 if field.cond_for.cond.value else 1
            self.value.append(f"(({value_absent} if self.{field.cond_for.id} is None else {value_present}) << {shift})")
        elif isinstance(field, ast.ScalarField):
            # Out-of-range scalar values are clamped with a warning
            # rather than failing serialization.
            max_value = (1 << field.width) - 1
            self.append_(f"if self.{field.id} > {max_value}:")
            self.append_(f"    print(f\"Invalid value for field {field.parent.id}::{field.id}:" +
                         f" {{self.{field.id}}} > {max_value}; the value will be truncated\")")
            self.append_(f"    self.{field.id} &= {max_value}")
            self.value.append(f"(self.{field.id} << {shift})")
        elif isinstance(field, ast.FixedField) and field.enum_id:
            # Fixed field constrained to an enum tag.
            self.value.append(f"({field.enum_id}.{field.tag_id} << {shift})")
        elif isinstance(field, ast.FixedField):
            # Fixed field holding a literal constant.
            self.value.append(f"({field.value} << {shift})")
        elif isinstance(field, ast.TypedefField):
            # Bit-sized typedef field; the value is used directly as an
            # integer expression (enum typedefs — TODO confirm).
            self.value.append(f"(self.{field.id} << {shift})")

        elif isinstance(field, ast.SizeField):
            # Automatic size field: emits the byte length of the
            # referenced payload, body, or array field.
            max_size = (1 << field.width) - 1
            value_field = core.get_packet_field(field.parent, field.field_id)
            size_modifier = ''

            if getattr(value_field, 'size_modifier', None):
                size_modifier = f' + {value_field.size_modifier}'

            if isinstance(value_field, (ast.PayloadField, ast.BodyField)):
                # A payload exceeding the representable size is a hard
                # error: the length cannot be silently truncated.
                self.append_(f"_payload_size = len(payload or self.payload or []){size_modifier}")
                self.append_(f"if _payload_size > {max_size}:")
                self.append_(f"    print(f\"Invalid length for payload field:" +
                             f"  {{_payload_size}} > {max_size}; the packet cannot be generated\")")
                self.append_(f"    raise Exception(\"Invalid payload length\")")
                array_size = "_payload_size"
            elif isinstance(value_field, ast.ArrayField) and value_field.width:
                # Array of fixed-width scalar elements (width in bits).
                array_size = f"(len(self.{value_field.id}) * {int(value_field.width / 8)}{size_modifier})"
            elif isinstance(value_field, ast.ArrayField) and isinstance(value_field.type, ast.EnumDeclaration):
                # Array of enum elements: fixed element width.
                array_size = f"(len(self.{value_field.id}) * {int(value_field.type.width / 8)}{size_modifier})"
            elif isinstance(value_field, ast.ArrayField):
                # Array of variable-size elements: sum the element sizes.
                self.append_(
                    f"_{value_field.id}_size = sum([elt.size for elt in self.{value_field.id}]){size_modifier}")
                array_size = f"_{value_field.id}_size"
            else:
                raise Exception("Unsupported field type")
            self.value.append(f"({array_size} << {shift})")

        elif isinstance(field, ast.CountField):
            # Automatic element count field; arrays longer than the
            # representable count are truncated with a warning.
            max_count = (1 << field.width) - 1
            self.append_(f"if len(self.{field.field_id}) > {max_count}:")
            self.append_(f"    print(f\"Invalid length for field {field.parent.id}::{field.field_id}:" +
                         f"  {{len(self.{field.field_id})}} > {max_count}; the array will be truncated\")")
            self.append_(f"    del self.{field.field_id}[{max_count}:]")
            self.value.append(f"(len(self.{field.field_id}) << {shift})")
        elif isinstance(field, ast.ReservedField):
            # Reserved bits contribute nothing to self.value and are
            # therefore emitted as zero by pack_bit_fields_.
            pass
        else:
            raise Exception(f'Unsupported bit field type {field.kind}')

        # Check if a byte boundary is reached.
        self.shift += width
        if (self.shift % 8) == 0:
            self.pack_bit_fields_()
766
767    def pack_bit_fields_(self):
768        """Pack serialized bit fields."""
769
770        # Should have an integral number of bytes now.
771        assert (self.shift % 8) == 0
772
773        # Generate the backing integer, and serialize it
774        # using the configured endiannes,
775        size = int(self.shift / 8)
776
777        if len(self.value) == 0:
778            self.append_(f"_span.extend([0] * {size})")
779        elif len(self.value) == 1:
780            self.extend_(self.value[0], size)
781        else:
782            self.append_(f"_value = (")
783            self.append_("    " + " |\n    ".join(self.value))
784            self.append_(")")
785            self.extend_('_value', size)
786
787        # Reset state.
788        self.shift = 0
789        self.value = []
790
791    def serialize_typedef_field_(self, field: ast.TypedefField):
792        """Serialize a typedef field, to the exclusion of Enum fields."""
793
794        if self.shift != 0:
795            raise Exception('Typedef field does not start on an octet boundary')
796        if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None):
797            raise Exception('Derived struct used in typedef field')
798
799        if isinstance(field.type, ast.ChecksumDeclaration):
800            size = int(field.type.width / 8)
801            self.append_(f"_checksum = {field.type.function}(_span[_checksum_start:])")
802            self.extend_('_checksum', size)
803        else:
804            self.append_(f"_span.extend(self.{field.id}.serialize())")
805
806    def serialize_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]):
807        """Serialize body and payload fields."""
808
809        if self.shift != 0 and self.byteorder == 'big':
810            raise Exception('Payload field does not start on an octet boundary')
811
812        if self.shift == 0:
813            self.append_(f"_span.extend(payload or self.payload or [])")
814        else:
815            # Supported case of packet inheritance;
816            # the incomplete fields are serialized into
817            # the payload, rather than separately.
818            # First extract the padding bits from the payload,
819            # then recombine them with the bit fields to be serialized.
820            rounded_size = int((self.shift + 7) / 8)
821            padding_bits = 8 * rounded_size - self.shift
822            self.append_(f"_payload = payload or self.payload or bytes()")
823            self.append_(f"if len(_payload) < {rounded_size}:")
824            self.append_(f"    raise Exception(f\"Invalid length for payload field:" +
825                         f"  {{len(_payload)}} < {rounded_size}\")")
826            self.append_(
827                f"_padding = int.from_bytes(_payload[:{rounded_size}], byteorder='{self.byteorder}') >> {self.shift}")
828            self.value.append(f"(_padding << {self.shift})")
829            self.shift += padding_bits
830            self.pack_bit_fields_()
831            self.append_(f"_span.extend(_payload[{rounded_size}:])")
832
833    def serialize_checksum_field_(self, field: ast.ChecksumField):
834        """Generate a checksum check."""
835
836        self.append_("_checksum_start = len(_span)")
837
838    def serialize(self, field: ast.Field):
839        if field.cond:
840            self.serialize_optional_field_(field)
841
842        # Field has bit granularity.
843        # Append the field to the current chunk,
844        # check if a byte boundary was reached.
845        elif core.is_bit_field(field):
846            self.serialize_bit_field_(field)
847
848        # Padding fields.
849        elif isinstance(field, ast.PaddingField):
850            pass
851
852        # Array fields.
853        elif isinstance(field, ast.ArrayField):
854            self.serialize_array_field_(field)
855
856        # Other typedef fields.
857        elif isinstance(field, ast.TypedefField):
858            self.serialize_typedef_field_(field)
859
860        # Payload and body fields.
861        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
862            self.serialize_payload_field_(field)
863
864        # Checksum fields.
865        elif isinstance(field, ast.ChecksumField):
866            self.serialize_checksum_field_(field)
867
868        else:
869            raise Exception(f'Unimplemented field type {field.kind}')
870
871
def generate_toplevel_packet_serializer(packet: ast.Declaration) -> List[str]:
    """Generate the serialize() function for a toplevel Packet or Struct
       declaration."""

    emitter = FieldSerializer(byteorder=packet.file.byteorder)
    for field in packet.fields:
        emitter.serialize(field)

    # Wrap the field serializers with the span setup and final return.
    body = ['_span = bytearray()']
    body.extend(emitter.code)
    body.append('return bytes(_span)')
    return body
880
881
def generate_derived_packet_serializer(packet: ast.Declaration) -> List[str]:
    """Generate the serialize() function for a derived Packet or Struct
       declaration."""

    # A body shift is only supported with little-endian byte order.
    packet_shift = core.get_packet_shift(packet)
    if packet_shift and packet.file.byteorder == 'big':
        raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift")

    emitter = FieldSerializer(byteorder=packet.file.byteorder, shift=packet_shift)
    for field in packet.fields:
        emitter.serialize(field)

    # The local fields are handed to the parent serializer as payload.
    body = ['_span = bytearray()']
    body.extend(emitter.code)
    body.append(f'return {packet.parent.id}.serialize(self, payload = bytes(_span))')
    return body
895
896
def generate_packet_parser(packet: ast.Declaration) -> List[str]:
    """Generate the parse() function for a toplevel Packet or Struct
       declaration."""

    # A body shift is only supported with little-endian byte order.
    packet_shift = core.get_packet_shift(packet)
    if packet_shift and packet.file.byteorder == 'big':
        raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift")

    # Convert the packet constraints to a boolean expression.
    validation = []
    constraints = core.get_all_packet_constraints(packet)
    if constraints:
        checks = []
        for c in constraints:
            if c.value is None:
                field = core.get_packet_field(packet, c.id)
                checks.append(f"fields['{c.id}'] != {field.type_id}.{c.tag_id}")
            else:
                checks.append(f"fields['{c.id}'] != {hex(c.value)}")
        validation = [f"if {' or '.join(checks)}:", "    raise Exception(\"Invalid constraint field values\")"]

    # Parse fields iteratively.
    parser = FieldParser(byteorder=packet.file.byteorder, shift=packet_shift)
    for field in packet.fields:
        parser.parse(field)
    parser.done()

    # Specialize to child packets: try parsing every child packet
    # successively until one is successfully parsed; fall through to
    # the parent packet if no child packet matches.
    # TODO: order child packets by decreasing size in case no constraint
    # is given for specialization.
    specialization = []
    for _, child in core.get_derived_packets(packet):
        specialization += [
            "try:",
            f"    return {child.id}.parse(fields.copy(), payload)",
            "except Exception as exn:",
            "    pass",
        ]

    # Toplevel packets start with an empty field dictionary.
    decl = ["fields = {'payload': None}"] if not packet.parent_id else []
    return decl + validation + parser.code + specialization + [f"return {packet.id}(**fields), span"]
943
944
def generate_packet_size_getter(packet: ast.Declaration) -> List[str]:
    """Generate the body of the size property for a Packet or Struct.

    Returns the code lines computing the serialized size in bytes:
    a constant part (sum of the fixed-width fields) plus variable
    terms (payload, arrays, struct fields, and optional fields).
    """
    constant_width = 0
    variable_width = []
    for f in packet.fields:
        field_size = core.get_field_size(f)
        if f.cond:
            # Optional fields count zero bytes when absent.
            # Bug fix: these terms must accumulate in variable_width;
            # returning them directly aborted the loop and produced a
            # str where the declared return type is List[str].
            # Optional scalar/enum fields are serialized on whole
            # octets (see the optional-field serializer), hence the
            # width / 8 conversion to keep all terms in bytes.
            if isinstance(f, ast.ScalarField):
                variable_width.append(f"(0 if self.{f.id} is None else {int(f.width / 8)})")
            elif isinstance(f, ast.TypedefField) and isinstance(f.type, ast.EnumDeclaration):
                variable_width.append(f"(0 if self.{f.id} is None else {int(f.type.width / 8)})")
            elif isinstance(f, ast.TypedefField):
                variable_width.append(f"(0 if self.{f.id} is None else self.{f.id}.size)")
            else:
                raise Exception(f"unsupported field type {f.__class__.__name__}")
        elif field_size is not None:
            constant_width += field_size
        elif isinstance(f, (ast.PayloadField, ast.BodyField)):
            variable_width.append("len(self.payload)")
        elif isinstance(f, ast.TypedefField):
            variable_width.append(f"self.{f.id}.size")
        elif isinstance(f, ast.ArrayField) and isinstance(f.type, (ast.StructDeclaration, ast.CustomFieldDeclaration)):
            variable_width.append(f"sum([elt.size for elt in self.{f.id}])")
        elif isinstance(f, ast.ArrayField) and isinstance(f.type, ast.EnumDeclaration):
            # Bug fix: enum width is in bits; convert to bytes for
            # consistency with every other variable_width term (cf. the
            # size-field serializer which uses type.width / 8).
            variable_width.append(f"len(self.{f.id}) * {int(f.type.width / 8)}")
        elif isinstance(f, ast.ArrayField):
            variable_width.append(f"len(self.{f.id}) * {int(f.width / 8)}")
        else:
            raise Exception("Unsupported field type")

    # constant_width was accumulated in bits; convert to bytes.
    constant_width = int(constant_width / 8)
    if len(variable_width) == 0:
        return [f"return {constant_width}"]
    elif len(variable_width) == 1 and constant_width:
        return [f"return {variable_width[0]} + {constant_width}"]
    elif len(variable_width) == 1:
        return [f"return {variable_width[0]}"]
    elif len(variable_width) > 1 and constant_width:
        return ([f"return {constant_width} + ("] + " +\n    ".join(variable_width).split("\n") + [")"])
    elif len(variable_width) > 1:
        return (["return ("] + " +\n    ".join(variable_width).split("\n") + [")"])
    else:
        assert False
987
988
def generate_packet_post_init(decl: ast.Declaration) -> List[str]:
    """Generate __post_init__ function to set constraint field values."""

    # Gather all constraints from parent packets.
    constraints = core.get_all_packet_constraints(decl)
    if not constraints:
        return ["pass"]

    lines = []
    for c in constraints:
        if c.value is None:
            # Enum tag constraint: resolve the tag through the field type.
            field = core.get_packet_field(decl, c.id)
            lines.append(f"self.{c.id} = {field.type_id}.{c.tag_id}")
        else:
            # Literal value constraint.
            lines.append(f"self.{c.id} = {c.value}")
    return lines
1007
1008
def generate_enum_declaration(decl: ast.EnumDeclaration) -> str:
    """Generate the implementation of an enum type."""

    # Enums in python are closed and ranges cannot be represented;
    # instead the generated code uses Union[int, Enum]
    # when ranges are used. Only valued tags become members.
    tag_decls = [f"{t.id} = {hex(t.value)}" for t in decl.tags if t.value is not None]

    if core.is_open_enum(decl):
        # Open enums accept any integer value.
        unknown_handler = ["return v"]
    else:
        # Closed enums accept values inside declared ranges,
        # and re-raise the ValueError otherwise.
        unknown_handler = []
        for t in decl.tags:
            if t.range is not None:
                lo, hi = t.range
                unknown_handler.append(f"if v >= 0x{lo:x} and v <= 0x{hi:x}:")
                unknown_handler.append("    return v")
        unknown_handler.append("raise exn")

    return dedent("""\

        class {enum_name}(enum.IntEnum):
            {tag_decls}

            @staticmethod
            def from_int(v: int) -> Union[int, '{enum_name}']:
                try:
                    return {enum_name}(v)
                except ValueError as exn:
                    {unknown_handler}

        """).format(enum_name=decl.id,
                    tag_decls=indent(tag_decls, 1),
                    unknown_handler=indent(unknown_handler, 3))
1046
1047
def generate_packet_declaration(packet: ast.Declaration) -> str:
    """Generate the implementation a toplevel Packet or Struct
       declaration.

    Emits a @dataclass definition with one field declaration per
    packet field, plus the generated __post_init__, parse, serialize,
    and size members.
    """

    packet_name = packet.id
    field_decls = []
    for f in packet.fields:
        if f.cond:
            # Optional fields default to None (absent).
            if isinstance(f, ast.ScalarField):
                field_decls.append(f"{f.id}: Optional[int] = field(kw_only=True, default=None)")
            elif isinstance(f, ast.TypedefField):
                field_decls.append(f"{f.id}: Optional[{f.type_id}] = field(kw_only=True, default=None)")
            else:
                pass
        elif f.cond_for:
            # The fields used as condition for optional fields are
            # not generated since their value is tied to the value of the
            # optional field.
            pass
        elif isinstance(f, ast.ScalarField):
            field_decls.append(f"{f.id}: int = field(kw_only=True, default=0)")
        elif isinstance(f, ast.TypedefField):
            # Enum with a leading range tag: default to the range start
            # (a plain int, since ranges are not enum members).
            if isinstance(f.type, ast.EnumDeclaration) and f.type.tags[0].range:
                field_decls.append(
                    f"{f.id}: {f.type_id} = field(kw_only=True, default={f.type.tags[0].range[0]})")
            elif isinstance(f.type, ast.EnumDeclaration):
                # Default to the first declared enum tag.
                field_decls.append(
                    f"{f.id}: {f.type_id} = field(kw_only=True, default={f.type_id}.{f.type.tags[0].id})")
            elif isinstance(f.type, ast.ChecksumDeclaration):
                # Checksum values are computed at serialization time.
                field_decls.append(f"{f.id}: int = field(kw_only=True, default=0)")
            elif isinstance(f.type, (ast.StructDeclaration, ast.CustomFieldDeclaration)):
                field_decls.append(f"{f.id}: {f.type_id} = field(kw_only=True, default_factory={f.type_id})")
            else:
                raise Exception("Unsupported typedef field type")
        elif isinstance(f, ast.ArrayField) and f.width == 8:
            # Byte arrays map to bytearray for convenience.
            field_decls.append(f"{f.id}: bytearray = field(kw_only=True, default_factory=bytearray)")
        elif isinstance(f, ast.ArrayField) and f.width:
            field_decls.append(f"{f.id}: List[int] = field(kw_only=True, default_factory=list)")
        elif isinstance(f, ast.ArrayField) and f.type_id:
            field_decls.append(f"{f.id}: List[{f.type_id}] = field(kw_only=True, default_factory=list)")
        # NOTE(review): other field kinds (payload, size, count,
        # reserved, ...) intentionally produce no dataclass field.

    # Derived packets delegate to the parent's parse/serialize.
    if packet.parent_id:
        parent_name = packet.parent_id
        parent_fields = 'fields: dict, '
        serializer = generate_derived_packet_serializer(packet)
    else:
        parent_name = 'Packet'
        parent_fields = ''
        serializer = generate_toplevel_packet_serializer(packet)

    parser = generate_packet_parser(packet)
    size = generate_packet_size_getter(packet)
    post_init = generate_packet_post_init(packet)

    return dedent("""\

        @dataclass
        class {packet_name}({parent_name}):
            {field_decls}

            def __post_init__(self):
                {post_init}

            @staticmethod
            def parse({parent_fields}span: bytes) -> Tuple['{packet_name}', bytes]:
                {parser}

            def serialize(self, payload: bytes = None) -> bytes:
                {serializer}

            @property
            def size(self) -> int:
                {size}
        """).format(packet_name=packet_name,
                    parent_name=parent_name,
                    parent_fields=parent_fields,
                    field_decls=indent(field_decls, 1),
                    post_init=indent(post_init, 2),
                    parser=indent(parser, 2),
                    serializer=indent(serializer, 2),
                    size=indent(size, 2))
1129
1130
def generate_custom_field_declaration_check(decl: ast.CustomFieldDeclaration) -> str:
    """Generate the code to validate a user custom field implementation.

    This code is to be executed when the generated module is loaded to ensure
    the user gets an immediate and clear error message when the provided
    custom types do not fit the expected template.
    """
    template = dedent("""\

        if (not callable(getattr({custom_field_name}, 'parse', None)) or
            not callable(getattr({custom_field_name}, 'parse_all', None))):
            raise Exception('The custom field type {custom_field_name} does not implement the parse method')
    """)
    return template.format(custom_field_name=decl.id)
1144
1145
def generate_checksum_declaration_check(decl: ast.ChecksumDeclaration) -> str:
    """Generate the code to validate a user checksum field implementation.

    This code is to be executed when the generated module is loaded to ensure
    the user gets an immediate and clear error message when the provided
    checksum functions do not fit the expected template.
    """
    template = dedent("""\

        if not callable({checksum_name}):
            raise Exception('{checksum_name} is not callable')
    """)
    return template.format(checksum_name=decl.id)
1158
1159
def run(input: argparse.FileType, output: argparse.FileType, custom_type_location: Optional[str], exclude_declaration: List[str]):
    """Generate the python backend from a parsed PDL file.

    :param input: readable stream holding the PDL-JSON AST.
    :param output: writable stream for the generated python module.
    :param custom_type_location: module from which declared custom
        field and checksum types are imported, if provided.
    :param exclude_declaration: identifiers of declarations to skip.
    """
    file = ast.File.from_json(json.load(input))
    core.desugar(file)

    excluded = set(exclude_declaration)

    # Collect custom field and checksum helper names, and the runtime
    # checks validating their implementations.
    custom_types = []
    custom_type_checks = ""
    for d in file.declarations:
        if d.id in excluded:
            continue
        if isinstance(d, ast.CustomFieldDeclaration):
            custom_types.append(d.id)
            custom_type_checks += generate_custom_field_declaration_check(d)
        elif isinstance(d, ast.ChecksumDeclaration):
            custom_types.append(d.id)
            custom_type_checks += generate_checksum_declaration_check(d)

    # Emit the file header, helper imports, prelude, and checks.
    output.write(f"# File generated from {input.name}, with the command:\n")
    output.write(f"#  {' '.join(sys.argv)}\n")
    output.write("# /!\\ Do not edit by hand.\n")
    if custom_types and custom_type_location:
        output.write(f"\nfrom {custom_type_location} import {', '.join(custom_types)}\n")
    output.write(generate_prelude())
    output.write(custom_type_checks)

    # Emit one implementation per enum, packet, or struct declaration.
    for d in file.declarations:
        if d.id in excluded:
            continue
        if isinstance(d, ast.EnumDeclaration):
            output.write(generate_enum_declaration(d))
        elif isinstance(d, (ast.PacketDeclaration, ast.StructDeclaration)):
            output.write(generate_packet_declaration(d))
1193
1194
def main() -> int:
    """Generate python PDL backend."""
    arg_parser = argparse.ArgumentParser(description=__doc__)
    arg_parser.add_argument('--input',
                            type=argparse.FileType('r'),
                            default=sys.stdin,
                            help='Input PDL-JSON source')
    arg_parser.add_argument('--output',
                            type=argparse.FileType('w'),
                            default=sys.stdout,
                            help='Output Python file')
    arg_parser.add_argument('--custom-type-location',
                            type=str,
                            required=False,
                            help='Module of declaration of custom types')
    arg_parser.add_argument('--exclude-declaration',
                            type=str,
                            default=[],
                            action='append',
                            help='Exclude declaration from the generated output')
    # Argument names map one-to-one onto run()'s parameters.
    return run(**vars(arg_parser.parse_args()))
1210
1211
# Script entry point: exit status forwarded from main().
if __name__ == '__main__':
    sys.exit(main())
1214