# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""QPS benchmark WorkerService: runs benchmark servers/clients and reports stats."""

from concurrent import futures
import multiprocessing
import random
import threading
import time

try:
    # The resource module is not available on Windows. While this server only
    # supports Linux, we must still be _importable_ on Windows.
    import resource
except ImportError:
    pass

import grpc

from src.proto.grpc.testing import benchmark_service_pb2_grpc
from src.proto.grpc.testing import control_pb2
from src.proto.grpc.testing import stats_pb2
from src.proto.grpc.testing import worker_service_pb2_grpc
from tests.qps import benchmark_client
from tests.qps import benchmark_server
from tests.qps import client_runner
from tests.qps import histogram
from tests.unit import resources
from tests.unit import test_common


class Snapshotter:
    """Tracks elapsed time and process CPU usage between two points.

    ``snapshot()`` records the current time and the process's user/system
    CPU time. ``reset()`` makes the most recent snapshot the new baseline.
    ``stats()`` reports the deltas between that baseline and the latest
    snapshot.
    """

    def __init__(self):
        self._start_time = 0.0
        self._end_time = 0.0
        self._last_utime = 0.0
        self._utime = 0.0
        self._last_stime = 0.0
        self._stime = 0.0

    def get_time_elapsed(self):
        """Return seconds between the baseline and the latest snapshot."""
        return self._end_time - self._start_time

    def get_utime(self):
        """Return user CPU seconds consumed since the baseline."""
        return self._utime - self._last_utime

    def get_stime(self):
        """Return system CPU seconds consumed since the baseline."""
        return self._stime - self._last_stime

    def snapshot(self):
        """Record the current time and this process's CPU usage."""
        # Only differences between readings are used, so a monotonic clock
        # keeps elapsed time correct even if the wall clock is adjusted.
        self._end_time = time.monotonic()

        usage = resource.getrusage(resource.RUSAGE_SELF)
        self._utime = usage.ru_utime
        self._stime = usage.ru_stime

    def reset(self):
        """Make the most recent snapshot the baseline for future deltas."""
        self._start_time = self._end_time
        self._last_utime = self._utime
        self._last_stime = self._stime

    def stats(self):
        """Return the deltas as keyword arguments for the stats protos."""
        return {
            "time_elapsed": self.get_time_elapsed(),
            "time_user": self.get_utime(),
            "time_system": self.get_stime(),
        }


class WorkerServer(worker_service_pb2_grpc.WorkerServiceServicer):
    """Python Worker Server implementation."""

    def __init__(self, server_port=None):
        """Create the worker.

        Args:
            server_port: optional port to use for benchmark servers whose
                config requests port 0.
        """
        self._quit_event = threading.Event()
        self._server_port = server_port
        self._snapshotter = Snapshotter()

    def RunServer(self, request_iterator, context):
        """Run a benchmark server and stream its status.

        The first request's ``setup`` config is used to build and start the
        server, and an initial status is yielded. Each later request yields
        a fresh status; if ``mark.reset`` is set, the stats interval is
        restarted afterwards. The server is stopped when the stream ends or
        the generator is closed early.
        """
        # pylint: disable=stop-iteration-return
        config = next(request_iterator).setup
        # pylint: enable=stop-iteration-return
        server, port = self._create_server(config)
        cores = multiprocessing.cpu_count()
        server.start()
        try:
            self._snapshotter.snapshot()
            self._snapshotter.reset()
            yield self._get_server_status(port, cores)

            for request in request_iterator:
                self._snapshotter.snapshot()
                status = self._get_server_status(port, cores)
                if request.mark.reset:
                    self._snapshotter.reset()
                yield status
        finally:
            # Runs on normal stream end, on error, and when the generator is
            # closed early (GeneratorExit), so the server never leaks.
            server.stop(None)

    def _get_server_status(self, port, cores):
        """Build a ServerStatus from the current snapshot deltas."""
        stats = stats_pb2.ServerStats(**self._snapshotter.stats())
        return control_pb2.ServerStatus(stats=stats, port=port, cores=cores)

    def _create_server(self, config):
        """Build (but do not start) a benchmark server for ``config``.

        Returns:
            A ``(server, port)`` tuple, where ``port`` is the value returned
            by the server's ``add_*_port`` call.

        Raises:
            Exception: if ``config.server_type`` is not supported.
        """
        if config.async_server_threads == 0:
            # Mirror the historical concurrent.futures.ThreadPoolExecutor
            # default (cpu_count() * 5, Python 3.5-3.7); passing None does
            # not seem to work here.
            server_threads = multiprocessing.cpu_count() * 5
        else:
            server_threads = config.async_server_threads
        server = test_common.test_server(max_workers=server_threads)
        if config.server_type == control_pb2.ASYNC_SERVER:
            servicer = benchmark_server.BenchmarkServer()
            benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
                servicer, server
            )
        elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
            resp_size = config.payload_config.bytebuf_params.resp_size
            servicer = benchmark_server.GenericBenchmarkServer(resp_size)
            method_implementations = {
                "StreamingCall": grpc.stream_stream_rpc_method_handler(
                    servicer.StreamingCall
                ),
                "UnaryCall": grpc.unary_unary_rpc_method_handler(
                    servicer.UnaryCall
                ),
            }
            handler = grpc.method_handlers_generic_handler(
                "grpc.testing.BenchmarkService", method_implementations
            )
            server.add_generic_rpc_handlers((handler,))
        else:
            raise Exception(
                "Unsupported server type {}".format(config.server_type)
            )

        # A port given to the worker overrides a config that asks for port 0.
        if self._server_port is not None and config.port == 0:
            server_port = self._server_port
        else:
            server_port = config.port

        if config.HasField("security_params"):  # Use SSL
            server_creds = grpc.ssl_server_credentials(
                ((resources.private_key(), resources.certificate_chain()),)
            )
            port = server.add_secure_port(
                "[::]:{}".format(server_port), server_creds
            )
        else:
            port = server.add_insecure_port("[::]:{}".format(server_port))

        return (server, port)

    def RunClient(self, request_iterator, context):
        """Run benchmark clients and stream their status.

        The first request's ``setup`` config is used to start one client
        runner per channel, assigning server targets round-robin. An
        initial status is yielded. Each later request yields a fresh status;
        if ``mark.reset`` is set, the latency histogram and stats interval
        are reset afterwards. All started runners are stopped when the
        stream ends or the generator is closed early.

        Raises:
            Exception: if channels are requested but no server targets are
                given, or the client configuration is unsupported.
        """
        # pylint: disable=stop-iteration-return
        config = next(request_iterator).setup
        # pylint: enable=stop-iteration-return
        if config.client_channels > 0 and not config.server_targets:
            raise Exception("Client config has no server targets")
        client_runners = []
        qps_data = histogram.Histogram(
            config.histogram_params.resolution,
            config.histogram_params.max_possible,
        )
        self._snapshotter.snapshot()
        self._snapshotter.reset()

        try:
            # Create a client for each channel
            for i in range(config.client_channels):
                server = config.server_targets[i % len(config.server_targets)]
                runner = self._create_client_runner(server, config, qps_data)
                runner.start()
                # Track only runners that started, so cleanup stops exactly
                # those.
                client_runners.append(runner)

            self._snapshotter.snapshot()
            yield self._get_client_status(qps_data)

            # Respond to stat requests
            for request in request_iterator:
                self._snapshotter.snapshot()
                status = self._get_client_status(qps_data)
                if request.mark.reset:
                    qps_data.reset()
                    self._snapshotter.reset()
                yield status
        finally:
            # Cleanup the clients, including on error or early close.
            for runner in client_runners:
                runner.stop()

    def _get_client_status(self, qps_data):
        """Build a ClientStatus from the histogram and snapshot deltas."""
        latencies = qps_data.get_data()
        stats = stats_pb2.ClientStats(
            latencies=latencies, **self._snapshotter.stats()
        )
        return control_pb2.ClientStatus(stats=stats)

    def _create_client_runner(self, server, config, qps_data):
        """Create an unstarted client runner targeting ``server``.

        Returns:
            A closed-loop runner when the config selects ``closed_loop``
            load. Otherwise, an open-loop runner whose inter-arrival times
            are exponentially distributed, with the offered load split
            evenly across ``config.client_channels``.

        Raises:
            Exception: if the client type / rpc type combination is not
                supported.
        """
        no_ping_pong = False
        if config.client_type == control_pb2.SYNC_CLIENT:
            if config.rpc_type == control_pb2.UNARY:
                client = benchmark_client.UnarySyncBenchmarkClient(
                    server, config, qps_data
                )
            elif config.rpc_type == control_pb2.STREAMING:
                client = benchmark_client.StreamingSyncBenchmarkClient(
                    server, config, qps_data
                )
            elif config.rpc_type == control_pb2.STREAMING_FROM_SERVER:
                no_ping_pong = True
                client = benchmark_client.ServerStreamingSyncBenchmarkClient(
                    server, config, qps_data
                )
            else:
                # Previously fell through, leaving `client` unbound.
                raise Exception(
                    "Unsupported rpc type {}".format(config.rpc_type)
                )
        elif config.client_type == control_pb2.ASYNC_CLIENT:
            if config.rpc_type == control_pb2.UNARY:
                client = benchmark_client.UnaryAsyncBenchmarkClient(
                    server, config, qps_data
                )
            else:
                raise Exception("Async streaming client not supported")
        else:
            raise Exception(
                "Unsupported client type {}".format(config.client_type)
            )

        if config.load_params.WhichOneof("load") == "closed_loop":
            runner = client_runner.ClosedLoopClientRunner(
                client, config.outstanding_rpcs_per_channel, no_ping_pong
            )
        else:  # Open loop Poisson
            # In multi-channel tests, we split the load across all channels.
            load_factor = float(config.client_channels)
            alpha = config.load_params.poisson.offered_load / load_factor

            def poisson():
                while True:
                    yield random.expovariate(alpha)

            runner = client_runner.OpenLoopClientRunner(client, poisson())

        return runner

    def CoreCount(self, request, context):
        """Report the number of CPUs on this machine."""
        return control_pb2.CoreResponse(cores=multiprocessing.cpu_count())

    def QuitWorker(self, request, context):
        """Signal ``wait_for_quit`` to return."""
        self._quit_event.set()
        return control_pb2.Void()

    def wait_for_quit(self):
        """Block until ``QuitWorker`` has been called."""
        self._quit_event.wait()