1""" 2 ast 3 ~~~ 4 5 The `ast` module helps Python applications to process trees of the Python 6 abstract syntax grammar. The abstract syntax itself might change with 7 each Python release; this module helps to find out programmatically what 8 the current grammar looks like and allows modifications of it. 9 10 An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as 11 a flag to the `compile()` builtin function or by using the `parse()` 12 function from this module. The result will be a tree of objects whose 13 classes all inherit from `ast.AST`. 14 15 A modified abstract syntax tree can be compiled into a Python code object 16 using the built-in `compile()` function. 17 18 Additionally various helper functions are provided that make working with 19 the trees simpler. The main intention of the helper functions and this 20 module in general is to provide an easy to use interface for libraries 21 that work tightly with the python syntax (template engines for example). 22 23 24 :copyright: Copyright 2008 by Armin Ronacher. 25 :license: Python License. 26""" 27import sys 28from _ast import * 29from contextlib import contextmanager, nullcontext 30from enum import IntEnum, auto, _simple_enum 31 32 33def parse(source, filename='<unknown>', mode='exec', *, 34 type_comments=False, feature_version=None): 35 """ 36 Parse the source into an AST node. 37 Equivalent to compile(source, filename, mode, PyCF_ONLY_AST). 38 Pass type_comments=True to get back type comments where the syntax allows. 39 """ 40 flags = PyCF_ONLY_AST 41 if type_comments: 42 flags |= PyCF_TYPE_COMMENTS 43 if isinstance(feature_version, tuple): 44 major, minor = feature_version # Should be a 2-tuple. 45 assert major == 3 46 feature_version = minor 47 elif feature_version is None: 48 feature_version = -1 49 # Else it should be an int giving the minor version for 3.x. 50 return compile(source, filename, mode, flags, 51 _feature_version=feature_version) 52 53 54def literal_eval(node_or_string): 55 """ 56 Evaluate an expression node or a string containing only a Python 57 expression. The string or node provided may only consist of the following 58 Python literal structures: strings, bytes, numbers, tuples, lists, dicts, 59 sets, booleans, and None. 60 61 Caution: A complex expression can overflow the C stack and cause a crash. 62 """ 63 if isinstance(node_or_string, str): 64 node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval') 65 if isinstance(node_or_string, Expression): 66 node_or_string = node_or_string.body 67 def _raise_malformed_node(node): 68 msg = "malformed node or string" 69 if lno := getattr(node, 'lineno', None): 70 msg += f' on line {lno}' 71 raise ValueError(msg + f': {node!r}') 72 def _convert_num(node): 73 if not isinstance(node, Constant) or type(node.value) not in (int, float, complex): 74 _raise_malformed_node(node) 75 return node.value 76 def _convert_signed_num(node): 77 if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)): 78 operand = _convert_num(node.operand) 79 if isinstance(node.op, UAdd): 80 return + operand 81 else: 82 return - operand 83 return _convert_num(node) 84 def _convert(node): 85 if isinstance(node, Constant): 86 return node.value 87 elif isinstance(node, Tuple): 88 return tuple(map(_convert, node.elts)) 89 elif isinstance(node, List): 90 return list(map(_convert, node.elts)) 91 elif isinstance(node, Set): 92 return set(map(_convert, node.elts)) 93 elif (isinstance(node, Call) and isinstance(node.func, Name) and 94 node.func.id == 'set' and node.args == node.keywords == []): 95 return set() 96 elif isinstance(node, Dict): 97 if len(node.keys) != len(node.values): 98 _raise_malformed_node(node) 99 return dict(zip(map(_convert, node.keys), 100 map(_convert, node.values))) 101 elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)): 102 left = _convert_signed_num(node.left) 103 right = _convert_num(node.right) 104 if isinstance(left, (int, float)) and isinstance(right, complex): 105 if isinstance(node.op, Add): 106 return left + right 107 else: 108 return left - right 109 return _convert_signed_num(node) 110 return _convert(node_or_string) 111 112 113def dump(node, annotate_fields=True, include_attributes=False, *, indent=None): 114 """ 115 Return a formatted dump of the tree in node. This is mainly useful for 116 debugging purposes. If annotate_fields is true (by default), 117 the returned string will show the names and the values for fields. 118 If annotate_fields is false, the result string will be more compact by 119 omitting unambiguous field names. Attributes such as line 120 numbers and column offsets are not dumped by default. If this is wanted, 121 include_attributes can be set to true. If indent is a non-negative 122 integer or string, then the tree will be pretty-printed with that indent 123 level. None (the default) selects the single line representation. 124 """ 125 def _format(node, level=0): 126 if indent is not None: 127 level += 1 128 prefix = '\n' + indent * level 129 sep = ',\n' + indent * level 130 else: 131 prefix = '' 132 sep = ', ' 133 if isinstance(node, AST): 134 cls = type(node) 135 args = [] 136 allsimple = True 137 keywords = annotate_fields 138 for name in node._fields: 139 try: 140 value = getattr(node, name) 141 except AttributeError: 142 keywords = True 143 continue 144 if value is None and getattr(cls, name, ...) is None: 145 keywords = True 146 continue 147 value, simple = _format(value, level) 148 allsimple = allsimple and simple 149 if keywords: 150 args.append('%s=%s' % (name, value)) 151 else: 152 args.append(value) 153 if include_attributes and node._attributes: 154 for name in node._attributes: 155 try: 156 value = getattr(node, name) 157 except AttributeError: 158 continue 159 if value is None and getattr(cls, name, ...) is None: 160 continue 161 value, simple = _format(value, level) 162 allsimple = allsimple and simple 163 args.append('%s=%s' % (name, value)) 164 if allsimple and len(args) <= 3: 165 return '%s(%s)' % (node.__class__.__name__, ', '.join(args)), not args 166 return '%s(%s%s)' % (node.__class__.__name__, prefix, sep.join(args)), False 167 elif isinstance(node, list): 168 if not node: 169 return '[]', True 170 return '[%s%s]' % (prefix, sep.join(_format(x, level)[0] for x in node)), False 171 return repr(node), True 172 173 if not isinstance(node, AST): 174 raise TypeError('expected AST, got %r' % node.__class__.__name__) 175 if indent is not None and not isinstance(indent, str): 176 indent = ' ' * indent 177 return _format(node)[0] 178 179 180def copy_location(new_node, old_node): 181 """ 182 Copy source location (`lineno`, `col_offset`, `end_lineno`, and `end_col_offset` 183 attributes) from *old_node* to *new_node* if possible, and return *new_node*. 184 """ 185 for attr in 'lineno', 'col_offset', 'end_lineno', 'end_col_offset': 186 if attr in old_node._attributes and attr in new_node._attributes: 187 value = getattr(old_node, attr, None) 188 # end_lineno and end_col_offset are optional attributes, and they 189 # should be copied whether the value is None or not. 190 if value is not None or ( 191 hasattr(old_node, attr) and attr.startswith("end_") 192 ): 193 setattr(new_node, attr, value) 194 return new_node 195 196 197def fix_missing_locations(node): 198 """ 199 When you compile a node tree with compile(), the compiler expects lineno and 200 col_offset attributes for every node that supports them. This is rather 201 tedious to fill in for generated nodes, so this helper adds these attributes 202 recursively where not already set, by setting them to the values of the 203 parent node. It works recursively starting at *node*. 204 """ 205 def _fix(node, lineno, col_offset, end_lineno, end_col_offset): 206 if 'lineno' in node._attributes: 207 if not hasattr(node, 'lineno'): 208 node.lineno = lineno 209 else: 210 lineno = node.lineno 211 if 'end_lineno' in node._attributes: 212 if getattr(node, 'end_lineno', None) is None: 213 node.end_lineno = end_lineno 214 else: 215 end_lineno = node.end_lineno 216 if 'col_offset' in node._attributes: 217 if not hasattr(node, 'col_offset'): 218 node.col_offset = col_offset 219 else: 220 col_offset = node.col_offset 221 if 'end_col_offset' in node._attributes: 222 if getattr(node, 'end_col_offset', None) is None: 223 node.end_col_offset = end_col_offset 224 else: 225 end_col_offset = node.end_col_offset 226 for child in iter_child_nodes(node): 227 _fix(child, lineno, col_offset, end_lineno, end_col_offset) 228 _fix(node, 1, 0, 1, 0) 229 return node 230 231 232def increment_lineno(node, n=1): 233 """ 234 Increment the line number and end line number of each node in the tree 235 starting at *node* by *n*. This is useful to "move code" to a different 236 location in a file. 237 """ 238 for child in walk(node): 239 # TypeIgnore is a special case where lineno is not an attribute 240 # but rather a field of the node itself. 241 if isinstance(child, TypeIgnore): 242 child.lineno = getattr(child, 'lineno', 0) + n 243 continue 244 245 if 'lineno' in child._attributes: 246 child.lineno = getattr(child, 'lineno', 0) + n 247 if ( 248 "end_lineno" in child._attributes 249 and (end_lineno := getattr(child, "end_lineno", 0)) is not None 250 ): 251 child.end_lineno = end_lineno + n 252 return node 253 254 255def iter_fields(node): 256 """ 257 Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields`` 258 that is present on *node*. 259 """ 260 for field in node._fields: 261 try: 262 yield field, getattr(node, field) 263 except AttributeError: 264 pass 265 266 267def iter_child_nodes(node): 268 """ 269 Yield all direct child nodes of *node*, that is, all fields that are nodes 270 and all items of fields that are lists of nodes. 271 """ 272 for name, field in iter_fields(node): 273 if isinstance(field, AST): 274 yield field 275 elif isinstance(field, list): 276 for item in field: 277 if isinstance(item, AST): 278 yield item 279 280 281def get_docstring(node, clean=True): 282 """ 283 Return the docstring for the given node or None if no docstring can 284 be found. If the node provided does not have docstrings a TypeError 285 will be raised. 286 287 If *clean* is `True`, all tabs are expanded to spaces and any whitespace 288 that can be uniformly removed from the second line onwards is removed. 289 """ 290 if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)): 291 raise TypeError("%r can't have docstrings" % node.__class__.__name__) 292 if not(node.body and isinstance(node.body[0], Expr)): 293 return None 294 node = node.body[0].value 295 if isinstance(node, Str): 296 text = node.s 297 elif isinstance(node, Constant) and isinstance(node.value, str): 298 text = node.value 299 else: 300 return None 301 if clean: 302 import inspect 303 text = inspect.cleandoc(text) 304 return text 305 306 307def _splitlines_no_ff(source): 308 """Split a string into lines ignoring form feed and other chars. 309 310 This mimics how the Python parser splits source code. 311 """ 312 idx = 0 313 lines = [] 314 next_line = '' 315 while idx < len(source): 316 c = source[idx] 317 next_line += c 318 idx += 1 319 # Keep \r\n together 320 if c == '\r' and idx < len(source) and source[idx] == '\n': 321 next_line += '\n' 322 idx += 1 323 if c in '\r\n': 324 lines.append(next_line) 325 next_line = '' 326 327 if next_line: 328 lines.append(next_line) 329 return lines 330 331 332def _pad_whitespace(source): 333 r"""Replace all chars except '\f\t' in a line with spaces.""" 334 result = '' 335 for c in source: 336 if c in '\f\t': 337 result += c 338 else: 339 result += ' ' 340 return result 341 342 343def get_source_segment(source, node, *, padded=False): 344 """Get source code segment of the *source* that generated *node*. 345 346 If some location information (`lineno`, `end_lineno`, `col_offset`, 347 or `end_col_offset`) is missing, return None. 348 349 If *padded* is `True`, the first line of a multi-line statement will 350 be padded with spaces to match its original position. 351 """ 352 try: 353 if node.end_lineno is None or node.end_col_offset is None: 354 return None 355 lineno = node.lineno - 1 356 end_lineno = node.end_lineno - 1 357 col_offset = node.col_offset 358 end_col_offset = node.end_col_offset 359 except AttributeError: 360 return None 361 362 lines = _splitlines_no_ff(source) 363 if end_lineno == lineno: 364 return lines[lineno].encode()[col_offset:end_col_offset].decode() 365 366 if padded: 367 padding = _pad_whitespace(lines[lineno].encode()[:col_offset].decode()) 368 else: 369 padding = '' 370 371 first = padding + lines[lineno].encode()[col_offset:].decode() 372 last = lines[end_lineno].encode()[:end_col_offset].decode() 373 lines = lines[lineno+1:end_lineno] 374 375 lines.insert(0, first) 376 lines.append(last) 377 return ''.join(lines) 378 379 380def walk(node): 381 """ 382 Recursively yield all descendant nodes in the tree starting at *node* 383 (including *node* itself), in no specified order. This is useful if you 384 only want to modify nodes in place and don't care about the context. 385 """ 386 from collections import deque 387 todo = deque([node]) 388 while todo: 389 node = todo.popleft() 390 todo.extend(iter_child_nodes(node)) 391 yield node 392 393 394class NodeVisitor(object): 395 """ 396 A node visitor base class that walks the abstract syntax tree and calls a 397 visitor function for every node found. This function may return a value 398 which is forwarded by the `visit` method. 399 400 This class is meant to be subclassed, with the subclass adding visitor 401 methods. 402 403 Per default the visitor functions for the nodes are ``'visit_'`` + 404 class name of the node. So a `TryFinally` node visit function would 405 be `visit_TryFinally`. This behavior can be changed by overriding 406 the `visit` method. If no visitor function exists for a node 407 (return value `None`) the `generic_visit` visitor is used instead. 408 409 Don't use the `NodeVisitor` if you want to apply changes to nodes during 410 traversing. For this a special visitor exists (`NodeTransformer`) that 411 allows modifications. 412 """ 413 414 def visit(self, node): 415 """Visit a node.""" 416 method = 'visit_' + node.__class__.__name__ 417 visitor = getattr(self, method, self.generic_visit) 418 return visitor(node) 419 420 def generic_visit(self, node): 421 """Called if no explicit visitor function exists for a node.""" 422 for field, value in iter_fields(node): 423 if isinstance(value, list): 424 for item in value: 425 if isinstance(item, AST): 426 self.visit(item) 427 elif isinstance(value, AST): 428 self.visit(value) 429 430 def visit_Constant(self, node): 431 value = node.value 432 type_name = _const_node_type_names.get(type(value)) 433 if type_name is None: 434 for cls, name in _const_node_type_names.items(): 435 if isinstance(value, cls): 436 type_name = name 437 break 438 if type_name is not None: 439 method = 'visit_' + type_name 440 try: 441 visitor = getattr(self, method) 442 except AttributeError: 443 pass 444 else: 445 import warnings 446 warnings.warn(f"{method} is deprecated; add visit_Constant", 447 DeprecationWarning, 2) 448 return visitor(node) 449 return self.generic_visit(node) 450 451 452class NodeTransformer(NodeVisitor): 453 """ 454 A :class:`NodeVisitor` subclass that walks the abstract syntax tree and 455 allows modification of nodes. 456 457 The `NodeTransformer` will walk the AST and use the return value of the 458 visitor methods to replace or remove the old node. If the return value of 459 the visitor method is ``None``, the node will be removed from its location, 460 otherwise it is replaced with the return value. The return value may be the 461 original node in which case no replacement takes place. 462 463 Here is an example transformer that rewrites all occurrences of name lookups 464 (``foo``) to ``data['foo']``:: 465 466 class RewriteName(NodeTransformer): 467 468 def visit_Name(self, node): 469 return Subscript( 470 value=Name(id='data', ctx=Load()), 471 slice=Constant(value=node.id), 472 ctx=node.ctx 473 ) 474 475 Keep in mind that if the node you're operating on has child nodes you must 476 either transform the child nodes yourself or call the :meth:`generic_visit` 477 method for the node first. 478 479 For nodes that were part of a collection of statements (that applies to all 480 statement nodes), the visitor may also return a list of nodes rather than 481 just a single node. 482 483 Usually you use the transformer like this:: 484 485 node = YourTransformer().visit(node) 486 """ 487 488 def generic_visit(self, node): 489 for field, old_value in iter_fields(node): 490 if isinstance(old_value, list): 491 new_values = [] 492 for value in old_value: 493 if isinstance(value, AST): 494 value = self.visit(value) 495 if value is None: 496 continue 497 elif not isinstance(value, AST): 498 new_values.extend(value) 499 continue 500 new_values.append(value) 501 old_value[:] = new_values 502 elif isinstance(old_value, AST): 503 new_node = self.visit(old_value) 504 if new_node is None: 505 delattr(node, field) 506 else: 507 setattr(node, field, new_node) 508 return node 509 510 511# If the ast module is loaded more than once, only add deprecated methods once 512if not hasattr(Constant, 'n'): 513 # The following code is for backward compatibility. 514 # It will be removed in future. 515 516 def _getter(self): 517 """Deprecated. Use value instead.""" 518 return self.value 519 520 def _setter(self, value): 521 self.value = value 522 523 Constant.n = property(_getter, _setter) 524 Constant.s = property(_getter, _setter) 525 526class _ABC(type): 527 528 def __init__(cls, *args): 529 cls.__doc__ = """Deprecated AST node class. Use ast.Constant instead""" 530 531 def __instancecheck__(cls, inst): 532 if not isinstance(inst, Constant): 533 return False 534 if cls in _const_types: 535 try: 536 value = inst.value 537 except AttributeError: 538 return False 539 else: 540 return ( 541 isinstance(value, _const_types[cls]) and 542 not isinstance(value, _const_types_not.get(cls, ())) 543 ) 544 return type.__instancecheck__(cls, inst) 545 546def _new(cls, *args, **kwargs): 547 for key in kwargs: 548 if key not in cls._fields: 549 # arbitrary keyword arguments are accepted 550 continue 551 pos = cls._fields.index(key) 552 if pos < len(args): 553 raise TypeError(f"{cls.__name__} got multiple values for argument {key!r}") 554 if cls in _const_types: 555 return Constant(*args, **kwargs) 556 return Constant.__new__(cls, *args, **kwargs) 557 558class Num(Constant, metaclass=_ABC): 559 _fields = ('n',) 560 __new__ = _new 561 562class Str(Constant, metaclass=_ABC): 563 _fields = ('s',) 564 __new__ = _new 565 566class Bytes(Constant, metaclass=_ABC): 567 _fields = ('s',) 568 __new__ = _new 569 570class NameConstant(Constant, metaclass=_ABC): 571 __new__ = _new 572 573class Ellipsis(Constant, metaclass=_ABC): 574 _fields = () 575 576 def __new__(cls, *args, **kwargs): 577 if cls is Ellipsis: 578 return Constant(..., *args, **kwargs) 579 return Constant.__new__(cls, *args, **kwargs) 580 581_const_types = { 582 Num: (int, float, complex), 583 Str: (str,), 584 Bytes: (bytes,), 585 NameConstant: (type(None), bool), 586 Ellipsis: (type(...),), 587} 588_const_types_not = { 589 Num: (bool,), 590} 591 592_const_node_type_names = { 593 bool: 'NameConstant', # should be before int 594 type(None): 'NameConstant', 595 int: 'Num', 596 float: 'Num', 597 complex: 'Num', 598 str: 'Str', 599 bytes: 'Bytes', 600 type(...): 'Ellipsis', 601} 602 603class slice(AST): 604 """Deprecated AST node class.""" 605 606class Index(slice): 607 """Deprecated AST node class. Use the index value directly instead.""" 608 def __new__(cls, value, **kwargs): 609 return value 610 611class ExtSlice(slice): 612 """Deprecated AST node class. Use ast.Tuple instead.""" 613 def __new__(cls, dims=(), **kwargs): 614 return Tuple(list(dims), Load(), **kwargs) 615 616# If the ast module is loaded more than once, only add deprecated methods once 617if not hasattr(Tuple, 'dims'): 618 # The following code is for backward compatibility. 619 # It will be removed in future. 620 621 def _dims_getter(self): 622 """Deprecated. Use elts instead.""" 623 return self.elts 624 625 def _dims_setter(self, value): 626 self.elts = value 627 628 Tuple.dims = property(_dims_getter, _dims_setter) 629 630class Suite(mod): 631 """Deprecated AST node class. Unused in Python 3.""" 632 633class AugLoad(expr_context): 634 """Deprecated AST node class. Unused in Python 3.""" 635 636class AugStore(expr_context): 637 """Deprecated AST node class. Unused in Python 3.""" 638 639class Param(expr_context): 640 """Deprecated AST node class. Unused in Python 3.""" 641 642 643# Large float and imaginary literals get turned into infinities in the AST. 644# We unparse those infinities to INFSTR. 645_INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) 646 647@_simple_enum(IntEnum) 648class _Precedence: 649 """Precedence table that originated from python grammar.""" 650 651 NAMED_EXPR = auto() # <target> := <expr1> 652 TUPLE = auto() # <expr1>, <expr2> 653 YIELD = auto() # 'yield', 'yield from' 654 TEST = auto() # 'if'-'else', 'lambda' 655 OR = auto() # 'or' 656 AND = auto() # 'and' 657 NOT = auto() # 'not' 658 CMP = auto() # '<', '>', '==', '>=', '<=', '!=', 659 # 'in', 'not in', 'is', 'is not' 660 EXPR = auto() 661 BOR = EXPR # '|' 662 BXOR = auto() # '^' 663 BAND = auto() # '&' 664 SHIFT = auto() # '<<', '>>' 665 ARITH = auto() # '+', '-' 666 TERM = auto() # '*', '@', '/', '%', '//' 667 FACTOR = auto() # unary '+', '-', '~' 668 POWER = auto() # '**' 669 AWAIT = auto() # 'await' 670 ATOM = auto() 671 672 def next(self): 673 try: 674 return self.__class__(self + 1) 675 except ValueError: 676 return self 677 678 679_SINGLE_QUOTES = ("'", '"') 680_MULTI_QUOTES = ('"""', "'''") 681_ALL_QUOTES = (*_SINGLE_QUOTES, *_MULTI_QUOTES) 682 683class _Unparser(NodeVisitor): 684 """Methods in this class recursively traverse an AST and 685 output source code for the abstract syntax; original formatting 686 is disregarded.""" 687 688 def __init__(self, *, _avoid_backslashes=False): 689 self._source = [] 690 self._precedences = {} 691 self._type_ignores = {} 692 self._indent = 0 693 self._avoid_backslashes = _avoid_backslashes 694 self._in_try_star = False 695 696 def interleave(self, inter, f, seq): 697 """Call f on each item in seq, calling inter() in between.""" 698 seq = iter(seq) 699 try: 700 f(next(seq)) 701 except StopIteration: 702 pass 703 else: 704 for x in seq: 705 inter() 706 f(x) 707 708 def items_view(self, traverser, items): 709 """Traverse and separate the given *items* with a comma and append it to 710 the buffer. If *items* is a single item sequence, a trailing comma 711 will be added.""" 712 if len(items) == 1: 713 traverser(items[0]) 714 self.write(",") 715 else: 716 self.interleave(lambda: self.write(", "), traverser, items) 717 718 def maybe_newline(self): 719 """Adds a newline if it isn't the start of generated source""" 720 if self._source: 721 self.write("\n") 722 723 def fill(self, text=""): 724 """Indent a piece of text and append it, according to the current 725 indentation level""" 726 self.maybe_newline() 727 self.write(" " * self._indent + text) 728 729 def write(self, *text): 730 """Add new source parts""" 731 self._source.extend(text) 732 733 @contextmanager 734 def buffered(self, buffer = None): 735 if buffer is None: 736 buffer = [] 737 738 original_source = self._source 739 self._source = buffer 740 yield buffer 741 self._source = original_source 742 743 @contextmanager 744 def block(self, *, extra = None): 745 """A context manager for preparing the source for blocks. It adds 746 the character':', increases the indentation on enter and decreases 747 the indentation on exit. If *extra* is given, it will be directly 748 appended after the colon character. 749 """ 750 self.write(":") 751 if extra: 752 self.write(extra) 753 self._indent += 1 754 yield 755 self._indent -= 1 756 757 @contextmanager 758 def delimit(self, start, end): 759 """A context manager for preparing the source for expressions. It adds 760 *start* to the buffer and enters, after exit it adds *end*.""" 761 762 self.write(start) 763 yield 764 self.write(end) 765 766 def delimit_if(self, start, end, condition): 767 if condition: 768 return self.delimit(start, end) 769 else: 770 return nullcontext() 771 772 def require_parens(self, precedence, node): 773 """Shortcut to adding precedence related parens""" 774 return self.delimit_if("(", ")", self.get_precedence(node) > precedence) 775 776 def get_precedence(self, node): 777 return self._precedences.get(node, _Precedence.TEST) 778 779 def set_precedence(self, precedence, *nodes): 780 for node in nodes: 781 self._precedences[node] = precedence 782 783 def get_raw_docstring(self, node): 784 """If a docstring node is found in the body of the *node* parameter, 785 return that docstring node, None otherwise. 786 787 Logic mirrored from ``_PyAST_GetDocString``.""" 788 if not isinstance( 789 node, (AsyncFunctionDef, FunctionDef, ClassDef, Module) 790 ) or len(node.body) < 1: 791 return None 792 node = node.body[0] 793 if not isinstance(node, Expr): 794 return None 795 node = node.value 796 if isinstance(node, Constant) and isinstance(node.value, str): 797 return node 798 799 def get_type_comment(self, node): 800 comment = self._type_ignores.get(node.lineno) or node.type_comment 801 if comment is not None: 802 return f" # type: {comment}" 803 804 def traverse(self, node): 805 if isinstance(node, list): 806 for item in node: 807 self.traverse(item) 808 else: 809 super().visit(node) 810 811 # Note: as visit() resets the output text, do NOT rely on 812 # NodeVisitor.generic_visit to handle any nodes (as it calls back in to 813 # the subclass visit() method, which resets self._source to an empty list) 814 def visit(self, node): 815 """Outputs a source code string that, if converted back to an ast 816 (using ast.parse) will generate an AST equivalent to *node*""" 817 self._source = [] 818 self.traverse(node) 819 return "".join(self._source) 820 821 def _write_docstring_and_traverse_body(self, node): 822 if (docstring := self.get_raw_docstring(node)): 823 self._write_docstring(docstring) 824 self.traverse(node.body[1:]) 825 else: 826 self.traverse(node.body) 827 828 def visit_Module(self, node): 829 self._type_ignores = { 830 ignore.lineno: f"ignore{ignore.tag}" 831 for ignore in node.type_ignores 832 } 833 self._write_docstring_and_traverse_body(node) 834 self._type_ignores.clear() 835 836 def visit_FunctionType(self, node): 837 with self.delimit("(", ")"): 838 self.interleave( 839 lambda: self.write(", "), self.traverse, node.argtypes 840 ) 841 842 self.write(" -> ") 843 self.traverse(node.returns) 844 845 def visit_Expr(self, node): 846 self.fill() 847 self.set_precedence(_Precedence.YIELD, node.value) 848 self.traverse(node.value) 849 850 def visit_NamedExpr(self, node): 851 with self.require_parens(_Precedence.NAMED_EXPR, node): 852 self.set_precedence(_Precedence.ATOM, node.target, node.value) 853 self.traverse(node.target) 854 self.write(" := ") 855 self.traverse(node.value) 856 857 def visit_Import(self, node): 858 self.fill("import ") 859 self.interleave(lambda: self.write(", "), self.traverse, node.names) 860 861 def visit_ImportFrom(self, node): 862 self.fill("from ") 863 self.write("." * (node.level or 0)) 864 if node.module: 865 self.write(node.module) 866 self.write(" import ") 867 self.interleave(lambda: self.write(", "), self.traverse, node.names) 868 869 def visit_Assign(self, node): 870 self.fill() 871 for target in node.targets: 872 self.set_precedence(_Precedence.TUPLE, target) 873 self.traverse(target) 874 self.write(" = ") 875 self.traverse(node.value) 876 if type_comment := self.get_type_comment(node): 877 self.write(type_comment) 878 879 def visit_AugAssign(self, node): 880 self.fill() 881 self.traverse(node.target) 882 self.write(" " + self.binop[node.op.__class__.__name__] + "= ") 883 self.traverse(node.value) 884 885 def visit_AnnAssign(self, node): 886 self.fill() 887 with self.delimit_if("(", ")", not node.simple and isinstance(node.target, Name)): 888 self.traverse(node.target) 889 self.write(": ") 890 self.traverse(node.annotation) 891 if node.value: 892 self.write(" = ") 893 self.traverse(node.value) 894 895 def visit_Return(self, node): 896 self.fill("return") 897 if node.value: 898 self.write(" ") 899 self.traverse(node.value) 900 901 def visit_Pass(self, node): 902 self.fill("pass") 903 904 def visit_Break(self, node): 905 self.fill("break") 906 907 def visit_Continue(self, node): 908 self.fill("continue") 909 910 def visit_Delete(self, node): 911 self.fill("del ") 912 self.interleave(lambda: self.write(", "), self.traverse, node.targets) 913 914 def visit_Assert(self, node): 915 self.fill("assert ") 916 self.traverse(node.test) 917 if node.msg: 918 self.write(", ") 919 self.traverse(node.msg) 920 921 def visit_Global(self, node): 922 self.fill("global ") 923 self.interleave(lambda: self.write(", "), self.write, node.names) 924 925 def visit_Nonlocal(self, node): 926 self.fill("nonlocal ") 927 self.interleave(lambda: self.write(", "), self.write, node.names) 928 929 def visit_Await(self, node): 930 with self.require_parens(_Precedence.AWAIT, node): 931 self.write("await") 932 if node.value: 933 self.write(" ") 934 self.set_precedence(_Precedence.ATOM, node.value) 935 self.traverse(node.value) 936 937 def visit_Yield(self, node): 938 with self.require_parens(_Precedence.YIELD, node): 939 self.write("yield") 940 if node.value: 941 self.write(" ") 942 self.set_precedence(_Precedence.ATOM, node.value) 943 self.traverse(node.value) 944 945 def visit_YieldFrom(self, node): 946 with self.require_parens(_Precedence.YIELD, node): 947 self.write("yield from ") 948 if not node.value: 949 raise ValueError("Node can't be used without a value attribute.") 950 self.set_precedence(_Precedence.ATOM, node.value) 951 self.traverse(node.value) 952 953 def visit_Raise(self, node): 954 self.fill("raise") 955 if not node.exc: 956 if node.cause: 957 raise ValueError(f"Node can't use cause without an exception.") 958 return 959 self.write(" ") 960 self.traverse(node.exc) 961 if node.cause: 962 self.write(" from ") 963 self.traverse(node.cause) 964 965 def do_visit_try(self, node): 966 self.fill("try") 967 with self.block(): 968 self.traverse(node.body) 969 for ex in node.handlers: 970 self.traverse(ex) 971 if node.orelse: 972 self.fill("else") 973 with self.block(): 974 self.traverse(node.orelse) 975 if node.finalbody: 976 self.fill("finally") 977 with self.block(): 978 self.traverse(node.finalbody) 979 980 def visit_Try(self, node): 981 prev_in_try_star = self._in_try_star 982 try: 983 self._in_try_star = False 984 self.do_visit_try(node) 985 finally: 986 self._in_try_star = prev_in_try_star 987 988 def visit_TryStar(self, node): 989 prev_in_try_star = self._in_try_star 990 try: 991 self._in_try_star = True 992 self.do_visit_try(node) 993 finally: 994 self._in_try_star = prev_in_try_star 995 996 def visit_ExceptHandler(self, node): 997 self.fill("except*" if self._in_try_star else "except") 998 if node.type: 999 self.write(" ") 1000 self.traverse(node.type) 1001 if node.name: 1002 self.write(" as ") 1003 self.write(node.name) 1004 with self.block(): 1005 self.traverse(node.body) 1006 1007 def visit_ClassDef(self, node): 1008 self.maybe_newline() 1009 for deco in node.decorator_list: 1010 self.fill("@") 1011 self.traverse(deco) 1012 self.fill("class " + node.name) 1013 with self.delimit_if("(", ")", condition = node.bases or node.keywords): 1014 comma = False 1015 for e in node.bases: 1016 if comma: 1017 self.write(", ") 1018 else: 1019 comma = True 1020 self.traverse(e) 1021 for e in node.keywords: 1022 if comma: 1023 self.write(", ") 1024 else: 1025 comma = True 1026 self.traverse(e) 1027 1028 with self.block(): 1029 self._write_docstring_and_traverse_body(node) 1030 1031 def visit_FunctionDef(self, node): 1032 self._function_helper(node, "def") 1033 1034 def visit_AsyncFunctionDef(self, node): 1035 self._function_helper(node, "async def") 1036 1037 def _function_helper(self, node, fill_suffix): 1038 self.maybe_newline() 1039 for deco in node.decorator_list: 1040 self.fill("@") 1041 self.traverse(deco) 1042 def_str = fill_suffix + " " + node.name 1043 self.fill(def_str) 1044 with self.delimit("(", ")"): 1045 self.traverse(node.args) 1046 if node.returns: 1047 self.write(" -> ") 1048 self.traverse(node.returns) 1049 with self.block(extra=self.get_type_comment(node)): 1050 self._write_docstring_and_traverse_body(node) 1051 1052 def visit_For(self, node): 1053 self._for_helper("for ", node) 1054 1055 def visit_AsyncFor(self, node): 1056 self._for_helper("async for ", node) 1057 1058 def _for_helper(self, fill, node): 1059 self.fill(fill) 1060 self.set_precedence(_Precedence.TUPLE, node.target) 1061 self.traverse(node.target) 1062 self.write(" in ") 1063 self.traverse(node.iter) 1064 with self.block(extra=self.get_type_comment(node)): 1065 self.traverse(node.body) 1066 if node.orelse: 1067 self.fill("else") 1068 with self.block(): 1069 self.traverse(node.orelse) 1070 1071 def visit_If(self, node): 1072 self.fill("if ") 1073 self.traverse(node.test) 1074 with self.block(): 1075 self.traverse(node.body) 1076 # collapse nested ifs into equivalent elifs. 1077 while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], If): 1078 node = node.orelse[0] 1079 self.fill("elif ") 1080 self.traverse(node.test) 1081 with self.block(): 1082 self.traverse(node.body) 1083 # final else 1084 if node.orelse: 1085 self.fill("else") 1086 with self.block(): 1087 self.traverse(node.orelse) 1088 1089 def visit_While(self, node): 1090 self.fill("while ") 1091 self.traverse(node.test) 1092 with self.block(): 1093 self.traverse(node.body) 1094 if node.orelse: 1095 self.fill("else") 1096 with self.block(): 1097 self.traverse(node.orelse) 1098 1099 def visit_With(self, node): 1100 self.fill("with ") 1101 self.interleave(lambda: self.write(", "), self.traverse, node.items) 1102 with self.block(extra=self.get_type_comment(node)): 1103 self.traverse(node.body) 1104 1105 def visit_AsyncWith(self, node): 1106 self.fill("async with ") 1107 self.interleave(lambda: self.write(", "), self.traverse, node.items) 1108 with self.block(extra=self.get_type_comment(node)): 1109 self.traverse(node.body) 1110 1111 def _str_literal_helper( 1112 self, string, *, quote_types=_ALL_QUOTES, escape_special_whitespace=False 1113 ): 1114 """Helper for writing string literals, minimizing escapes. 1115 Returns the tuple (string literal to write, possible quote types). 1116 """ 1117 def escape_char(c): 1118 # \n and \t are non-printable, but we only escape them if 1119 # escape_special_whitespace is True 1120 if not escape_special_whitespace and c in "\n\t": 1121 return c 1122 # Always escape backslashes and other non-printable characters 1123 if c == "\\" or not c.isprintable(): 1124 return c.encode("unicode_escape").decode("ascii") 1125 return c 1126 1127 escaped_string = "".join(map(escape_char, string)) 1128 possible_quotes = quote_types 1129 if "\n" in escaped_string: 1130 possible_quotes = [q for q in possible_quotes if q in _MULTI_QUOTES] 1131 possible_quotes = [q for q in possible_quotes if q not in escaped_string] 1132 if not possible_quotes: 1133 # If there aren't any possible_quotes, fallback to using repr 1134 # on the original string. Try to use a quote from quote_types, 1135 # e.g., so that we use triple quotes for docstrings. 1136 string = repr(string) 1137 quote = next((q for q in quote_types if string[0] in q), string[0]) 1138 return string[1:-1], [quote] 1139 if escaped_string: 1140 # Sort so that we prefer '''"''' over """\"""" 1141 possible_quotes.sort(key=lambda q: q[0] == escaped_string[-1]) 1142 # If we're using triple quotes and we'd need to escape a final 1143 # quote, escape it 1144 if possible_quotes[0][0] == escaped_string[-1]: 1145 assert len(possible_quotes[0]) == 3 1146 escaped_string = escaped_string[:-1] + "\\" + escaped_string[-1] 1147 return escaped_string, possible_quotes 1148 1149 def _write_str_avoiding_backslashes(self, string, *, quote_types=_ALL_QUOTES): 1150 """Write string literal value with a best effort attempt to avoid backslashes.""" 1151 string, quote_types = self._str_literal_helper(string, quote_types=quote_types) 1152 quote_type = quote_types[0] 1153 self.write(f"{quote_type}{string}{quote_type}") 1154 1155 def visit_JoinedStr(self, node): 1156 self.write("f") 1157 if self._avoid_backslashes: 1158 with self.buffered() as buffer: 1159 self._write_fstring_inner(node) 1160 return self._write_str_avoiding_backslashes("".join(buffer)) 1161 1162 # If we don't need to avoid backslashes globally (i.e., we only need 1163 # to avoid them inside FormattedValues), it's cosmetically preferred 1164 # to use escaped whitespace. That is, it's preferred to use backslashes 1165 # for cases like: f"{x}\n". To accomplish this, we keep track of what 1166 # in our buffer corresponds to FormattedValues and what corresponds to 1167 # Constant parts of the f-string, and allow escapes accordingly. 1168 fstring_parts = [] 1169 for value in node.values: 1170 with self.buffered() as buffer: 1171 self._write_fstring_inner(value) 1172 fstring_parts.append( 1173 ("".join(buffer), isinstance(value, Constant)) 1174 ) 1175 1176 new_fstring_parts = [] 1177 quote_types = list(_ALL_QUOTES) 1178 for value, is_constant in fstring_parts: 1179 value, quote_types = self._str_literal_helper( 1180 value, 1181 quote_types=quote_types, 1182 escape_special_whitespace=is_constant, 1183 ) 1184 new_fstring_parts.append(value) 1185 1186 value = "".join(new_fstring_parts) 1187 quote_type = quote_types[0] 1188 self.write(f"{quote_type}{value}{quote_type}") 1189 1190 def _write_fstring_inner(self, node): 1191 if isinstance(node, JoinedStr): 1192 # for both the f-string itself, and format_spec 1193 for value in node.values: 1194 self._write_fstring_inner(value) 1195 elif isinstance(node, Constant) and isinstance(node.value, str): 1196 value = node.value.replace("{", "{{").replace("}", "}}") 1197 self.write(value) 1198 elif isinstance(node, FormattedValue): 1199 self.visit_FormattedValue(node) 1200 else: 1201 raise ValueError(f"Unexpected node inside JoinedStr, {node!r}") 1202 1203 def visit_FormattedValue(self, node): 1204 def unparse_inner(inner): 1205 unparser = type(self)(_avoid_backslashes=True) 1206 unparser.set_precedence(_Precedence.TEST.next(), inner) 1207 return unparser.visit(inner) 1208 1209 with self.delimit("{", "}"): 1210 expr = unparse_inner(node.value) 1211 if "\\" in expr: 1212 raise ValueError( 1213 "Unable to avoid backslash in f-string expression part" 1214 ) 1215 if expr.startswith("{"): 1216 # Separate pair of opening brackets as "{ {" 1217 self.write(" ") 1218 self.write(expr) 1219 if node.conversion != -1: 1220 self.write(f"!{chr(node.conversion)}") 1221 if node.format_spec: 1222 self.write(":") 1223 self._write_fstring_inner(node.format_spec) 1224 1225 def visit_Name(self, node): 1226 self.write(node.id) 1227 1228 def _write_docstring(self, node): 1229 self.fill() 1230 if node.kind == "u": 1231 self.write("u") 1232 self._write_str_avoiding_backslashes(node.value, quote_types=_MULTI_QUOTES) 1233 1234 def _write_constant(self, value): 1235 if isinstance(value, (float, complex)): 1236 # Substitute overflowing decimal literal for AST infinities, 1237 # and inf - inf for NaNs. 1238 self.write( 1239 repr(value) 1240 .replace("inf", _INFSTR) 1241 .replace("nan", f"({_INFSTR}-{_INFSTR})") 1242 ) 1243 elif self._avoid_backslashes and isinstance(value, str): 1244 self._write_str_avoiding_backslashes(value) 1245 else: 1246 self.write(repr(value)) 1247 1248 def visit_Constant(self, node): 1249 value = node.value 1250 if isinstance(value, tuple): 1251 with self.delimit("(", ")"): 1252 self.items_view(self._write_constant, value) 1253 elif value is ...: 1254 self.write("...") 1255 else: 1256 if node.kind == "u": 1257 self.write("u") 1258 self._write_constant(node.value) 1259 1260 def visit_List(self, node): 1261 with self.delimit("[", "]"): 1262 self.interleave(lambda: self.write(", "), self.traverse, node.elts) 1263 1264 def visit_ListComp(self, node): 1265 with self.delimit("[", "]"): 1266 self.traverse(node.elt) 1267 for gen in node.generators: 1268 self.traverse(gen) 1269 1270 def visit_GeneratorExp(self, node): 1271 with self.delimit("(", ")"): 1272 self.traverse(node.elt) 1273 for gen in node.generators: 1274 self.traverse(gen) 1275 1276 def visit_SetComp(self, node): 1277 with self.delimit("{", "}"): 1278 self.traverse(node.elt) 1279 for gen in node.generators: 1280 self.traverse(gen) 1281 1282 def visit_DictComp(self, node): 1283 with self.delimit("{", "}"): 1284 self.traverse(node.key) 1285 self.write(": ") 1286 self.traverse(node.value) 1287 for gen in node.generators: 1288 self.traverse(gen) 1289 1290 def visit_comprehension(self, node): 1291 if node.is_async: 1292 self.write(" async for ") 1293 else: 1294 self.write(" for ") 1295 self.set_precedence(_Precedence.TUPLE, node.target) 1296 self.traverse(node.target) 1297 self.write(" in ") 1298 self.set_precedence(_Precedence.TEST.next(), node.iter, *node.ifs) 1299 self.traverse(node.iter) 1300 for if_clause in node.ifs: 1301 self.write(" if ") 1302 self.traverse(if_clause) 1303 1304 def visit_IfExp(self, node): 1305 with self.require_parens(_Precedence.TEST, node): 1306 self.set_precedence(_Precedence.TEST.next(), node.body, node.test) 1307 self.traverse(node.body) 1308 self.write(" if ") 1309 self.traverse(node.test) 1310 self.write(" else ") 1311 self.set_precedence(_Precedence.TEST, node.orelse) 1312 self.traverse(node.orelse) 1313 1314 def visit_Set(self, node): 1315 if node.elts: 1316 with self.delimit("{", "}"): 1317 self.interleave(lambda: self.write(", "), self.traverse, node.elts) 1318 else: 1319 # `{}` would be interpreted as a dictionary literal, and 1320 # `set` might be shadowed. Thus: 1321 self.write('{*()}') 1322 1323 def visit_Dict(self, node): 1324 def write_key_value_pair(k, v): 1325 self.traverse(k) 1326 self.write(": ") 1327 self.traverse(v) 1328 1329 def write_item(item): 1330 k, v = item 1331 if k is None: 1332 # for dictionary unpacking operator in dicts {**{'y': 2}} 1333 # see PEP 448 for details 1334 self.write("**") 1335 self.set_precedence(_Precedence.EXPR, v) 1336 self.traverse(v) 1337 else: 1338 write_key_value_pair(k, v) 1339 1340 with self.delimit("{", "}"): 1341 self.interleave( 1342 lambda: self.write(", "), write_item, zip(node.keys, node.values) 1343 ) 1344 1345 def visit_Tuple(self, node): 1346 with self.delimit_if( 1347 "(", 1348 ")", 1349 len(node.elts) == 0 or self.get_precedence(node) > _Precedence.TUPLE 1350 ): 1351 self.items_view(self.traverse, node.elts) 1352 1353 unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"} 1354 unop_precedence = { 1355 "not": _Precedence.NOT, 1356 "~": _Precedence.FACTOR, 1357 "+": _Precedence.FACTOR, 1358 "-": _Precedence.FACTOR, 1359 } 1360 1361 def visit_UnaryOp(self, node): 1362 operator = self.unop[node.op.__class__.__name__] 1363 operator_precedence = self.unop_precedence[operator] 1364 with self.require_parens(operator_precedence, node): 1365 self.write(operator) 1366 # factor prefixes (+, -, ~) shouldn't be separated 1367 # from the value they belong, (e.g: +1 instead of + 1) 1368 if operator_precedence is not _Precedence.FACTOR: 1369 self.write(" ") 1370 self.set_precedence(operator_precedence, node.operand) 1371 self.traverse(node.operand) 1372 1373 binop = { 1374 "Add": "+", 1375 "Sub": "-", 1376 "Mult": "*", 1377 "MatMult": "@", 1378 "Div": "/", 1379 "Mod": "%", 1380 "LShift": "<<", 1381 "RShift": ">>", 1382 "BitOr": "|", 1383 "BitXor": "^", 1384 "BitAnd": "&", 1385 "FloorDiv": "//", 1386 "Pow": "**", 1387 } 1388 1389 binop_precedence = { 1390 "+": _Precedence.ARITH, 1391 "-": _Precedence.ARITH, 1392 "*": _Precedence.TERM, 1393 "@": _Precedence.TERM, 1394 "/": _Precedence.TERM, 1395 "%": _Precedence.TERM, 1396 "<<": _Precedence.SHIFT, 1397 ">>": _Precedence.SHIFT, 1398 "|": _Precedence.BOR, 1399 "^": _Precedence.BXOR, 1400 "&": _Precedence.BAND, 1401 "//": _Precedence.TERM, 1402 "**": _Precedence.POWER, 1403 } 1404 1405 binop_rassoc = frozenset(("**",)) 1406 def visit_BinOp(self, node): 1407 operator = self.binop[node.op.__class__.__name__] 1408 operator_precedence = self.binop_precedence[operator] 1409 with self.require_parens(operator_precedence, node): 1410 if operator in self.binop_rassoc: 1411 left_precedence = operator_precedence.next() 1412 right_precedence = operator_precedence 1413 else: 1414 left_precedence = operator_precedence 1415 right_precedence = operator_precedence.next() 1416 1417 self.set_precedence(left_precedence, node.left) 1418 self.traverse(node.left) 1419 self.write(f" {operator} ") 1420 self.set_precedence(right_precedence, node.right) 1421 self.traverse(node.right) 1422 1423 cmpops = { 1424 "Eq": "==", 1425 "NotEq": "!=", 1426 "Lt": "<", 1427 "LtE": "<=", 1428 "Gt": ">", 1429 "GtE": ">=", 1430 "Is": "is", 1431 "IsNot": "is not", 1432 "In": "in", 1433 "NotIn": "not in", 1434 } 1435 1436 def visit_Compare(self, node): 1437 with self.require_parens(_Precedence.CMP, node): 1438 self.set_precedence(_Precedence.CMP.next(), node.left, *node.comparators) 1439 self.traverse(node.left) 1440 for o, e in zip(node.ops, node.comparators): 1441 self.write(" " + self.cmpops[o.__class__.__name__] + " ") 1442 self.traverse(e) 1443 1444 boolops = {"And": "and", "Or": "or"} 1445 boolop_precedence = {"and": _Precedence.AND, "or": _Precedence.OR} 1446 1447 def visit_BoolOp(self, node): 1448 operator = self.boolops[node.op.__class__.__name__] 1449 operator_precedence = self.boolop_precedence[operator] 1450 1451 def increasing_level_traverse(node): 1452 nonlocal operator_precedence 1453 operator_precedence = operator_precedence.next() 1454 self.set_precedence(operator_precedence, node) 1455 self.traverse(node) 1456 1457 with self.require_parens(operator_precedence, node): 1458 s = f" {operator} " 1459 self.interleave(lambda: self.write(s), increasing_level_traverse, node.values) 1460 1461 def visit_Attribute(self, node): 1462 self.set_precedence(_Precedence.ATOM, node.value) 1463 self.traverse(node.value) 1464 # Special case: 3.__abs__() is a syntax error, so if node.value 1465 # is an integer literal then we need to either parenthesize 1466 # it or add an extra space to get 3 .__abs__(). 1467 if isinstance(node.value, Constant) and isinstance(node.value.value, int): 1468 self.write(" ") 1469 self.write(".") 1470 self.write(node.attr) 1471 1472 def visit_Call(self, node): 1473 self.set_precedence(_Precedence.ATOM, node.func) 1474 self.traverse(node.func) 1475 with self.delimit("(", ")"): 1476 comma = False 1477 for e in node.args: 1478 if comma: 1479 self.write(", ") 1480 else: 1481 comma = True 1482 self.traverse(e) 1483 for e in node.keywords: 1484 if comma: 1485 self.write(", ") 1486 else: 1487 comma = True 1488 self.traverse(e) 1489 1490 def visit_Subscript(self, node): 1491 def is_non_empty_tuple(slice_value): 1492 return ( 1493 isinstance(slice_value, Tuple) 1494 and slice_value.elts 1495 ) 1496 1497 self.set_precedence(_Precedence.ATOM, node.value) 1498 self.traverse(node.value) 1499 with self.delimit("[", "]"): 1500 if is_non_empty_tuple(node.slice): 1501 # parentheses can be omitted if the tuple isn't empty 1502 self.items_view(self.traverse, node.slice.elts) 1503 else: 1504 self.traverse(node.slice) 1505 1506 def visit_Starred(self, node): 1507 self.write("*") 1508 self.set_precedence(_Precedence.EXPR, node.value) 1509 self.traverse(node.value) 1510 1511 def visit_Ellipsis(self, node): 1512 self.write("...") 1513 1514 def visit_Slice(self, node): 1515 if node.lower: 1516 self.traverse(node.lower) 1517 self.write(":") 1518 if node.upper: 1519 self.traverse(node.upper) 1520 if node.step: 1521 self.write(":") 1522 self.traverse(node.step) 1523 1524 def visit_Match(self, node): 1525 self.fill("match ") 1526 self.traverse(node.subject) 1527 with self.block(): 1528 for case in node.cases: 1529 self.traverse(case) 1530 1531 def visit_arg(self, node): 1532 self.write(node.arg) 1533 if node.annotation: 1534 self.write(": ") 1535 self.traverse(node.annotation) 1536 1537 def visit_arguments(self, node): 1538 first = True 1539 # normal arguments 1540 all_args = node.posonlyargs + node.args 1541 defaults = [None] * (len(all_args) - len(node.defaults)) + node.defaults 1542 for index, elements in enumerate(zip(all_args, defaults), 1): 1543 a, d = elements 1544 if first: 1545 first = False 1546 else: 1547 self.write(", ") 1548 self.traverse(a) 1549 if d: 1550 self.write("=") 1551 self.traverse(d) 1552 if index == len(node.posonlyargs): 1553 self.write(", /") 1554 1555 # varargs, or bare '*' if no varargs but keyword-only arguments present 1556 if node.vararg or node.kwonlyargs: 1557 if first: 1558 first = False 1559 else: 1560 self.write(", ") 1561 self.write("*") 1562 if node.vararg: 1563 self.write(node.vararg.arg) 1564 if node.vararg.annotation: 1565 self.write(": ") 1566 self.traverse(node.vararg.annotation) 1567 1568 # keyword-only arguments 1569 if node.kwonlyargs: 1570 for a, d in zip(node.kwonlyargs, node.kw_defaults): 1571 self.write(", ") 1572 self.traverse(a) 1573 if d: 1574 self.write("=") 1575 self.traverse(d) 1576 1577 # kwargs 1578 if node.kwarg: 1579 if first: 1580 first = False 1581 else: 1582 self.write(", ") 1583 self.write("**" + node.kwarg.arg) 1584 if node.kwarg.annotation: 1585 self.write(": ") 1586 self.traverse(node.kwarg.annotation) 1587 1588 def visit_keyword(self, node): 1589 if node.arg is None: 1590 self.write("**") 1591 else: 1592 self.write(node.arg) 1593 self.write("=") 1594 self.traverse(node.value) 1595 1596 def visit_Lambda(self, node): 1597 with self.require_parens(_Precedence.TEST, node): 1598 self.write("lambda") 1599 with self.buffered() as buffer: 1600 self.traverse(node.args) 1601 if buffer: 1602 self.write(" ", *buffer) 1603 self.write(": ") 1604 self.set_precedence(_Precedence.TEST, node.body) 1605 self.traverse(node.body) 1606 1607 def visit_alias(self, node): 1608 self.write(node.name) 1609 if node.asname: 1610 self.write(" as " + node.asname) 1611 1612 def visit_withitem(self, node): 1613 self.traverse(node.context_expr) 1614 if node.optional_vars: 1615 self.write(" as ") 1616 self.traverse(node.optional_vars) 1617 1618 def visit_match_case(self, node): 1619 self.fill("case ") 1620 self.traverse(node.pattern) 1621 if node.guard: 1622 self.write(" if ") 1623 self.traverse(node.guard) 1624 with self.block(): 1625 self.traverse(node.body) 1626 1627 def visit_MatchValue(self, node): 1628 self.traverse(node.value) 1629 1630 def visit_MatchSingleton(self, node): 1631 self._write_constant(node.value) 1632 1633 def visit_MatchSequence(self, node): 1634 with self.delimit("[", "]"): 1635 self.interleave( 1636 lambda: self.write(", "), self.traverse, node.patterns 1637 ) 1638 1639 def visit_MatchStar(self, node): 1640 name = node.name 1641 if name is None: 1642 name = "_" 1643 self.write(f"*{name}") 1644 1645 def visit_MatchMapping(self, node): 1646 def write_key_pattern_pair(pair): 1647 k, p = pair 1648 self.traverse(k) 1649 self.write(": ") 1650 self.traverse(p) 1651 1652 with self.delimit("{", "}"): 1653 keys = node.keys 1654 self.interleave( 1655 lambda: self.write(", "), 1656 write_key_pattern_pair, 1657 zip(keys, node.patterns, strict=True), 1658 ) 1659 rest = node.rest 1660 if rest is not None: 1661 if keys: 1662 self.write(", ") 1663 self.write(f"**{rest}") 1664 1665 def visit_MatchClass(self, node): 1666 self.set_precedence(_Precedence.ATOM, node.cls) 1667 self.traverse(node.cls) 1668 with self.delimit("(", ")"): 1669 patterns = node.patterns 1670 self.interleave( 1671 lambda: self.write(", "), self.traverse, patterns 1672 ) 1673 attrs = node.kwd_attrs 1674 if attrs: 1675 def write_attr_pattern(pair): 1676 attr, pattern = pair 1677 self.write(f"{attr}=") 1678 self.traverse(pattern) 1679 1680 if patterns: 1681 self.write(", ") 1682 self.interleave( 1683 lambda: self.write(", "), 1684 write_attr_pattern, 1685 zip(attrs, node.kwd_patterns, strict=True), 1686 ) 1687 1688 def visit_MatchAs(self, node): 1689 name = node.name 1690 pattern = node.pattern 1691 if name is None: 1692 self.write("_") 1693 elif pattern is None: 1694 self.write(node.name) 1695 else: 1696 with self.require_parens(_Precedence.TEST, node): 1697 self.set_precedence(_Precedence.BOR, node.pattern) 1698 self.traverse(node.pattern) 1699 self.write(f" as {node.name}") 1700 1701 def visit_MatchOr(self, node): 1702 with self.require_parens(_Precedence.BOR, node): 1703 self.set_precedence(_Precedence.BOR.next(), *node.patterns) 1704 self.interleave(lambda: self.write(" | "), self.traverse, node.patterns) 1705 1706def unparse(ast_obj): 1707 unparser = _Unparser() 1708 return unparser.visit(ast_obj) 1709 1710 1711def main(): 1712 import argparse 1713 1714 parser = argparse.ArgumentParser(prog='python -m ast') 1715 parser.add_argument('infile', type=argparse.FileType(mode='rb'), nargs='?', 1716 default='-', 1717 help='the file to parse; defaults to stdin') 1718 parser.add_argument('-m', '--mode', default='exec', 1719 choices=('exec', 'single', 'eval', 'func_type'), 1720 help='specify what kind of code must be parsed') 1721 parser.add_argument('--no-type-comments', default=True, action='store_false', 1722 help="don't add information about type comments") 1723 parser.add_argument('-a', '--include-attributes', action='store_true', 1724 help='include attributes such as line numbers and ' 1725 'column offsets') 1726 parser.add_argument('-i', '--indent', type=int, default=3, 1727 help='indentation of nodes (number of spaces)') 1728 args = parser.parse_args() 1729 1730 with args.infile as infile: 1731 source = infile.read() 1732 tree = parse(source, args.infile.name, args.mode, type_comments=args.no_type_comments) 1733 print(dump(tree, include_attributes=args.include_attributes, indent=args.indent)) 1734 1735if __name__ == '__main__': 1736 main() 1737