# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for test sharding protocol."""

import os
import subprocess
import sys

from absl.testing import _bazelize_command
from absl.testing import absltest
from absl.testing import parameterized
from absl.testing.tests import absltest_env


NUM_TEST_METHODS = 8  # Hard-coded, based on absltest_sharding_test_helper.py


class TestShardingTest(parameterized.TestCase):
  """Integration tests: Runs a test binary with sharding.

  This is done by setting the sharding environment variables.
  """

  def setUp(self):
    super().setUp()
    # Shard status file requested by the latest `_run_sharded` call, if any;
    # tearDown removes it so runs do not leak files into TEST_TMPDIR.
    self._shard_file = None

  def tearDown(self):
    super().tearDown()
    if self._shard_file is not None and os.path.exists(self._shard_file):
      os.unlink(self._shard_file)

  @staticmethod
  def _test_method_lines(output):
    """Returns the lines of helper output that identify an executed test.

    Every line starting with 'class' is counted as one executed test method.
    NOTE(review): presumably the helper prints exactly one such line per test
    method it runs -- confirm against absltest_sharding_test_helper.py.

    Args:
      output: str, the combined stdout/stderr of a helper run.

    Returns:
      list of str, the matching lines in output order.
    """
    return [line for line in output.splitlines() if line.startswith('class')]

  def _run_sharded(
      self,
      total_shards,
      shard_index,
      shard_file=None,
      additional_env=None,
      helper_name='absltest_sharding_test_helper',
  ):
    """Runs the py_test binary in a subprocess.

    Args:
      total_shards: int, the total number of shards.
      shard_index: int, the shard index.
      shard_file: string, if not 'None', the path to the shard file. This method
        asserts it is properly created.
      additional_env: Additional environment variables to be set for the py_test
        binary. The sharding variables set by this method take precedence over
        entries with the same name here.
      helper_name: The name of the helper binary.

    Returns:
      (stdout, exit_code) tuple of (string, int). The helper's stderr is merged
      into the returned stdout.
    """
    env = absltest_env.inherited_env()
    if additional_env:
      env.update(additional_env)
    # Applied after `additional_env` so the shard layout under test always wins.
    env.update({
        'TEST_TOTAL_SHARDS': str(total_shards),
        'TEST_SHARD_INDEX': str(shard_index),
    })
    if shard_file:
      # Record the file for cleanup in tearDown, and delete any stale copy so
      # the existence check below proves this run created it.
      self._shard_file = shard_file
      env['TEST_SHARD_STATUS_FILE'] = shard_file
      if os.path.exists(shard_file):
        os.unlink(shard_file)

    helper = 'absl/testing/tests/' + helper_name
    # subprocess.run owns the Popen lifetime (pipes closed, child reaped);
    # check=False because a non-zero exit code is itself a result under test.
    completed = subprocess.run(
        [_bazelize_command.get_executable_path(helper)],
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        check=False,
    )

    if shard_file:
      self.assertTrue(os.path.exists(shard_file))

    return (completed.stdout, completed.returncode)

  def _assert_sharding_correctness(self, total_shards):
    """Assert the primary correctness and performance of sharding.

    1. Completeness (all methods are run)
    2. Partition (each method run at most once)
    3. Balance (for performance)

    Args:
      total_shards: int, total number of shards.
    """
    outerr_by_shard = []  # Per shard: the executed-test lines it printed.
    combined_outerr = []  # Executed-test lines across all shards.
    exit_code_by_shard = []  # Exit code of each shard, in shard order.

    for shard_index in range(total_shards):
      out, exit_code = self._run_sharded(total_shards, shard_index)
      method_list = self._test_method_lines(out)
      outerr_by_shard.append(method_list)
      combined_outerr.extend(method_list)
      exit_code_by_shard.append(exit_code)

    self.assertLen([x for x in exit_code_by_shard if x != 0], 1,
                   'Expected exactly one failure')

    # Test completeness and partition properties.
    self.assertLen(combined_outerr, NUM_TEST_METHODS,
                   'Partition requirement not met')
    self.assertLen(set(combined_outerr), NUM_TEST_METHODS,
                   'Completeness requirement not met')

    # Test balance: every shard runs at least one fewer than the average.
    for shard_index, method_list in enumerate(outerr_by_shard):
      self.assertGreaterEqual(
          len(method_list),
          (NUM_TEST_METHODS / total_shards) - 1,
          'Shard %d of %d out of balance' % (shard_index, total_shards),
      )

  def test_shard_file(self):
    """Checks that the helper creates the requested shard status file."""
    self._run_sharded(3, 1, os.path.join(
        absltest.TEST_TMPDIR.value, 'shard_file'))

  def test_zero_shards(self):
    """Checks that zero total shards is rejected with exit code 1."""
    out, exit_code = self._run_sharded(0, 0)
    self.assertEqual(1, exit_code)
    self.assertIn('Bad sharding values. index=0, total=0', out,
                  'Bad output: %s' % (out))

  def test_with_four_shards(self):
    self._assert_sharding_correctness(4)

  def test_with_one_shard(self):
    self._assert_sharding_correctness(1)

  def test_with_ten_shards(self):
    shards = 10
    # This test relies on the shard count to be greater than the number of
    # tests, to ensure that the non-zero shards won't fail even if no tests ran
    # on Python 3.12+.
    self.assertGreater(shards, NUM_TEST_METHODS)
    self._assert_sharding_correctness(shards)

  def test_sharding_with_randomization(self):
    # If we're both sharding *and* randomizing, we need to confirm that we
    # randomize within the shard; we use two seeds to confirm we're seeing the
    # same tests (sharding is consistent) in a different order.
    tests_seen = []
    for seed in ('7', '17'):
      out, exit_code = self._run_sharded(
          2, 0, additional_env={'TEST_RANDOMIZE_ORDERING_SEED': seed})
      self.assertEqual(0, exit_code)
      tests_seen.append(self._test_method_lines(out))
    first_tests, second_tests = tests_seen  # pylint: disable=unbalanced-tuple-unpacking
    self.assertEqual(set(first_tests), set(second_tests))
    self.assertNotEqual(first_tests, second_tests)

  @parameterized.named_parameters(
      ('total_1_index_0', 1, 0, None),
      ('total_2_index_0', 2, 0, None),
      # The 2nd shard (index=1) should not fail.
      ('total_2_index_1', 2, 1, 0),
  )
  def test_no_tests_ran(
      self, total_shards, shard_index, override_expected_exit_code
  ):
    """Checks the exit code of a shard whose helper defines no tests.

    Without an override, the expected exit code is 5 on Python 3.12+ and 0 on
    older versions.

    Args:
      total_shards: int, the total number of shards.
      shard_index: int, the shard index to run.
      override_expected_exit_code: int or None; when not None, used as the
        expected exit code regardless of Python version.
    """
    if override_expected_exit_code is not None:
      expected_exit_code = override_expected_exit_code
    elif sys.version_info[:2] >= (3, 12):
      expected_exit_code = 5
    else:
      expected_exit_code = 0
    out, exit_code = self._run_sharded(
        total_shards,
        shard_index,
        helper_name='absltest_sharding_test_helper_no_tests',
    )
    self.assertEqual(
        expected_exit_code,
        exit_code,
        'Unexpected exit code, output:\n{}'.format(out),
    )


if __name__ == '__main__':
  absltest.main()