xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_tests/tests/stress/client.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Entry point for running stress tests."""
15
16from concurrent import futures
17import queue
18import threading
19
20from absl import app
21from absl.flags import argparse_flags
22import grpc
23
24from src.proto.grpc.testing import metrics_pb2_grpc
25from src.proto.grpc.testing import test_pb2_grpc
26from tests.interop import methods
27from tests.interop import resources
28from tests.qps import histogram
29from tests.stress import metrics_server
30from tests.stress import test_runner
31
32
33def _args(argv):
34    parser = argparse_flags.ArgumentParser()
35    parser.add_argument(
36        "--server_addresses",
37        help="comma separated list of hostname:port to run servers on",
38        default="localhost:8080",
39        type=str,
40    )
41    parser.add_argument(
42        "--test_cases",
43        help="comma separated list of testcase:weighting of tests to run",
44        default="large_unary:100",
45        type=str,
46    )
47    parser.add_argument(
48        "--test_duration_secs",
49        help="number of seconds to run the stress test",
50        default=-1,
51        type=int,
52    )
53    parser.add_argument(
54        "--num_channels_per_server",
55        help="number of channels per server",
56        default=1,
57        type=int,
58    )
59    parser.add_argument(
60        "--num_stubs_per_channel",
61        help="number of stubs to create per channel",
62        default=1,
63        type=int,
64    )
65    parser.add_argument(
66        "--metrics_port",
67        help="the port to listen for metrics requests on",
68        default=8081,
69        type=int,
70    )
71    parser.add_argument(
72        "--use_test_ca",
73        help="Whether to use our fake CA. Requires --use_tls=true",
74        default=False,
75        type=bool,
76    )
77    parser.add_argument(
78        "--use_tls", help="Whether to use TLS", default=False, type=bool
79    )
80    parser.add_argument(
81        "--server_host_override",
82        help="the server host to which to claim to connect",
83        type=str,
84    )
85    return parser.parse_args(argv[1:])
86
87
def _test_case_from_arg(test_case_arg):
    """Return the ``methods.TestCase`` member whose value equals the argument.

    Members are examined in enumeration order; the first match wins.

    Args:
      test_case_arg: A test case name such as ``"large_unary"``.

    Raises:
      ValueError: If no ``methods.TestCase`` member has that value.
    """
    found = next(
        (
            candidate
            for candidate in methods.TestCase
            if test_case_arg == candidate.value
        ),
        None,
    )
    if found is None:
        raise ValueError("No test case {}!".format(test_case_arg))
    return found
94
95
96def _parse_weighted_test_cases(test_case_args):
97    weighted_test_cases = {}
98    for test_case_arg in test_case_args.split(","):
99        name, weight = test_case_arg.split(":", 1)
100        test_case = _test_case_from_arg(name)
101        weighted_test_cases[test_case] = int(weight)
102    return weighted_test_cases
103
104
def _get_channel(target, args):
    """Create a channel to ``target`` and block until it is ready.

    Args:
      target: The ``host:port`` address of a test server.
      args: Parsed command-line arguments; ``use_tls``, ``use_test_ca`` and
        ``server_host_override`` are read.

    Returns:
      A channel (secure when ``args.use_tls`` is set, otherwise insecure)
      whose ready future has resolved.
    """
    if args.use_tls:
        # None makes grpc load its default root certificates.
        root_certificates = (
            resources.test_root_certificates() if args.use_test_ca else None
        )
        channel_credentials = grpc.ssl_channel_credentials(
            root_certificates=root_certificates
        )
        # Only override the TLS target name when one was requested. Channel
        # option values are expected to be str/int, so a None override is
        # omitted rather than passed through.
        options = ()
        if args.server_host_override is not None:
            options = (
                (
                    "grpc.ssl_target_name_override",
                    args.server_host_override,
                ),
            )
        channel = grpc.secure_channel(
            target, channel_credentials, options=options
        )
    else:
        channel = grpc.insecure_channel(target)

    # Wait for the channel to be ready before any messages are sent. Note
    # that result() is called without a timeout, so this blocks until the
    # channel connects.
    grpc.channel_ready_future(channel).result()
    return channel
129
130
def run_test(args):
    """Run the stress test described by ``args``.

    Starts a metrics server on ``args.metrics_port``. For every address in
    ``args.server_addresses`` it creates ``num_channels_per_server`` channels,
    and for each channel ``num_stubs_per_channel`` test runners. It then
    waits until either a runner reports an exception through the shared
    queue (that exception is re-raised here) or ``test_duration_secs``
    elapses; a negative duration waits indefinitely. Runners, channels and
    the metrics server are shut down on exit, including when setup fails.

    Args:
      args: The namespace produced by ``_args``.

    Raises:
      The exception object a runner placed on the exception queue, or any
      error raised while parsing test cases or creating channels.
    """
    test_cases = _parse_weighted_test_cases(args.test_cases)
    test_server_targets = args.server_addresses.split(",")
    # Propagate any client exceptions with a queue
    exception_queue = queue.Queue()
    stop_event = threading.Event()
    hist = histogram.Histogram(1, 1)
    channels = []
    started_runners = []

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=25))
    metrics_pb2_grpc.add_MetricsServiceServicer_to_server(
        metrics_server.MetricsServer(hist), server
    )
    server.add_insecure_port("[::]:{}".format(args.metrics_port))
    server.start()

    # Everything after server.start() lives inside try/finally so that the
    # metrics server and any open channels are cleaned up even if channel
    # creation fails part-way through.
    try:
        runners = []
        for test_server_target in test_server_targets:
            for _ in range(args.num_channels_per_server):
                channel = _get_channel(test_server_target, args)
                channels.append(channel)
                for _ in range(args.num_stubs_per_channel):
                    stub = test_pb2_grpc.TestServiceStub(channel)
                    runners.append(
                        test_runner.TestRunner(
                            stub, test_cases, hist, exception_queue, stop_event
                        )
                    )

        # Record only runners that were actually started, so the cleanup
        # below never joins one that was not (assumes TestRunner behaves
        # like threading.Thread, where that raises RuntimeError).
        for runner in runners:
            runner.start()
            started_runners.append(runner)

        timeout_secs = args.test_duration_secs
        if timeout_secs < 0:
            timeout_secs = None  # no duration: wait until a runner fails
        try:
            raise exception_queue.get(block=True, timeout=timeout_secs)
        except queue.Empty:
            # No exceptions thrown, success
            pass
    finally:
        stop_event.set()
        for runner in started_runners:
            runner.join()
        # Close channels only after every runner using them has stopped.
        for channel in channels:
            channel.close()
        server.stop(None)
173
174
if __name__ == "__main__":
    # absl passes the raw argv to _args and hands its return value to run_test.
    app.run(main=run_test, flags_parser=_args)
177