# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines test client behaviors (UNARY/STREAMING) (SYNC/ASYNC)."""

import abc
from concurrent import futures
import queue
import threading
import time

import grpc

from src.proto.grpc.testing import benchmark_service_pb2_grpc
from src.proto.grpc.testing import messages_pb2
from tests.unit import resources
from tests.unit import test_common

# Per-RPC deadline of 24 hours (in seconds), i.e. effectively unbounded for
# benchmark runs.
_TIMEOUT = 60 * 60 * 24


class GenericStub(object):
    """Stub for calling BenchmarkService generically, with raw-bytes payloads."""

    def __init__(self, channel):
        self.UnaryCall = channel.unary_unary(
            "/grpc.testing.BenchmarkService/UnaryCall",
            _registered_method=True,
        )
        self.StreamingFromServer = channel.unary_stream(
            "/grpc.testing.BenchmarkService/StreamingFromServer",
            _registered_method=True,
        )
        self.StreamingCall = channel.stream_stream(
            "/grpc.testing.BenchmarkService/StreamingCall",
            _registered_method=True,
        )


class BenchmarkClient(abc.ABC):
    """Benchmark client interface that exposes a non-blocking send_request()."""

    def __init__(self, server, config, hist):
        # Create the channel and stub.
        if config.HasField("security_params"):
            creds = grpc.ssl_channel_credentials(
                resources.test_root_certificates()
            )
            channel = test_common.test_secure_channel(
                server, creds, config.security_params.server_host_override
            )
        else:
            channel = grpc.insecure_channel(server)

        # Wait for the channel to be ready before we start sending messages.
        grpc.channel_ready_future(channel).result()

        if config.payload_config.WhichOneof("payload") == "simple_params":
            self._generic = False
            self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(
                channel
            )
            payload = messages_pb2.Payload(
                body=bytes(b"\0" * config.payload_config.simple_params.req_size)
            )
            self._request = messages_pb2.SimpleRequest(
                payload=payload,
                response_size=config.payload_config.simple_params.resp_size,
            )
        else:
            self._generic = True
            self._stub = GenericStub(channel)
            self._request = bytes(
                b"\0" * config.payload_config.bytebuf_params.req_size
            )

        self._hist = hist
        self._response_callbacks = []

    def add_response_callback(self, callback):
        """Registers a callback that is invoked as callback(client, query_time)."""
        self._response_callbacks.append(callback)

    @abc.abstractmethod
    def send_request(self):
        """Non-blocking wrapper for a client's request operation."""
        raise NotImplementedError()

    def start(self):
        pass

    def stop(self):
        pass

    def _handle_response(self, client, query_time):
        self._hist.add(query_time * 1e9)  # Report times in nanoseconds.
        for callback in self._response_callbacks:
            callback(client, query_time)


class UnarySyncBenchmarkClient(BenchmarkClient):
    def __init__(self, server, config, hist):
        super(UnarySyncBenchmarkClient, self).__init__(server, config, hist)
        self._pool = futures.ThreadPoolExecutor(
            max_workers=config.outstanding_rpcs_per_channel
        )

    def send_request(self):
        # Send requests in separate threads to support multiple outstanding
        # RPCs (see src/proto/grpc/testing/control.proto).
        self._pool.submit(self._dispatch_request)

    def stop(self):
        self._pool.shutdown(wait=True)
        self._stub = None

    def _dispatch_request(self):
        start_time = time.time()
        self._stub.UnaryCall(self._request, _TIMEOUT)
        end_time = time.time()
        self._handle_response(self, end_time - start_time)


class UnaryAsyncBenchmarkClient(BenchmarkClient):
    def send_request(self):
        # Use the Future callback API to support multiple outstanding RPCs.
        start_time = time.time()
        response_future = self._stub.UnaryCall.future(self._request, _TIMEOUT)
        response_future.add_done_callback(
            lambda resp: self._response_received(start_time, resp)
        )

    def _response_received(self, start_time, resp):
        resp.result()
        end_time = time.time()
        self._handle_response(self, end_time - start_time)

    def stop(self):
        self._stub = None


class _SyncStream(object):
    """Drives a single synchronous bidirectional StreamingCall from a request queue."""

    def __init__(self, stub, generic, request, handle_response):
        self._stub = stub
        self._generic = generic
        self._request = request
        self._handle_response = handle_response
        self._is_streaming = False
        self._request_queue = queue.Queue()
        self._send_time_queue = queue.Queue()

    def send_request(self):
        self._send_time_queue.put(time.time())
        self._request_queue.put(self._request)

    def start(self):
        self._is_streaming = True
        response_stream = self._stub.StreamingCall(
            self._request_generator(), _TIMEOUT
        )
        for _ in response_stream:
            self._handle_response(
                self, time.time() - self._send_time_queue.get_nowait()
            )

    def stop(self):
        self._is_streaming = False

    def _request_generator(self):
        while self._is_streaming:
            try:
                request = self._request_queue.get(block=True, timeout=1.0)
                yield request
            except queue.Empty:
                pass


class StreamingSyncBenchmarkClient(BenchmarkClient):
    def __init__(self, server, config, hist):
        super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist)
        self._pool = futures.ThreadPoolExecutor(
            max_workers=config.outstanding_rpcs_per_channel
        )
        self._streams = [
            _SyncStream(
                self._stub, self._generic, self._request, self._handle_response
            )
            for _ in range(config.outstanding_rpcs_per_channel)
        ]
        self._curr_stream = 0

    def send_request(self):
        # Use a round-robin scheduler to determine which stream to send on.
        self._streams[self._curr_stream].send_request()
        self._curr_stream = (self._curr_stream + 1) % len(self._streams)

    def start(self):
        for stream in self._streams:
            self._pool.submit(stream.start)

    def stop(self):
        for stream in self._streams:
            stream.stop()
        self._pool.shutdown(wait=True)
        self._stub = None


class ServerStreamingSyncBenchmarkClient(BenchmarkClient):
    def __init__(self, server, config, hist):
        super(ServerStreamingSyncBenchmarkClient, self).__init__(
            server, config, hist
        )
        if config.outstanding_rpcs_per_channel == 1:
            self._pool = None
        else:
            self._pool = futures.ThreadPoolExecutor(
                max_workers=config.outstanding_rpcs_per_channel
            )
        self._rpcs = []
        self._sender = None

    def send_request(self):
        if self._pool is None:
            self._sender = threading.Thread(
                target=self._one_stream_streaming_rpc, daemon=True
            )
            self._sender.start()
        else:
            self._pool.submit(self._one_stream_streaming_rpc)

    def _one_stream_streaming_rpc(self):
        response_stream = self._stub.StreamingFromServer(
            self._request, _TIMEOUT
        )
        self._rpcs.append(response_stream)
        start_time = time.time()
        for _ in response_stream:
            # The first interval is time-to-first-response; subsequent
            # intervals measure the gap between consecutive responses.
            self._handle_response(self, time.time() - start_time)
            start_time = time.time()

    def stop(self):
        for call in self._rpcs:
            call.cancel()
        if self._sender is not None:
            self._sender.join()
        if self._pool is not None:
            self._pool.shutdown(wait=False)
        self._stub = None
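

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the benchmark harness).
# In practice these clients are constructed and driven by the QPS worker from
# a ClientConfig proto; the snippet below only shows the call pattern. The
# server address "localhost:50051" and the _PrintingHistogram stand-in are
# assumptions for demonstration; the real harness passes a
# tests.qps.histogram.Histogram instance and a config received over the
# worker service.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from src.proto.grpc.testing import control_pb2

    class _PrintingHistogram(object):
        # Hypothetical stand-in exposing the add() method that
        # BenchmarkClient._handle_response expects.
        def add(self, latency_ns):
            print("query latency: %.3f ms" % (latency_ns / 1e6))

    config = control_pb2.ClientConfig(outstanding_rpcs_per_channel=1)
    config.payload_config.simple_params.req_size = 16
    config.payload_config.simple_params.resp_size = 16

    # Assumes a BenchmarkService server is already listening on this address.
    client = UnarySyncBenchmarkClient(
        "localhost:50051", config, _PrintingHistogram()
    )
    client.start()
    for _ in range(10):
        client.send_request()
    client.stop()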
259