1# Copyright 2017 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Tests for test sharding protocol."""
16
17import os
18import subprocess
19import sys
20
21from absl.testing import _bazelize_command
22from absl.testing import absltest
23from absl.testing import parameterized
24from absl.testing.tests import absltest_env
25
26
# Number of test methods the sharding tests below expect to see reported
# across all shards combined. Hard-coded, based on
# absltest_sharding_test_helper.py; update it whenever that helper changes.
NUM_TEST_METHODS = 8
28
29
class TestShardingTest(parameterized.TestCase):
  """Integration tests: Runs a test binary with sharding.

  This is done by setting the sharding environment variables
  (TEST_TOTAL_SHARDS, TEST_SHARD_INDEX and, optionally,
  TEST_SHARD_STATUS_FILE) and running a helper binary in a subprocess.
  """

  def setUp(self):
    super().setUp()
    # Path of the shard status file used by the current test, if any. It is
    # recorded by _run_sharded() and deleted again in tearDown().
    self._shard_file = None

  def tearDown(self):
    # Remove our own file first, but always give the parent class a chance to
    # tear down, even if the removal fails.
    try:
      if self._shard_file is not None:
        self._remove_if_exists(self._shard_file)
    finally:
      super().tearDown()

  @staticmethod
  def _remove_if_exists(path):
    """Deletes the file at `path`; a missing file is not an error."""
    try:
      os.unlink(path)
    except FileNotFoundError:
      pass

  @staticmethod
  def _test_method_lines(out):
    """Returns the lines of `out` that start with 'class', in order.

    Args:
      out: string, the combined stdout/stderr of a helper binary run.

    Returns:
      A list of strings. Presumably each line identifies one test method that
      was executed by the helper (verify against the helper's output format).
    """
    return [line for line in out.splitlines() if line.startswith('class')]

  def _run_sharded(
      self,
      total_shards,
      shard_index,
      shard_file=None,
      additional_env=None,
      helper_name='absltest_sharding_test_helper',
  ):
    """Runs the py_test binary in a subprocess.

    Args:
      total_shards: int, the total number of shards.
      shard_index: int, the shard index.
      shard_file: string, if not 'None', the path to the shard file. This method
        asserts it is properly created.
      additional_env: Additional environment variables to be set for the py_test
        binary.
      helper_name: The name of the helper binary.

    Returns:
      (stdout, exit_code) tuple of (string, int). stderr is merged into stdout.
    """
    env = absltest_env.inherited_env()
    if additional_env:
      env.update(additional_env)
    # The sharding variables are applied last, so `additional_env` cannot
    # override them.
    env.update({
        'TEST_TOTAL_SHARDS': str(total_shards),
        'TEST_SHARD_INDEX': str(shard_index),
    })
    if shard_file:
      # Remember the path so that tearDown() can clean it up.
      self._shard_file = shard_file
      env['TEST_SHARD_STATUS_FILE'] = shard_file
      # Start without the file so that the existence check below proves that
      # this particular run created it.
      self._remove_if_exists(shard_file)

    helper = 'absl/testing/tests/' + helper_name
    # A non-zero exit code is an expected outcome for several tests, so it is
    # returned to the caller instead of being raised (check=False).
    result = subprocess.run(
        [_bazelize_command.get_executable_path(helper)],
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        check=False,
    )

    if shard_file:
      self.assertTrue(
          os.path.exists(shard_file),
          'Shard status file %r was not created, output:\n%s'
          % (shard_file, result.stdout),
      )

    return (result.stdout, result.returncode)

  def _assert_sharding_correctness(self, total_shards):
    """Assert the primary correctness and performance of sharding.

    1. Completeness (all methods are run)
    2. Partition (each method run at most once)
    3. Balance (for performance)

    Additionally asserts that exactly one shard exits with a non-zero code.

    Args:
      total_shards: int, total number of shards.
    """
    methods_by_shard = []  # One list of method lines per shard.
    all_methods = []  # Method lines of all shards, concatenated.
    exit_codes = []  # One exit code per shard.

    for shard_index in range(total_shards):
      out, exit_code = self._run_sharded(total_shards, shard_index)
      shard_methods = self._test_method_lines(out)
      methods_by_shard.append(shard_methods)
      all_methods.extend(shard_methods)
      exit_codes.append(exit_code)

    failed_exit_codes = [code for code in exit_codes if code != 0]
    self.assertLen(failed_exit_codes, 1, 'Expected exactly one failure')

    # Test completeness and partition properties: the total count must match
    # and there must be no duplicates among the reported methods.
    self.assertLen(all_methods, NUM_TEST_METHODS,
                   'Partition requirement not met')
    self.assertLen(set(all_methods), NUM_TEST_METHODS,
                   'Completeness requirement not met')

    # Test balance: no shard may fall more than one method below the average.
    min_methods_per_shard = (NUM_TEST_METHODS / total_shards) - 1
    for shard_index, shard_methods in enumerate(methods_by_shard):
      self.assertGreaterEqual(
          len(shard_methods),
          min_methods_per_shard,
          'Shard %d of %d out of balance'
          % (shard_index, len(methods_by_shard)),
      )

  def test_shard_file(self):
    # _run_sharded() itself asserts that the shard status file gets created.
    self._run_sharded(3, 1, os.path.join(
        absltest.TEST_TMPDIR.value, 'shard_file'))

  def test_zero_shards(self):
    out, exit_code = self._run_sharded(0, 0)
    self.assertEqual(1, exit_code)
    self.assertIn('Bad sharding values. index=0, total=0', out,
                  'Bad output: %s' % (out))

  def test_with_four_shards(self):
    self._assert_sharding_correctness(4)

  def test_with_one_shard(self):
    self._assert_sharding_correctness(1)

  def test_with_ten_shards(self):
    shards = 10
    # This test relies on the shard count to be greater than the number of
    # tests, to ensure that the non-zero shards won't fail even if no tests ran
    # on Python 3.12+.
    self.assertGreater(shards, NUM_TEST_METHODS)
    self._assert_sharding_correctness(shards)

  def test_sharding_with_randomization(self):
    # If we're both sharding *and* randomizing, we need to confirm that we
    # randomize within the shard; we use two seeds to confirm we're seeing the
    # same tests (sharding is consistent) in a different order.
    tests_seen = []
    for seed in ('7', '17'):
      out, exit_code = self._run_sharded(
          2, 0, additional_env={'TEST_RANDOMIZE_ORDERING_SEED': seed})
      self.assertEqual(0, exit_code)
      tests_seen.append(self._test_method_lines(out))
    first_tests, second_tests = tests_seen  # pylint: disable=unbalanced-tuple-unpacking
    self.assertEqual(set(first_tests), set(second_tests))
    self.assertNotEqual(first_tests, second_tests)

  @parameterized.named_parameters(
      ('total_1_index_0', 1, 0, None),
      ('total_2_index_0', 2, 0, None),
      # The 2nd shard (index=1) should not fail.
      ('total_2_index_1', 2, 1, 0),
  )
  def test_no_tests_ran(
      self, total_shards, shard_index, override_expected_exit_code
  ):
    """Checks the exit code of a helper binary that runs no tests.

    Args:
      total_shards: int, the total number of shards.
      shard_index: int, the shard index.
      override_expected_exit_code: int or None. If not None, the exit code to
        expect regardless of the Python version. Otherwise 5 is expected on
        Python 3.12+ and 0 on older versions.
    """
    if override_expected_exit_code is not None:
      expected_exit_code = override_expected_exit_code
    elif sys.version_info[:2] >= (3, 12):
      expected_exit_code = 5
    else:
      expected_exit_code = 0
    out, exit_code = self._run_sharded(
        total_shards,
        shard_index,
        helper_name='absltest_sharding_test_helper_no_tests',
    )
    self.assertEqual(
        expected_exit_code,
        exit_code,
        'Unexpected exit code, output:\n{}'.format(out),
    )
196
197
# Run the tests through absltest's entry point when executed as a script.
if __name__ == '__main__':
  absltest.main()
200