xref: /aosp_15_r20/external/yapf/yapftests/split_penalty_test.py (revision 7249d1a64f4850ccf838e62a46276f891f72998e)
1# Copyright 2015 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for yapf.split_penalty."""
15
16import sys
17import textwrap
18import unittest
19
20from lib2to3 import pytree
21
22from yapf.yapflib import pytree_utils
23from yapf.yapflib import pytree_visitor
24from yapf.yapflib import split_penalty
25from yapf.yapflib import style
26
27from yapftests import yapf_test_helper
28
# Short local aliases for the split_penalty constants referenced by the
# expected-penalty tables in the tests below.
DOTTED_NAME = split_penalty.DOTTED_NAME
STRONGLY_CONNECTED = split_penalty.STRONGLY_CONNECTED
UNBREAKABLE = split_penalty.UNBREAKABLE
VERY_STRONGLY_CONNECTED = split_penalty.VERY_STRONGLY_CONNECTED
33
34
class SplitPenaltyTest(yapf_test_helper.YAPFTest):
  """Tests for the penalties assigned by split_penalty.ComputeSplitPenalties."""

  @classmethod
  def setUpClass(cls):
    # Install the style returned by CreateYapfStyle() as the global style
    # before any test in this class runs.
    yapf_style = style.CreateYapfStyle()
    style.SetGlobalStyle(yapf_style)

  def _ParseAndComputePenalties(self, code, dumptree=False):
    """Builds a pytree from source text and annotates it with split penalties.

    Arguments:
      code: the source code to parse, given as a string.
      dumptree: when True, the tree is dumped to stderr after the penalties
        have been assigned. Handy while debugging a failing expectation.

    Returns:
      The parse tree, with split penalties computed on it.
    """
    parsed = pytree_utils.ParseCodeToTree(code)
    split_penalty.ComputeSplitPenalties(parsed)
    if not dumptree:
      return parsed
    pytree_visitor.DumpPyTree(parsed, target_stream=sys.stderr)
    return parsed

  def _CheckPenalties(self, tree, list_of_expected):
    """Asserts that the tokens of the tree carry the expected penalties.

    Args:
      tree: the pytree to inspect.
      list_of_expected: list of (token value, penalty) pairs, in tree order.
        Nodes whose name is in pytree_utils.NONSEMANTIC_TOKENS are skipped
        while walking the tree, so they must not appear in this list.
    """

    def CollectPenalties(node, collected):
      # Non-semantic nodes contribute nothing, and neither do their children.
      if pytree_utils.NodeName(node) in pytree_utils.NONSEMANTIC_TOKENS:
        return collected
      if not isinstance(node, pytree.Leaf):
        for child in node.children:
          CollectPenalties(child, collected)
        return collected
      token_text = node.value
      penalty = pytree_utils.GetNodeAnnotation(
          node, pytree_utils.Annotation.SPLIT_PENALTY)
      collected.append((token_text, penalty))
      return collected

    actual = CollectPenalties(tree, [])
    self.assertEqual(list_of_expected, actual)

  def testUnbreakable(self):
    # Expected penalties for a simple function definition. The same table is
    # used for the variant with a trailing comment, since the comment token
    # is not part of the expected output.
    func_def_expected = [
        ('def', None),
        ('foo', UNBREAKABLE),
        ('(', UNBREAKABLE),
        ('x', None),
        (')', STRONGLY_CONNECTED),
        (':', UNBREAKABLE),
        ('pass', None),
    ]

    # Plain function definition.
    source = textwrap.dedent(r"""
      def foo(x):
        pass
      """)
    self._CheckPenalties(
        self._ParseAndComputePenalties(source), func_def_expected)

    # Function definition followed by a trailing comment.
    source = textwrap.dedent(r"""
      def foo(x):  # trailing comment
        pass
      """)
    self._CheckPenalties(
        self._ParseAndComputePenalties(source), func_def_expected)

    # Class definitions, without and with a base class.
    source = textwrap.dedent(r"""
      class A:
        pass
      class B(A):
        pass
      """)
    class_def_expected = [
        ('class', None),
        ('A', UNBREAKABLE),
        (':', UNBREAKABLE),
        ('pass', None),
        ('class', None),
        ('B', UNBREAKABLE),
        ('(', UNBREAKABLE),
        ('A', None),
        (')', None),
        (':', UNBREAKABLE),
        ('pass', None),
    ]
    self._CheckPenalties(
        self._ParseAndComputePenalties(source), class_def_expected)

    # Lambda definition.
    source = textwrap.dedent(r"""
      lambda a, b: None
      """)
    lambda_expected = [
        ('lambda', None),
        ('a', VERY_STRONGLY_CONNECTED),
        (',', VERY_STRONGLY_CONNECTED),
        ('b', VERY_STRONGLY_CONNECTED),
        (':', VERY_STRONGLY_CONNECTED),
        ('None', VERY_STRONGLY_CONNECTED),
    ]
    self._CheckPenalties(
        self._ParseAndComputePenalties(source), lambda_expected)

    # Dotted name in an import statement.
    source = textwrap.dedent(r"""
      import a.b.c
      """)
    dotted_import_expected = [
        ('import', None),
        ('a', None),
        ('.', UNBREAKABLE),
        ('b', UNBREAKABLE),
        ('.', UNBREAKABLE),
        ('c', UNBREAKABLE),
    ]
    self._CheckPenalties(
        self._ParseAndComputePenalties(source), dotted_import_expected)

  def testStronglyConnected(self):
    # Dictionary literal whose keys include a call taking a lambda.
    source = textwrap.dedent(r"""
      a = {
          'x': 42,
          y(lambda a: 23): 37,
      }
      """)
    dict_expected = [
        ('a', None),
        ('=', None),
        ('{', None),
        ("'x'", None),
        (':', STRONGLY_CONNECTED),
        ('42', None),
        (',', None),
        ('y', None),
        ('(', UNBREAKABLE),
        ('lambda', STRONGLY_CONNECTED),
        ('a', VERY_STRONGLY_CONNECTED),
        (':', VERY_STRONGLY_CONNECTED),
        ('23', VERY_STRONGLY_CONNECTED),
        (')', VERY_STRONGLY_CONNECTED),
        (':', STRONGLY_CONNECTED),
        ('37', None),
        (',', None),
        ('}', None),
    ]
    self._CheckPenalties(self._ParseAndComputePenalties(source), dict_expected)

    # List comprehension with a filter clause.
    source = textwrap.dedent(r"""
      [a for a in foo if a.x == 37]
      """)
    comprehension_expected = [
        ('[', None),
        ('a', None),
        ('for', 0),
        ('a', STRONGLY_CONNECTED),
        ('in', STRONGLY_CONNECTED),
        ('foo', STRONGLY_CONNECTED),
        ('if', 0),
        ('a', STRONGLY_CONNECTED),
        ('.', VERY_STRONGLY_CONNECTED),
        ('x', DOTTED_NAME),
        ('==', STRONGLY_CONNECTED),
        ('37', STRONGLY_CONNECTED),
        (']', None),
    ]
    self._CheckPenalties(
        self._ParseAndComputePenalties(source), comprehension_expected)

  def testFuncCalls(self):
    # Plain function call with positional arguments.
    source = 'foo(1, 2, 3)\n'
    call_expected = [
        ('foo', None),
        ('(', UNBREAKABLE),
        ('1', None),
        (',', UNBREAKABLE),
        ('2', None),
        (',', UNBREAKABLE),
        ('3', None),
        (')', VERY_STRONGLY_CONNECTED),
    ]
    self._CheckPenalties(self._ParseAndComputePenalties(source), call_expected)

    # Method call: the callee is built from more than one trailer.
    source = 'foo.bar.baz(1, 2, 3)\n'
    method_call_expected = [
        ('foo', None),
        ('.', VERY_STRONGLY_CONNECTED),
        ('bar', DOTTED_NAME),
        ('.', VERY_STRONGLY_CONNECTED),
        ('baz', DOTTED_NAME),
        ('(', STRONGLY_CONNECTED),
        ('1', None),
        (',', UNBREAKABLE),
        ('2', None),
        (',', UNBREAKABLE),
        ('3', None),
        (')', VERY_STRONGLY_CONNECTED),
    ]
    self._CheckPenalties(
        self._ParseAndComputePenalties(source), method_call_expected)

    # Call whose only argument is a generator expression.
    source = 'max(i for i in xrange(10))\n'
    generator_arg_expected = [
        ('max', None),
        ('(', UNBREAKABLE),
        ('i', 0),
        ('for', 0),
        ('i', STRONGLY_CONNECTED),
        ('in', STRONGLY_CONNECTED),
        ('xrange', STRONGLY_CONNECTED),
        ('(', UNBREAKABLE),
        ('10', STRONGLY_CONNECTED),
        (')', VERY_STRONGLY_CONNECTED),
        (')', VERY_STRONGLY_CONNECTED),
    ]
    self._CheckPenalties(
        self._ParseAndComputePenalties(source), generator_arg_expected)
263
264
# Run the tests through unittest's command-line entry point when this file is
# executed directly as a script.
if __name__ == '__main__':
  unittest.main()
267