1# Copyright 2019 the gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Test for multiprocessing example."""
15
import ast
import logging
import math
import os
import re
import subprocess
import tempfile
import time
import unittest
24
# Resolved path of the directory one level above the directory that holds
# this file; the executables launched by the test are looked up there.
_BINARY_DIR = os.path.realpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
)
# Paths of the "server" and "client" programs inside _BINARY_DIR.
_SERVER_PATH, _CLIENT_PATH = (
    os.path.join(_BINARY_DIR, executable)
    for executable in ("server", "client")
)
30
31
def is_prime(n):
    """Return whether ``n`` is a prime number.

    Reference implementation used to check the primality verdicts that the
    test reads from the client's output. Uses trial division over every
    candidate divisor ``d`` with ``d * d <= n``.

    Args:
      n: The integer to classify.

    Returns:
      True if ``n`` is prime; False for composite numbers and for any
      ``n`` smaller than 2.
    """
    # NOTE(review): the previous bound, range(2, ceil(sqrt(n))), excluded
    # sqrt(n) itself and therefore reported squares of primes (4, 9, 25, ...)
    # as prime. If the implementation under test uses the same bound it will
    # now disagree with this reference on those values -- fix it there too.
    if n < 2:
        return False
    divisor = 2
    # Pure integer comparison: inclusive of the square root and free of the
    # float rounding that math.sqrt would introduce for large n.
    while divisor * divisor <= n:
        if n % divisor == 0:
            return False
        divisor += 1
    return True
38
39
def _get_server_address(server_stream, timeout=60.0, poll_interval=0.05):
    """Extract the server's bind address from its captured output.

    Repeatedly rewinds ``server_stream`` and scans it line by line for a line
    matching ``Binding to '<address>'``. The stream is re-read from the start
    on every pass because it may still be growing while it is polled.

    Args:
      server_stream: A seekable text stream holding the server's output.
      timeout: Maximum number of seconds to keep polling. ``None`` polls
        forever, which was the behaviour before this parameter existed.
      poll_interval: Seconds to sleep between passes so that waiting does
        not spin a CPU core.

    Returns:
      The text captured between the single quotes of the first matching line.

    Raises:
      TimeoutError: If no matching line shows up within ``timeout`` seconds.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        server_stream.seek(0)
        line = server_stream.readline()
        while line:
            matches = re.search(r"Binding to '(.+)'", line)
            if matches is not None:
                return matches.group(1)
            line = server_stream.readline()
        # A full pass found nothing; give up once the deadline has passed,
        # otherwise wait briefly for more output to arrive.
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                "No \"Binding to '...'\" line found in server output "
                "within {} seconds.".format(timeout)
            )
        time.sleep(poll_interval)
49
50
class MultiprocessingExampleTest(unittest.TestCase):
    """End-to-end test that runs the example's server and client programs."""

    @staticmethod
    def _stop_process(process, grace_seconds=10):
        """Terminate ``process`` and reap it, killing it if it lingers."""
        process.terminate()
        try:
            process.wait(timeout=grace_seconds)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()

    def test_multiprocessing_example(self):
        """Run the client against a live server and validate its output.

        The server is started first and its address is read from its
        captured stdout. The client is then run to completion with that
        address as its only argument. The last line of the client's output
        is parsed as a Python literal: a sequence of items whose element 0
        is a candidate number and element 1 is the reported primality.
        """
        # The context managers and the finally clause guarantee that both
        # temporary files are closed and the server is stopped and reaped
        # even when a step in between raises.
        with tempfile.TemporaryFile(mode="r") as server_stdout:
            server_process = subprocess.Popen(
                (_SERVER_PATH,), stdout=server_stdout
            )
            try:
                server_address = _get_server_address(server_stdout)
                with tempfile.TemporaryFile(mode="r") as client_stdout:
                    client_process = subprocess.Popen(
                        (
                            _CLIENT_PATH,
                            server_address,
                        ),
                        stdout=client_stdout,
                    )
                    client_process.wait()
                    client_stdout.seek(0)
                    client_output = client_stdout.read()
            finally:
                self._stop_process(server_process)

        # Fail with a clear message rather than an obscure parse error when
        # the client itself did not finish cleanly.
        self.assertEqual(
            0,
            client_process.returncode,
            "client exited with a non-zero status",
        )
        # literal_eval only accepts Python literals, so the client output is
        # parsed without executing arbitrary code.
        results = ast.literal_eval(client_output.strip().split("\n")[-1])
        values = tuple(result[0] for result in results)
        self.assertSequenceEqual(range(2, 10000), values)
        for result in results:
            self.assertEqual(
                is_prime(result[0]),
                result[1],
                "wrong primality verdict for {}".format(result[0]),
            )
72
73
if __name__ == "__main__":
    # Only when executed as a script: install the default logging
    # configuration, then hand control to the unittest runner.
    logging.basicConfig()
    # verbosity=2 makes the runner list every test with its outcome.
    unittest.main(verbosity=2)
77