1import os
2import platform
3import sys
4import unittest
5
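# These tests are meant to be run from the repository root: put the current
# directory first on sys.path so the local pycparser package is the one imported.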
7sys.path.insert(0, '.')
8
9from pycparser import c_parser, c_generator, c_ast, parse_file
10
# Single parser instance shared by every test in this module (used by
# parse_to_ast below). The lexer/yacc optimizations are switched off and yacc
# debugging is switched on; the parse tables are named by 'yacctab'.
# NOTE(review): the exact on-disk effects of these flags depend on the
# pycparser/PLY version in use — confirm before changing them.
_c_parser = c_parser.CParser(
    yacctab='yacctab',
    yacc_optimize=False,
    yacc_debug=True,
    lex_optimize=False,
)
16
17
def compare_asts(ast1, ast2):
    """Return True if two ASTs are structurally identical, else False.

    Two nodes match when they have exactly the same type, equal values for
    every attribute named in ``attr_names``, the same number of children, and
    pairwise-matching children. ``children()`` yields ``(name, child)``
    tuples; a tuple pair matches when the names are equal and the children
    match recursively.
    """
    if type(ast1) != type(ast2):
        return False
    if isinstance(ast1, tuple):
        # A (name, child) entry from children(); the type check above
        # guarantees ast2 is a tuple too.
        return ast1[0] == ast2[0] and compare_asts(ast1[1], ast2[1])
    for attr in ast1.attr_names:
        if getattr(ast1, attr) != getattr(ast2, attr):
            return False
    children1 = tuple(ast1.children())
    children2 = tuple(ast2.children())
    # A different child count means different trees. Checking this first
    # prevents zip() from silently ignoring trailing children of either side.
    if len(children1) != len(children2):
        return False
    return all(compare_asts(c1, c2) for c1, c2 in zip(children1, children2))
34
35
def parse_to_ast(src):
    """Parse the C source text *src* with the module-wide parser.

    Returns the root node produced by the parser; the tests index its
    ``ext`` list to reach the top-level declarations.
    """
    parser = _c_parser
    root = parser.parse(src)
    return root
38
39
class TestFunctionDeclGeneration(unittest.TestCase):
    """Check that CGenerator can render individual FuncDecl nodes."""

    class _FuncDeclVisitor(c_ast.NodeVisitor):
        """Collect the generated C text of each FuncDecl node it visits."""

        def __init__(self):
            # Generated text, one entry per visited FuncDecl, in visit order.
            self.stubs = []

        def visit_FuncDecl(self, node):
            # generic_visit is not called here, so the children of a visited
            # FuncDecl are not walked further by this visitor.
            gen = c_generator.CGenerator()
            self.stubs.append(gen.visit(node))

    def test_partial_funcdecl_generation(self):
        src = r'''
            void noop(void);
            void *something(void *thing);
            int add(int x, int y);'''
        ast = parse_to_ast(src)
        v = TestFunctionDeclGeneration._FuncDeclVisitor()
        v.visit(ast)
        self.assertEqual(len(v.stubs), 3)
        # assertIn reports the actual stubs on failure, unlike assertTrue(x in y).
        self.assertIn(r'void noop(void)', v.stubs)
        self.assertIn(r'void *something(void *thing)', v.stubs)
        self.assertIn(r'int add(int x, int y)', v.stubs)
61
62
class TestCtoC(unittest.TestCase):
    """Round-trip tests: C source -> AST -> C (via CGenerator) -> AST.

    Each case parses the regenerated C again and requires its AST to match
    the AST of the original source according to compare_asts().
    """

    def _run_c_to_c(self, src):
        """Parse *src* and return the C text CGenerator produces for it."""
        ast = parse_to_ast(src)
        generator = c_generator.CGenerator()
        return generator.visit(ast)

    def _assert_ctoc_correct(self, src):
        """ Checks that the c2c translation was correct by parsing the code
            generated by c2c for src and comparing the AST with the original
            AST.
        """
        src2 = self._run_c_to_c(src)
        # Show both the input and the regenerated code. Most tests call this
        # helper several times, so the input identifies which call failed.
        self.assertTrue(
            compare_asts(parse_to_ast(src), parse_to_ast(src2)),
            'round-trip AST mismatch\n--- original ---\n%s\n'
            '--- regenerated ---\n%s' % (src, src2))

    def test_trivial_decls(self):
        self._assert_ctoc_correct('int a;')
        self._assert_ctoc_correct('int b, a;')
        self._assert_ctoc_correct('int c, b, a;')

    def test_complex_decls(self):
        self._assert_ctoc_correct('int** (*a)(void);')
        self._assert_ctoc_correct('int** (*a)(void*, int);')
        self._assert_ctoc_correct('int (*b)(char * restrict k, float);')
        self._assert_ctoc_correct('int test(const char* const* arg);')
        self._assert_ctoc_correct('int test(const char** const arg);')

    def test_ternary(self):
        self._assert_ctoc_correct('''
            int main(void)
            {
                int a, b;
                (a == 0) ? (b = 1) : (b = 2);
            }''')

    def test_casts(self):
        self._assert_ctoc_correct(r'''
            int main() {
                int b = (int) f;
                int c = (int*) f;
            }''')
        self._assert_ctoc_correct(r'''
            int main() {
                int a = (int) b + 8;
                int t = (int) c;
            }
        ''')

    def test_initlist(self):
        self._assert_ctoc_correct('int arr[] = {1, 2, 3};')

    def test_exprs(self):
        self._assert_ctoc_correct('''
            int main(void)
            {
                int a;
                int b = a++;
                int c = ++a;
                int d = a--;
                int e = --a;
            }''')

    def test_statements(self):
        # Note the two separate unary minuses in '- - a'. If they were
        # regenerated as '--a', the result would re-parse as a pre-decrement
        # and the ASTs would no longer match.
        self._assert_ctoc_correct(r'''
            int main() {
                int a;
                a = 5;
                ;
                b = - - a;
                return a;
            }''')

    def test_struct_decl(self):
        self._assert_ctoc_correct(r'''
            typedef struct node_t {
                struct node_t* next;
                int data;
            } node;
            ''')

    def test_krstyle(self):
        self._assert_ctoc_correct(r'''
            int main(argc, argv)
            int argc;
            char** argv;
            {
                return 0;
            }
        ''')

    def test_switchcase(self):
        self._assert_ctoc_correct(r'''
        int main() {
            switch (myvar) {
            case 10:
            {
                k = 10;
                p = k + 1;
                break;
            }
            case 20:
            case 30:
                return 20;
            default:
                break;
            }
        }
        ''')

    def test_nest_initializer_list(self):
        self._assert_ctoc_correct(r'''
        int main()
        {
           int i[1][1] = { { 1 } };
        }''')

    def test_nest_named_initializer(self):
        self._assert_ctoc_correct(r'''struct test
            {
                int i;
                struct test_i_t
                {
                    int k;
                } test_i;
                int j;
            };
            struct test test_var = {.i = 0, .test_i = {.k = 1}, .j = 2};
        ''')

    def test_expr_list_in_initializer_list(self):
        self._assert_ctoc_correct(r'''
        int main()
        {
           int i[1] = { (1, 2) };
        }''')

    def test_issue36(self):
        self._assert_ctoc_correct(r'''
            int main() {
            }''')

    def test_issue37(self):
        self._assert_ctoc_correct(r'''
            int main(void)
            {
              unsigned size;
              size = sizeof(size);
              return 0;
            }''')

    def test_issue66(self):
        # A non-existing body must not be generated
        # (previous valid behavior, still working)
        self._assert_ctoc_correct(r'''
            struct foo;
            ''')
        # An empty body must be generated
        # (added behavior)
        self._assert_ctoc_correct(r'''
            struct foo {};
            ''')

    def test_issue83(self):
        self._assert_ctoc_correct(r'''
            void x(void) {
                int i = (9, k);
            }
            ''')

    def test_issue84(self):
        self._assert_ctoc_correct(r'''
            void x(void) {
                for (int i = 0;;)
                    i;
            }
            ''')

    def test_issue246(self):
        self._assert_ctoc_correct(r'''
            int array[3] = {[0] = 0, [1] = 1, [1+1] = 2};
            ''')

    def test_exprlist_with_semi(self):
        self._assert_ctoc_correct(r'''
            void x() {
                if (i < j)
                    tmp = C[i], C[i] = C[j], C[j] = tmp;
                if (i <= j)
                    i++, j--;
            }
        ''')

    def test_exprlist_with_subexprlist(self):
        self._assert_ctoc_correct(r'''
            void x() {
                (a = b, (b = c, c = a));
            }
        ''')

    def test_comma_operator_funcarg(self):
        self._assert_ctoc_correct(r'''
            void f(int x) { return x; }
            int main(void) { f((1, 2)); return 0; }
        ''')

    def test_comma_op_in_ternary(self):
        self._assert_ctoc_correct(r'''
            void f() {
                (0, 0) ? (0, 0) : (0, 0);
            }
        ''')

    def test_comma_op_assignment(self):
        self._assert_ctoc_correct(r'''
            void f() {
                i = (a, b, c);
            }
        ''')

    def test_pragma(self):
        self._assert_ctoc_correct(r'''
            #pragma foo
            void f() {
                #pragma bar
                i = (a, b, c);
            }
            typedef struct s {
            #pragma baz
           } s;
        ''')

    def test_compound_literal(self):
        self._assert_ctoc_correct('char **foo = (char *[]){ "x", "y", "z" };')
        self._assert_ctoc_correct('int i = ++(int){ 1 };')
        self._assert_ctoc_correct('struct foo_s foo = (struct foo_s){ 1, 2 };')

    def test_enum(self):
        self._assert_ctoc_correct(r'''
            enum e
            {
              a,
              b = 2,
              c = 3
            };
        ''')
        self._assert_ctoc_correct(r'''
            enum f
            {
                g = 4,
                h,
                i
            };
        ''')

    def test_enum_typedef(self):
        self._assert_ctoc_correct('typedef enum EnumName EnumTypedefName;')

    def test_generate_struct_union_enum_exception(self):
        # The private helper is expected to raise AssertionError when it is
        # given this Struct node together with an empty name.
        generator = c_generator.CGenerator()
        self.assertRaises(
            AssertionError,
            generator._generate_struct_union_enum,
            n=c_ast.Struct(
                name='TestStruct',
                decls=[],
            ),
            name='',
        )

    def test_array_decl(self):
        self._assert_ctoc_correct('int g(const int a[const 20]){}')
        ast = parse_to_ast('const int a[const 20];')
        generator = c_generator.CGenerator()
        # Render the declaration's type node, then the array's element type.
        self.assertEqual(generator.visit(ast.ext[0].type),
                         'const int [const 20]')
        self.assertEqual(generator.visit(ast.ext[0].type.type),
                         'const int')

    def test_ptr_decl(self):
        src = 'const int ** const  x;'
        self._assert_ctoc_correct(src)
        ast = parse_to_ast(src)
        generator = c_generator.CGenerator()
        # Each .type step removes one pointer level from the declared type.
        self.assertEqual(generator.visit(ast.ext[0].type),
                         'const int ** const')
        self.assertEqual(generator.visit(ast.ext[0].type.type),
                         'const int *')
        self.assertEqual(generator.visit(ast.ext[0].type.type.type),
                         'const int')
356
357
class TestCasttoC(unittest.TestCase):
    """Render Cast nodes whose target type comes from a parsed declaration."""

    def _find_file(self, name):
        """Return the path of *name* in the ``c_files`` directory next to this module.

        Fails the calling test with a descriptive message if the file does
        not exist. ``self.assertTrue`` is used instead of a bare ``assert`` so
        the check still runs under ``python -O``. It raises the same
        AssertionError.
        """
        test_dir = os.path.dirname(__file__)
        path = os.path.join(test_dir, 'c_files', name)
        self.assertTrue(os.path.exists(path),
                        'test input file not found: %s' % path)
        return path

    def test_to_type(self):
        src = 'int *x;'
        generator = c_generator.CGenerator()
        test_fun = c_ast.FuncCall(c_ast.ID('test_fun'), c_ast.ExprList([]))

        ast1 = parse_to_ast(src)
        # ext[0].type is the pointer type node of 'int *x' (rendered as
        # 'int *' in the cast); its .type is the pointee type 'int'.
        int_ptr_type = ast1.ext[0].type
        int_type = int_ptr_type.type
        self.assertEqual(generator.visit(c_ast.Cast(int_ptr_type, test_fun)),
                         '(int *) test_fun()')
        self.assertEqual(generator.visit(c_ast.Cast(int_type, test_fun)),
                         '(int) test_fun()')

    @unittest.skipUnless(platform.system() == 'Linux',
                         'cpp only works on Linux')
    def test_to_type_with_cpp(self):
        generator = c_generator.CGenerator()
        test_fun = c_ast.FuncCall(c_ast.ID('test_fun'), c_ast.ExprList([]))
        memmgr_path = self._find_file('memmgr.h')

        ast2 = parse_file(memmgr_path, use_cpp=True)
        # NOTE(review): relies on the layout of memmgr.h. ext[-3] is assumed
        # to be a declaration whose .type.type is a 'void *' type node
        # (presumably a function returning void *). Update this index if the
        # header changes.
        void_ptr_type = ast2.ext[-3].type.type
        void_type = void_ptr_type.type
        self.assertEqual(generator.visit(c_ast.Cast(void_ptr_type, test_fun)),
                         '(void *) test_fun()')
        self.assertEqual(generator.visit(c_ast.Cast(void_type, test_fun)),
                         '(void) test_fun()')
392
if __name__ == '__main__':
    # Executed as a script: let unittest discover and run the TestCases above.
    unittest.main()
395