1# Copyright 2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from dataclasses import dataclass, field
16from typing import Optional, List, Dict, Tuple
17
# Registry mapping a JSON ``kind`` string to the dataclass representing it.
# Filled by the ``node`` decorator and looked up by ``convert_``.
constructors_ = {}


def node(kind: str):
    """Return a class decorator that registers a dataclass under ``kind``.

    The decorated class is first converted with ``dataclasses.dataclass``.
    The resulting class is stored in ``constructors_[kind]`` and returned.
    """

    def register(cls):
        data_cls = dataclass(cls)
        constructors_[kind] = data_cls
        return data_cls

    return register
29
30
@dataclass
class SourceLocation:
    """A single position inside a source file.

    Built by ``convert_`` from the ``start``/``end`` members of a JSON ``loc``
    object. Presumably ``offset`` is measured from the start of the file;
    the base (0 or 1) of ``line``/``column`` is not shown here — confirm.
    """

    offset: int
    line: int
    column: int
36
37
@dataclass
class SourceRange:
    """A span of source text: a file identifier plus start and end positions."""

    # Integer identifier of the source file (presumably an index assigned by
    # the parser — confirm).
    file: int
    start: SourceLocation
    end: SourceLocation
43
44
@dataclass
class Node:
    """Common base of every AST node imported from the PDL parser JSON.

    Attributes:
        kind: The JSON ``kind`` string identifying the node type.
        loc: Source span of the node. ``convert_`` always builds this as a
            ``SourceRange`` (file plus start/end positions), not as a single
            ``SourceLocation``.
    """

    kind: str
    loc: SourceRange
49
50
@node('tag')
class Tag(Node):
    """Tag node (JSON kind ``tag``), as listed in enum declarations.

    Named by ``id``. It may carry a single ``value``, a ``range`` of values,
    and/or nested ``tags``. Each of these is optional and defaults to ``None``.
    """

    id: str
    value: Optional[int] = None
    # (start, end) pair: convert_ turns JSON objects that carry both 'start'
    # and 'end' keys into such tuples.
    range: Optional[Tuple[int, int]] = None
    tags: Optional[List['Tag']] = None
57
58
@node('constraint')
class Constraint(Node):
    """Constraint node (JSON kind ``constraint``) on the field named ``id``.

    The constrained value is given either as an integer ``value`` or as an
    enum tag identifier ``tag_id``. Presumably exactly one of the two is set
    — confirm against the parser.
    """

    id: str
    value: Optional[int]
    tag_id: Optional[str]
64
65
@dataclass
class Field(Node):
    """Base class of all field nodes.

    The attributes declared here are not constructor arguments; they are
    filled in after the field object has been created.
    """

    # Enclosing declaration; assigned by Declaration.__post_init__.
    parent: Node = field(init=False)
    # Condition guarding this field, if any; assigned by convert_ from the
    # JSON 'cond' member.
    cond: Optional[Constraint] = field(init=False, default=None)
    # Backlink to the optional field that uses this field as its condition,
    # if any (assigned outside this class).
    cond_for: Optional['Field'] = field(init=False, default=None)
73
@node('checksum_field')
class ChecksumField(Field):
    """Checksum field node (JSON kind ``checksum_field``)."""

    # Identifier of the field this checksum field refers to (presumably the
    # checksum value field — confirm against the PDL grammar).
    field_id: str
77
78
@node('padding_field')
class PaddingField(Field):
    """Padding field node (JSON kind ``padding_field``)."""

    # Padding size; the unit is not shown here (presumably octets — confirm).
    size: int
82
83
@node('size_field')
class SizeField(Field):
    """Size field node (JSON kind ``size_field``).

    Presumably encodes the size of the field named by ``field_id`` on
    ``width`` bits — confirm the unit against the parser.
    """

    field_id: str
    width: int
88
89
@node('elementsize_field')
class ElementSize(Field):
    """Element-size field node (JSON kind ``elementsize_field``).

    Presumably encodes the size of each element of the array named by
    ``field_id`` on ``width`` bits — confirm against the parser.
    """

    field_id: str
    width: int
94
95
@node('count_field')
class CountField(Field):
    """Count field node (JSON kind ``count_field``).

    Presumably encodes the element count of the array named by ``field_id``
    on ``width`` bits — confirm against the parser.
    """

    field_id: str
    width: int
100
101
@node('body_field')
class BodyField(Field):
    """Body field node (JSON kind ``body_field``).

    Its ``id`` is always ``'_body_'`` and is not a constructor argument.
    """

    id: str = field(init=False, default='_body_')
105
106
@node('payload_field')
class PayloadField(Field):
    """Payload field node (JSON kind ``payload_field``).

    Its ``id`` is always ``'_payload_'`` and is not a constructor argument.
    """

    # Optional size modifier string (format not shown here).
    size_modifier: Optional[str]
    id: str = field(init=False, default='_payload_')
111
112
@node('fixed_field')
class FixedField(Field):
    """Fixed field node (JSON kind ``fixed_field``).

    All members are optional. Presumably either a scalar ``value`` of
    ``width`` bits, or an enum tag (``tag_id`` of the enum ``enum_id``), is
    provided — confirm against the parser.
    """

    width: Optional[int] = None
    value: Optional[int] = None
    enum_id: Optional[str] = None
    tag_id: Optional[str] = None

    @property
    def type(self) -> Optional['Declaration']:
        """Declaration named by ``enum_id``, or ``None`` when ``enum_id`` is unset.

        Resolved through the typedef scope of the file owning the parent
        declaration.
        """
        if not self.enum_id:
            return None
        return self.parent.file.typedef_scope[self.enum_id]
123
124
@node('reserved_field')
class ReservedField(Field):
    """Reserved field node (JSON kind ``reserved_field``)."""

    # Width of the reserved area (presumably in bits — confirm).
    width: int
128
129
@node('array_field')
class ArrayField(Field):
    """Array field node (JSON kind ``array_field``).

    Presumably elements are either scalars of ``width`` bits or instances of
    the declaration named by ``type_id``. ``size``, when set, is presumably
    a fixed element count. Confirm both against the parser.
    """

    id: str
    width: Optional[int]
    type_id: Optional[str]
    size_modifier: Optional[str]
    size: Optional[int]
    # Not a constructor argument; assigned outside this class.
    padded_size: Optional[int] = field(init=False, default=None)

    @property
    def type(self) -> Optional['Declaration']:
        """Declaration named by ``type_id``, or ``None`` when ``type_id`` is unset."""
        if not self.type_id:
            return None
        return self.parent.file.typedef_scope[self.type_id]
142
143
@node('scalar_field')
class ScalarField(Field):
    """Named scalar field node (JSON kind ``scalar_field``)."""

    id: str
    # Scalar width (presumably in bits — confirm).
    width: int
148
149
@node('typedef_field')
class TypedefField(Field):
    """Field node (JSON kind ``typedef_field``) whose type is a named declaration."""

    id: str
    type_id: str

    @property
    def type(self) -> 'Declaration':
        """Declaration named by ``type_id``, looked up in the file's typedef scope.

        Raises ``KeyError`` if ``type_id`` is not declared.
        """
        scope = self.parent.file.typedef_scope
        return scope[self.type_id]
158
159
@node('group_field')
class GroupField(Field):
    """Group field node (JSON kind ``group_field``).

    References the group declaration ``group_id`` together with the
    constraints applied to it.
    """

    group_id: str
    constraints: List[Constraint]
164
165
@dataclass
class Declaration(Node):
    """Base class of top-level declarations.

    ``file`` is not a constructor argument; ``File.__post_init__`` assigns it
    when the declaration is attached to a File.
    """

    file: 'File' = field(init=False)

    def __post_init__(self):
        # Declarations that own a ``fields`` list (packets, structs, groups)
        # give each field a backlink to its enclosing declaration.
        for member in getattr(self, 'fields', ()):
            member.parent = self
174
175
@node('endianness_declaration')
class EndiannessDeclaration(Node):
    """Endianness declaration node (JSON kind ``endianness_declaration``).

    ``File.byteorder`` treats ``value == 'little_endian'`` as little endian
    and anything else as big endian.
    """

    value: str
179
180
@node('checksum_declaration')
class ChecksumDeclaration(Declaration):
    """Checksum declaration node (JSON kind ``checksum_declaration``).

    ``function`` names the checksum implementation. ``width`` is presumably
    the checksum width in bits — confirm.
    """

    id: str
    function: str
    width: int
186
187
@node('custom_field_declaration')
class CustomFieldDeclaration(Declaration):
    """Custom field declaration node (JSON kind ``custom_field_declaration``).

    ``function`` names the implementation backing the field. ``width`` is
    optional (presumably unset for variable-size custom fields — confirm).
    """

    id: str
    function: str
    width: Optional[int]
193
194
@node('enum_declaration')
class EnumDeclaration(Declaration):
    """Enum declaration node (JSON kind ``enum_declaration``).

    Lists its ``tags``; ``width`` is presumably the encoded width in bits.
    """

    id: str
    tags: List[Tag]
    width: int
200
201
@node('packet_declaration')
class PacketDeclaration(Declaration):
    """Packet declaration node (JSON kind ``packet_declaration``).

    A packet may derive from another packet (``parent_id``) under the given
    ``constraints``, and owns an ordered list of ``fields``.
    """

    id: str
    parent_id: Optional[str]
    constraints: List[Constraint]
    fields: List[Field]

    @property
    def parent(self) -> Optional['PacketDeclaration']:
        """Packet declaration named by ``parent_id``, or ``None`` when unset."""
        if not self.parent_id:
            return None
        return self.file.packet_scope[self.parent_id]
212
213
@node('struct_declaration')
class StructDeclaration(Declaration):
    """Struct declaration node (JSON kind ``struct_declaration``).

    A struct may derive from another struct (``parent_id``) under the given
    ``constraints``, and owns an ordered list of ``fields``.
    """

    id: str
    parent_id: Optional[str]
    constraints: List[Constraint]
    fields: List[Field]

    @property
    def parent(self) -> Optional['StructDeclaration']:
        """Declaration named by ``parent_id`` (typedef scope), or ``None`` when unset."""
        if not self.parent_id:
            return None
        return self.file.typedef_scope[self.parent_id]
224
225
@node('group_declaration')
class GroupDeclaration(Declaration):
    """Group declaration node (JSON kind ``group_declaration``): a named list of fields."""

    id: str
    fields: List[Field]
230
231
@dataclass
class File:
    """Root of a PDL file imported from the parser's JSON export.

    Besides the endianness and the ordered list of declarations, a File
    indexes its declarations by identifier:

    * ``packet_scope``: packet declarations;
    * ``group_scope``: group declarations;
    * ``typedef_scope``: every other declaration (e.g. enums, structs,
      checksums, custom fields).

    The scopes are built in ``__post_init__`` and are not constructor arguments.
    """

    endianness: EndiannessDeclaration
    declarations: List[Declaration]
    packet_scope: Dict[str, Declaration] = field(init=False)
    typedef_scope: Dict[str, Declaration] = field(init=False)
    group_scope: Dict[str, Declaration] = field(init=False)

    def __post_init__(self):
        self.packet_scope = dict()
        self.typedef_scope = dict()
        self.group_scope = dict()

        # Construct the toplevel declaration scopes. Each declaration also
        # gets a backlink to this file, which the `type`/`parent` properties
        # use to resolve identifiers.
        for d in self.declarations:
            d.file = self
            if isinstance(d, PacketDeclaration):
                self.packet_scope[d.id] = d
            elif isinstance(d, GroupDeclaration):
                self.group_scope[d.id] = d
            else:
                self.typedef_scope[d.id] = d

    @classmethod
    def from_json(cls, obj: dict) -> 'File':
        """Import a File exported as JSON object by the PDL parser.

        Args:
            obj: Decoded JSON object. Its ``endianness`` and ``declarations``
                members are converted to AST nodes with ``convert_``.

        Returns:
            A new instance of ``cls`` built from the converted members.
        """
        endianness = convert_(obj['endianness'])
        declarations = convert_(obj['declarations'])
        return cls(endianness, declarations)

    @property
    def byteorder(self) -> str:
        """Return 'little' for a little_endian file, 'big' otherwise.

        These are the ``byteorder`` strings accepted by ``int.to_bytes`` and
        ``int.from_bytes``.
        """
        return 'little' if self.endianness.value == 'little_endian' else 'big'

    @property
    def byteorder_short(self) -> str:
        """Return the abbreviated byte order: 'le' for little endian, 'be' otherwise."""
        # A property cannot receive arguments, so the former unused `short`
        # parameter was dropped; derive from `byteorder` to keep both in sync.
        return 'le' if self.byteorder == 'little' else 'be'
269
270
def convert_(obj: object) -> object:
    """Recursively convert a JSON value exported by the PDL parser into AST nodes.

    * ``None``, ``int`` (including ``bool``) and ``str`` are returned unchanged.
    * Lists are converted element-wise.
    * Objects with both ``start`` and ``end`` keys become a ``(start, end)``
      tuple.
    * Any other object must carry ``kind`` and ``loc`` members. ``kind``
      selects the node class registered with ``node`` in ``constructors_``,
      ``loc`` is turned into a ``SourceRange``, and every remaining member is
      converted recursively and passed to the node constructor.

    Raises:
        ValueError: if an object's ``kind`` has no registered node class.
        TypeError: if ``obj`` is not one of the JSON value types above.
    """
    if obj is None:
        return None
    if isinstance(obj, (int, str)):
        return obj
    if isinstance(obj, list):
        return [convert_(elt) for elt in obj]
    if isinstance(obj, dict):
        if 'start' in obj and 'end' in obj:
            return (obj['start'], obj['end'])
        kind = obj['kind']
        raw_loc = obj['loc']
        loc = SourceRange(raw_loc['file'],
                          SourceLocation(**raw_loc['start']),
                          SourceLocation(**raw_loc['end']))
        constructor = constructors_.get(kind)
        if constructor is None:
            raise ValueError(f'Unknown kind {kind}')
        members = {'loc': loc, 'kind': kind}
        cond = None
        for name, value in obj.items():
            if name == 'cond':
                # Field.cond is declared with init=False, so it cannot be
                # passed to the constructor and is assigned afterwards.
                cond = convert_(value)
            elif name not in ('kind', 'loc'):
                members[name] = convert_(value)
        val = constructor(**members)
        if cond is not None:
            val.cond = cond
        return val
    # Previously unreachable (the branch above tested isinstance(obj, object)).
    raise TypeError(f'Unhandled json object type: {type(obj).__name__}')
299