# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Entry point for running stress tests."""

from concurrent import futures
import queue
import threading

from absl import app
from absl.flags import argparse_flags
import grpc

from src.proto.grpc.testing import metrics_pb2_grpc
from src.proto.grpc.testing import test_pb2_grpc
from tests.interop import methods
from tests.interop import resources
from tests.qps import histogram
from tests.stress import metrics_server
from tests.stress import test_runner

# Spellings accepted by the boolean command-line flags (compared lowercased).
# The empty string maps to False, matching what ``bool("")`` used to yield.
_TRUE_STRINGS = frozenset(("true", "t", "yes", "y", "1", "on"))
_FALSE_STRINGS = frozenset(("false", "f", "no", "n", "0", "off", ""))


def _bool_flag(value):
    """Convert a command-line string into a bool.

    ``type=bool`` must not be used with argparse: ``bool("false")`` is
    ``True`` because every non-empty string is truthy, so ``--use_tls=false``
    would silently enable TLS.

    Args:
        value: the raw flag value (or an actual bool, returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        ValueError: if the value is not a recognised boolean spelling;
            argparse reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise ValueError("expected a boolean value, got {!r}".format(value))


def _args(argv):
    """Parse the stress client's command-line flags.

    Args:
        argv: the full argument vector; ``argv[0]`` (the program name) is
            skipped.

    Returns:
        The namespace produced by the argument parser.
    """
    parser = argparse_flags.ArgumentParser()
    parser.add_argument(
        "--server_addresses",
        help="comma separated list of hostname:port to run servers on",
        default="localhost:8080",
        type=str,
    )
    parser.add_argument(
        "--test_cases",
        help="comma separated list of testcase:weighting of tests to run",
        default="large_unary:100",
        type=str,
    )
    parser.add_argument(
        "--test_duration_secs",
        help="number of seconds to run the stress test",
        default=-1,
        type=int,
    )
    parser.add_argument(
        "--num_channels_per_server",
        help="number of channels per server",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--num_stubs_per_channel",
        help="number of stubs to create per channel",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--metrics_port",
        help="the port to listen for metrics requests on",
        default=8081,
        type=int,
    )
    parser.add_argument(
        "--use_test_ca",
        help="Whether to use our fake CA. Requires --use_tls=true",
        default=False,
        type=_bool_flag,
    )
    parser.add_argument(
        "--use_tls",
        help="Whether to use TLS",
        default=False,
        type=_bool_flag,
    )
    parser.add_argument(
        "--server_host_override",
        help="the server host to which to claim to connect",
        type=str,
    )
    return parser.parse_args(argv[1:])


def _test_case_from_arg(test_case_arg):
    """Return the ``methods.TestCase`` member whose value is *test_case_arg*.

    Raises:
        ValueError: if no member has that value.
    """
    for test_case in methods.TestCase:
        if test_case_arg == test_case.value:
            return test_case
    raise ValueError("No test case {}!".format(test_case_arg))


def _parse_weighted_test_cases(test_case_args):
    """Parse a ``name:weight,name:weight,...`` string.

    Args:
        test_case_args: comma separated ``testcase:weighting`` entries.

    Returns:
        A dict mapping each ``methods.TestCase`` to its integer weight.

    Raises:
        ValueError: if an entry lacks a ``:weight`` part, names an unknown
            test case, or has a non-integer weight.
    """
    weighted_test_cases = {}
    for test_case_arg in test_case_args.split(","):
        name, separator, weight = test_case_arg.partition(":")
        if not separator:
            raise ValueError(
                "Expected testcase:weighting, got {!r}".format(test_case_arg)
            )
        weighted_test_cases[_test_case_from_arg(name)] = int(weight)
    return weighted_test_cases


def _get_channel(target, args):
    """Create a channel to *target* and block until it is ready.

    Args:
        target: the ``hostname:port`` of the server to connect to.
        args: parsed flags; reads ``use_tls``, ``use_test_ca`` and
            ``server_host_override``.

    Returns:
        A secure channel when ``args.use_tls`` is set, otherwise an insecure
        one.
    """
    if args.use_tls:
        if args.use_test_ca:
            root_certificates = resources.test_root_certificates()
        else:
            root_certificates = None  # will load default roots.
        channel_credentials = grpc.ssl_channel_credentials(
            root_certificates=root_certificates
        )
        # Only pass the override when one was supplied: the flag has no
        # default, and a None channel-argument value is not a valid option.
        options = ()
        if args.server_host_override:
            options = (
                (
                    "grpc.ssl_target_name_override",
                    args.server_host_override,
                ),
            )
        channel = grpc.secure_channel(
            target, channel_credentials, options=options
        )
    else:
        channel = grpc.insecure_channel(target)

    # waits for the channel to be ready before we start sending messages
    grpc.channel_ready_future(channel).result()
    return channel


def run_test(args):
    """Run the stress test described by the parsed flags.

    Starts the metrics server, creates one test runner per stub, and then
    waits on the shared exception queue. The first exception taken from the
    queue is re-raised. If ``test_duration_secs`` elapses with the queue
    still empty, the run is treated as a success. A negative duration means
    wait with no timeout.

    Args:
        args: the namespace returned by ``_args``.
    """
    test_cases = _parse_weighted_test_cases(args.test_cases)
    test_server_targets = args.server_addresses.split(",")
    # Propagate any client exceptions with a queue
    exception_queue = queue.Queue()
    stop_event = threading.Event()
    hist = histogram.Histogram(1, 1)
    runners = []

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=25))
    metrics_pb2_grpc.add_MetricsServiceServicer_to_server(
        metrics_server.MetricsServer(hist), server
    )
    server.add_insecure_port("[::]:{}".format(args.metrics_port))
    server.start()

    for test_server_target in test_server_targets:
        for _ in range(args.num_channels_per_server):
            channel = _get_channel(test_server_target, args)
            for _ in range(args.num_stubs_per_channel):
                stub = test_pb2_grpc.TestServiceStub(channel)
                runner = test_runner.TestRunner(
                    stub, test_cases, hist, exception_queue, stop_event
                )
                runners.append(runner)

    for runner in runners:
        runner.start()
    try:
        timeout_secs = args.test_duration_secs
        if timeout_secs < 0:
            # Queue.get treats a None timeout as "block indefinitely".
            timeout_secs = None
        raise exception_queue.get(block=True, timeout=timeout_secs)
    except queue.Empty:
        # No exceptions thrown, success
        pass
    finally:
        # Signal every runner to stop, then wait for all of them before
        # shutting the metrics server down.
        stop_event.set()
        for runner in runners:
            runner.join()
        runner = None
        server.stop(None)


if __name__ == "__main__":
    app.run(run_test, flags_parser=_args)