xref: /aosp_15_r20/external/sandboxed-api/sandboxed_api/tools/generator2/code.py (revision ec63e07ab9515d95e79c211197c445ef84cefa6a)
1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Module related to code analysis and generation."""
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19from ctypes import util
20import itertools
21import os
22# pylint: disable=unused-import
23from typing import (Text, List, Optional, Set, Dict, Callable, IO,
24                    Generator as Gen, Tuple, Union, Sequence)  # pyformat: disable
25# pylint: enable=unused-import
26from clang import cindex
27
28
29_PARSE_OPTIONS = (
30    cindex.TranslationUnit.PARSE_SKIP_FUNCTION_BODIES
31    | cindex.TranslationUnit.PARSE_INCOMPLETE |
32    # for include directives
33    cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD)
34
35
36def _init_libclang():
37  """Finds and initializes the libclang library."""
38  if cindex.Config.loaded:
39    return
40  # Try to find libclang in the standard location and a few versioned paths
41  # that are used on Debian (and others). If LD_LIBRARY_PATH is set, it is
42  # used as well.
43  for version in [
44      '',
45      '16',
46      '15',
47      '14',
48      '13',
49      '12',
50      '11',
51      '10',
52      '9',
53      '8',
54      '7',
55      '6.0',
56      '5.0',
57      '4.0',
58  ]:
59    libname = 'clang' + ('-' + version if version else '')
60    libclang = util.find_library(libname)
61    if libclang:
62      cindex.Config.set_library_file(libclang)
63      break
64
65
66def get_header_guard(path):
67  # type: (Text) -> Text
68  """Generates header guard string from path."""
69  # the output file will be most likely somewhere in genfiles, strip the
70  # prefix in that case, also strip .gen if this is a step before clang-format
71  if not path:
72    raise ValueError('Cannot prepare header guard from path: {}'.format(path))
73  if 'genfiles/' in path:
74    path = path.split('genfiles/')[1]
75  if path.endswith('.gen'):
76    path = path.split('.gen')[0]
77  path = path.upper().replace('.', '_').replace('-', '_').replace('/', '_')
78  return path + '_'
79
80
81def _stringify_tokens(tokens, separator='\n'):
82  # type: (Sequence[cindex.Token], Text) -> Text
83  """Converts tokens to text respecting line position (disrespecting column)."""
84  previous = OutputLine(0, [])  # not used in output
85  lines = []  # type: List[OutputLine]
86
87  for _, group in itertools.groupby(tokens, lambda t: t.location.line):
88    group_list = list(group)
89    line = OutputLine(previous.next_tab, group_list)
90
91    lines.append(line)
92    previous = line
93
94  return separator.join(str(l) for l in lines)
95
96
97TYPE_MAPPING = {
98    cindex.TypeKind.VOID: '::sapi::v::Void',
99    cindex.TypeKind.CHAR_S: '::sapi::v::Char',
100    cindex.TypeKind.CHAR_U: '::sapi::v::Char',
101    cindex.TypeKind.INT: '::sapi::v::Int',
102    cindex.TypeKind.UINT: '::sapi::v::UInt',
103    cindex.TypeKind.LONG: '::sapi::v::Long',
104    cindex.TypeKind.ULONG: '::sapi::v::ULong',
105    cindex.TypeKind.UCHAR: '::sapi::v::UChar',
106    cindex.TypeKind.USHORT: '::sapi::v::UShort',
107    cindex.TypeKind.SHORT: '::sapi::v::Short',
108    cindex.TypeKind.LONGLONG: '::sapi::v::LLong',
109    cindex.TypeKind.ULONGLONG: '::sapi::v::ULLong',
110    cindex.TypeKind.FLOAT: '::sapi::v::Reg<float>',
111    cindex.TypeKind.DOUBLE: '::sapi::v::Reg<double>',
112    cindex.TypeKind.LONGDOUBLE: '::sapi::v::Reg<long double>',
113    cindex.TypeKind.SCHAR: '::sapi::v::SChar',
114    cindex.TypeKind.SHORT: '::sapi::v::Short',
115    cindex.TypeKind.BOOL: '::sapi::v::Bool',
116}
117
118
119class Type(object):
120  """Class representing a type.
121
122  Wraps cindex.Type of the argument/return value and provides helpers for the
123  code generation.
124  """
125
126  def __init__(self, tu, clang_type):
127    # type: (_TranslationUnit, cindex.Type) -> None
128    self._clang_type = clang_type
129    self._tu = tu
130
131  # pylint: disable=protected-access
132  def __eq__(self, other):
133    # type: (Type) -> bool
134    # Use get_usr() to deduplicate Type objects based on declaration
135    decl = self._get_declaration()
136    decl_o = other._get_declaration()
137
138    return decl.get_usr() == decl_o.get_usr()
139
140  def __ne__(self, other):
141    # type: (Type) -> bool
142    return not self.__eq__(other)
143
144  def __lt__(self, other):
145    # type: (Type) -> bool
146    """Compares two Types belonging to the same TranslationUnit.
147
148    This is being used to properly order types before emitting to generated
149    file. To be more specific: structure definition that contains field that is
150    a typedef should end up after that typedef definition. This is achieved by
151    exploiting the order in which clang iterate over AST in translation unit.
152
153    Args:
154      other: other comparison type
155
156    Returns:
157      true if this Type occurs earlier in the AST than 'other'
158    """
159    self._validate_tu(other)
160    return (self._tu.order[self._get_declaration().hash] <
161            self._tu.order[other._get_declaration().hash])  # pylint: disable=protected-access
162
163  def __gt__(self, other):
164    # type: (Type) -> bool
165    """Compares two Types belonging to the same TranslationUnit.
166
167    This is being used to properly order types before emitting to generated
168    file. To be more specific: structure definition that contains field that is
169    a typedef should end up after that typedef definition. This is achieved by
170    exploiting the order in which clang iterate over AST in translation unit.
171
172    Args:
173      other: other comparison type
174
175    Returns:
176      true if this Type occurs later in the AST than 'other'
177    """
178    self._validate_tu(other)
179    return (self._tu.order[self._get_declaration().hash] >
180            self._tu.order[other._get_declaration().hash])  # pylint: disable=protected-access
181
182  def __hash__(self):
183    """Types with the same declaration should hash to the same value."""
184    return hash(self._get_declaration().get_usr())
185
186  def _validate_tu(self, other):
187    # type: (Type) -> None
188    if self._tu != other._tu:  # pylint: disable=protected-access
189      raise ValueError('Cannot compare types from different translation units.')
190
191  def is_void(self):
192    # type: () -> bool
193    return self._clang_type.kind == cindex.TypeKind.VOID
194
195  def is_typedef(self):
196    # type: () -> bool
197    return self._clang_type.kind == cindex.TypeKind.TYPEDEF
198
199  def is_elaborated(self):
200    # type: () -> bool
201    return self._clang_type.kind == cindex.TypeKind.ELABORATED
202
203  # Hack: both class and struct types are indistinguishable except for
204  # declaration cursor kind
205  def is_sugared_record(self):  # class, struct, union
206    # type: () -> bool
207    return self._clang_type.get_declaration().kind in (
208        cindex.CursorKind.STRUCT_DECL, cindex.CursorKind.UNION_DECL,
209        cindex.CursorKind.CLASS_DECL)
210
211  def is_struct(self):
212    # type: () -> bool
213    return (self._clang_type.get_declaration().kind ==
214            cindex.CursorKind.STRUCT_DECL)
215
216  def is_class(self):
217    # type: () -> bool
218    return (self._clang_type.get_declaration().kind ==
219            cindex.CursorKind.CLASS_DECL)
220
221  def is_union(self):
222    # type: () -> bool
223    return (self._clang_type.get_declaration().kind ==
224            cindex.CursorKind.UNION_DECL)
225
226  def is_function(self):
227    # type: () -> bool
228    return self._clang_type.kind == cindex.TypeKind.FUNCTIONPROTO
229
230  def is_sugared_ptr(self):
231    # type: () -> bool
232    return self._clang_type.get_canonical().kind == cindex.TypeKind.POINTER
233
234  def is_sugared_enum(self):
235    # type: () -> bool
236    return self._clang_type.get_canonical().kind == cindex.TypeKind.ENUM
237
238  def is_const_array(self):
239    # type: () -> bool
240    return self._clang_type.kind == cindex.TypeKind.CONSTANTARRAY
241
242  def is_simple_type(self):
243    # type: () -> bool
244    return self._clang_type.kind in TYPE_MAPPING
245
246  def get_pointee(self):
247    # type: () -> Type
248    return Type(self._tu, self._clang_type.get_pointee())
249
250  def _get_declaration(self):
251    # type: () -> cindex.Cursor
252    decl = self._clang_type.get_declaration()
253    if decl.kind == cindex.CursorKind.NO_DECL_FOUND and self.is_sugared_ptr():
254      decl = self.get_pointee()._get_declaration()  # pylint: disable=protected-access
255
256    return decl
257
258  def get_related_types(self, result=None, skip_self=False):
259    # type: (Optional[Set[Type]], bool) -> Set[Type]
260    """Returns all types related to this one eg. typedefs, nested structs."""
261    if result is None:
262      result = set()
263
264    # Base case.
265    if self in result or self.is_simple_type() or self.is_class():
266      return result
267
268    # Sugar types.
269    if self.is_typedef():
270      return self._get_related_types_of_typedef(result)
271
272    if self.is_elaborated():
273      return Type(self._tu,
274                  self._clang_type.get_named_type()).get_related_types(
275                      result, skip_self)
276
277    # Composite types.
278    if self.is_const_array():
279      t = Type(self._tu, self._clang_type.get_array_element_type())
280      return t.get_related_types(result)
281
282    if self._clang_type.kind in (cindex.TypeKind.POINTER,
283                                 cindex.TypeKind.MEMBERPOINTER,
284                                 cindex.TypeKind.LVALUEREFERENCE,
285                                 cindex.TypeKind.RVALUEREFERENCE):
286      return self.get_pointee().get_related_types(result, skip_self)
287
288    # union + struct, class should be filtered out
289    if self.is_struct() or self.is_union():
290      return self._get_related_types_of_record(result, skip_self)
291
292    if self.is_function():
293      return self._get_related_types_of_function(result)
294
295    if self.is_sugared_enum():
296      if not skip_self:
297        result.add(self)
298        self._tu.search_for_macro_name(self._get_declaration())
299      return result
300
301    # Ignore all cindex.TypeKind.UNEXPOSED AST nodes
302    # TODO(b/256934562): Remove the disable once the pytype bug is fixed.
303    return result  # pytype: disable=bad-return-type
304
305  def _get_related_types_of_typedef(self, result):
306    # type: (Set[Type]) -> Set[Type]
307    """Returns all intermediate types related to the typedef."""
308    result.add(self)
309    decl = self._clang_type.get_declaration()
310    self._tu.search_for_macro_name(decl)
311
312    t = Type(self._tu, decl.underlying_typedef_type)
313    if t.is_sugared_ptr():
314      t = t.get_pointee()
315
316    if not t.is_simple_type():
317      skip_child = self.contains_declaration(t)
318      if t.is_sugared_record() and skip_child:
319        # if child declaration is contained in parent, we don't have to emit it
320        self._tu.types_to_skip.add(t)
321      result.update(t.get_related_types(result, skip_child))
322
323    return result
324
325  def _get_related_types_of_record(self, result, skip_self=False):
326    # type: (Set[Type], bool) -> Set[Type]
327    """Returns all types related to the structure."""
328    # skip unnamed structures eg. typedef struct {...} x;
329    # struct {...} will be rendered as part of typedef rendering
330    decl = self._get_declaration()
331    if not decl.is_anonymous() and not skip_self:
332      self._tu.search_for_macro_name(decl)
333      result.add(self)
334
335    for f in self._clang_type.get_fields():
336      self._tu.search_for_macro_name(f)
337      result.update(Type(self._tu, f.type).get_related_types(result))
338
339    return result
340
341  def _get_related_types_of_function(self, result):
342    # type: (Set[Type]) -> Set[Type]
343    """Returns all types related to the function."""
344    for arg in self._clang_type.argument_types():
345      result.update(Type(self._tu, arg).get_related_types(result))
346    related = Type(self._tu,
347                   self._clang_type.get_result()).get_related_types(result)
348    result.update(related)
349
350    return result
351
352  def contains_declaration(self, other):
353    # type: (Type) -> bool
354    """Checks if string representation of a type contains the other type."""
355    self_extent = self._get_declaration().extent
356    other_extent = other._get_declaration().extent  # pylint: disable=protected-access
357
358    if other_extent.start.file is None:
359      return False
360    return (other_extent.start in self_extent and
361            other_extent.end in self_extent)
362
363  def stringify(self):
364    # type: () -> Text
365    """Returns string representation of the Type."""
366    # (szwl): as simple as possible, keeps macros in separate lines not to
367    # break things; this will go through clang format nevertheless
368    tokens = [
369        x for x in self._get_declaration().get_tokens()
370        if x.kind is not cindex.TokenKind.COMMENT
371    ]
372
373    return _stringify_tokens(tokens)
374
375
376class OutputLine(object):
377  """Helper class for Type printing."""
378
379  def __init__(self, tab, tokens):
380    # type: (int, List[cindex.Token]) -> None
381    self.tokens = tokens
382    self.spellings = []
383    self.define = False
384    self.tab = tab
385    self.next_tab = tab
386    list(map(self._process_token, self.tokens))
387
388  def _process_token(self, t):
389    # type: (cindex.Token) -> None
390    """Processes a token, setting up internal states rel. to intendation."""
391    if t.spelling == '#':
392      self.define = True
393    elif t.spelling == '{':
394      self.next_tab += 1
395    elif t.spelling == '}':
396      self.tab -= 1
397      self.next_tab -= 1
398
399    is_bracket = t.spelling == '('
400    is_macro = len(self.spellings) == 1 and self.spellings[0] == '#'
401    if self.spellings and not is_bracket and not is_macro:
402      self.spellings.append(' ')
403    self.spellings.append(t.spelling)
404
405  def __str__(self):
406    # type: () -> Text
407    tabs = ('\t' * self.tab) if not self.define else ''
408    return tabs + ''.join(t for t in self.spellings)
409
410
411class ArgumentType(Type):
412  """Class representing function argument type.
413
414  Object fields are being used by the code template:
415  pos: argument position
416  type: string representation of the type
417  argument: string representation of the type as function argument
418  mapped_type: SAPI equivalent of the type
419  wrapped: wraps type in SAPI object constructor
420  call_argument: type (or it's sapi wrapper) used in function call
421  """
422
423  def __init__(self, function, pos, arg_type, name=None):
424    # type: (Function, int, cindex.Type, Optional[Text]) -> None
425    super(ArgumentType, self).__init__(function.translation_unit(), arg_type)
426    self._function = function
427
428    self.pos = pos
429    self.name = name or 'a{}'.format(pos)
430    self.type = arg_type.spelling
431
432    template = '{}' if self.is_sugared_ptr() else '&{}_'
433    self.call_argument = template.format(self.name)
434
435  def __str__(self):
436    # type: () -> Text
437    """Returns function argument prepared from the type."""
438    if self.is_sugared_ptr():
439      return '::sapi::v::Ptr* {}'.format(self.name)
440
441    return '{} {}'.format(self._clang_type.spelling, self.name)
442
443  @property
444  def wrapped(self):
445    # type: () -> Text
446    return '{} {name}_(({name}))'.format(self.mapped_type, name=self.name)
447
448  @property
449  def mapped_type(self):
450    # type: () -> Text
451    """Maps the type to its SAPI equivalent."""
452    if self.is_sugared_ptr():
453      # TODO(szwl): const ptrs do not play well with SAPI C++ API...
454      spelling = self._clang_type.spelling.replace('const', '')
455      return '::sapi::v::Reg<{}>'.format(spelling)
456
457    type_ = self._clang_type
458
459    if type_.kind == cindex.TypeKind.TYPEDEF:
460      type_ = self._clang_type.get_canonical()
461    if type_.kind == cindex.TypeKind.ELABORATED:
462      type_ = type_.get_canonical()
463    if type_.kind == cindex.TypeKind.ENUM:
464      return '::sapi::v::IntBase<{}>'.format(self._clang_type.spelling)
465    if type_.kind in [
466        cindex.TypeKind.CONSTANTARRAY, cindex.TypeKind.INCOMPLETEARRAY
467    ]:
468      return '::sapi::v::Reg<{}>'.format(self._clang_type.spelling)
469
470    if type_.kind == cindex.TypeKind.LVALUEREFERENCE:
471      return 'LVALUEREFERENCE::NOT_SUPPORTED'
472
473    if type_.kind == cindex.TypeKind.RVALUEREFERENCE:
474      return 'RVALUEREFERENCE::NOT_SUPPORTED'
475
476    if type_.kind in [cindex.TypeKind.RECORD, cindex.TypeKind.ELABORATED]:
477      raise ValueError('Elaborate type (eg. struct) in mapped_type is not '
478                       'supported: function {}, arg {}, type {}, location {}'
479                       ''.format(self._function.name, self.pos,
480                                 self._clang_type.spelling,
481                                 self._function.cursor.location))
482
483    if type_.kind not in TYPE_MAPPING:
484      raise KeyError('Key {} does not exist in TYPE_MAPPING.'
485                     ' function {}, arg {}, type {}, location {}'
486                     ''.format(type_.kind, self._function.name, self.pos,
487                               self._clang_type.spelling,
488                               self._function.cursor.location))
489
490    return TYPE_MAPPING[type_.kind]
491
492
493class ReturnType(ArgumentType):
494  """Class representing function return type.
495
496     Attributes:
497       return_type: absl::StatusOr<T> where T is original return type, or
498                    absl::Status for functions returning void
499  """
500
501  def __init__(self, function, arg_type):
502    # type: (Function, cindex.Type) -> None
503    super(ReturnType, self).__init__(function, 0, arg_type, None)
504
505  def __str__(self):
506    # type: () -> Text
507    """Returns function return type prepared from the type."""
508    # TODO(szwl): const ptrs do not play well with SAPI C++ API...
509    spelling = self._clang_type.spelling.replace('const', '')
510    return_type = 'absl::StatusOr<{}>'.format(spelling)
511    return_type = 'absl::Status' if self.is_void() else return_type
512    return return_type
513
514
515class Function(object):
516  """Class representing SAPI-wrapped function used by the template.
517
518  Wraps Clang cursor object of kind FUNCTION_DECL and provides helpers to
519  aid code generation.
520  """
521
522  def __init__(self, tu, cursor):
523    # type: (_TranslationUnit, cindex.Cursor) -> None
524    self._tu = tu
525    self.cursor = cursor  # type: cindex.Index
526    self.name = cursor.spelling  # type: Text
527    self.result = ReturnType(self, cursor.result_type)
528    self.original_definition = '{} {}'.format(
529        cursor.result_type.spelling, self.cursor.displayname)  # type: Text
530
531    types = self.cursor.get_arguments()
532    self.argument_types = [
533        ArgumentType(self, i, t.type, t.spelling) for i, t in enumerate(types)
534    ]
535
536  def translation_unit(self):
537    # type: () -> _TranslationUnit
538    return self._tu
539
540  def arguments(self):
541    # type: () -> List[ArgumentType]
542    return self.argument_types
543
544  def call_arguments(self):
545    # type: () -> List[Text]
546    return [a.call_argument for a in self.argument_types]
547
548  def get_absolute_path(self):
549    # type: () -> Text
550    return self.cursor.location.file.name
551
552  def get_include_path(self, prefix):
553    # type: (Optional[Text]) -> Text
554    """Creates a proper include path."""
555    # TODO(szwl): sanity checks
556    # TODO(szwl): prefix 'utils/' and the path is '.../fileutils/...' case
557    if prefix and not prefix.endswith('/'):
558      prefix += '/'
559
560    if not prefix:
561      return self.get_absolute_path()
562    elif prefix in self.get_absolute_path():
563      return prefix + self.get_absolute_path().split(prefix)[-1]
564    return prefix + self.get_absolute_path().split('/')[-1]
565
566  def get_related_types(self, processed=None):
567    # type: (Optional[Set[Type]]) -> Set[Type]
568    result = self.result.get_related_types(processed)
569    for a in self.argument_types:
570      result.update(a.get_related_types(processed))
571
572    return result
573
574  def is_mangled(self):
575    # type: () -> bool
576    return self.cursor.mangled_name != self.cursor.spelling
577
578  def __hash__(self):
579    # type: () -> int
580    return hash(self.cursor.get_usr())
581
582  def __eq__(self, other):
583    # type: (Function) -> bool
584    return self.cursor.mangled_name == other.cursor.mangled_name
585
586
587class _TranslationUnit(object):
588  """Class wrapping clang's _TranslationUnit. Provides extra utilities."""
589
590  def __init__(self, path, tu, limit_scan_depth=False):
591    # type: (Text, cindex.TranslationUnit, bool) -> None
592    self.path = path
593    self.limit_scan_depth = limit_scan_depth
594    self._tu = tu
595    self._processed = False
596    self.forward_decls = dict()
597    self.functions = set()
598    self.order = dict()
599    self.defines = {}
600    self.required_defines = set()
601    self.types_to_skip = set()
602
603  def _process(self):
604    # type: () -> None
605    """Walks the cursor tree and caches some for future use."""
606    if not self._processed:
607      # self.includes[self._tu.spelling] = (0, self._tu.cursor)
608      self._processed = True
609      # TODO(szwl): duplicates?
610      # TODO(szwl): for d in translation_unit.diagnostics:, handle that
611
612      for i, cursor in enumerate(self._walk_preorder()):
613        # Workaround for issue#32
614        # ignore all the cursors with kinds not implemented in python bindings
615        try:
616          cursor.kind
617        except ValueError:
618          continue
619        # naive way to order types: they should be ordered when walking the tree
620        if cursor.kind.is_declaration():
621          self.order[cursor.hash] = i
622
623        if (cursor.kind == cindex.CursorKind.MACRO_DEFINITION and
624            cursor.location.file):
625          self.order[cursor.hash] = i
626          self.defines[cursor.spelling] = cursor
627
628        # most likely a forward decl of struct
629        if (cursor.kind == cindex.CursorKind.STRUCT_DECL and
630            not cursor.is_definition()):
631          self.forward_decls[Type(self, cursor.type)] = cursor
632        if (cursor.kind == cindex.CursorKind.FUNCTION_DECL and
633            cursor.linkage != cindex.LinkageKind.INTERNAL):
634          if self.limit_scan_depth:
635            if (cursor.location and cursor.location.file.name == self.path):
636              self.functions.add(Function(self, cursor))
637          else:
638            self.functions.add(Function(self, cursor))
639
640  def get_functions(self):
641    # type: () -> Set[Function]
642    if not self._processed:
643      self._process()
644    return self.functions
645
646  def _walk_preorder(self):
647    # type: () -> Gen
648    for c in self._tu.cursor.walk_preorder():
649      yield c
650
651  def search_for_macro_name(self, cursor):
652    # type: (cindex.Cursor) -> None
653    """Searches for possible macro usage in constant array types."""
654    tokens = list(t.spelling for t in cursor.get_tokens())
655    try:
656      for token in tokens:
657        if token in self.defines and token not in self.required_defines:
658          self.required_defines.add(token)
659          self.search_for_macro_name(self.defines[token])
660    except ValueError:
661      return
662
663
664class Analyzer(object):
665  """Class responsible for analysis."""
666
667  @staticmethod
668  def process_files(input_paths, compile_flags, limit_scan_depth=False):
669    # type: (Text, List[Text], bool) -> List[_TranslationUnit]
670    """Processes files with libclang and returns TranslationUnit objects."""
671    _init_libclang()
672
673    tus = []
674    for path in input_paths:
675      tu = Analyzer._analyze_file_for_tu(
676          path, compile_flags=compile_flags, limit_scan_depth=limit_scan_depth)
677      tus.append(tu)
678    return tus
679
680  # pylint: disable=line-too-long
681  @staticmethod
682  def _analyze_file_for_tu(path,
683                           compile_flags=None,
684                           test_file_existence=True,
685                           unsaved_files=None,
686                           limit_scan_depth=False):
687    # type: (Text, Optional[List[Text]], bool, Optional[Tuple[Text, Union[Text, IO[Text]]]], bool) -> _TranslationUnit
688    """Returns Analysis object for given path."""
689    compile_flags = compile_flags or []
690    if test_file_existence and not os.path.isfile(path):
691      raise IOError('Path {} does not exist.'.format(path))
692
693    _init_libclang()
694    index = cindex.Index.create()  # type: cindex.Index
695    # TODO(szwl): hack until I figure out how python swig does that.
696    # Headers will be parsed as C++. C libs usually have
697    # '#ifdef __cplusplus extern "C"' for compatibility with c++
698    lang = '-xc++' if not path.endswith('.c') else '-xc'
699    args = [lang]
700    args += compile_flags
701    args.append('-I.')
702    return _TranslationUnit(
703        path,
704        index.parse(
705            path,
706            args=args,
707            unsaved_files=unsaved_files,
708            options=_PARSE_OPTIONS),
709        limit_scan_depth=limit_scan_depth)
710
711
712class Generator(object):
713  """Class responsible for code generation."""
714
715  AUTO_GENERATED = ('// AUTO-GENERATED by the Sandboxed API generator.\n'
716                    '// Edits will be discarded when regenerating this file.\n')
717
718  GUARD_START = ('#ifndef {0}\n' '#define {0}')
719  GUARD_END = '#endif  // {}'
720  EMBED_INCLUDE = '#include "{}"'
721  EMBED_CLASS = ('class {0}Sandbox : public ::sapi::Sandbox {{\n'
722                 ' public:\n'
723                 '  {0}Sandbox() : ::sapi::Sandbox({1}_embed_create()) {{}}\n'
724                 '}};')
725
726  def __init__(self, translation_units):
727    # type: (List[cindex.TranslationUnit]) -> None
728    """Initializes the generator.
729
730    Args:
731      translation_units: list of translation_units for analyzed files,
732        facultative. If not given, then one is computed for each element of
733        input_paths
734    """
735    self.translation_units = translation_units
736    self.functions = None
737    _init_libclang()
738
739  def generate(self,
740               name,
741               function_names,
742               namespace=None,
743               output_file=None,
744               embed_dir=None,
745               embed_name=None):
746    # pylint: disable=line-too-long
747    # type: (Text, List[Text], Optional[Text], Optional[Text], Optional[Text], Optional[Text]) -> Text
748    """Generates structures, functions and typedefs.
749
750    Args:
751      name: name of the class that will contain generated interface
752      function_names: list of function names to export to the interface
753      namespace: namespace of the interface
754      output_file: path to the output file, used to generate header guards;
755        defaults to None that does not generate the guard #include directives;
756        defaults to None that causes to emit the whole file path
757      embed_dir: path to directory with embed includes
758      embed_name: name of the embed object
759
760    Returns:
761      generated interface as a string
762    """
763    related_types = self._get_related_types(function_names)
764    forward_decls = self._get_forward_decls(related_types)
765    functions = self._get_functions(function_names)
766    related_types = [(t.stringify() + ';') for t in related_types]
767    defines = self._get_defines()
768
769    api = {
770        'name': name,
771        'functions': functions,
772        'related_types': defines + forward_decls + related_types,
773        'namespaces': namespace.split('::') if namespace else [],
774        'embed_dir': embed_dir,
775        'embed_name': embed_name,
776        'output_file': output_file
777    }
778    return self.format_template(**api)
779
780  def _get_functions(self, func_names=None):
781    # type: (Optional[List[Text]]) -> List[Function]
782    """Gets Function objects that will be used to generate interface."""
783    if self.functions is not None:
784      return self.functions
785    self.functions = []
786    # TODO(szwl): for d in translation_unit.diagnostics:, handle that
787    for translation_unit in self.translation_units:
788      self.functions += [
789          f for f in translation_unit.get_functions()
790          if not func_names or f.name in func_names
791      ]
792    # allow only nonmangled functions - C++ overloads are not handled in
793    # code generation
794    self.functions = [f for f in self.functions if not f.is_mangled()]
795
796    # remove duplicates
797    self.functions = list(set(self.functions))
798    self.functions.sort(key=lambda x: x.name)
799    return self.functions
800
801  def _get_related_types(self, func_names=None):
802    # type: (Optional[List[Text]]) -> List[Type]
803    """Gets type definitions related to chosen functions.
804
805    Types related to one function will land in the same translation unit,
806    we gather the types, sort it and put as a sublist in types list.
807    This is necessary as we can't compare types from two different translation
808    units.
809
810    Args:
811      func_names: list of function names to take into consideration, empty means
812        all functions.
813
814    Returns:
815      list of types in correct (ready to render) order
816    """
817    processed = set()
818    fn_related_types = set()
819    types = []
820    types_to_skip = set()
821
822    for f in self._get_functions(func_names):
823      fn_related_types = f.get_related_types()
824      types += sorted(r for r in fn_related_types if r not in processed)
825      processed.update(fn_related_types)
826      types_to_skip.update(f.translation_unit().types_to_skip)
827
828    return [t for t in types if t not in types_to_skip]
829
830  def _get_defines(self):
831    # type: () -> List[Text]
832    """Gets #define directives that appeared during TranslationUnit processing.
833
834    Returns:
835      list of #define string representations
836    """
837
838    def make_sort_condition(translation_unit):
839      return lambda cursor: translation_unit.order[cursor.hash]
840
841    result = []
842    for tu in self.translation_units:
843      tmp_result = []
844      sort_condition = make_sort_condition(tu)
845      for name in tu.required_defines:
846        if name in tu.defines:
847          define = tu.defines[name]
848          tmp_result.append(define)
849      for define in sorted(tmp_result, key=sort_condition):
850        result.append('#define ' +
851                      _stringify_tokens(define.get_tokens(), separator=' \\\n'))
852    return result
853
854  def _get_forward_decls(self, types):
855    # type: (List[Type]) -> List[Text]
856    """Gets forward declarations of related types, if present."""
857    forward_decls = dict()
858    result = []
859    done = set()
860    for tu in self.translation_units:
861      forward_decls.update(tu.forward_decls)
862
863      for t in types:
864        if t in forward_decls and t not in done:
865          result.append(_stringify_tokens(forward_decls[t].get_tokens()) + ';')
866          done.add(t)
867
868    return result
869
870  def _format_function(self, f):
871    # type: (Function) -> Text
872    """Renders one function of the Api.
873
874    Args:
875      f: function object with information necessary to emit full function body
876
877    Returns:
878      filled function template
879    """
880    result = []
881    result.append('  // {}'.format(f.original_definition))
882
883    arguments = ', '.join(str(a) for a in f.arguments())
884    result.append('  {} {}({}) {{'.format(f.result, f.name, arguments))
885    result.append('    {} ret;'.format(f.result.mapped_type))
886
887    argument_types = []
888    for a in f.argument_types:
889      if not a.is_sugared_ptr():
890        argument_types.append(a.wrapped + ';')
891    if argument_types:
892      for arg in argument_types:
893        result.append('    {}'.format(arg))
894
895    call_arguments = f.call_arguments()
896    if call_arguments:  # fake empty space to add ',' before first argument
897      call_arguments.insert(0, '')
898    result.append('')
899    # For OSS, the macro below will be replaced.
900    result.append('    SAPI_RETURN_IF_ERROR(sandbox_->Call("{}", &ret{}));'
901                  ''.format(f.name, ', '.join(call_arguments)))
902
903    return_status = 'return absl::OkStatus();'
904    if f.result and not f.result.is_void():
905      if f.result and f.result.is_sugared_enum():
906        return_status = ('return static_cast<{}>'
907                         '(ret.GetValue());').format(f.result.type)
908      else:
909        return_status = 'return ret.GetValue();'
910    result.append('    {}'.format(return_status))
911    result.append('  }')
912
913    return '\n'.join(result)
914
915  def format_template(self, name, functions, related_types, namespaces,
916                      embed_dir, embed_name, output_file):
917    # pylint: disable=line-too-long
918    # type: (Text, List[Function], List[Text], List[Text], Text, Text, Text) -> Text
919    # pylint: enable=line-too-long
920    """Formats arguments into proper interface header file.
921
922    Args:
923      name: name of the Api - 'Test' will yield TestApi object
924      functions: list of functions to generate
925      related_types: types used in the above functions
926      namespaces: list of namespaces to wrap the Api class with
927      embed_dir: directory where the embedded library lives
928      embed_name: name of embedded library
929      output_file: interface output path - used in header guard generation
930
931    Returns:
932      generated header file text
933    """
934    result = [Generator.AUTO_GENERATED]
935
936    header_guard = get_header_guard(output_file) if output_file else ''
937    if header_guard:
938      result.append(Generator.GUARD_START.format(header_guard))
939
940    # Copybara transform results in the paths below.
941    result.append('#include "absl/status/status.h"')
942    result.append('#include "absl/status/statusor.h"')
943    result.append('#include "sandboxed_api/sandbox.h"')
944    result.append('#include "sandboxed_api/util/status_macros.h"')
945    result.append('#include "sandboxed_api/vars.h"')
946
947    if embed_name:
948      embed_dir = embed_dir or ''
949      result.append(
950          Generator.EMBED_INCLUDE.format(
951              os.path.join(embed_dir, embed_name) + '_embed.h'))
952
953    if namespaces:
954      result.append('')
955      for n in namespaces:
956        result.append('namespace {} {{'.format(n))
957
958    if related_types:
959      result.append('')
960      for t in related_types:
961        result.append(t)
962
963    result.append('')
964
965    if embed_name:
966      result.append(
967          Generator.EMBED_CLASS.format(name, embed_name.replace('-', '_')))
968
969    result.append('class {}Api {{'.format(name))
970    result.append(' public:')
971    result.append('  explicit {}Api(::sapi::Sandbox* sandbox)'
972                  ' : sandbox_(sandbox) {{}}'.format(name))
973    result.append('  // Deprecated')
974    result.append('  ::sapi::Sandbox* GetSandbox() const { return sandbox(); }')
975    result.append('  ::sapi::Sandbox* sandbox() const { return sandbox_; }')
976
977    for f in functions:
978      result.append('')
979      result.append(self._format_function(f))
980
981    result.append('')
982    result.append(' private:')
983    result.append('  ::sapi::Sandbox* sandbox_;')
984    result.append('};')
985    result.append('')
986
987    if namespaces:
988      for n in reversed(namespaces):
989        result.append('}}  // namespace {}'.format(n))
990
991    if header_guard:
992      result.append(Generator.GUARD_END.format(header_guard))
993
994    result.append('')
995
996    return '\n'.join(result)
997