xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_tests/tests/qps/worker_server.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15from concurrent import futures
16import multiprocessing
17import random
18import threading
19import time
20
21try:
22    # The resource module is not available on Windows. While this server only
23    # supports Linux, we must still be _importable_ on Windows.
24    import resource
25except ImportError:
26    pass
27
28import grpc
29
30from src.proto.grpc.testing import benchmark_service_pb2_grpc
31from src.proto.grpc.testing import control_pb2
32from src.proto.grpc.testing import stats_pb2
33from src.proto.grpc.testing import worker_service_pb2_grpc
34from tests.qps import benchmark_client
35from tests.qps import benchmark_server
36from tests.qps import client_runner
37from tests.qps import histogram
38from tests.unit import resources
39from tests.unit import test_common
40
41
class Snapshotter:
    """Tracks wall-clock and process CPU time between measurement points.

    Usage pattern: call snapshot() to record "now", reset() to make the most
    recent snapshot the new baseline, and stats() to read the deltas between
    the baseline and the latest snapshot.
    """

    def __init__(self):
        # Wall-clock window boundaries (seconds since epoch).
        self._start_time = 0.0
        self._end_time = 0.0
        # User CPU time at the baseline and at the latest snapshot.
        self._last_utime = 0.0
        self._utime = 0.0
        # System CPU time at the baseline and at the latest snapshot.
        self._last_stime = 0.0
        self._stime = 0.0

    def get_time_elapsed(self):
        """Wall-clock seconds between the baseline and the last snapshot."""
        return self._end_time - self._start_time

    def get_utime(self):
        """User CPU seconds consumed since the baseline."""
        return self._utime - self._last_utime

    def get_stime(self):
        """System CPU seconds consumed since the baseline."""
        return self._stime - self._last_stime

    def snapshot(self):
        """Record the current wall-clock time and CPU usage of this process."""
        self._end_time = time.time()
        usage = resource.getrusage(resource.RUSAGE_SELF)
        self._utime, self._stime = usage.ru_utime, usage.ru_stime

    def reset(self):
        """Promote the most recent snapshot to be the new baseline."""
        self._start_time = self._end_time
        self._last_utime = self._utime
        self._last_stime = self._stime

    def stats(self):
        """Return the current deltas keyed to match proto stats fields."""
        return dict(
            time_elapsed=self.get_time_elapsed(),
            time_user=self.get_utime(),
            time_system=self.get_stime(),
        )
78
79
class WorkerServer(worker_service_pb2_grpc.WorkerServiceServicer):
    """Python Worker Server implementation.

    Implements the benchmarking WorkerService: the driver asks this worker to
    run a benchmark server (RunServer) or benchmark clients (RunClient), and
    streams back CPU/wall-clock usage and latency statistics.
    """

    def __init__(self, server_port=None):
        # Set by QuitWorker; wait_for_quit() blocks on it.
        self._quit_event = threading.Event()
        # Optional fixed port to use when the driver requests port 0.
        self._server_port = server_port
        self._snapshotter = Snapshotter()

    def RunServer(self, request_iterator, context):
        """Bidi-streaming handler that runs a benchmark server.

        The first request carries the server setup; every subsequent request
        is a Mark asking for current stats (optionally resetting the
        measurement baseline). Yields a ServerStatus per request.
        """
        # pylint: disable=stop-iteration-return
        config = next(request_iterator).setup
        # pylint: enable=stop-iteration-return
        server, port = self._create_server(config)
        cores = multiprocessing.cpu_count()
        server.start()
        self._snapshotter.snapshot()
        self._snapshotter.reset()
        try:
            yield self._get_server_status(port, cores)

            for request in request_iterator:
                self._snapshotter.snapshot()
                status = self._get_server_status(port, cores)
                if request.mark.reset:
                    self._snapshotter.reset()
                yield status
        finally:
            # Stop the benchmark server even if the driver cancels the
            # stream; otherwise the server (and its port) would leak.
            server.stop(None)

    def _get_server_status(self, port, cores):
        """Build a ServerStatus from the snapshotter's current deltas."""
        stats = stats_pb2.ServerStats(**self._snapshotter.stats())
        return control_pb2.ServerStatus(stats=stats, port=port, cores=cores)

    def _create_server(self, config):
        """Create (but do not start) a benchmark server per *config*.

        Returns a (server, bound_port) tuple. Raises for unsupported server
        types.
        """
        if config.async_server_threads == 0:
            # This is the default concurrent.futures thread pool size, but
            # None doesn't seem to work
            server_threads = multiprocessing.cpu_count() * 5
        else:
            server_threads = config.async_server_threads
        server = test_common.test_server(max_workers=server_threads)
        if config.server_type == control_pb2.ASYNC_SERVER:
            servicer = benchmark_server.BenchmarkServer()
            benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
                servicer, server
            )
        elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
            resp_size = config.payload_config.bytebuf_params.resp_size
            servicer = benchmark_server.GenericBenchmarkServer(resp_size)
            method_implementations = {
                "StreamingCall": grpc.stream_stream_rpc_method_handler(
                    servicer.StreamingCall
                ),
                "UnaryCall": grpc.unary_unary_rpc_method_handler(
                    servicer.UnaryCall
                ),
            }
            handler = grpc.method_handlers_generic_handler(
                "grpc.testing.BenchmarkService", method_implementations
            )
            server.add_generic_rpc_handlers((handler,))
        else:
            raise Exception(
                "Unsupported server type {}".format(config.server_type)
            )

        # Honor an externally-supplied port only when the driver did not
        # request a specific one.
        if self._server_port is not None and config.port == 0:
            server_port = self._server_port
        else:
            server_port = config.port

        if config.HasField("security_params"):  # Use SSL
            server_creds = grpc.ssl_server_credentials(
                ((resources.private_key(), resources.certificate_chain()),)
            )
            port = server.add_secure_port(
                "[::]:{}".format(server_port), server_creds
            )
        else:
            port = server.add_insecure_port("[::]:{}".format(server_port))

        return (server, port)

    def RunClient(self, request_iterator, context):
        """Bidi-streaming handler that runs benchmark clients.

        The first request carries the client setup; every subsequent request
        is a Mark asking for current stats (optionally resetting the latency
        histogram and CPU baseline). Yields a ClientStatus per request.
        """
        # pylint: disable=stop-iteration-return
        config = next(request_iterator).setup
        # pylint: enable=stop-iteration-return
        client_runners = []
        qps_data = histogram.Histogram(
            config.histogram_params.resolution,
            config.histogram_params.max_possible,
        )
        self._snapshotter.snapshot()
        self._snapshotter.reset()

        try:
            # Create a client for each channel, spreading clients across the
            # configured server targets round-robin.
            for i in range(config.client_channels):
                server = config.server_targets[i % len(config.server_targets)]
                runner = self._create_client_runner(server, config, qps_data)
                client_runners.append(runner)
                runner.start()

            self._snapshotter.snapshot()
            yield self._get_client_status(qps_data)

            # Respond to stat requests
            for request in request_iterator:
                self._snapshotter.snapshot()
                status = self._get_client_status(qps_data)
                if request.mark.reset:
                    qps_data.reset()
                    self._snapshotter.reset()
                yield status
        finally:
            # Stop every runner that was started, even if the stream is
            # cancelled or a later runner fails to start.
            for runner in client_runners:
                runner.stop()

    def _get_client_status(self, qps_data):
        """Build a ClientStatus from the histogram and snapshotter deltas."""
        latencies = qps_data.get_data()
        stats = stats_pb2.ClientStats(
            latencies=latencies, **self._snapshotter.stats()
        )
        return control_pb2.ClientStatus(stats=stats)

    def _create_client_runner(self, server, config, qps_data):
        """Build a (not yet started) client runner targeting *server*.

        Raises for unsupported client/rpc type combinations.
        """
        no_ping_pong = False
        if config.client_type == control_pb2.SYNC_CLIENT:
            if config.rpc_type == control_pb2.UNARY:
                client = benchmark_client.UnarySyncBenchmarkClient(
                    server, config, qps_data
                )
            elif config.rpc_type == control_pb2.STREAMING:
                client = benchmark_client.StreamingSyncBenchmarkClient(
                    server, config, qps_data
                )
            elif config.rpc_type == control_pb2.STREAMING_FROM_SERVER:
                no_ping_pong = True
                client = benchmark_client.ServerStreamingSyncBenchmarkClient(
                    server, config, qps_data
                )
            else:
                # Previously an unsupported rpc_type fell through and left
                # `client` unbound, producing a confusing UnboundLocalError.
                raise Exception(
                    "Unsupported rpc type {}".format(config.rpc_type)
                )
        elif config.client_type == control_pb2.ASYNC_CLIENT:
            if config.rpc_type == control_pb2.UNARY:
                client = benchmark_client.UnaryAsyncBenchmarkClient(
                    server, config, qps_data
                )
            else:
                raise Exception("Async streaming client not supported")
        else:
            raise Exception(
                "Unsupported client type {}".format(config.client_type)
            )

        # In multi-channel tests, we split the load across all channels
        load_factor = float(config.client_channels)
        if config.load_params.WhichOneof("load") == "closed_loop":
            runner = client_runner.ClosedLoopClientRunner(
                client, config.outstanding_rpcs_per_channel, no_ping_pong
            )
        else:  # Open loop Poisson
            alpha = config.load_params.poisson.offered_load / load_factor

            def poisson():
                # Inter-arrival times for a Poisson process of rate alpha.
                while True:
                    yield random.expovariate(alpha)

            runner = client_runner.OpenLoopClientRunner(client, poisson())

        return runner

    def CoreCount(self, request, context):
        """Report the number of CPU cores available to this worker."""
        return control_pb2.CoreResponse(cores=multiprocessing.cpu_count())

    def QuitWorker(self, request, context):
        """Signal wait_for_quit() so the hosting process can shut down."""
        self._quit_event.set()
        return control_pb2.Void()

    def wait_for_quit(self):
        """Block until a QuitWorker RPC has been received."""
        self._quit_event.wait()
257