xref: /aosp_15_r20/external/yapf/yapf/yapflib/split_penalty.py (revision 7249d1a64f4850ccf838e62a46276f891f72998e)
1# Copyright 2015 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Computation of split penalties before/between tokens."""
15
16import re
17
18from lib2to3 import pytree
19from lib2to3.pgen2 import token as grammar_token
20
21from yapf.yapflib import format_token
22from yapf.yapflib import py3compat
23from yapf.yapflib import pytree_utils
24from yapf.yapflib import pytree_visitor
25from yapf.yapflib import style
26from yapf.yapflib import subtypes
27
# TODO(morbo): Document the annotations in a centralized place. E.g., the
# README file.

# Generic strength levels shared by the Visit_* methods and helpers in this
# module. A larger value makes a split before the annotated token less
# desirable; UNBREAKABLE marks places where the code says a break "can't"
# happen.
UNBREAKABLE = 10**6
NAMED_ASSIGN = 15000  # around '=' / ':' in keyword args and typed params
DOTTED_NAME = 4000  # around the '.' of an attribute-access trailer
VERY_STRONGLY_CONNECTED = 3500
STRONGLY_CONNECTED = 3000
CONNECTED = 500
TOGETHER = 100  # after each comma in a testlist_gexp

# Per-grammar-rule penalty increments, mostly applied by the Visit_* method
# whose name matches the grammar rule.
OR_TEST = 1000
AND_TEST = 1100
NOT_TEST = 1200
COMPARISON = 1300
STAR_EXPR = 1300
EXPR = 1400
XOR_EXPR = 1500
AND_EXPR = 1700
SHIFT_EXPR = 1800
ARITH_EXPR = 1900
TERM = 2000
FACTOR = 2100
POWER = 2200
ATOM = 2300

# Trailer-specific penalties used by Visit_trailer.
ONE_ELEMENT_ARGUMENT = 500  # before the lone argument of a call
SUBSCRIPT = 6000  # inside subscript brackets
54
55
def ComputeSplitPenalties(tree):
  """Annotate the tokens of a parse tree with split penalties.

  A fresh _SplitPenaltyAssigner walks the whole tree; the penalties it
  computes are stored as SPLIT_PENALTY annotations on the tree's nodes.

  Arguments:
    tree: the top-level pytree node to annotate with penalties.
  """
  assigner = _SplitPenaltyAssigner()
  assigner.Visit(tree)
63
64
class _SplitPenaltyAssigner(pytree_visitor.PyTreeVisitor):
  """Assigns split penalties to tokens, based on parse tree structure.

  Split penalties are attached as annotations to tokens. Each Visit_<rule>
  method handles the grammar rule of the same name. Some methods annotate
  before visiting their children (DefaultNodeVisit) and some after. Because
  the helpers either overwrite (_SetSplitPenalty), raise to a floor
  (_RecAnnotate / _SetExpressionPenalty) or add (_IncreasePenalty), the order
  of those calls relative to DefaultNodeVisit matters.
  """

  def Visit(self, node):
    """Visit `node` unless it is a pseudo token (has an `is_pseudo` attr)."""
    if not hasattr(node, 'is_pseudo'):  # Ignore pseudo tokens.
      super(_SplitPenaltyAssigner, self).Visit(node)

  def Visit_import_as_names(self, node):  # pylint: disable=invalid-name
    """Give each name after a comma the SPLIT_PENALTY_IMPORT_NAMES penalty."""
    # import_as_names ::= import_as_name (',' import_as_name)* [',']
    self.DefaultNodeVisit(node)
    prev_child = None
    for child in node.children:
      # NOTE(review): when `child` is a Node (e.g. `a as b`) the annotation is
      # set on the node itself rather than on its first leaf -- confirm that
      # this is where the formatter reads it.
      if (prev_child and isinstance(prev_child, pytree.Leaf) and
          prev_child.value == ','):
        _SetSplitPenalty(child, style.Get('SPLIT_PENALTY_IMPORT_NAMES'))
      prev_child = child

  def Visit_classdef(self, node):  # pylint: disable=invalid-name
    """Make the class name, an opening '(' and the ':' unbreakable."""
    # classdef ::= 'class' NAME ['(' [arglist] ')'] ':' suite
    #
    # NAME
    _SetUnbreakable(node.children[1])
    # More than the 4 children of `'class' NAME ':' suite` means a '(' follows
    # the name.
    if len(node.children) > 4:
      # opening '('
      _SetUnbreakable(node.children[2])
    # ':'
    _SetUnbreakable(node.children[-2])
    self.DefaultNodeVisit(node)

  def Visit_funcdef(self, node):  # pylint: disable=invalid-name
    """Pin the function name and ':' and tie a return annotation to '->'."""
    # funcdef ::= 'def' NAME parameters ['->' test] ':' suite
    #
    # Can't break before the function name and before the colon. The parameters
    # are handled by child iteration.
    # Skip any 'simple_stmt' children that sit right after 'def' (presumably
    # spliced-in comments -- TODO confirm); the next child is the name.
    colon_idx = 1
    while pytree_utils.NodeName(node.children[colon_idx]) == 'simple_stmt':
      colon_idx += 1
    _SetUnbreakable(node.children[colon_idx])
    # Scan forward to the ':' leaf, remembering where a '->' was seen.
    arrow_idx = -1
    while colon_idx < len(node.children):
      if isinstance(node.children[colon_idx], pytree.Leaf):
        if node.children[colon_idx].value == ':':
          break
        if node.children[colon_idx].value == '->':
          arrow_idx = colon_idx
      colon_idx += 1
    _SetUnbreakable(node.children[colon_idx])
    self.DefaultNodeVisit(node)
    if arrow_idx > 0:
      # Applied after visiting the children so these values are final: the
      # last leaf before '->' gets penalty 0, '->' itself is unbreakable and
      # the return annotation is strongly connected.
      _SetSplitPenalty(
          pytree_utils.LastLeafNode(node.children[arrow_idx - 1]), 0)
      _SetUnbreakable(node.children[arrow_idx])
      _SetStronglyConnected(node.children[arrow_idx + 1])

  def Visit_lambdef(self, node):  # pylint: disable=invalid-name
    """Raise the split penalty of the leaves inside a lambda expression."""
    # lambdef ::= 'lambda' [varargslist] ':' test
    # Multi-line lambdas are allowed either by the ALLOW_MULTILINE_LAMBDAS style
    # option or by a `pylint: disable=g-long-lambda` comment leaf that is a
    # direct child of the lambda; allowed lambdas get the weaker penalty.
    allow_multiline_lambdas = style.Get('ALLOW_MULTILINE_LAMBDAS')
    if not allow_multiline_lambdas:
      for child in node.children:
        if child.type == grammar_token.COMMENT:
          if re.search(r'pylint:.*disable=.*\bg-long-lambda', child.value):
            allow_multiline_lambdas = True
            break

    if allow_multiline_lambdas:
      _SetExpressionPenalty(node, STRONGLY_CONNECTED)
    else:
      _SetExpressionPenalty(node, VERY_STRONGLY_CONNECTED)

  def Visit_parameters(self, node):  # pylint: disable=invalid-name
    """Pin the opening '(' of a parameter list; keep ')' close by default."""
    # parameters ::= '(' [typedargslist] ')'
    self.DefaultNodeVisit(node)

    # Can't break before the opening paren of a parameter list.
    _SetUnbreakable(node.children[0])
    # With INDENT/DEDENT_CLOSING_BRACKETS the closing ')' is allowed to move to
    # its own line, so only strongly connect it when neither option is set.
    if not (style.Get('INDENT_CLOSING_BRACKETS') or
            style.Get('DEDENT_CLOSING_BRACKETS')):
      _SetStronglyConnected(node.children[-1])

  def Visit_arglist(self, node):  # pylint: disable=invalid-name
    """Forbid splits before commas and bump penalties inside atom arguments."""
    # arglist ::= argument (',' argument)* [',']
    if node.children[0].type == grammar_token.STAR:
      # Python 3 treats a star expression as a specific expression type.
      # Process it in that method.
      self.Visit_star_expr(node)
      return

    self.DefaultNodeVisit(node)

    for index in py3compat.range(1, len(node.children)):
      child = node.children[index]
      if isinstance(child, pytree.Leaf) and child.value == ',':
        _SetUnbreakable(child)

    for child in node.children:
      if pytree_utils.NodeName(child) == 'atom':
        _IncreasePenalty(child, CONNECTED)

  def Visit_argument(self, node):  # pylint: disable=invalid-name
    """Discourage splits around the '=' of a keyword argument."""
    # argument ::= test [comp_for] | test '=' test  # Really [keyword '='] test
    self.DefaultNodeVisit(node)

    for index in py3compat.range(1, len(node.children) - 1):
      child = node.children[index]
      if isinstance(child, pytree.Leaf) and child.value == '=':
        # Penalize the '=' itself and the first leaf of the value after it.
        _SetSplitPenalty(
            pytree_utils.FirstLeafNode(node.children[index]), NAMED_ASSIGN)
        _SetSplitPenalty(
            pytree_utils.FirstLeafNode(node.children[index + 1]), NAMED_ASSIGN)

  def Visit_tname(self, node):  # pylint: disable=invalid-name
    """Discourage splits around the ':' of an annotated parameter."""
    # tname ::= NAME [':' test]
    self.DefaultNodeVisit(node)

    for index in py3compat.range(1, len(node.children) - 1):
      child = node.children[index]
      if isinstance(child, pytree.Leaf) and child.value == ':':
        # Penalize the ':' itself and the first leaf of the annotation.
        _SetSplitPenalty(
            pytree_utils.FirstLeafNode(node.children[index]), NAMED_ASSIGN)
        _SetSplitPenalty(
            pytree_utils.FirstLeafNode(node.children[index + 1]), NAMED_ASSIGN)

  def Visit_dotted_name(self, node):  # pylint: disable=invalid-name
    """Make everything after the leading name of a dotted name unbreakable."""
    # dotted_name ::= NAME ('.' NAME)*
    for child in node.children:
      self.Visit(child)
    # When the first child is a pseudo token, the leading NAME is at index 1,
    # so start one position later.
    start = 2 if hasattr(node.children[0], 'is_pseudo') else 1
    for i in py3compat.range(start, len(node.children)):
      _SetUnbreakable(node.children[i])

  def Visit_dictsetmaker(self, node):  # pylint: disable=invalid-name
    """Strongly connect each ':' of a dict display to the key before it."""
    # dictsetmaker ::= ( (test ':' test
    #                      (comp_for | (',' test ':' test)* [','])) |
    #                    (test (comp_for | (',' test)* [','])) )
    for child in node.children:
      self.Visit(child)
      if child.type == grammar_token.COLON:
        # This is a key to a dictionary. We don't want to split the key if at
        # all possible.
        _SetStronglyConnected(child)

  def Visit_trailer(self, node):  # pylint: disable=invalid-name
    """Assign penalties for '.NAME', call and subscript trailers.

    Unlike most visitors here, penalties are set before the children are
    visited.
    """
    # trailer ::= '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME
    if node.children[0].value == '.':
      # SPLIT_BEFORE_DOT decides which side of the '.' gets the cheaper
      # DOTTED_NAME penalty.
      before = style.Get('SPLIT_BEFORE_DOT')
      _SetSplitPenalty(node.children[0],
                       VERY_STRONGLY_CONNECTED if before else DOTTED_NAME)
      _SetSplitPenalty(node.children[1],
                       DOTTED_NAME if before else VERY_STRONGLY_CONNECTED)
    elif len(node.children) == 2:
      # Don't split an empty argument list if at all possible.
      _SetSplitPenalty(node.children[1], VERY_STRONGLY_CONNECTED)
    elif len(node.children) == 3:
      # Bracketed trailer with exactly one inner node: `(x)` or `[x]`.
      name = pytree_utils.NodeName(node.children[1])
      if name in {'argument', 'comparison'}:
        # Don't split an argument list with one element if at all possible.
        _SetStronglyConnected(node.children[1])
        if (len(node.children[1].children) > 1 and
            pytree_utils.NodeName(node.children[1].children[1]) == 'comp_for'):
          # Don't penalize splitting before a comp_for expression.
          _SetSplitPenalty(pytree_utils.FirstLeafNode(node.children[1]), 0)
        else:
          _SetSplitPenalty(
              pytree_utils.FirstLeafNode(node.children[1]),
              ONE_ELEMENT_ARGUMENT)
      elif (node.children[0].type == grammar_token.LSQB and
            len(node.children[1].children) > 2 and
            (name.endswith('_test') or name.endswith('_expr'))):
        # A subscript that is a binary logical/bitwise expression: keep both
        # operands together.
        _SetStronglyConnected(node.children[1].children[0])
        _SetStronglyConnected(node.children[1].children[2])

        # Still allow splitting around the operator.
        split_before = ((name.endswith('_test') and
                         style.Get('SPLIT_BEFORE_LOGICAL_OPERATOR')) or
                        (name.endswith('_expr') and
                         style.Get('SPLIT_BEFORE_BITWISE_OPERATOR')))
        if split_before:
          _SetSplitPenalty(
              pytree_utils.LastLeafNode(node.children[1].children[1]), 0)
        else:
          _SetSplitPenalty(
              pytree_utils.FirstLeafNode(node.children[1].children[2]), 0)

        # Don't split the ending bracket of a subscript list.
        _RecAnnotate(node.children[-1], pytree_utils.Annotation.SPLIT_PENALTY,
                     VERY_STRONGLY_CONNECTED)
      elif name not in {
          'arglist', 'argument', 'term', 'or_test', 'and_test', 'comparison',
          'atom', 'power'
      }:
        # Don't split an argument list with one element if at all possible.
        stypes = pytree_utils.GetNodeAnnotation(
            pytree_utils.FirstLeafNode(node), pytree_utils.Annotation.SUBTYPE)
        if stypes and subtypes.SUBSCRIPT_BRACKET in stypes:
          _IncreasePenalty(node, SUBSCRIPT)

          # Bump up the split penalty for the first part of a subscript. We
          # would rather not split there.
          _IncreasePenalty(node.children[1], CONNECTED)
        else:
          _SetStronglyConnected(node.children[1], node.children[2])

      if name == 'arglist':
        _SetStronglyConnected(node.children[-1])

    self.DefaultNodeVisit(node)

  def Visit_power(self, node):  # pylint: disable=invalid-name,missing-docstring
    """Keep an atom glued to its trailers and discourage splits before ')'."""
    # power ::= atom trailer* ['**' factor]
    self.DefaultNodeVisit(node)

    # When atom is followed by a trailer, we can not break between them.
    # E.g. arr[idx] - no break allowed between 'arr' and '['.
    if (len(node.children) > 1 and
        pytree_utils.NodeName(node.children[1]) == 'trailer'):
      # children[1] itself is a whole trailer: we don't want to
      # mark all of it as unbreakable, only its first token: (, [ or .
      first = pytree_utils.FirstLeafNode(node.children[1])
      if first.value != '.':
        _SetUnbreakable(node.children[1].children[0])

      # A special case when there are more trailers in the sequence. Given:
      #   atom tr1 tr2
      # The last token of tr1 and the first token of tr2 comprise a strongly
      # connected region. For example: foo.bar.baz(1)
      # We don't want breaks between either of the '.', '(', or '[' and the
      # names *preceding* them.
      prev_trailer_idx = 1
      while prev_trailer_idx < len(node.children) - 1:
        cur_trailer_idx = prev_trailer_idx + 1
        cur_trailer = node.children[cur_trailer_idx]
        if pytree_utils.NodeName(cur_trailer) != 'trailer':
          break

        # Now we know we have two trailers one after the other
        prev_trailer = node.children[prev_trailer_idx]
        if prev_trailer.children[-1].value != ')':
          # Strongly connect the previous trailer's last token if it's not a
          # function call:
          #   atom tr1() tr2
          # It may be necessary (though undesirable) to split up a previous
          # function call's parentheses to the next line.
          _SetStronglyConnected(prev_trailer.children[-1])
        _SetStronglyConnected(cur_trailer.children[0])
        prev_trailer_idx = cur_trailer_idx

    # We don't want to split before the last ')' of a function call. This also
    # takes care of the special case of:
    #   atom tr1 tr2 ... trn
    # where the 'tr#' are trailers that may end in a ')'.
    for trailer in node.children[1:]:
      if pytree_utils.NodeName(trailer) != 'trailer':
        break
      if trailer.children[0].value in '([':
        if len(trailer.children) > 2:
          stypes = pytree_utils.GetNodeAnnotation(
              trailer.children[0], pytree_utils.Annotation.SUBTYPE)
          if stypes and subtypes.SUBSCRIPT_BRACKET in stypes:
            _SetStronglyConnected(
                pytree_utils.FirstLeafNode(trailer.children[1]))

          # Find the closing bracket, stepping over a trailing comment leaf.
          last_child_node = pytree_utils.LastLeafNode(trailer)
          if last_child_node.value.strip().startswith('#'):
            last_child_node = last_child_node.prev_sibling
          if not (style.Get('INDENT_CLOSING_BRACKETS') or
                  style.Get('DEDENT_CLOSING_BRACKETS')):
            # Only glue the closing bracket when the contents do not end in a
            # trailing comma.
            last = pytree_utils.LastLeafNode(last_child_node.prev_sibling)
            if last.value != ',':
              if last_child_node.value == ']':
                _SetUnbreakable(last_child_node)
              else:
                _SetSplitPenalty(last_child_node, VERY_STRONGLY_CONNECTED)
        else:
          # If the trailer's children are '()', then make it a strongly
          # connected region.  It's sometimes necessary, though undesirable, to
          # split the two.
          _SetStronglyConnected(trailer.children[-1])

  def Visit_subscriptlist(self, node):  # pylint: disable=invalid-name
    """Allow free splits before the first subscript and after each comma."""
    # subscriptlist ::= subscript (',' subscript)* [',']
    self.DefaultNodeVisit(node)
    _SetSplitPenalty(pytree_utils.FirstLeafNode(node), 0)
    prev_child = None
    for child in node.children:
      if prev_child and prev_child.type == grammar_token.COMMA:
        _SetSplitPenalty(pytree_utils.FirstLeafNode(child), 0)
      prev_child = child

  def Visit_subscript(self, node):  # pylint: disable=invalid-name
    """Strongly connect every part of a single subscript or slice."""
    # subscript ::= test | [test] ':' [test] [sliceop]
    _SetStronglyConnected(*node.children)
    self.DefaultNodeVisit(node)

  def Visit_comp_for(self, node):  # pylint: disable=invalid-name
    """Make a split before 'for' free; keep the rest of the clause together."""
    # comp_for ::= 'for' exprlist 'in' testlist_safe [comp_iter]
    _SetSplitPenalty(pytree_utils.FirstLeafNode(node), 0)
    _SetStronglyConnected(*node.children[1:])
    self.DefaultNodeVisit(node)

  def Visit_old_comp_for(self, node):  # pylint: disable=invalid-name
    """Treat an `old_comp_for` node exactly like `comp_for`."""
    # Python 3.7
    self.Visit_comp_for(node)

  def Visit_comp_if(self, node):  # pylint: disable=invalid-name
    """Use SPLIT_PENALTY_BEFORE_IF_EXPR before 'if'; keep its condition together."""
    # comp_if ::= 'if' old_test [comp_iter]
    _SetSplitPenalty(node.children[0],
                     style.Get('SPLIT_PENALTY_BEFORE_IF_EXPR'))
    _SetStronglyConnected(*node.children[1:])
    self.DefaultNodeVisit(node)

  def Visit_old_comp_if(self, node):  # pylint: disable=invalid-name
    """Treat an `old_comp_if` node exactly like `comp_if`."""
    # Python 3.7
    self.Visit_comp_if(node)

  def Visit_test(self, node):  # pylint: disable=invalid-name
    """Add OR_TEST to the penalties inside a test expression."""
    # test ::= or_test ['if' or_test 'else' test] | lambdef
    # Unlike the other expression visitors, the increase is applied before the
    # children are visited.
    _IncreasePenalty(node, OR_TEST)
    self.DefaultNodeVisit(node)

  def Visit_or_test(self, node):  # pylint: disable=invalid-name
    """Add OR_TEST inside the expression, then remove it at each 'or'."""
    # or_test ::= and_test ('or' and_test)*
    self.DefaultNodeVisit(node)
    _IncreasePenalty(node, OR_TEST)
    # Odd indices hold the 'or' operators. Lower the penalty either before
    # the operator or before the operand following it, per the style option.
    index = 1
    while index + 1 < len(node.children):
      if style.Get('SPLIT_BEFORE_LOGICAL_OPERATOR'):
        _DecrementSplitPenalty(
            pytree_utils.FirstLeafNode(node.children[index]), OR_TEST)
      else:
        _DecrementSplitPenalty(
            pytree_utils.FirstLeafNode(node.children[index + 1]), OR_TEST)
      index += 2

  def Visit_and_test(self, node):  # pylint: disable=invalid-name
    """Add AND_TEST inside the expression, then remove it at each 'and'."""
    # and_test ::= not_test ('and' not_test)*
    self.DefaultNodeVisit(node)
    _IncreasePenalty(node, AND_TEST)
    # Odd indices hold the 'and' operators; see Visit_or_test.
    index = 1
    while index + 1 < len(node.children):
      if style.Get('SPLIT_BEFORE_LOGICAL_OPERATOR'):
        _DecrementSplitPenalty(
            pytree_utils.FirstLeafNode(node.children[index]), AND_TEST)
      else:
        _DecrementSplitPenalty(
            pytree_utils.FirstLeafNode(node.children[index + 1]), AND_TEST)
      index += 2

  def Visit_not_test(self, node):  # pylint: disable=invalid-name
    """Add NOT_TEST to the penalties inside a 'not' expression."""
    # not_test ::= 'not' not_test | comparison
    self.DefaultNodeVisit(node)
    _IncreasePenalty(node, NOT_TEST)

  def Visit_comparison(self, node):  # pylint: disable=invalid-name
    """Keep simple '==', 'in', 'not in', 'is not' comparisons together."""
    # comparison ::= expr (comp_op expr)*
    self.DefaultNodeVisit(node)
    if len(node.children) == 3 and _StronglyConnectedCompOp(node):
      _IncreasePenalty(node.children[1], VERY_STRONGLY_CONNECTED)
      _SetSplitPenalty(
          pytree_utils.FirstLeafNode(node.children[2]), STRONGLY_CONNECTED)
    else:
      _IncreasePenalty(node, COMPARISON)

  def Visit_star_expr(self, node):  # pylint: disable=invalid-name
    """Add STAR_EXPR to the penalties inside a star expression."""
    # star_expr ::= '*' expr
    self.DefaultNodeVisit(node)
    _IncreasePenalty(node, STAR_EXPR)

  def Visit_expr(self, node):  # pylint: disable=invalid-name
    """Add EXPR inside a '|' expression and price the '|' split points."""
    # expr ::= xor_expr ('|' xor_expr)*
    self.DefaultNodeVisit(node)
    _IncreasePenalty(node, EXPR)
    _SetBitwiseOperandPenalty(node, '|')

  def Visit_xor_expr(self, node):  # pylint: disable=invalid-name
    """Add XOR_EXPR inside a '^' expression and price the '^' split points."""
    # xor_expr ::= and_expr ('^' and_expr)*
    self.DefaultNodeVisit(node)
    _IncreasePenalty(node, XOR_EXPR)
    _SetBitwiseOperandPenalty(node, '^')

  def Visit_and_expr(self, node):  # pylint: disable=invalid-name
    """Add AND_EXPR inside a '&' expression and price the '&' split points."""
    # and_expr ::= shift_expr ('&' shift_expr)*
    self.DefaultNodeVisit(node)
    _IncreasePenalty(node, AND_EXPR)
    _SetBitwiseOperandPenalty(node, '&')

  def Visit_shift_expr(self, node):  # pylint: disable=invalid-name
    """Add SHIFT_EXPR to the penalties inside a shift expression."""
    # shift_expr ::= arith_expr (('<<'|'>>') arith_expr)*
    self.DefaultNodeVisit(node)
    _IncreasePenalty(node, SHIFT_EXPR)

  # Token names (as reported by pytree_utils.NodeName) of '+' and '-'.
  _ARITH_OPS = frozenset({'PLUS', 'MINUS'})

  def Visit_arith_expr(self, node):  # pylint: disable=invalid-name
    """Add ARITH_EXPR inside a '+'/'-' expression and price its split points."""
    # arith_expr ::= term (('+'|'-') term)*
    self.DefaultNodeVisit(node)
    _IncreasePenalty(node, ARITH_EXPR)
    _SetExpressionOperandPenalty(node, self._ARITH_OPS)

  # Token names of the multiplicative operators '*', '@', '/', '%', '//'.
  _TERM_OPS = frozenset({'STAR', 'AT', 'SLASH', 'PERCENT', 'DOUBLESLASH'})

  def Visit_term(self, node):  # pylint: disable=invalid-name
    """Add TERM inside a multiplicative expression and price its split points."""
    # term ::= factor (('*'|'@'|'/'|'%'|'//') factor)*
    self.DefaultNodeVisit(node)
    _IncreasePenalty(node, TERM)
    _SetExpressionOperandPenalty(node, self._TERM_OPS)

  def Visit_factor(self, node):  # pylint: disable=invalid-name
    """Add FACTOR to the penalties inside a unary expression."""
    # factor ::= ('+'|'-'|'~') factor | power
    self.DefaultNodeVisit(node)
    _IncreasePenalty(node, FACTOR)

  def Visit_atom(self, node):  # pylint: disable=invalid-name
    """Price splits inside real parentheses and keep empty [] / {} together."""
    # atom ::= ('(' [yield_expr|testlist_gexp] ')' |
    #           '[' [listmaker] ']' |
    #           '{' [dictsetmaker] '}')
    self.DefaultNodeVisit(node)
    # Pseudo parentheses (tokens with `is_pseudo`) are skipped here.
    if (node.children[0].value == '(' and
        not hasattr(node.children[0], 'is_pseudo')):
      if node.children[-1].value == ')':
        if pytree_utils.NodeName(node.parent) == 'if_stmt':
          _SetSplitPenalty(node.children[-1], STRONGLY_CONNECTED)
        else:
          if len(node.children) > 2:
            _SetSplitPenalty(pytree_utils.FirstLeafNode(node.children[1]), EXPR)
          _SetSplitPenalty(node.children[-1], ATOM)
    elif node.children[0].value in '[{' and len(node.children) == 2:
      # Keep empty containers together if we can.
      _SetUnbreakable(node.children[-1])

  def Visit_testlist_gexp(self, node):  # pylint: disable=invalid-name
    """Forbid splits before commas; give the item after each comma TOGETHER."""
    self.DefaultNodeVisit(node)
    prev_was_comma = False
    for child in node.children:
      if isinstance(child, pytree.Leaf) and child.value == ',':
        _SetUnbreakable(child)
        prev_was_comma = True
      else:
        if prev_was_comma:
          _SetSplitPenalty(pytree_utils.FirstLeafNode(child), TOGETHER)
        prev_was_comma = False
508
509
def _SetUnbreakable(node):
  """Raise every leaf under `node` to at least the UNBREAKABLE penalty.

  Leaves that already carry a higher SPLIT_PENALTY keep it; _RecAnnotate only
  ever increases the annotation.
  """
  penalty_key = pytree_utils.Annotation.SPLIT_PENALTY
  _RecAnnotate(node, penalty_key, UNBREAKABLE)
513
514
def _SetStronglyConnected(*nodes):
  """Raise every leaf of each given node to at least STRONGLY_CONNECTED.

  Arguments:
    *nodes: parse-tree nodes or leaves to annotate; higher existing
      SPLIT_PENALTY values are left untouched.
  """
  penalty_key = pytree_utils.Annotation.SPLIT_PENALTY
  for subtree in nodes:
    _RecAnnotate(subtree, penalty_key, STRONGLY_CONNECTED)
520
521
def _SetExpressionPenalty(node, penalty):
  """Raise leaf split penalties inside an expression to at least `penalty`.

  The expression's first leaf, and any '(', 'for' or 'if' leaves, are left
  alone. Every other leaf whose current SPLIT_PENALTY (0 if unset) is below
  `penalty` is set to `penalty`.

  Arguments:
    node: root of the expression subtree.
    penalty: the minimum split penalty to enforce.
  """
  first_leaf = pytree_utils.FirstLeafNode(node)
  skipped_values = {'(', 'for', 'if'}
  # Each leaf is judged only by its own annotation, so traversal order is
  # irrelevant; an explicit stack replaces the recursive walk.
  pending = [node]
  while pending:
    current = pending.pop()
    if current is first_leaf:
      continue
    if not isinstance(current, pytree.Leaf):
      pending.extend(current.children)
      continue
    if current.value in skipped_values:
      continue
    existing = pytree_utils.GetNodeAnnotation(
        current, pytree_utils.Annotation.SPLIT_PENALTY, default=0)
    if existing < penalty:
      _SetSplitPenalty(current, penalty)
541
542
def _SetBitwiseOperandPenalty(node, op):
  """Apply SPLIT_PENALTY_BITWISE_OPERATOR at each split point around `op`.

  With SPLIT_BEFORE_BITWISE_OPERATOR the operator leaf itself is annotated;
  otherwise the first leaf of the operand that follows it is.

  Arguments:
    node: a binary bitwise expression node (operands alternate with ops).
    op: the operator text to look for, e.g. '|'.
  """
  children = node.children
  for position, child in enumerate(children[1:-1], start=1):
    if not (isinstance(child, pytree.Leaf) and child.value == op):
      continue
    if style.Get('SPLIT_BEFORE_BITWISE_OPERATOR'):
      target = child
    else:
      target = pytree_utils.FirstLeafNode(children[position + 1])
    _SetSplitPenalty(target, style.Get('SPLIT_PENALTY_BITWISE_OPERATOR'))
553
554
def _SetExpressionOperandPenalty(node, ops):
  """Apply SPLIT_PENALTY_ARITHMETIC_OPERATOR around arithmetic operators.

  A child counts as an operator when its pytree_utils.NodeName is in `ops`.
  With SPLIT_BEFORE_ARITHMETIC_OPERATOR the operator itself is annotated;
  otherwise the first leaf of the following operand is.

  Arguments:
    node: a binary arithmetic expression node.
    ops: collection of operator token names, e.g. {'PLUS', 'MINUS'}.
  """
  children = node.children
  for position, child in enumerate(children[1:-1], start=1):
    if pytree_utils.NodeName(child) not in ops:
      continue
    if style.Get('SPLIT_BEFORE_ARITHMETIC_OPERATOR'):
      target = child
    else:
      target = pytree_utils.FirstLeafNode(children[position + 1])
    _SetSplitPenalty(target, style.Get('SPLIT_PENALTY_ARITHMETIC_OPERATOR'))
565
566
def _IncreasePenalty(node, amt):
  """Add `amt` to the split penalty of the leaves inside `node`.

  The subtree's first leaf and any '(' or 'for' leaves are skipped. A leaf
  without a SPLIT_PENALTY annotation is treated as having penalty 0.

  Arguments:
    node: root of the subtree to adjust.
    amt: amount added to each affected leaf's penalty.
  """
  first_leaf = pytree_utils.FirstLeafNode(node)
  stack = [node]
  while stack:
    item = stack.pop()
    if item is first_leaf:
      continue
    if isinstance(item, pytree.Leaf):
      if item.value not in {'(', 'for'}:
        current = pytree_utils.GetNodeAnnotation(
            item, pytree_utils.Annotation.SPLIT_PENALTY, default=0)
        _SetSplitPenalty(item, current + amt)
    else:
      stack.extend(item.children)
585
586
def _RecAnnotate(tree, annotate_name, annotate_value):
  """Raise the given annotation on every leaf of a subtree to a floor value.

  The annotation is only ever increased. A leaf whose current value
  (0 if unset) is already greater than or equal to `annotate_value` is left
  unchanged.

  Args:
    tree: subtree to annotate
    annotate_name: name of the annotation to set
    annotate_value: value of the annotation to set
  """
  to_visit = [tree]
  while to_visit:
    current = to_visit.pop()
    to_visit.extend(current.children)
    if not isinstance(current, pytree.Leaf):
      continue
    existing = pytree_utils.GetNodeAnnotation(
        current, annotate_name, default=0)
    if existing < annotate_value:
      pytree_utils.SetNodeAnnotation(current, annotate_name, annotate_value)
605
606
def _StronglyConnectedCompOp(op):
  """Return True if a comparison's operator should bind tightly.

  Matches the two-token operators 'not in' and 'is not' (a comp_op node) and
  the single-token operators '==' and 'in'.

  Arguments:
    op: a comparison node; its operator is expected at children[1].
  """
  operator = op.children[1]
  if (pytree_utils.NodeName(operator) == 'comp_op' and
      len(operator.children) == 2):
    words = (pytree_utils.FirstLeafNode(operator).value,
             pytree_utils.LastLeafNode(operator).value)
    if words in (('not', 'in'), ('is', 'not')):
      return True
  return isinstance(operator, pytree.Leaf) and operator.value in {'==', 'in'}
620
621
def _DecrementSplitPenalty(node, amt):
  """Subtract `amt` from the split penalty of `node`, clamping at zero.

  A node without a SPLIT_PENALTY annotation is treated as if its penalty were
  exactly `amt`, so it ends up at 0.
  """
  current = pytree_utils.GetNodeAnnotation(
      node, pytree_utils.Annotation.SPLIT_PENALTY, default=amt)
  _SetSplitPenalty(node, max(current - amt, 0))
627
628
def _SetSplitPenalty(node, penalty):
  """Set the SPLIT_PENALTY annotation of `node` itself to `penalty`.

  Unlike _RecAnnotate, this does not descend into children and does not
  compare against any existing value.
  """
  key = pytree_utils.Annotation.SPLIT_PENALTY
  pytree_utils.SetNodeAnnotation(node, key, penalty)
632