1#!/usr/bin/env python3 2 3# Copyright 2023 Google LLC 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# https://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16 17import argparse 18from dataclasses import dataclass, field 19import json 20from pathlib import Path 21import sys 22from textwrap import dedent 23from typing import List, Tuple, Union, Optional 24 25from pdl import ast, core 26from pdl.utils import indent 27 28 29def mask(width: int) -> str: 30 return hex((1 << width) - 1) 31 32 33def generate_prelude() -> str: 34 return dedent("""\ 35 from dataclasses import dataclass, field, fields 36 from typing import Optional, List, Tuple, Union 37 import enum 38 import inspect 39 import math 40 41 @dataclass 42 class Packet: 43 payload: Optional[bytes] = field(repr=False, default_factory=bytes, compare=False) 44 45 @classmethod 46 def parse_all(cls, span: bytes) -> 'Packet': 47 packet, remain = getattr(cls, 'parse')(span) 48 if len(remain) > 0: 49 raise Exception('Unexpected parsing remainder') 50 return packet 51 52 @property 53 def size(self) -> int: 54 pass 55 56 def show(self, prefix: str = ''): 57 print(f'{self.__class__.__name__}') 58 59 def print_val(p: str, pp: str, name: str, align: int, typ, val): 60 if name == 'payload': 61 pass 62 63 # Scalar fields. 64 elif typ is int: 65 print(f'{p}{name:{align}} = {val} (0x{val:x})') 66 67 # Byte fields. 68 elif typ is bytes: 69 print(f'{p}{name:{align}} = [', end='') 70 line = '' 71 n_pp = '' 72 for (idx, b) in enumerate(val): 73 if idx > 0 and idx % 8 == 0: 74 print(f'{n_pp}{line}') 75 line = '' 76 n_pp = pp + (' ' * (align + 4)) 77 line += f' {b:02x}' 78 print(f'{n_pp}{line} ]') 79 80 # Enum fields. 81 elif inspect.isclass(typ) and issubclass(typ, enum.IntEnum): 82 print(f'{p}{name:{align}} = {typ.__name__}::{val.name} (0x{val:x})') 83 84 # Struct fields. 85 elif inspect.isclass(typ) and issubclass(typ, globals().get('Packet')): 86 print(f'{p}{name:{align}} = ', end='') 87 val.show(prefix=pp) 88 89 # Array fields. 90 elif getattr(typ, '__origin__', None) == list: 91 print(f'{p}{name:{align}}') 92 last = len(val) - 1 93 align = 5 94 for (idx, elt) in enumerate(val): 95 n_p = pp + ('├── ' if idx != last else '└── ') 96 n_pp = pp + ('│ ' if idx != last else ' ') 97 print_val(n_p, n_pp, f'[{idx}]', align, typ.__args__[0], val[idx]) 98 99 # Custom fields. 100 elif inspect.isclass(typ): 101 print(f'{p}{name:{align}} = {repr(val)}') 102 103 else: 104 print(f'{p}{name:{align}} = ##{typ}##') 105 106 last = len(fields(self)) - 1 107 align = max(len(f.name) for f in fields(self) if f.name != 'payload') 108 109 for (idx, f) in enumerate(fields(self)): 110 p = prefix + ('├── ' if idx != last else '└── ') 111 pp = prefix + ('│ ' if idx != last else ' ') 112 val = getattr(self, f.name) 113 114 print_val(p, pp, f.name, align, f.type, val) 115 """) 116 117 118@dataclass 119class FieldParser: 120 byteorder: str 121 offset: int = 0 122 shift: int = 0 123 chunk: List[Tuple[int, int, ast.Field]] = field(default_factory=lambda: []) 124 unchecked_code: List[str] = field(default_factory=lambda: []) 125 code: List[str] = field(default_factory=lambda: []) 126 127 def unchecked_append_(self, line: str): 128 """Append unchecked field parsing code. 129 The function check_size_ must be called to generate a size guard 130 after parsing is completed.""" 131 self.unchecked_code.append(line) 132 133 def append_(self, code: str): 134 """Append field parsing code. 135 There must be no unchecked code left before this function is called.""" 136 assert len(self.unchecked_code) == 0 137 self.code.extend(code.split('\n')) 138 139 def check_size_(self, size: str): 140 """Generate a check of the current span size.""" 141 self.append_(f"if len(span) < {size}:") 142 self.append_(f" raise Exception('Invalid packet size')") 143 144 def check_code_(self): 145 """Generate a size check for pending field parsing.""" 146 if len(self.unchecked_code) > 0: 147 assert len(self.chunk) == 0 148 unchecked_code = self.unchecked_code 149 self.unchecked_code = [] 150 self.check_size_(str(self.offset)) 151 self.code.extend(unchecked_code) 152 153 def consume_span_(self, keep: int = 0) -> str: 154 """Skip consumed span bytes.""" 155 if self.offset > 0: 156 self.check_code_() 157 self.append_(f'span = span[{self.offset - keep}:]') 158 self.offset = 0 159 160 def parse_array_element_dynamic_(self, field: ast.ArrayField, span: str): 161 """Parse a single array field element of variable size.""" 162 if isinstance(field.type, ast.StructDeclaration): 163 self.append_(f" element, {span} = {field.type_id}.parse({span})") 164 self.append_(f" {field.id}.append(element)") 165 else: 166 raise Exception(f'Unexpected array element type {field.type_id} {field.width}') 167 168 def parse_array_element_static_(self, field: ast.ArrayField, span: str): 169 """Parse a single array field element of constant size.""" 170 if field.width is not None: 171 element = f"int.from_bytes({span}, byteorder='{self.byteorder}')" 172 self.append_(f" {field.id}.append({element})") 173 elif isinstance(field.type, ast.EnumDeclaration): 174 element = f"int.from_bytes({span}, byteorder='{self.byteorder}')" 175 element = f"{field.type_id}({element})" 176 self.append_(f" {field.id}.append({element})") 177 else: 178 element = f"{field.type_id}.parse_all({span})" 179 self.append_(f" {field.id}.append({element})") 180 181 def parse_byte_array_field_(self, field: ast.ArrayField): 182 """Parse the selected u8 array field.""" 183 array_size = core.get_array_field_size(field) 184 padded_size = field.padded_size 185 186 # Shift the span to reset the offset to 0. 187 self.consume_span_() 188 189 # Derive the array size. 190 if isinstance(array_size, int): 191 size = array_size 192 elif isinstance(array_size, ast.SizeField): 193 size = f'{field.id}_size - {field.size_modifier}' if field.size_modifier else f'{field.id}_size' 194 elif isinstance(array_size, ast.CountField): 195 size = f'{field.id}_count' 196 else: 197 size = None 198 199 # Parse from the padded array if padding is present. 200 if padded_size and size is not None: 201 self.check_size_(padded_size) 202 self.append_(f"if {size} > {padded_size}:") 203 self.append_(" raise Exception('Array size is larger than the padding size')") 204 self.append_(f"fields['{field.id}'] = list(span[:{size}])") 205 self.append_(f"span = span[{padded_size}:]") 206 207 elif size is not None: 208 self.check_size_(size) 209 self.append_(f"fields['{field.id}'] = list(span[:{size}])") 210 self.append_(f"span = span[{size}:]") 211 212 else: 213 self.append_(f"fields['{field.id}'] = list(span)") 214 self.append_(f"span = bytes()") 215 216 def parse_array_field_(self, field: ast.ArrayField): 217 """Parse the selected array field.""" 218 array_size = core.get_array_field_size(field) 219 element_width = core.get_array_element_size(field) 220 padded_size = field.padded_size 221 222 if element_width: 223 if element_width % 8 != 0: 224 raise Exception('Array element size is not a multiple of 8') 225 element_width = int(element_width / 8) 226 227 if isinstance(array_size, int): 228 size = None 229 count = array_size 230 elif isinstance(array_size, ast.SizeField): 231 size = f'{field.id}_size' 232 count = None 233 elif isinstance(array_size, ast.CountField): 234 size = None 235 count = f'{field.id}_count' 236 else: 237 size = None 238 count = None 239 240 # Shift the span to reset the offset to 0. 241 self.consume_span_() 242 243 # Apply the size modifier. 244 if field.size_modifier and size: 245 self.append_(f"{size} = {size} - {field.size_modifier}") 246 247 # Parse from the padded array if padding is present. 248 if padded_size: 249 self.check_size_(padded_size) 250 self.append_(f"remaining_span = span[{padded_size}:]") 251 self.append_(f"span = span[:{padded_size}]") 252 253 # The element width is not known, but the array full octet size 254 # is known by size field. Parse elements item by item as a vector. 255 if element_width is None and size is not None: 256 self.check_size_(size) 257 self.append_(f"array_span = span[:{size}]") 258 self.append_(f"{field.id} = []") 259 self.append_("while len(array_span) > 0:") 260 self.parse_array_element_dynamic_(field, 'array_span') 261 self.append_(f"fields['{field.id}'] = {field.id}") 262 self.append_(f"span = span[{size}:]") 263 264 # The element width is not known, but the array element count 265 # is known statically or by count field. 266 # Parse elements item by item as a vector. 267 elif element_width is None and count is not None: 268 self.append_(f"{field.id} = []") 269 self.append_(f"for n in range({count}):") 270 self.parse_array_element_dynamic_(field, 'span') 271 self.append_(f"fields['{field.id}'] = {field.id}") 272 273 # Neither the count not size is known, 274 # parse elements until the end of the span. 275 elif element_width is None: 276 self.append_(f"{field.id} = []") 277 self.append_("while len(span) > 0:") 278 self.parse_array_element_dynamic_(field, 'span') 279 self.append_(f"fields['{field.id}'] = {field.id}") 280 281 # The element width is known, and the array element count is known 282 # statically, or by count field. 283 elif count is not None: 284 array_size = (f'{count}' if element_width == 1 else f'{count} * {element_width}') 285 self.check_size_(array_size) 286 self.append_(f"{field.id} = []") 287 self.append_(f"for n in range({count}):") 288 span = ('span[n:n + 1]' if element_width == 1 else f'span[n * {element_width}:(n + 1) * {element_width}]') 289 self.parse_array_element_static_(field, span) 290 self.append_(f"fields['{field.id}'] = {field.id}") 291 self.append_(f"span = span[{array_size}:]") 292 293 # The element width is known, and the array full size is known 294 # by size field, or unknown (in which case it is the remaining span 295 # length). 296 else: 297 if size is not None: 298 self.check_size_(size) 299 array_size = size or 'len(span)' 300 if element_width != 1: 301 self.append_(f"if {array_size} % {element_width} != 0:") 302 self.append_(" raise Exception('Array size is not a multiple of the element size')") 303 self.append_(f"{field.id}_count = int({array_size} / {element_width})") 304 array_count = f'{field.id}_count' 305 else: 306 array_count = array_size 307 self.append_(f"{field.id} = []") 308 self.append_(f"for n in range({array_count}):") 309 span = ('span[n:n + 1]' if element_width == 1 else f'span[n * {element_width}:(n + 1) * {element_width}]') 310 self.parse_array_element_static_(field, span) 311 self.append_(f"fields['{field.id}'] = {field.id}") 312 if size is not None: 313 self.append_(f"span = span[{size}:]") 314 else: 315 self.append_(f"span = bytes()") 316 317 # Drop the padding 318 if padded_size: 319 self.append_(f"span = remaining_span") 320 321 def parse_optional_field_(self, field: ast.Field): 322 """Parse the selected optional field. 323 Optional fields must start and end on a byte boundary.""" 324 325 if self.shift != 0: 326 raise Exception('Optional field does not start on an octet boundary') 327 if (isinstance(field, ast.TypedefField) and 328 isinstance(field.type, ast.StructDeclaration) and 329 field.type.parent_id is not None): 330 raise Exception('Derived struct used in optional typedef field') 331 332 self.consume_span_() 333 334 if isinstance(field, ast.ScalarField): 335 self.append_(dedent(""" 336 if {cond_id} == {cond_value}: 337 if len(span) < {size}: 338 raise Exception('Invalid packet size') 339 fields['{field_id}'] = int.from_bytes(span[:{size}], byteorder='{byteorder}') 340 span = span[{size}:] 341 """.format(size=int(field.width / 8), 342 field_id=field.id, 343 cond_id=field.cond.id, 344 cond_value=field.cond.value, 345 byteorder=self.byteorder))) 346 347 elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration): 348 self.append_(dedent(""" 349 if {cond_id} == {cond_value}: 350 if len(span) < {size}: 351 raise Exception('Invalid packet size') 352 fields['{field_id}'] = {type_id}( 353 int.from_bytes(span[:{size}], byteorder='{byteorder}')) 354 span = span[{size}:] 355 """.format(size=int(field.type.width / 8), 356 field_id=field.id, 357 type_id=field.type_id, 358 cond_id=field.cond.id, 359 cond_value=field.cond.value, 360 byteorder=self.byteorder))) 361 362 elif isinstance(field, ast.TypedefField): 363 self.append_(dedent(""" 364 if {cond_id} == {cond_value}: 365 {field_id}, span = {type_id}.parse(span) 366 fields['{field_id}'] = {field_id} 367 """.format(field_id=field.id, 368 type_id=field.type_id, 369 cond_id=field.cond.id, 370 cond_value=field.cond.value))) 371 372 else: 373 raise Exception(f"unsupported field type {field.__class__.__name__}") 374 375 def parse_bit_field_(self, field: ast.Field): 376 """Parse the selected field as a bit field. 377 The field is added to the current chunk. When a byte boundary 378 is reached all saved fields are extracted together.""" 379 380 # Add to current chunk. 381 width = core.get_field_size(field) 382 self.chunk.append((self.shift, width, field)) 383 self.shift += width 384 385 # Wait for more fields if not on a byte boundary. 386 if (self.shift % 8) != 0: 387 return 388 389 # Parse the backing integer using the configured endiannes, 390 # extract field values. 391 size = int(self.shift / 8) 392 end_offset = self.offset + size 393 394 if size == 1: 395 value = f"span[{self.offset}]" 396 else: 397 span = f"span[{self.offset}:{end_offset}]" 398 self.unchecked_append_(f"value_ = int.from_bytes({span}, byteorder='{self.byteorder}')") 399 value = "value_" 400 401 for shift, width, field in self.chunk: 402 v = (value if len(self.chunk) == 1 and shift == 0 else f"({value} >> {shift}) & {mask(width)}") 403 404 if field.cond_for: 405 self.unchecked_append_(f"{field.id} = {v}") 406 elif isinstance(field, ast.ScalarField): 407 self.unchecked_append_(f"fields['{field.id}'] = {v}") 408 elif isinstance(field, ast.FixedField) and field.enum_id: 409 self.unchecked_append_(f"if {v} != {field.enum_id}.{field.tag_id}:") 410 self.unchecked_append_(f" raise Exception('Unexpected fixed field value')") 411 elif isinstance(field, ast.FixedField): 412 self.unchecked_append_(f"if {v} != {hex(field.value)}:") 413 self.unchecked_append_(f" raise Exception('Unexpected fixed field value')") 414 elif isinstance(field, ast.TypedefField): 415 self.unchecked_append_(f"fields['{field.id}'] = {field.type_id}.from_int({v})") 416 elif isinstance(field, ast.SizeField): 417 self.unchecked_append_(f"{field.field_id}_size = {v}") 418 elif isinstance(field, ast.CountField): 419 self.unchecked_append_(f"{field.field_id}_count = {v}") 420 elif isinstance(field, ast.ReservedField): 421 pass 422 else: 423 raise Exception(f'Unsupported bit field type {field.kind}') 424 425 # Reset state. 426 self.offset = end_offset 427 self.shift = 0 428 self.chunk = [] 429 430 def parse_typedef_field_(self, field: ast.TypedefField): 431 """Parse a typedef field, to the exclusion of Enum fields.""" 432 433 if self.shift != 0: 434 raise Exception('Typedef field does not start on an octet boundary') 435 if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None): 436 raise Exception('Derived struct used in typedef field') 437 438 width = core.get_declaration_size(field.type) 439 if width is None: 440 self.consume_span_() 441 self.append_(f"{field.id}, span = {field.type_id}.parse(span)") 442 self.append_(f"fields['{field.id}'] = {field.id}") 443 else: 444 if width % 8 != 0: 445 raise Exception('Typedef field type size is not a multiple of 8') 446 width = int(width / 8) 447 end_offset = self.offset + width 448 # Checksum value field is generated alongside checksum start. 449 # Deal with this field as padding. 450 if not isinstance(field.type, ast.ChecksumDeclaration): 451 span = f'span[{self.offset}:{end_offset}]' 452 self.unchecked_append_(f"fields['{field.id}'] = {field.type_id}.parse_all({span})") 453 self.offset = end_offset 454 455 def parse_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]): 456 """Parse body and payload fields.""" 457 458 payload_size = core.get_payload_field_size(field) 459 offset_from_end = core.get_field_offset_from_end(field) 460 461 # If the payload is not byte aligned, do parse the bit fields 462 # that can be extracted, but do not consume the input bytes as 463 # they will also be included in the payload span. 464 if self.shift != 0: 465 if payload_size: 466 raise Exception("Unexpected payload size for non byte aligned payload") 467 468 rounded_size = int((self.shift + 7) / 8) 469 padding_bits = 8 * rounded_size - self.shift 470 self.parse_bit_field_(core.make_reserved_field(padding_bits)) 471 self.consume_span_(rounded_size) 472 else: 473 self.consume_span_() 474 475 # The payload or body has a known size. 476 # Consume the payload and update the span in case 477 # fields are placed after the payload. 478 if payload_size: 479 if getattr(field, 'size_modifier', None): 480 self.append_(f"{field.id}_size -= {field.size_modifier}") 481 self.check_size_(f'{field.id}_size') 482 self.append_(f"payload = span[:{field.id}_size]") 483 self.append_(f"span = span[{field.id}_size:]") 484 # The payload or body is the last field of a packet, 485 # consume the remaining span. 486 elif offset_from_end == 0: 487 self.append_(f"payload = span") 488 self.append_(f"span = bytes([])") 489 # The payload or body is followed by fields of static size. 490 # Consume the span that is not reserved for the following fields. 491 elif offset_from_end is not None: 492 if (offset_from_end % 8) != 0: 493 raise Exception('Payload field offset from end of packet is not a multiple of 8') 494 offset_from_end = int(offset_from_end / 8) 495 self.check_size_(f'{offset_from_end}') 496 self.append_(f"payload = span[:-{offset_from_end}]") 497 self.append_(f"span = span[-{offset_from_end}:]") 498 self.append_(f"fields['payload'] = payload") 499 500 def parse_checksum_field_(self, field: ast.ChecksumField): 501 """Generate a checksum check.""" 502 503 # The checksum value field can be read starting from the current 504 # offset if the fields in between are of fixed size, or from the end 505 # of the span otherwise. 506 self.consume_span_() 507 value_field = core.get_packet_field(field.parent, field.field_id) 508 offset_from_start = 0 509 offset_from_end = 0 510 start_index = field.parent.fields.index(field) 511 value_index = field.parent.fields.index(value_field) 512 value_size = int(core.get_field_size(value_field) / 8) 513 514 for f in field.parent.fields[start_index + 1:value_index]: 515 size = core.get_field_size(f) 516 if size is None: 517 offset_from_start = None 518 break 519 else: 520 offset_from_start += size 521 522 trailing_fields = field.parent.fields[value_index:] 523 trailing_fields.reverse() 524 for f in trailing_fields: 525 size = core.get_field_size(f) 526 if size is None: 527 offset_from_end = None 528 break 529 else: 530 offset_from_end += size 531 532 if offset_from_start is not None: 533 if offset_from_start % 8 != 0: 534 raise Exception('Checksum value field is not aligned to an octet boundary') 535 offset_from_start = int(offset_from_start / 8) 536 checksum_span = f'span[:{offset_from_start}]' 537 if value_size > 1: 538 start = offset_from_start 539 end = offset_from_start + value_size 540 value = f"int.from_bytes(span[{start}:{end}], byteorder='{self.byteorder}')" 541 else: 542 value = f'span[{offset_from_start}]' 543 self.check_size_(offset_from_start + value_size) 544 545 elif offset_from_end is not None: 546 sign = '' 547 if offset_from_end % 8 != 0: 548 raise Exception('Checksum value field is not aligned to an octet boundary') 549 offset_from_end = int(offset_from_end / 8) 550 checksum_span = f'span[:-{offset_from_end}]' 551 if value_size > 1: 552 start = offset_from_end 553 end = offset_from_end - value_size 554 value = f"int.from_bytes(span[-{start}:-{end}], byteorder='{self.byteorder}')" 555 else: 556 value = f'span[-{offset_from_end}]' 557 self.check_size_(offset_from_end) 558 559 else: 560 raise Exception('Checksum value field cannot be read at constant offset') 561 562 self.append_(f"{value_field.id} = {value}") 563 self.append_(f"fields['{value_field.id}'] = {value_field.id}") 564 self.append_(f"computed_{value_field.id} = {value_field.type.function}({checksum_span})") 565 self.append_(f"if computed_{value_field.id} != {value_field.id}:") 566 self.append_(" raise Exception(f'Invalid checksum computation:" + 567 f" {{computed_{value_field.id}}} != {{{value_field.id}}}')") 568 569 def parse(self, field: ast.Field): 570 if field.cond: 571 self.parse_optional_field_(field) 572 573 # Field has bit granularity. 574 # Append the field to the current chunk, 575 # check if a byte boundary was reached. 576 elif core.is_bit_field(field): 577 self.parse_bit_field_(field) 578 579 # Padding fields. 580 elif isinstance(field, ast.PaddingField): 581 pass 582 583 # Array fields. 584 elif isinstance(field, ast.ArrayField) and field.width == 8: 585 self.parse_byte_array_field_(field) 586 587 elif isinstance(field, ast.ArrayField): 588 self.parse_array_field_(field) 589 590 # Other typedef fields. 591 elif isinstance(field, ast.TypedefField): 592 self.parse_typedef_field_(field) 593 594 # Payload and body fields. 595 elif isinstance(field, (ast.PayloadField, ast.BodyField)): 596 self.parse_payload_field_(field) 597 598 # Checksum fields. 599 elif isinstance(field, ast.ChecksumField): 600 self.parse_checksum_field_(field) 601 602 else: 603 raise Exception(f'Unimplemented field type {field.kind}') 604 605 def done(self): 606 self.consume_span_() 607 608 609@dataclass 610class FieldSerializer: 611 byteorder: str 612 shift: int = 0 613 value: List[str] = field(default_factory=lambda: []) 614 code: List[str] = field(default_factory=lambda: []) 615 indent: int = 0 616 617 def indent_(self): 618 self.indent += 1 619 620 def unindent_(self): 621 self.indent -= 1 622 623 def append_(self, line: str): 624 """Append field serializing code.""" 625 lines = line.split('\n') 626 self.code.extend([' ' * self.indent + line for line in lines]) 627 628 def extend_(self, value: str, length: int): 629 """Append data to the span being constructed.""" 630 if length == 1: 631 self.append_(f"_span.append({value})") 632 else: 633 self.append_(f"_span.extend(int.to_bytes({value}, length={length}, byteorder='{self.byteorder}'))") 634 635 def serialize_array_element_(self, field: ast.ArrayField): 636 """Serialize a single array field element.""" 637 if field.width is not None: 638 length = int(field.width / 8) 639 self.extend_('_elt', length) 640 elif isinstance(field.type, ast.EnumDeclaration): 641 length = int(field.type.width / 8) 642 self.extend_('_elt', length) 643 else: 644 self.append_("_span.extend(_elt.serialize())") 645 646 def serialize_array_field_(self, field: ast.ArrayField): 647 """Serialize the selected array field.""" 648 if field.padded_size: 649 self.append_(f"_{field.id}_start = len(_span)") 650 651 if field.width == 8: 652 self.append_(f"_span.extend(self.{field.id})") 653 else: 654 self.append_(f"for _elt in self.{field.id}:") 655 self.indent_() 656 self.serialize_array_element_(field) 657 self.unindent_() 658 659 if field.padded_size: 660 self.append_(f"_span.extend([0] * ({field.padded_size} - len(_span) + _{field.id}_start))") 661 662 def serialize_optional_field_(self, field: ast.Field): 663 if isinstance(field, ast.ScalarField): 664 self.append_(dedent( 665 """ 666 if self.{field_id} is not None: 667 _span.extend(int.to_bytes(self.{field_id}, length={size}, byteorder='{byteorder}')) 668 """.format(field_id=field.id, 669 size=int(field.width / 8), 670 byteorder=self.byteorder))) 671 672 elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration): 673 self.append_(dedent( 674 """ 675 if self.{field_id} is not None: 676 _span.extend(int.to_bytes(self.{field_id}, length={size}, byteorder='{byteorder}')) 677 """.format(field_id=field.id, 678 size=int(field.type.width / 8), 679 byteorder=self.byteorder))) 680 681 elif isinstance(field, ast.TypedefField): 682 self.append_(dedent( 683 """ 684 if self.{field_id} is not None: 685 _span.extend(self.{field_id}.serialize()) 686 """.format(field_id=field.id))) 687 688 else: 689 raise Exception(f"unsupported field type {field.__class__.__name__}") 690 691 def serialize_bit_field_(self, field: ast.Field): 692 """Serialize the selected field as a bit field. 693 The field is added to the current chunk. When a byte boundary 694 is reached all saved fields are serialized together.""" 695 696 # Add to current chunk. 697 width = core.get_field_size(field) 698 shift = self.shift 699 700 if isinstance(field, str): 701 self.value.append(f"({field} << {shift})") 702 elif field.cond_for: 703 # Scalar field used as condition for an optional field. 704 # The width is always 1, the value is determined from 705 # the presence or absence of the optional field. 706 value_present = field.cond_for.cond.value 707 value_absent = 0 if field.cond_for.cond.value else 1 708 self.value.append(f"(({value_absent} if self.{field.cond_for.id} is None else {value_present}) << {shift})") 709 elif isinstance(field, ast.ScalarField): 710 max_value = (1 << field.width) - 1 711 self.append_(f"if self.{field.id} > {max_value}:") 712 self.append_(f" print(f\"Invalid value for field {field.parent.id}::{field.id}:" + 713 f" {{self.{field.id}}} > {max_value}; the value will be truncated\")") 714 self.append_(f" self.{field.id} &= {max_value}") 715 self.value.append(f"(self.{field.id} << {shift})") 716 elif isinstance(field, ast.FixedField) and field.enum_id: 717 self.value.append(f"({field.enum_id}.{field.tag_id} << {shift})") 718 elif isinstance(field, ast.FixedField): 719 self.value.append(f"({field.value} << {shift})") 720 elif isinstance(field, ast.TypedefField): 721 self.value.append(f"(self.{field.id} << {shift})") 722 723 elif isinstance(field, ast.SizeField): 724 max_size = (1 << field.width) - 1 725 value_field = core.get_packet_field(field.parent, field.field_id) 726 size_modifier = '' 727 728 if getattr(value_field, 'size_modifier', None): 729 size_modifier = f' + {value_field.size_modifier}' 730 731 if isinstance(value_field, (ast.PayloadField, ast.BodyField)): 732 self.append_(f"_payload_size = len(payload or self.payload or []){size_modifier}") 733 self.append_(f"if _payload_size > {max_size}:") 734 self.append_(f" print(f\"Invalid length for payload field:" + 735 f" {{_payload_size}} > {max_size}; the packet cannot be generated\")") 736 self.append_(f" raise Exception(\"Invalid payload length\")") 737 array_size = "_payload_size" 738 elif isinstance(value_field, ast.ArrayField) and value_field.width: 739 array_size = f"(len(self.{value_field.id}) * {int(value_field.width / 8)}{size_modifier})" 740 elif isinstance(value_field, ast.ArrayField) and isinstance(value_field.type, ast.EnumDeclaration): 741 array_size = f"(len(self.{value_field.id}) * {int(value_field.type.width / 8)}{size_modifier})" 742 elif isinstance(value_field, ast.ArrayField): 743 self.append_( 744 f"_{value_field.id}_size = sum([elt.size for elt in self.{value_field.id}]){size_modifier}") 745 array_size = f"_{value_field.id}_size" 746 else: 747 raise Exception("Unsupported field type") 748 self.value.append(f"({array_size} << {shift})") 749 750 elif isinstance(field, ast.CountField): 751 max_count = (1 << field.width) - 1 752 self.append_(f"if len(self.{field.field_id}) > {max_count}:") 753 self.append_(f" print(f\"Invalid length for field {field.parent.id}::{field.field_id}:" + 754 f" {{len(self.{field.field_id})}} > {max_count}; the array will be truncated\")") 755 self.append_(f" del self.{field.field_id}[{max_count}:]") 756 self.value.append(f"(len(self.{field.field_id}) << {shift})") 757 elif isinstance(field, ast.ReservedField): 758 pass 759 else: 760 raise Exception(f'Unsupported bit field type {field.kind}') 761 762 # Check if a byte boundary is reached. 763 self.shift += width 764 if (self.shift % 8) == 0: 765 self.pack_bit_fields_() 766 767 def pack_bit_fields_(self): 768 """Pack serialized bit fields.""" 769 770 # Should have an integral number of bytes now. 771 assert (self.shift % 8) == 0 772 773 # Generate the backing integer, and serialize it 774 # using the configured endiannes, 775 size = int(self.shift / 8) 776 777 if len(self.value) == 0: 778 self.append_(f"_span.extend([0] * {size})") 779 elif len(self.value) == 1: 780 self.extend_(self.value[0], size) 781 else: 782 self.append_(f"_value = (") 783 self.append_(" " + " |\n ".join(self.value)) 784 self.append_(")") 785 self.extend_('_value', size) 786 787 # Reset state. 788 self.shift = 0 789 self.value = [] 790 791 def serialize_typedef_field_(self, field: ast.TypedefField): 792 """Serialize a typedef field, to the exclusion of Enum fields.""" 793 794 if self.shift != 0: 795 raise Exception('Typedef field does not start on an octet boundary') 796 if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None): 797 raise Exception('Derived struct used in typedef field') 798 799 if isinstance(field.type, ast.ChecksumDeclaration): 800 size = int(field.type.width / 8) 801 self.append_(f"_checksum = {field.type.function}(_span[_checksum_start:])") 802 self.extend_('_checksum', size) 803 else: 804 self.append_(f"_span.extend(self.{field.id}.serialize())") 805 806 def serialize_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]): 807 """Serialize body and payload fields.""" 808 809 if self.shift != 0 and self.byteorder == 'big': 810 raise Exception('Payload field does not start on an octet boundary') 811 812 if self.shift == 0: 813 self.append_(f"_span.extend(payload or self.payload or [])") 814 else: 815 # Supported case of packet inheritance; 816 # the incomplete fields are serialized into 817 # the payload, rather than separately. 818 # First extract the padding bits from the payload, 819 # then recombine them with the bit fields to be serialized. 820 rounded_size = int((self.shift + 7) / 8) 821 padding_bits = 8 * rounded_size - self.shift 822 self.append_(f"_payload = payload or self.payload or bytes()") 823 self.append_(f"if len(_payload) < {rounded_size}:") 824 self.append_(f" raise Exception(f\"Invalid length for payload field:" + 825 f" {{len(_payload)}} < {rounded_size}\")") 826 self.append_( 827 f"_padding = int.from_bytes(_payload[:{rounded_size}], byteorder='{self.byteorder}') >> {self.shift}") 828 self.value.append(f"(_padding << {self.shift})") 829 self.shift += padding_bits 830 self.pack_bit_fields_() 831 self.append_(f"_span.extend(_payload[{rounded_size}:])") 832 833 def serialize_checksum_field_(self, field: ast.ChecksumField): 834 """Generate a checksum check.""" 835 836 self.append_("_checksum_start = len(_span)") 837 838 def serialize(self, field: ast.Field): 839 if field.cond: 840 self.serialize_optional_field_(field) 841 842 # Field has bit granularity. 843 # Append the field to the current chunk, 844 # check if a byte boundary was reached. 845 elif core.is_bit_field(field): 846 self.serialize_bit_field_(field) 847 848 # Padding fields. 849 elif isinstance(field, ast.PaddingField): 850 pass 851 852 # Array fields. 853 elif isinstance(field, ast.ArrayField): 854 self.serialize_array_field_(field) 855 856 # Other typedef fields. 857 elif isinstance(field, ast.TypedefField): 858 self.serialize_typedef_field_(field) 859 860 # Payload and body fields. 861 elif isinstance(field, (ast.PayloadField, ast.BodyField)): 862 self.serialize_payload_field_(field) 863 864 # Checksum fields. 865 elif isinstance(field, ast.ChecksumField): 866 self.serialize_checksum_field_(field) 867 868 else: 869 raise Exception(f'Unimplemented field type {field.kind}') 870 871 872def generate_toplevel_packet_serializer(packet: ast.Declaration) -> List[str]: 873 """Generate the serialize() function for a toplevel Packet or Struct 874 declaration.""" 875 876 serializer = FieldSerializer(byteorder=packet.file.byteorder) 877 for f in packet.fields: 878 serializer.serialize(f) 879 return ['_span = bytearray()'] + serializer.code + ['return bytes(_span)'] 880 881 882def generate_derived_packet_serializer(packet: ast.Declaration) -> List[str]: 883 """Generate the serialize() function for a derived Packet or Struct 884 declaration.""" 885 886 packet_shift = core.get_packet_shift(packet) 887 if packet_shift and packet.file.byteorder == 'big': 888 raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift") 889 890 serializer = FieldSerializer(byteorder=packet.file.byteorder, shift=packet_shift) 891 for f in packet.fields: 892 serializer.serialize(f) 893 return ['_span = bytearray()' 894 ] + serializer.code + [f'return {packet.parent.id}.serialize(self, payload = bytes(_span))'] 895 896 897def generate_packet_parser(packet: ast.Declaration) -> List[str]: 898 """Generate the parse() function for a toplevel Packet or Struct 899 declaration.""" 900 901 packet_shift = core.get_packet_shift(packet) 902 if packet_shift and packet.file.byteorder == 'big': 903 raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift") 904 905 # Convert the packet constraints to a boolean expression. 906 validation = [] 907 constraints = core.get_all_packet_constraints(packet) 908 if constraints: 909 cond = [] 910 for c in constraints: 911 if c.value is not None: 912 cond.append(f"fields['{c.id}'] != {hex(c.value)}") 913 else: 914 field = core.get_packet_field(packet, c.id) 915 cond.append(f"fields['{c.id}'] != {field.type_id}.{c.tag_id}") 916 917 validation = [f"if {' or '.join(cond)}:", " raise Exception(\"Invalid constraint field values\")"] 918 919 # Parse fields iteratively. 920 parser = FieldParser(byteorder=packet.file.byteorder, shift=packet_shift) 921 for f in packet.fields: 922 parser.parse(f) 923 parser.done() 924 925 # Specialize to child packets. 926 children = core.get_derived_packets(packet) 927 decl = [] if packet.parent_id else ['fields = {\'payload\': None}'] 928 specialization = [] 929 930 if len(children) != 0: 931 # Try parsing every child packet successively until one is 932 # successfully parsed. Return a parsing error if none is valid. 933 # Return parent packet if no child packet matches. 934 # TODO: order child packets by decreasing size in case no constraint 935 # is given for specialization. 936 for _, child in children: 937 specialization.append("try:") 938 specialization.append(f" return {child.id}.parse(fields.copy(), payload)") 939 specialization.append("except Exception as exn:") 940 specialization.append(" pass") 941 942 return decl + validation + parser.code + specialization + [f"return {packet.id}(**fields), span"] 943 944 945def generate_packet_size_getter(packet: ast.Declaration) -> List[str]: 946 constant_width = 0 947 variable_width = [] 948 for f in packet.fields: 949 field_size = core.get_field_size(f) 950 if f.cond: 951 if isinstance(f, ast.ScalarField): 952 return f"(0 if self.{f.id} is None else {f.width})" 953 elif isinstance(f, ast.TypedefField) and isinstance(f.type, ast.EnumDeclaration): 954 return f"(0 if self.{f.id} is None else {f.type.width})" 955 elif isinstance(f, ast.TypedefField): 956 return f"(0 if self.{f.id} is None else self.{f.id}.size)" 957 else: 958 raise Exception(f"unsupported field type {f.__class__.__name__}") 959 elif field_size is not None: 960 constant_width += field_size 961 elif isinstance(f, (ast.PayloadField, ast.BodyField)): 962 variable_width.append("len(self.payload)") 963 elif isinstance(f, ast.TypedefField): 964 variable_width.append(f"self.{f.id}.size") 965 elif isinstance(f, ast.ArrayField) and isinstance(f.type, (ast.StructDeclaration, ast.CustomFieldDeclaration)): 966 variable_width.append(f"sum([elt.size for elt in self.{f.id}])") 967 elif isinstance(f, ast.ArrayField) and isinstance(f.type, ast.EnumDeclaration): 968 variable_width.append(f"len(self.{f.id}) * {f.type.width}") 969 elif isinstance(f, ast.ArrayField): 970 variable_width.append(f"len(self.{f.id}) * {int(f.width / 8)}") 971 else: 972 raise Exception("Unsupported field type") 973 974 constant_width = int(constant_width / 8) 975 if len(variable_width) == 0: 976 return [f"return {constant_width}"] 977 elif len(variable_width) == 1 and constant_width: 978 return [f"return {variable_width[0]} + {constant_width}"] 979 elif len(variable_width) == 1: 980 return [f"return {variable_width[0]}"] 981 elif len(variable_width) > 1 and constant_width: 982 return ([f"return {constant_width} + ("] + " +\n ".join(variable_width).split("\n") + [")"]) 983 elif len(variable_width) > 1: 984 return (["return ("] + " +\n ".join(variable_width).split("\n") + [")"]) 985 else: 986 assert False 987 988 989def generate_packet_post_init(decl: ast.Declaration) -> List[str]: 990 """Generate __post_init__ function to set constraint field values.""" 991 992 # Gather all constraints from parent packets. 993 constraints = core.get_all_packet_constraints(decl) 994 995 if constraints: 996 code = [] 997 for c in constraints: 998 if c.value is not None: 999 code.append(f"self.{c.id} = {c.value}") 1000 else: 1001 field = core.get_packet_field(decl, c.id) 1002 code.append(f"self.{c.id} = {field.type_id}.{c.tag_id}") 1003 return code 1004 1005 else: 1006 return ["pass"] 1007 1008 1009def generate_enum_declaration(decl: ast.EnumDeclaration) -> str: 1010 """Generate the implementation of an enum type.""" 1011 1012 enum_name = decl.id 1013 tag_decls = [] 1014 for t in decl.tags: 1015 # Enums in python are closed and ranges cannot be represented; 1016 # instead the generated code uses Union[int, Enum] 1017 # when ranges are used. 1018 if t.value is not None: 1019 tag_decls.append(f"{t.id} = {hex(t.value)}") 1020 1021 if core.is_open_enum(decl): 1022 unknown_handler = ["return v"] 1023 else: 1024 unknown_handler = [] 1025 for t in decl.tags: 1026 if t.range is not None: 1027 unknown_handler.append(f"if v >= 0x{t.range[0]:x} and v <= 0x{t.range[1]:x}:") 1028 unknown_handler.append(f" return v") 1029 unknown_handler.append("raise exn") 1030 1031 return dedent("""\ 1032 1033 class {enum_name}(enum.IntEnum): 1034 {tag_decls} 1035 1036 @staticmethod 1037 def from_int(v: int) -> Union[int, '{enum_name}']: 1038 try: 1039 return {enum_name}(v) 1040 except ValueError as exn: 1041 {unknown_handler} 1042 1043 """).format(enum_name=enum_name, 1044 tag_decls=indent(tag_decls, 1), 1045 unknown_handler=indent(unknown_handler, 3)) 1046 1047 1048def generate_packet_declaration(packet: ast.Declaration) -> str: 1049 """Generate the implementation a toplevel Packet or Struct 1050 declaration.""" 1051 1052 packet_name = packet.id 1053 field_decls = [] 1054 for f in packet.fields: 1055 if f.cond: 1056 if isinstance(f, ast.ScalarField): 1057 field_decls.append(f"{f.id}: Optional[int] = field(kw_only=True, default=None)") 1058 elif isinstance(f, ast.TypedefField): 1059 field_decls.append(f"{f.id}: Optional[{f.type_id}] = field(kw_only=True, default=None)") 1060 else: 1061 pass 1062 elif f.cond_for: 1063 # The fields used as condition for optional fields are 1064 # not generated since their value is tied to the value of the 1065 # optional field. 1066 pass 1067 elif isinstance(f, ast.ScalarField): 1068 field_decls.append(f"{f.id}: int = field(kw_only=True, default=0)") 1069 elif isinstance(f, ast.TypedefField): 1070 if isinstance(f.type, ast.EnumDeclaration) and f.type.tags[0].range: 1071 field_decls.append( 1072 f"{f.id}: {f.type_id} = field(kw_only=True, default={f.type.tags[0].range[0]})") 1073 elif isinstance(f.type, ast.EnumDeclaration): 1074 field_decls.append( 1075 f"{f.id}: {f.type_id} = field(kw_only=True, default={f.type_id}.{f.type.tags[0].id})") 1076 elif isinstance(f.type, ast.ChecksumDeclaration): 1077 field_decls.append(f"{f.id}: int = field(kw_only=True, default=0)") 1078 elif isinstance(f.type, (ast.StructDeclaration, ast.CustomFieldDeclaration)): 1079 field_decls.append(f"{f.id}: {f.type_id} = field(kw_only=True, default_factory={f.type_id})") 1080 else: 1081 raise Exception("Unsupported typedef field type") 1082 elif isinstance(f, ast.ArrayField) and f.width == 8: 1083 field_decls.append(f"{f.id}: bytearray = field(kw_only=True, default_factory=bytearray)") 1084 elif isinstance(f, ast.ArrayField) and f.width: 1085 field_decls.append(f"{f.id}: List[int] = field(kw_only=True, default_factory=list)") 1086 elif isinstance(f, ast.ArrayField) and f.type_id: 1087 field_decls.append(f"{f.id}: List[{f.type_id}] = field(kw_only=True, default_factory=list)") 1088 1089 if packet.parent_id: 1090 parent_name = packet.parent_id 1091 parent_fields = 'fields: dict, ' 1092 serializer = generate_derived_packet_serializer(packet) 1093 else: 1094 parent_name = 'Packet' 1095 parent_fields = '' 1096 serializer = generate_toplevel_packet_serializer(packet) 1097 1098 parser = generate_packet_parser(packet) 1099 size = generate_packet_size_getter(packet) 1100 post_init = generate_packet_post_init(packet) 1101 1102 return dedent("""\ 1103 1104 @dataclass 1105 class {packet_name}({parent_name}): 1106 {field_decls} 1107 1108 def __post_init__(self): 1109 {post_init} 1110 1111 @staticmethod 1112 def parse({parent_fields}span: bytes) -> Tuple['{packet_name}', bytes]: 1113 {parser} 1114 1115 def serialize(self, payload: bytes = None) -> bytes: 1116 {serializer} 1117 1118 @property 1119 def size(self) -> int: 1120 {size} 1121 """).format(packet_name=packet_name, 1122 parent_name=parent_name, 1123 parent_fields=parent_fields, 1124 field_decls=indent(field_decls, 1), 1125 post_init=indent(post_init, 2), 1126 parser=indent(parser, 2), 1127 serializer=indent(serializer, 2), 1128 size=indent(size, 2)) 1129 1130 1131def generate_custom_field_declaration_check(decl: ast.CustomFieldDeclaration) -> str: 1132 """Generate the code to validate a user custom field implementation. 1133 1134 This code is to be executed when the generated module is loaded to ensure 1135 the user gets an immediate and clear error message when the provided 1136 custom types do not fit the expected template. 1137 """ 1138 return dedent("""\ 1139 1140 if (not callable(getattr({custom_field_name}, 'parse', None)) or 1141 not callable(getattr({custom_field_name}, 'parse_all', None))): 1142 raise Exception('The custom field type {custom_field_name} does not implement the parse method') 1143 """).format(custom_field_name=decl.id) 1144 1145 1146def generate_checksum_declaration_check(decl: ast.ChecksumDeclaration) -> str: 1147 """Generate the code to validate a user checksum field implementation. 1148 1149 This code is to be executed when the generated module is loaded to ensure 1150 the user gets an immediate and clear error message when the provided 1151 checksum functions do not fit the expected template. 1152 """ 1153 return dedent("""\ 1154 1155 if not callable({checksum_name}): 1156 raise Exception('{checksum_name} is not callable') 1157 """).format(checksum_name=decl.id) 1158 1159 1160def run(input: argparse.FileType, output: argparse.FileType, custom_type_location: Optional[str], exclude_declaration: List[str]): 1161 file = ast.File.from_json(json.load(input)) 1162 core.desugar(file) 1163 1164 custom_types = [] 1165 custom_type_checks = "" 1166 for d in file.declarations: 1167 if d.id in exclude_declaration: 1168 continue 1169 1170 if isinstance(d, ast.CustomFieldDeclaration): 1171 custom_types.append(d.id) 1172 custom_type_checks += generate_custom_field_declaration_check(d) 1173 elif isinstance(d, ast.ChecksumDeclaration): 1174 custom_types.append(d.id) 1175 custom_type_checks += generate_checksum_declaration_check(d) 1176 1177 output.write(f"# File generated from {input.name}, with the command:\n") 1178 output.write(f"# {' '.join(sys.argv)}\n") 1179 output.write("# /!\\ Do not edit by hand.\n") 1180 if custom_types and custom_type_location: 1181 output.write(f"\nfrom {custom_type_location} import {', '.join(custom_types)}\n") 1182 output.write(generate_prelude()) 1183 output.write(custom_type_checks) 1184 1185 for d in file.declarations: 1186 if d.id in exclude_declaration: 1187 continue 1188 1189 if isinstance(d, ast.EnumDeclaration): 1190 output.write(generate_enum_declaration(d)) 1191 elif isinstance(d, (ast.PacketDeclaration, ast.StructDeclaration)): 1192 output.write(generate_packet_declaration(d)) 1193 1194 1195def main() -> int: 1196 """Generate python PDL backend.""" 1197 parser = argparse.ArgumentParser(description=__doc__) 1198 parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source') 1199 parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output Python file') 1200 parser.add_argument('--custom-type-location', 1201 type=str, 1202 required=False, 1203 help='Module of declaration of custom types') 1204 parser.add_argument('--exclude-declaration', 1205 type=str, 1206 default=[], 1207 action='append', 1208 help='Exclude declaration from the generated output') 1209 return run(**vars(parser.parse_args())) 1210 1211 1212if __name__ == '__main__': 1213 sys.exit(main()) 1214