# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module related to code analysis and generation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ctypes import util
import itertools
import os
import re
# pylint: disable=unused-import
from typing import (Text, List, Optional, Set, Dict, Callable, IO,
                    Generator as Gen, Tuple, Union, Sequence)  # pyformat: disable
# pylint: enable=unused-import
from clang import cindex


_PARSE_OPTIONS = (
    cindex.TranslationUnit.PARSE_SKIP_FUNCTION_BODIES
    | cindex.TranslationUnit.PARSE_INCOMPLETE |
    # for include directives
    cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD)

# Matches 'const' only as a whole word, so identifiers that merely contain
# the substring (e.g. 'constant_t') are left untouched.
_CONST_QUALIFIER = re.compile(r'\bconst\b')


def _init_libclang():
  """Finds and initializes the libclang library."""
  if cindex.Config.loaded:
    return
  # Try to find libclang in the standard location and a few versioned paths
  # that are used on Debian (and others). If LD_LIBRARY_PATH is set, it is
  # used as well.
  for version in [
      '',
      '16',
      '15',
      '14',
      '13',
      '12',
      '11',
      '10',
      '9',
      '8',
      '7',
      '6.0',
      '5.0',
      '4.0',
  ]:
    libname = 'clang' + ('-' + version if version else '')
    libclang = util.find_library(libname)
    if libclang:
      cindex.Config.set_library_file(libclang)
      break


def get_header_guard(path):
  # type: (Text) -> Text
  """Generates header guard string from path.

  Args:
    path: path of the generated header

  Returns:
    upper-cased path with '.', '-' and '/' replaced by '_' and a trailing '_'

  Raises:
    ValueError: if path is empty
  """
  # the output file will be most likely somewhere in genfiles, strip the
  # prefix in that case, also strip .gen if this is a step before clang-format
  if not path:
    raise ValueError('Cannot prepare header guard from path: {}'.format(path))
  if 'genfiles/' in path:
    # Only drop what precedes the first 'genfiles/'; later occurrences are
    # part of the path proper.
    path = path.split('genfiles/', 1)[1]
  if path.endswith('.gen'):
    # Remove just the trailing suffix; '.gen' may also occur mid-path.
    path = path[:-len('.gen')]
  path = path.upper().replace('.', '_').replace('-', '_').replace('/', '_')
  return path + '_'


def _strip_const(spelling):
  # type: (Text) -> Text
  """Removes every whole-word 'const' from a type spelling."""
  return _CONST_QUALIFIER.sub('', spelling)


def _stringify_tokens(tokens, separator='\n'):
  # type: (Sequence[cindex.Token], Text) -> Text
  """Converts tokens to text respecting line position (disrespecting column).

  Args:
    tokens: tokens to render; consecutive tokens sharing a source line end up
      on one output line
    separator: string placed between rendered lines

  Returns:
    rendered text
  """
  previous = OutputLine(0, [])  # not used in output
  lines = []  # type: List[OutputLine]

  for _, group in itertools.groupby(tokens, lambda t: t.location.line):
    # Each line starts at the indentation the previous line left behind.
    line = OutputLine(previous.next_tab, list(group))
    lines.append(line)
    previous = line

  return separator.join(str(l) for l in lines)


# Maps clang builtin type kinds to the SAPI variable wrapper emitted for them.
TYPE_MAPPING = {
    cindex.TypeKind.VOID: '::sapi::v::Void',
    cindex.TypeKind.CHAR_S: '::sapi::v::Char',
    cindex.TypeKind.CHAR_U: '::sapi::v::Char',
    cindex.TypeKind.INT: '::sapi::v::Int',
    cindex.TypeKind.UINT: '::sapi::v::UInt',
    cindex.TypeKind.LONG: '::sapi::v::Long',
    cindex.TypeKind.ULONG: '::sapi::v::ULong',
    cindex.TypeKind.UCHAR: '::sapi::v::UChar',
    cindex.TypeKind.USHORT: '::sapi::v::UShort',
    cindex.TypeKind.SHORT: '::sapi::v::Short',
    cindex.TypeKind.LONGLONG: '::sapi::v::LLong',
    cindex.TypeKind.ULONGLONG: '::sapi::v::ULLong',
    cindex.TypeKind.FLOAT: '::sapi::v::Reg<float>',
    cindex.TypeKind.DOUBLE: '::sapi::v::Reg<double>',
    cindex.TypeKind.LONGDOUBLE: '::sapi::v::Reg<long double>',
    cindex.TypeKind.SCHAR: '::sapi::v::SChar',
    cindex.TypeKind.BOOL: '::sapi::v::Bool',
}


class Type(object):
  """Class representing a type.

  Wraps cindex.Type of the argument/return value and provides helpers for the
  code generation.
  """

  def __init__(self, tu, clang_type):
    # type: (_TranslationUnit, cindex.Type) -> None
    self._clang_type = clang_type
    self._tu = tu

  # pylint: disable=protected-access
  def __eq__(self, other):
    # type: (Type) -> bool
    # Use get_usr() to deduplicate Type objects based on declaration
    decl = self._get_declaration()
    decl_o = other._get_declaration()

    return decl.get_usr() == decl_o.get_usr()

  def __ne__(self, other):
    # type: (Type) -> bool
    return not self.__eq__(other)

  def __lt__(self, other):
    # type: (Type) -> bool
    """Compares two Types belonging to the same TranslationUnit.

    This is being used to properly order types before emitting to generated
    file. To be more specific: structure definition that contains field that is
    a typedef should end up after that typedef definition. This is achieved by
    exploiting the order in which clang iterate over AST in translation unit.

    Args:
      other: other comparison type

    Returns:
      true if this Type occurs earlier in the AST than 'other'

    Raises:
      ValueError: if the types come from different translation units
    """
    self._validate_tu(other)
    return self._order_index() < other._order_index()

  def __gt__(self, other):
    # type: (Type) -> bool
    """Compares two Types belonging to the same TranslationUnit.

    Mirror image of __lt__; see there for why AST order is used.

    Args:
      other: other comparison type

    Returns:
      true if this Type occurs later in the AST than 'other'

    Raises:
      ValueError: if the types come from different translation units
    """
    self._validate_tu(other)
    return self._order_index() > other._order_index()

  # pylint: enable=protected-access

  def __hash__(self):
    """Types with the same declaration should hash to the same value."""
    return hash(self._get_declaration().get_usr())

  def _order_index(self):
    # type: () -> int
    """Returns the position recorded for this type's declaration in the TU."""
    return self._tu.order[self._get_declaration().hash]

  def _validate_tu(self, other):
    # type: (Type) -> None
    if self._tu != other._tu:  # pylint: disable=protected-access
      raise ValueError('Cannot compare types from different translation units.')

  def is_void(self):
    # type: () -> bool
    return self._clang_type.kind == cindex.TypeKind.VOID

  def is_typedef(self):
    # type: () -> bool
    return self._clang_type.kind == cindex.TypeKind.TYPEDEF

  def is_elaborated(self):
    # type: () -> bool
    return self._clang_type.kind == cindex.TypeKind.ELABORATED

  # Hack: both class and struct types are indistinguishable except for
  # declaration cursor kind
  def is_sugared_record(self):  # class, struct, union
    # type: () -> bool
    return self._clang_type.get_declaration().kind in (
        cindex.CursorKind.STRUCT_DECL, cindex.CursorKind.UNION_DECL,
        cindex.CursorKind.CLASS_DECL)

  def is_struct(self):
    # type: () -> bool
    return (self._clang_type.get_declaration().kind ==
            cindex.CursorKind.STRUCT_DECL)

  def is_class(self):
    # type: () -> bool
    return (self._clang_type.get_declaration().kind ==
            cindex.CursorKind.CLASS_DECL)

  def is_union(self):
    # type: () -> bool
    return (self._clang_type.get_declaration().kind ==
            cindex.CursorKind.UNION_DECL)

  def is_function(self):
    # type: () -> bool
    return self._clang_type.kind == cindex.TypeKind.FUNCTIONPROTO

  def is_sugared_ptr(self):
    # type: () -> bool
    return self._clang_type.get_canonical().kind == cindex.TypeKind.POINTER

  def is_sugared_enum(self):
    # type: () -> bool
    return self._clang_type.get_canonical().kind == cindex.TypeKind.ENUM

  def is_const_array(self):
    # type: () -> bool
    return self._clang_type.kind == cindex.TypeKind.CONSTANTARRAY

  def is_simple_type(self):
    # type: () -> bool
    return self._clang_type.kind in TYPE_MAPPING

  def get_pointee(self):
    # type: () -> Type
    return Type(self._tu, self._clang_type.get_pointee())

  def _get_declaration(self):
    # type: () -> cindex.Cursor
    """Returns the declaration cursor, looking through pointers if needed."""
    decl = self._clang_type.get_declaration()
    if decl.kind == cindex.CursorKind.NO_DECL_FOUND and self.is_sugared_ptr():
      decl = self.get_pointee()._get_declaration()  # pylint: disable=protected-access

    return decl

  def get_related_types(self, result=None, skip_self=False):
    # type: (Optional[Set[Type]], bool) -> Set[Type]
    """Returns all types related to this one eg. typedefs, nested structs.

    Args:
      result: set that collected types are added to; a new set is created when
        None
      skip_self: when True, a record or enum does not add itself to the result

    Returns:
      the (possibly extended) result set
    """
    if result is None:
      result = set()

    # Base case.
    if self in result or self.is_simple_type() or self.is_class():
      return result

    # Sugar types.
    if self.is_typedef():
      return self._get_related_types_of_typedef(result)

    if self.is_elaborated():
      return Type(self._tu,
                  self._clang_type.get_named_type()).get_related_types(
                      result, skip_self)

    # Composite types.
    if self.is_const_array():
      t = Type(self._tu, self._clang_type.get_array_element_type())
      return t.get_related_types(result)

    if self._clang_type.kind in (cindex.TypeKind.POINTER,
                                 cindex.TypeKind.MEMBERPOINTER,
                                 cindex.TypeKind.LVALUEREFERENCE,
                                 cindex.TypeKind.RVALUEREFERENCE):
      return self.get_pointee().get_related_types(result, skip_self)

    # union + struct, class should be filtered out
    if self.is_struct() or self.is_union():
      return self._get_related_types_of_record(result, skip_self)

    if self.is_function():
      return self._get_related_types_of_function(result)

    if self.is_sugared_enum():
      if not skip_self:
        result.add(self)
        self._tu.search_for_macro_name(self._get_declaration())
      return result

    # Ignore all cindex.TypeKind.UNEXPOSED AST nodes
    # TODO(b/256934562): Remove the disable once the pytype bug is fixed.
    return result  # pytype: disable=bad-return-type

  def _get_related_types_of_typedef(self, result):
    # type: (Set[Type]) -> Set[Type]
    """Returns all intermediate types related to the typedef."""
    result.add(self)
    decl = self._clang_type.get_declaration()
    self._tu.search_for_macro_name(decl)

    t = Type(self._tu, decl.underlying_typedef_type)
    if t.is_sugared_ptr():
      t = t.get_pointee()

    if not t.is_simple_type():
      skip_child = self.contains_declaration(t)
      if t.is_sugared_record() and skip_child:
        # if child declaration is contained in parent, we don't have to emit it
        self._tu.types_to_skip.add(t)
      result.update(t.get_related_types(result, skip_child))

    return result

  def _get_related_types_of_record(self, result, skip_self=False):
    # type: (Set[Type], bool) -> Set[Type]
    """Returns all types related to the structure."""
    # skip unnamed structures eg. typedef struct {...} x;
    # struct {...} will be rendered as part of typedef rendering
    decl = self._get_declaration()
    if not decl.is_anonymous() and not skip_self:
      self._tu.search_for_macro_name(decl)
      result.add(self)

    for f in self._clang_type.get_fields():
      self._tu.search_for_macro_name(f)
      result.update(Type(self._tu, f.type).get_related_types(result))

    return result

  def _get_related_types_of_function(self, result):
    # type: (Set[Type]) -> Set[Type]
    """Returns all types related to the function."""
    for arg in self._clang_type.argument_types():
      result.update(Type(self._tu, arg).get_related_types(result))
    related = Type(self._tu,
                   self._clang_type.get_result()).get_related_types(result)
    result.update(related)

    return result

  def contains_declaration(self, other):
    # type: (Type) -> bool
    """Checks if this type's declaration extent encloses the other's."""
    self_extent = self._get_declaration().extent
    other_extent = other._get_declaration().extent  # pylint: disable=protected-access

    if other_extent.start.file is None:
      return False
    return (other_extent.start in self_extent and
            other_extent.end in self_extent)

  def stringify(self):
    # type: () -> Text
    """Returns string representation of the Type."""
    # (szwl): as simple as possible, keeps macros in separate lines not to
    # break things; this will go through clang format nevertheless
    tokens = [
        x for x in self._get_declaration().get_tokens()
        if x.kind is not cindex.TokenKind.COMMENT
    ]

    return _stringify_tokens(tokens)


class OutputLine(object):
  """Helper class for Type printing.

  Attributes:
    tokens: tokens making up this line
    spellings: token spellings interleaved with the spaces to print
    define: True when the line contains a '#' token
    tab: indentation level of this line
    next_tab: indentation level the following line should start with
  """

  def __init__(self, tab, tokens):
    # type: (int, List[cindex.Token]) -> None
    self.tokens = tokens
    self.spellings = []
    self.define = False
    self.tab = tab
    self.next_tab = tab
    for token in self.tokens:
      self._process_token(token)

  def _process_token(self, t):
    # type: (cindex.Token) -> None
    """Processes a token, setting up internal states rel. to indentation."""
    if t.spelling == '#':
      self.define = True
    elif t.spelling == '{':
      self.next_tab += 1
    elif t.spelling == '}':
      self.tab -= 1
      self.next_tab -= 1

    # No space goes before '(' nor directly after a leading '#'.
    is_bracket = t.spelling == '('
    is_macro = len(self.spellings) == 1 and self.spellings[0] == '#'
    if self.spellings and not is_bracket and not is_macro:
      self.spellings.append(' ')
    self.spellings.append(t.spelling)

  def __str__(self):
    # type: () -> Text
    tabs = ('\t' * self.tab) if not self.define else ''
    return tabs + ''.join(self.spellings)


class ArgumentType(Type):
  """Class representing function argument type.

  Object fields are being used by the code template:
  pos: argument position
  type: string representation of the type
  argument: string representation of the type as function argument
  mapped_type: SAPI equivalent of the type
  wrapped: wraps type in SAPI object constructor
  call_argument: type (or its sapi wrapper) used in function call
  """

  def __init__(self, function, pos, arg_type, name=None):
    # type: (Function, int, cindex.Type, Optional[Text]) -> None
    super(ArgumentType, self).__init__(function.translation_unit(), arg_type)
    self._function = function

    self.pos = pos
    self.name = name or 'a{}'.format(pos)
    self.type = arg_type.spelling

    # Pointers are passed through; everything else is passed as the address
    # of the '<name>_' wrapper declared by `wrapped`.
    template = '{}' if self.is_sugared_ptr() else '&{}_'
    self.call_argument = template.format(self.name)

  def __str__(self):
    # type: () -> Text
    """Returns function argument prepared from the type."""
    if self.is_sugared_ptr():
      return '::sapi::v::Ptr* {}'.format(self.name)

    return '{} {}'.format(self._clang_type.spelling, self.name)

  @property
  def wrapped(self):
    # type: () -> Text
    return '{} {name}_(({name}))'.format(self.mapped_type, name=self.name)

  @property
  def mapped_type(self):
    # type: () -> Text
    """Maps the type to its SAPI equivalent.

    Raises:
      ValueError: for record/elaborated types, which are not supported
      KeyError: if the type kind has no entry in TYPE_MAPPING
    """
    if self.is_sugared_ptr():
      # TODO(szwl): const ptrs do not play well with SAPI C++ API...
      spelling = _strip_const(self._clang_type.spelling)
      return '::sapi::v::Reg<{}>'.format(spelling)

    type_ = self._clang_type

    if type_.kind == cindex.TypeKind.TYPEDEF:
      type_ = self._clang_type.get_canonical()
    if type_.kind == cindex.TypeKind.ELABORATED:
      type_ = type_.get_canonical()
    if type_.kind == cindex.TypeKind.ENUM:
      return '::sapi::v::IntBase<{}>'.format(self._clang_type.spelling)
    if type_.kind in [
        cindex.TypeKind.CONSTANTARRAY, cindex.TypeKind.INCOMPLETEARRAY
    ]:
      return '::sapi::v::Reg<{}>'.format(self._clang_type.spelling)

    if type_.kind == cindex.TypeKind.LVALUEREFERENCE:
      return 'LVALUEREFERENCE::NOT_SUPPORTED'

    if type_.kind == cindex.TypeKind.RVALUEREFERENCE:
      return 'RVALUEREFERENCE::NOT_SUPPORTED'

    if type_.kind in [cindex.TypeKind.RECORD, cindex.TypeKind.ELABORATED]:
      raise ValueError('Elaborate type (eg. struct) in mapped_type is not '
                       'supported: function {}, arg {}, type {}, location {}'
                       ''.format(self._function.name, self.pos,
                                 self._clang_type.spelling,
                                 self._function.cursor.location))

    if type_.kind not in TYPE_MAPPING:
      raise KeyError('Key {} does not exist in TYPE_MAPPING.'
                     ' function {}, arg {}, type {}, location {}'
                     ''.format(type_.kind, self._function.name, self.pos,
                               self._clang_type.spelling,
                               self._function.cursor.location))

    return TYPE_MAPPING[type_.kind]


class ReturnType(ArgumentType):
  """Class representing function return type.

  Attributes:
    return_type: absl::StatusOr<T> where T is original return type, or
      absl::Status for functions returning void
  """

  def __init__(self, function, arg_type):
    # type: (Function, cindex.Type) -> None
    super(ReturnType, self).__init__(function, 0, arg_type, None)

  def __str__(self):
    # type: () -> Text
    """Returns function return type prepared from the type."""
    if self.is_void():
      return 'absl::Status'
    # TODO(szwl): const ptrs do not play well with SAPI C++ API...
    return 'absl::StatusOr<{}>'.format(_strip_const(self._clang_type.spelling))


class Function(object):
  """Class representing SAPI-wrapped function used by the template.

  Wraps Clang cursor object of kind FUNCTION_DECL and provides helpers to
  aid code generation.
  """

  def __init__(self, tu, cursor):
    # type: (_TranslationUnit, cindex.Cursor) -> None
    self._tu = tu
    self.cursor = cursor  # type: cindex.Cursor
    self.name = cursor.spelling  # type: Text
    self.result = ReturnType(self, cursor.result_type)
    self.original_definition = '{} {}'.format(
        cursor.result_type.spelling, self.cursor.displayname)  # type: Text

    types = self.cursor.get_arguments()
    self.argument_types = [
        ArgumentType(self, i, t.type, t.spelling) for i, t in enumerate(types)
    ]

  def translation_unit(self):
    # type: () -> _TranslationUnit
    return self._tu

  def arguments(self):
    # type: () -> List[ArgumentType]
    return self.argument_types

  def call_arguments(self):
    # type: () -> List[Text]
    return [a.call_argument for a in self.argument_types]

  def get_absolute_path(self):
    # type: () -> Text
    return self.cursor.location.file.name

  def get_include_path(self, prefix):
    # type: (Optional[Text]) -> Text
    """Creates a proper include path.

    Args:
      prefix: include prefix; a trailing '/' is appended when missing

    Returns:
      the declaring file's path when prefix is empty; otherwise prefix followed
      by the part of that path after the last occurrence of prefix, or by the
      file's base name when prefix does not occur in the path
    """
    # TODO(szwl): sanity checks
    # TODO(szwl): prefix 'utils/' and the path is '.../fileutils/...' case
    if prefix and not prefix.endswith('/'):
      prefix += '/'

    absolute_path = self.get_absolute_path()
    if not prefix:
      return absolute_path
    if prefix in absolute_path:
      return prefix + absolute_path.split(prefix)[-1]
    return prefix + absolute_path.split('/')[-1]

  def get_related_types(self, processed=None):
    # type: (Optional[Set[Type]]) -> Set[Type]
    """Returns types related to the return type and to every argument."""
    result = self.result.get_related_types(processed)
    for a in self.argument_types:
      result.update(a.get_related_types(processed))

    return result

  def is_mangled(self):
    # type: () -> bool
    return self.cursor.mangled_name != self.cursor.spelling

  # NOTE(review): __hash__ keys on the USR while __eq__ compares mangled
  # names; this relies on equal mangled names implying equal USRs - confirm.
  def __hash__(self):
    # type: () -> int
    return hash(self.cursor.get_usr())

  def __eq__(self, other):
    # type: (Function) -> bool
    return self.cursor.mangled_name == other.cursor.mangled_name


class _TranslationUnit(object):
  """Class wrapping clang's _TranslationUnit. Provides extra utilities."""

  def __init__(self, path, tu, limit_scan_depth=False):
    # type: (Text, cindex.TranslationUnit, bool) -> None
    self.path = path
    self.limit_scan_depth = limit_scan_depth
    self._tu = tu
    self._processed = False
    self.forward_decls = dict()
    self.functions = set()
    self.order = dict()
    self.defines = {}
    self.required_defines = set()
    self.types_to_skip = set()

  def _process(self):
    # type: () -> None
    """Walks the cursor tree and caches some for future use."""
    if self._processed:
      return
    # self.includes[self._tu.spelling] = (0, self._tu.cursor)
    self._processed = True
    # TODO(szwl): duplicates?
    # TODO(szwl): for d in translation_unit.diagnostics:, handle that

    for i, cursor in enumerate(self._walk_preorder()):
      # Workaround for issue#32
      # ignore all the cursors with kinds not implemented in python bindings
      try:
        kind = cursor.kind
      except ValueError:
        continue
      # naive way to order types: they should be ordered when walking the tree
      if kind.is_declaration():
        self.order[cursor.hash] = i

      if (kind == cindex.CursorKind.MACRO_DEFINITION and
          cursor.location.file):
        self.order[cursor.hash] = i
        self.defines[cursor.spelling] = cursor

      # most likely a forward decl of struct
      if (kind == cindex.CursorKind.STRUCT_DECL and
          not cursor.is_definition()):
        self.forward_decls[Type(self, cursor.type)] = cursor

      if (kind == cindex.CursorKind.FUNCTION_DECL and
          cursor.linkage != cindex.LinkageKind.INTERNAL):
        if self.limit_scan_depth:
          # A location may carry no file at all; such declarations cannot
          # belong to the analyzed path, so skip them instead of crashing.
          location = cursor.location
          source_file = location.file if location else None
          if not source_file or source_file.name != self.path:
            continue
        self.functions.add(Function(self, cursor))

  def get_functions(self):
    # type: () -> Set[Function]
    if not self._processed:
      self._process()
    return self.functions

  def _walk_preorder(self):
    # type: () -> Gen
    for c in self._tu.cursor.walk_preorder():
      yield c

  def search_for_macro_name(self, cursor):
    # type: (cindex.Cursor) -> None
    """Searches for possible macro usage in constant array types.

    Every token of the cursor that names a known #define is recorded in
    required_defines, and the define's own tokens are searched recursively.

    Args:
      cursor: cursor whose tokens are inspected
    """
    tokens = [t.spelling for t in cursor.get_tokens()]
    try:
      for token in tokens:
        if token in self.defines and token not in self.required_defines:
          self.required_defines.add(token)
          self.search_for_macro_name(self.defines[token])
    except ValueError:
      return


class Analyzer(object):
  """Class responsible for analysis."""

  @staticmethod
  def process_files(input_paths, compile_flags, limit_scan_depth=False):
    # type: (List[Text], List[Text], bool) -> List[_TranslationUnit]
    """Processes files with libclang and returns TranslationUnit objects."""
    _init_libclang()

    return [
        Analyzer._analyze_file_for_tu(
            path, compile_flags=compile_flags,
            limit_scan_depth=limit_scan_depth) for path in input_paths
    ]

  # pylint: disable=line-too-long
  @staticmethod
  def _analyze_file_for_tu(path,
                           compile_flags=None,
                           test_file_existence=True,
                           unsaved_files=None,
                           limit_scan_depth=False):
    # type: (Text, Optional[List[Text]], bool, Optional[Tuple[Text, Union[Text, IO[Text]]]], bool) -> _TranslationUnit
    """Returns Analysis object for given path.

    Args:
      path: path of the file to parse
      compile_flags: extra compiler flags passed to libclang
      test_file_existence: when True, fail early if path is not a file
      unsaved_files: forwarded unchanged to cindex.Index.parse
      limit_scan_depth: forwarded to the returned _TranslationUnit

    Returns:
      _TranslationUnit wrapping the parsed file

    Raises:
      IOError: if test_file_existence is set and path is not an existing file
    """
    compile_flags = compile_flags or []
    if test_file_existence and not os.path.isfile(path):
      raise IOError('Path {} does not exist.'.format(path))

    _init_libclang()
    index = cindex.Index.create()  # type: cindex.Index
    # TODO(szwl): hack until I figure out how python swig does that.
    # Headers will be parsed as C++. C libs usually have
    # '#ifdef __cplusplus extern "C"' for compatibility with c++
    lang = '-xc++' if not path.endswith('.c') else '-xc'
    args = [lang]
    args += compile_flags
    args.append('-I.')
    return _TranslationUnit(
        path,
        index.parse(
            path,
            args=args,
            unsaved_files=unsaved_files,
            options=_PARSE_OPTIONS),
        limit_scan_depth=limit_scan_depth)


class Generator(object):
  """Class responsible for code generation."""

  AUTO_GENERATED = ('// AUTO-GENERATED by the Sandboxed API generator.\n'
                    '// Edits will be discarded when regenerating this file.\n')

  GUARD_START = ('#ifndef {0}\n' '#define {0}')
  GUARD_END = '#endif // {}'
  EMBED_INCLUDE = '#include "{}"'
  EMBED_CLASS = ('class {0}Sandbox : public ::sapi::Sandbox {{\n'
                 ' public:\n'
                 ' {0}Sandbox() : ::sapi::Sandbox({1}_embed_create()) {{}}\n'
                 '}};')

  def __init__(self, translation_units):
    # type: (List[_TranslationUnit]) -> None
    """Initializes the generator.

    Args:
      translation_units: list of translation units for the analyzed files
    """
    self.translation_units = translation_units
    self.functions = None
    _init_libclang()

  def generate(self,
               name,
               function_names,
               namespace=None,
               output_file=None,
               embed_dir=None,
               embed_name=None):
    # pylint: disable=line-too-long
    # type: (Text, List[Text], Optional[Text], Optional[Text], Optional[Text], Optional[Text]) -> Text
    """Generates structures, functions and typedefs.

    Args:
      name: name of the class that will contain generated interface
      function_names: list of function names to export to the interface
      namespace: namespace of the interface
      output_file: path to the output file, used to generate header guards;
        when None, no header guard is emitted
      embed_dir: path to directory with embed includes
      embed_name: name of the embed object

    Returns:
      generated interface as a string
    """
    # Related types must be collected first: that is what populates each
    # translation unit's required_defines, read by _get_defines() below.
    related_types = self._get_related_types(function_names)
    forward_decls = self._get_forward_decls(related_types)
    functions = self._get_functions(function_names)
    related_types = [(t.stringify() + ';') for t in related_types]
    defines = self._get_defines()

    api = {
        'name': name,
        'functions': functions,
        'related_types': defines + forward_decls + related_types,
        'namespaces': namespace.split('::') if namespace else [],
        'embed_dir': embed_dir,
        'embed_name': embed_name,
        'output_file': output_file
    }
    return self.format_template(**api)

  def _get_functions(self, func_names=None):
    # type: (Optional[List[Text]]) -> List[Function]
    """Gets Function objects that will be used to generate interface.

    The result is computed once and cached; later calls return the cached
    list regardless of func_names.

    Args:
      func_names: names of functions to keep, empty/None keeps all

    Returns:
      deduplicated, non-mangled functions sorted by name
    """
    if self.functions is not None:
      return self.functions
    self.functions = []
    # TODO(szwl): for d in translation_unit.diagnostics:, handle that
    for translation_unit in self.translation_units:
      self.functions += [
          f for f in translation_unit.get_functions()
          if not func_names or f.name in func_names
      ]
    # allow only nonmangled functions - C++ overloads are not handled in
    # code generation
    self.functions = [f for f in self.functions if not f.is_mangled()]

    # remove duplicates
    self.functions = list(set(self.functions))
    self.functions.sort(key=lambda x: x.name)
    return self.functions

  def _get_related_types(self, func_names=None):
    # type: (Optional[List[Text]]) -> List[Type]
    """Gets type definitions related to chosen functions.

    Types related to one function will land in the same translation unit,
    we gather the types, sort it and put as a sublist in types list.
    This is necessary as we can't compare types from two different translation
    units.

    Args:
      func_names: list of function names to take into consideration, empty means
        all functions.

    Returns:
      list of types in correct (ready to render) order
    """
    processed = set()
    types = []
    types_to_skip = set()

    for f in self._get_functions(func_names):
      fn_related_types = f.get_related_types()
      types += sorted(r for r in fn_related_types if r not in processed)
      processed.update(fn_related_types)
      types_to_skip.update(f.translation_unit().types_to_skip)

    return [t for t in types if t not in types_to_skip]

  def _get_defines(self):
    # type: () -> List[Text]
    """Gets #define directives that appeared during TranslationUnit processing.

    Returns:
      list of #define string representations
    """

    def make_sort_condition(translation_unit):
      return lambda cursor: translation_unit.order[cursor.hash]

    result = []
    for tu in self.translation_units:
      required = [
          tu.defines[name] for name in tu.required_defines
          if name in tu.defines
      ]
      # Emit the defines in the order they were met while walking the TU.
      for define in sorted(required, key=make_sort_condition(tu)):
        result.append('#define ' +
                      _stringify_tokens(define.get_tokens(), separator=' \\\n'))
    return result

  def _get_forward_decls(self, types):
    # type: (List[Type]) -> List[Text]
    """Gets forward declarations of related types, if present."""
    forward_decls = dict()
    result = []
    done = set()
    for tu in self.translation_units:
      forward_decls.update(tu.forward_decls)

    for t in types:
      if t in forward_decls and t not in done:
        result.append(_stringify_tokens(forward_decls[t].get_tokens()) + ';')
        done.add(t)

    return result

  def _format_function(self, f):
    # type: (Function) -> Text
    """Renders one function of the Api.

    Args:
      f: function object with information necessary to emit full function body

    Returns:
      filled function template
    """
    result = []
    result.append(' // {}'.format(f.original_definition))

    arguments = ', '.join(str(a) for a in f.arguments())
    result.append(' {} {}({}) {{'.format(f.result, f.name, arguments))
    result.append(' {} ret;'.format(f.result.mapped_type))

    # Non-pointer arguments get a local SAPI wrapper variable.
    result.extend(' {}'.format(a.wrapped + ';')
                  for a in f.argument_types
                  if not a.is_sugared_ptr())

    call_arguments = f.call_arguments()
    if call_arguments:  # fake empty space to add ',' before first argument
      call_arguments.insert(0, '')
    result.append('')
    # For OSS, the macro below will be replaced.
    result.append(' SAPI_RETURN_IF_ERROR(sandbox_->Call("{}", &ret{}));'
                  ''.format(f.name, ', '.join(call_arguments)))

    return_status = 'return absl::OkStatus();'
    if f.result and not f.result.is_void():
      if f.result and f.result.is_sugared_enum():
        return_status = ('return static_cast<{}>'
                         '(ret.GetValue());').format(f.result.type)
      else:
        return_status = 'return ret.GetValue();'
    result.append(' {}'.format(return_status))
    result.append(' }')

    return '\n'.join(result)

  def format_template(self, name, functions, related_types, namespaces,
                      embed_dir, embed_name, output_file):
    # pylint: disable=line-too-long
    # type: (Text, List[Function], List[Text], List[Text], Text, Text, Text) -> Text
    # pylint: enable=line-too-long
    """Formats arguments into proper interface header file.

    Args:
      name: name of the Api - 'Test' will yield TestApi object
      functions: list of functions to generate
      related_types: types used in the above functions
      namespaces: list of namespaces to wrap the Api class with
      embed_dir: directory where the embedded library lives
      embed_name: name of embedded library
      output_file: interface output path - used in header guard generation

    Returns:
      generated header file text
    """
    result = [Generator.AUTO_GENERATED]

    header_guard = get_header_guard(output_file) if output_file else ''
    if header_guard:
      result.append(Generator.GUARD_START.format(header_guard))

    # Copybara transform results in the paths below.
    result.append('#include "absl/status/status.h"')
    result.append('#include "absl/status/statusor.h"')
    result.append('#include "sandboxed_api/sandbox.h"')
    result.append('#include "sandboxed_api/util/status_macros.h"')
    result.append('#include "sandboxed_api/vars.h"')

    if embed_name:
      embed_dir = embed_dir or ''
      result.append(
          Generator.EMBED_INCLUDE.format(
              os.path.join(embed_dir, embed_name) + '_embed.h'))

    if namespaces:
      result.append('')
      for n in namespaces:
        result.append('namespace {} {{'.format(n))

    if related_types:
      result.append('')
      for t in related_types:
        result.append(t)

    result.append('')

    if embed_name:
      result.append(
          Generator.EMBED_CLASS.format(name, embed_name.replace('-', '_')))

    result.append('class {}Api {{'.format(name))
    result.append(' public:')
    result.append(' explicit {}Api(::sapi::Sandbox* sandbox)'
                  ' : sandbox_(sandbox) {{}}'.format(name))
    result.append(' // Deprecated')
    result.append(' ::sapi::Sandbox* GetSandbox() const { return sandbox(); }')
    result.append(' ::sapi::Sandbox* sandbox() const { return sandbox_; }')

    for f in functions:
      result.append('')
      result.append(self._format_function(f))

    result.append('')
    result.append(' private:')
    result.append(' ::sapi::Sandbox* sandbox_;')
    result.append('};')
    result.append('')

    if namespaces:
      for n in reversed(namespaces):
        result.append('}} // namespace {}'.format(n))

    if header_guard:
      result.append(Generator.GUARD_END.format(header_guard))

    result.append('')

    return '\n'.join(result)