1#------------------------------------------------------------------------------
2# pycparser: c_generator.py
3#
4# C code generator from pycparser AST nodes.
5#
6# Eli Bendersky [https://eli.thegreenplace.net/]
7# License: BSD
8#------------------------------------------------------------------------------
9from . import c_ast
10
11
class CGenerator(object):
    """ Uses the same visitor pattern as c_ast.NodeVisitor, but modified to
        return a value from each visit method, using string accumulation in
        generic_visit.

        Every visit_* method returns the generated C text for its node as a
        string. Indentation, and the trailing ';' / newline of statements,
        are added by _generate_stmt.
    """
    def __init__(self):
        # Statements start with indentation of self.indent_level spaces, using
        # the _make_indent method
        #
        self.indent_level = 0

    def _make_indent(self):
        """ Returns the whitespace prefix for the current indent level.
        """
        return ' ' * self.indent_level

    def visit(self, node):
        """ Dispatches to the visit_<ClassName> method matching the class
            name of the node, or to generic_visit if there is none.
        """
        method = 'visit_' + node.__class__.__name__
        return getattr(self, method, self.generic_visit)(node)

    def generic_visit(self, node):
        """ Fallback visitor: returns '' for None, otherwise the
            concatenation of the text generated for every child of the node.
        """
        if node is None:
            return ''
        return ''.join(self.visit(child) for _, child in node.children())

    def visit_Constant(self, n):
        """ Returns the constant's value attribute unchanged.
        """
        return n.value

    def visit_ID(self, n):
        """ Returns the identifier's name.
        """
        return n.name

    def visit_Pragma(self, n):
        """ Returns '#pragma', followed by the pragma string when non-empty.
        """
        ret = '#pragma'
        if n.string:
            ret += ' ' + n.string
        return ret

    def visit_ArrayRef(self, n):
        """ Returns 'name[subscript]'; a non-simple name is parenthesized.
        """
        arrref = self._parenthesize_unless_simple(n.name)
        return arrref + '[' + self.visit(n.subscript) + ']'

    def visit_StructRef(self, n):
        """ Returns name, n.type and field joined together; a non-simple
            name is parenthesized.
        """
        sref = self._parenthesize_unless_simple(n.name)
        return sref + n.type + self.visit(n.field)

    def visit_FuncCall(self, n):
        """ Returns 'name(args)'; a non-simple callee is parenthesized.
        """
        fref = self._parenthesize_unless_simple(n.name)
        return fref + '(' + self.visit(n.args) + ')'

    def visit_UnaryOp(self, n):
        """ Returns the text of a unary operation. 'p++' and 'p--' are
            emitted as postfix operators, 'sizeof' as 'sizeof(...)', and any
            other op as a prefix of the (possibly parenthesized) operand.
        """
        if n.op == 'sizeof':
            # Always parenthesize the argument of sizeof since it can be
            # a name. Handled first so that the operand is generated only
            # once for this case.
            return 'sizeof(%s)' % self.visit(n.expr)

        operand = self._parenthesize_unless_simple(n.expr)
        if n.op == 'p++':
            return '%s++' % operand
        elif n.op == 'p--':
            return '%s--' % operand
        else:
            return '%s%s' % (n.op, operand)

    def visit_BinaryOp(self, n):
        """ Returns 'left op right'; each operand that is not a simple node
            is parenthesized.
        """
        lval_str = self._parenthesize_if(
            n.left, lambda d: not self._is_simple_node(d))
        rval_str = self._parenthesize_if(
            n.right, lambda d: not self._is_simple_node(d))
        return '%s %s %s' % (lval_str, n.op, rval_str)

    def visit_Assignment(self, n):
        """ Returns 'lvalue op rvalue'; the rvalue is parenthesized when it
            is itself an Assignment.
        """
        rval_str = self._parenthesize_if(
            n.rvalue, lambda d: isinstance(d, c_ast.Assignment))
        return '%s %s %s' % (self.visit(n.lvalue), n.op, rval_str)

    def visit_IdentifierType(self, n):
        """ Returns the names of the type joined by single spaces.
        """
        return ' '.join(n.names)

    def _visit_expr(self, n):
        """ Visits n as an expression: an InitList is wrapped in braces, an
            ExprList in parentheses, anything else is emitted as is.
        """
        if isinstance(n, c_ast.InitList):
            return '{' + self.visit(n) + '}'
        elif isinstance(n, c_ast.ExprList):
            return '(' + self.visit(n) + ')'
        else:
            return self.visit(n)

    def visit_Decl(self, n, no_type=False):
        """ Returns the text of a declaration, including an optional
            ' : bitsize' and ' = init' suffix.
        """
        # no_type is used when a Decl is part of a DeclList, where the type is
        # explicitly only for the first declaration in a list.
        #
        s = n.name if no_type else self._generate_decl(n)
        if n.bitsize:
            s += ' : ' + self.visit(n.bitsize)
        if n.init:
            s += ' = ' + self._visit_expr(n.init)
        return s

    def visit_DeclList(self, n):
        """ Returns the declarations separated by ', '. Only the first one
            is emitted with its type; the rest are emitted with no_type=True.
        """
        s = self.visit(n.decls[0])
        if len(n.decls) > 1:
            s += ', ' + ', '.join(self.visit_Decl(decl, no_type=True)
                                  for decl in n.decls[1:])
        return s

    def visit_Typedef(self, n):
        """ Returns the storage specifiers (if any) followed by the
            generated type of the typedef.
        """
        s = ''
        if n.storage:
            s += ' '.join(n.storage) + ' '
        s += self._generate_type(n.type)
        return s

    def visit_Cast(self, n):
        """ Returns '(type) expr'; the type is emitted without a declared
            name and a non-simple expression is parenthesized.
        """
        s = '(' + self._generate_type(n.to_type, emit_declname=False) + ')'
        return s + ' ' + self._parenthesize_unless_simple(n.expr)

    def visit_ExprList(self, n):
        """ Returns the sub-expressions separated by ', '.
        """
        return ', '.join(self._visit_expr(expr) for expr in n.exprs)

    def visit_InitList(self, n):
        """ Returns the initializer expressions separated by ', ' (without
            the surrounding braces, which _visit_expr adds).
        """
        return ', '.join(self._visit_expr(expr) for expr in n.exprs)

    def visit_Enum(self, n):
        """ Returns the text of an enum specifier.
        """
        return self._generate_struct_union_enum(n, name='enum')

    def visit_Enumerator(self, n):
        """ Returns one indented enumerator line, 'name,' or
            'name = value,', terminated by a newline.
        """
        if not n.value:
            return '{indent}{name},\n'.format(
                indent=self._make_indent(),
                name=n.name,
            )
        else:
            return '{indent}{name} = {value},\n'.format(
                indent=self._make_indent(),
                name=n.name,
                value=self.visit(n.value),
            )

    def visit_FuncDef(self, n):
        """ Returns a function definition: the declaration, then any
            param_decls (each terminated by ';'), then the body.
        """
        decl = self.visit(n.decl)
        # The body is always generated starting from indent level 0.
        self.indent_level = 0
        body = self.visit(n.body)
        if n.param_decls:
            knrdecls = ';\n'.join(self.visit(p) for p in n.param_decls)
            return decl + '\n' + knrdecls + ';\n' + body + '\n'
        else:
            return decl + '\n' + body + '\n'

    def visit_FileAST(self, n):
        """ Returns the text of all external declarations. Function
            definitions are emitted as is, pragmas get a trailing newline and
            everything else gets a trailing ';' and newline.
        """
        parts = []
        for ext in n.ext:
            if isinstance(ext, c_ast.FuncDef):
                parts.append(self.visit(ext))
            elif isinstance(ext, c_ast.Pragma):
                parts.append(self.visit(ext) + '\n')
            else:
                parts.append(self.visit(ext) + ';\n')
        return ''.join(parts)

    def visit_Compound(self, n):
        """ Returns a brace-enclosed block whose items are generated two
            indent levels deeper than the braces.
        """
        s = self._make_indent() + '{\n'
        self.indent_level += 2
        if n.block_items:
            s += ''.join(self._generate_stmt(stmt) for stmt in n.block_items)
        self.indent_level -= 2
        s += self._make_indent() + '}\n'
        return s

    def visit_CompoundLiteral(self, n):
        """ Returns '(type){init}'.
        """
        return '(' + self.visit(n.type) + '){' + self.visit(n.init) + '}'

    def visit_EmptyStatement(self, n):
        """ Returns ';'.
        """
        return ';'

    def visit_ParamList(self, n):
        """ Returns the parameters separated by ', '.
        """
        return ', '.join(self.visit(param) for param in n.params)

    def visit_Return(self, n):
        """ Returns 'return;' or 'return expr;'.
        """
        s = 'return'
        if n.expr:
            s += ' ' + self.visit(n.expr)
        return s + ';'

    def visit_Break(self, n):
        """ Returns 'break;'.
        """
        return 'break;'

    def visit_Continue(self, n):
        """ Returns 'continue;'.
        """
        return 'continue;'

    def visit_TernaryOp(self, n):
        """ Returns '(cond) ? (iftrue) : (iffalse)'; all three parts are
            always parenthesized.
        """
        s = '(' + self._visit_expr(n.cond) + ') ? '
        s += '(' + self._visit_expr(n.iftrue) + ') : '
        s += '(' + self._visit_expr(n.iffalse) + ')'
        return s

    def visit_If(self, n):
        """ Returns an if statement, with an else part when iffalse is set.
            Both branches are generated with add_indent=True.
        """
        s = 'if ('
        if n.cond:
            s += self.visit(n.cond)
        s += ')\n'
        s += self._generate_stmt(n.iftrue, add_indent=True)
        if n.iffalse:
            s += self._make_indent() + 'else\n'
            s += self._generate_stmt(n.iffalse, add_indent=True)
        return s

    def visit_For(self, n):
        """ Returns a for statement; init, cond and next are each optional.
        """
        s = 'for ('
        if n.init:
            s += self.visit(n.init)
        s += ';'
        if n.cond:
            s += ' ' + self.visit(n.cond)
        s += ';'
        if n.next:
            s += ' ' + self.visit(n.next)
        s += ')\n'
        s += self._generate_stmt(n.stmt, add_indent=True)
        return s

    def visit_While(self, n):
        """ Returns a while statement.
        """
        s = 'while ('
        if n.cond:
            s += self.visit(n.cond)
        s += ')\n'
        s += self._generate_stmt(n.stmt, add_indent=True)
        return s

    def visit_DoWhile(self, n):
        """ Returns a do ... while (cond); statement.
        """
        s = 'do\n'
        s += self._generate_stmt(n.stmt, add_indent=True)
        s += self._make_indent() + 'while ('
        if n.cond:
            s += self.visit(n.cond)
        s += ');'
        return s

    def visit_Switch(self, n):
        """ Returns a switch statement.
        """
        s = 'switch (' + self.visit(n.cond) + ')\n'
        s += self._generate_stmt(n.stmt, add_indent=True)
        return s

    def visit_Case(self, n):
        """ Returns 'case expr:' followed by its statements, each generated
            with add_indent=True.
        """
        s = 'case ' + self.visit(n.expr) + ':\n'
        s += ''.join(self._generate_stmt(stmt, add_indent=True)
                     for stmt in n.stmts)
        return s

    def visit_Default(self, n):
        """ Returns 'default:' followed by its statements, each generated
            with add_indent=True.
        """
        s = 'default:\n'
        s += ''.join(self._generate_stmt(stmt, add_indent=True)
                     for stmt in n.stmts)
        return s

    def visit_Label(self, n):
        """ Returns 'name:' on its own line followed by the labeled
            statement.
        """
        return n.name + ':\n' + self._generate_stmt(n.stmt)

    def visit_Goto(self, n):
        """ Returns 'goto name;'.
        """
        return 'goto ' + n.name + ';'

    def visit_EllipsisParam(self, n):
        """ Returns '...'.
        """
        return '...'

    def visit_Struct(self, n):
        """ Returns the text of a struct specifier.
        """
        return self._generate_struct_union_enum(n, 'struct')

    def visit_Typename(self, n):
        """ Returns the generated type of the Typename node.
        """
        return self._generate_type(n.type)

    def visit_Union(self, n):
        """ Returns the text of a union specifier.
        """
        return self._generate_struct_union_enum(n, 'union')

    def visit_NamedInitializer(self, n):
        """ Returns a designated initializer: every ID in n.name is emitted
            as '.name', every other designator as '[expr]', followed by
            ' = expr'.
        """
        designators = []
        for name in n.name:
            if isinstance(name, c_ast.ID):
                designators.append('.' + name.name)
            else:
                designators.append('[' + self.visit(name) + ']')
        return ''.join(designators) + ' = ' + self._visit_expr(n.expr)

    def visit_FuncDecl(self, n):
        """ Returns the generated type of the function declarator.
        """
        return self._generate_type(n)

    def visit_ArrayDecl(self, n):
        """ Returns the generated type, without the declared name.
        """
        return self._generate_type(n, emit_declname=False)

    def visit_TypeDecl(self, n):
        """ Returns the generated type, without the declared name.
        """
        return self._generate_type(n, emit_declname=False)

    def visit_PtrDecl(self, n):
        """ Returns the generated type, without the declared name.
        """
        return self._generate_type(n, emit_declname=False)

    def _generate_struct_union_enum(self, n, name):
        """ Generates code for structs, unions, and enums. name should be
            'struct', 'union', or 'enum'.
        """
        if name in ('struct', 'union'):
            members = n.decls
            body_function = self._generate_struct_union_body
        else:
            assert name == 'enum'
            members = None if n.values is None else n.values.enumerators
            body_function = self._generate_enum_body
        s = name + ' ' + (n.name or '')
        if members is not None:
            # None means no members
            # Empty sequence means an empty list of members
            s += '\n'
            s += self._make_indent()
            self.indent_level += 2
            s += '{\n'
            s += body_function(members)
            self.indent_level -= 2
            s += self._make_indent() + '}'
        return s

    def _generate_struct_union_body(self, members):
        """ Returns the member declarations, each generated as a statement.
        """
        return ''.join(self._generate_stmt(decl) for decl in members)

    def _generate_enum_body(self, members):
        """ Returns the enumerator lines, without a ',' after the last one.
        """
        # `[:-2] + '\n'` removes the final `,` from the enumerator list
        return ''.join(self.visit(value) for value in members)[:-2] + '\n'

    def _generate_stmt(self, n, add_indent=False):
        """ Generation from a statement node. This method exists as a wrapper
            for individual visit_* methods to handle different treatment of
            some statements in this context.

            With add_indent=True the statement's leading indentation is two
            levels deeper than the current indent level.
        """
        typ = type(n)
        # The deeper level applies only to the prefix computed here;
        # indent_level is restored before n itself is visited.
        if add_indent:
            self.indent_level += 2
        indent = self._make_indent()
        if add_indent:
            self.indent_level -= 2

        if typ in (
                c_ast.Decl, c_ast.Assignment, c_ast.Cast, c_ast.UnaryOp,
                c_ast.BinaryOp, c_ast.TernaryOp, c_ast.FuncCall, c_ast.ArrayRef,
                c_ast.StructRef, c_ast.Constant, c_ast.ID, c_ast.Typedef,
                c_ast.ExprList):
            # These can also appear in an expression context so no semicolon
            # is added to them automatically
            #
            return indent + self.visit(n) + ';\n'
        elif typ in (c_ast.Compound,):
            # No extra indentation required before the opening brace of a
            # compound - because it consists of multiple lines it has to
            # compute its own indentation.
            #
            return self.visit(n)
        else:
            return indent + self.visit(n) + '\n'

    def _generate_decl(self, n):
        """ Generation from a Decl node: function specifiers, then storage
            specifiers, then the generated type.
        """
        s = ''
        if n.funcspec:
            s = ' '.join(n.funcspec) + ' '
        if n.storage:
            s += ' '.join(n.storage) + ' '
        s += self._generate_type(n.type)
        return s

    def _generate_type(self, n, modifiers=None, emit_declname=True):
        """ Recursive generation from a type node. n is the type node.
            modifiers collects the PtrDecl, ArrayDecl and FuncDecl modifiers
            encountered on the way down to a TypeDecl, to allow proper
            generation from it. When emit_declname is False the declared
            name of the TypeDecl is left out.
        """
        # A None default avoids sharing one list object between calls.
        if modifiers is None:
            modifiers = []
        typ = type(n)

        if typ == c_ast.TypeDecl:
            s = ''
            if n.quals:
                s += ' '.join(n.quals) + ' '
            s += self.visit(n.type)

            nstr = n.declname if n.declname and emit_declname else ''
            # Resolve modifiers.
            # Wrap in parens to distinguish pointer to array and pointer to
            # function syntax.
            #
            for i, modifier in enumerate(modifiers):
                if isinstance(modifier, c_ast.ArrayDecl):
                    if (i != 0 and
                            isinstance(modifiers[i - 1], c_ast.PtrDecl)):
                        nstr = '(' + nstr + ')'
                    nstr += '['
                    if modifier.dim_quals:
                        nstr += ' '.join(modifier.dim_quals) + ' '
                    nstr += self.visit(modifier.dim) + ']'
                elif isinstance(modifier, c_ast.FuncDecl):
                    if (i != 0 and
                            isinstance(modifiers[i - 1], c_ast.PtrDecl)):
                        nstr = '(' + nstr + ')'
                    nstr += '(' + self.visit(modifier.args) + ')'
                elif isinstance(modifier, c_ast.PtrDecl):
                    if modifier.quals:
                        nstr = '* %s%s' % (' '.join(modifier.quals),
                                           ' ' + nstr if nstr else '')
                    else:
                        nstr = '*' + nstr
            if nstr:
                s += ' ' + nstr
            return s
        elif typ == c_ast.Decl:
            # NOTE(review): this hands n.type, not n, to _generate_decl,
            # which reads funcspec/storage/type from its argument - confirm
            # this is intended.
            return self._generate_decl(n.type)
        elif typ == c_ast.Typename:
            return self._generate_type(n.type, emit_declname=emit_declname)
        elif typ == c_ast.IdentifierType:
            return ' '.join(n.names) + ' '
        elif typ in (c_ast.ArrayDecl, c_ast.PtrDecl, c_ast.FuncDecl):
            # Build a new list rather than appending, so the caller's list
            # is never modified.
            return self._generate_type(n.type, modifiers + [n],
                                       emit_declname=emit_declname)
        else:
            return self.visit(n)

    def _parenthesize_if(self, n, condition):
        """ Visits 'n' and returns its string representation, parenthesized
            if the condition function applied to the node returns True.
        """
        s = self._visit_expr(n)
        if condition(n):
            return '(' + s + ')'
        else:
            return s

    def _parenthesize_unless_simple(self, n):
        """ Common use case for _parenthesize_if
        """
        return self._parenthesize_if(n, lambda d: not self._is_simple_node(d))

    def _is_simple_node(self, n):
        """ Returns True for nodes that are "simple" - i.e. nodes that always
            have higher precedence than operators.
        """
        return isinstance(n, (c_ast.Constant, c_ast.ID, c_ast.ArrayRef,
                              c_ast.StructRef, c_ast.FuncCall))
445