xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_tests/tests/unit/_compression_test.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests server and client side compression."""
15
16from concurrent import futures
17import contextlib
18import functools
19import itertools
20import logging
21import os
22import unittest
23
24import grpc
25from grpc import _grpcio_metadata
26
27from tests.unit import _tcp_proxy
28from tests.unit.framework.common import test_constants
29
# Fully-qualified names of the four RPC methods served by _GenericHandler.
_UNARY_UNARY = "/test/UnaryUnary"
_UNARY_STREAM = "/test/UnaryStream"
_STREAM_UNARY = "/test/StreamUnary"
_STREAM_STREAM = "/test/StreamStream"

# Use only a sixteenth of the shared stream length to cut down on test time.
_STREAM_LENGTH = test_constants.STREAM_LENGTH // 16

_HOST = "localhost"

# One hundred zero bytes: a highly compressible payload.
_REQUEST = bytes(100)
# Minimum relative reduction in wire bytes for a run to count as compressed
# (used by CompressionTest.assertCompressed / assertNotCompressed).
_COMPRESSION_RATIO_THRESHOLD = 5e-2

# Compression settings exercised by the generated tests; None means the
# setting is not configured at all.
_COMPRESSION_METHODS = (
    None,
    # Disabled for test tractability.
    # grpc.Compression.NoCompression,
    # grpc.Compression.Deflate,
    grpc.Compression.Gzip,
)

# Fragments used to build the names of the generated test methods.
_COMPRESSION_NAMES = {
    None: "Uncompressed",
    grpc.Compression.NoCompression: "NoCompression",
    grpc.Compression.Deflate: "DeflateCompression",
    grpc.Compression.Gzip: "GzipCompression",
}

# Axes of the test matrix: one test method is generated for every element of
# the Cartesian product of these candidate-value tuples.
_TEST_OPTIONS = dict(
    client_streaming=(True, False),
    server_streaming=(True, False),
    channel_compression=_COMPRESSION_METHODS,
    multicallable_compression=_COMPRESSION_METHODS,
    server_compression=_COMPRESSION_METHODS,
    server_call_compression=_COMPRESSION_METHODS,
)
64
65
66def _make_handle_unary_unary(pre_response_callback):
67    def _handle_unary(request, servicer_context):
68        if pre_response_callback:
69            pre_response_callback(request, servicer_context)
70        return request
71
72    return _handle_unary
73
74
def _make_handle_unary_stream(pre_response_callback):
    """Create a unary-stream behavior that repeats its request.

    Args:
      pre_response_callback: Optional callable invoked once as
        ``pre_response_callback(request, servicer_context)`` when the
        returned generator starts, before any response is yielded. A falsy
        value disables the hook.

    Returns:
      A generator function yielding ``request`` ``_STREAM_LENGTH`` times.
    """

    def _repeat_request(request, servicer_context):
        if pre_response_callback:
            pre_response_callback(request, servicer_context)
        yield from itertools.repeat(request, _STREAM_LENGTH)

    return _repeat_request
82    return _handle_unary_stream
83
84
85def _make_handle_stream_unary(pre_response_callback):
86    def _handle_stream_unary(request_iterator, servicer_context):
87        if pre_response_callback:
88            pre_response_callback(request_iterator, servicer_context)
89        response = None
90        for request in request_iterator:
91            if not response:
92                response = request
93        return response
94
95    return _handle_stream_unary
96
97
98def _make_handle_stream_stream(pre_response_callback):
99    def _handle_stream(request_iterator, servicer_context):
100        # TODO(issue:#6891) We should be able to remove this loop,
101        # and replace with return; yield
102        for request in request_iterator:
103            if pre_response_callback:
104                pre_response_callback(request, servicer_context)
105            yield request
106
107    return _handle_stream
108
109
def set_call_compression(
    compression_method, request_or_iterator, servicer_context
):
    """Pre-response hook that sets the compression for the call.

    Bind ``compression_method`` with ``functools.partial`` to obtain a
    ``(request_or_iterator, servicer_context)`` hook accepted by the
    ``_make_handle_*`` factories.

    Args:
      compression_method: Value passed to ``servicer_context.set_compression``.
      request_or_iterator: Ignored.
      servicer_context: The server-side call context.
    """
    _ = request_or_iterator  # Intentionally unused.
    apply_compression = servicer_context.set_compression
    apply_compression(compression_method)
115
116
def disable_next_compression(request, servicer_context):
    """Pre-response hook that disables compression of the next message.

    Args:
      request: Ignored.
      servicer_context: The server-side call context.
    """
    _ = request  # Intentionally unused.
    turn_off = servicer_context.disable_next_message_compression
    turn_off()
120
121
def disable_first_compression(request, servicer_context):
    """Disable compression only for the response to request number 0.

    Expects ``request`` to be ASCII digits, as sent by _stream_stream_client.
    Only the request that parses to 0 triggers
    ``disable_next_message_compression``.

    Args:
      request: ASCII-encoded integer request bytes.
      servicer_context: The server-side call context.
    """
    request_number = int(request.decode("ascii"))
    if request_number != 0:
        return
    servicer_context.disable_next_message_compression()
125
126
class _MethodHandler(grpc.RpcMethodHandler):
    """RPC method handler built from the echo behaviors above.

    No request deserializer or response serializer is configured, so the
    behaviors see raw request bytes. Exactly one of the four behavior
    attributes is populated, chosen by the streaming flags. The other three
    stay None.
    """

    def __init__(
        self, request_streaming, response_streaming, pre_response_callback
    ):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        if request_streaming:
            if response_streaming:
                self.stream_stream = _make_handle_stream_stream(
                    pre_response_callback
                )
            else:
                self.stream_unary = _make_handle_stream_unary(
                    pre_response_callback
                )
        elif response_streaming:
            self.unary_stream = _make_handle_unary_stream(
                pre_response_callback
            )
        else:
            self.unary_unary = _make_handle_unary_unary(
                pre_response_callback
            )
150
151
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler serving the four echo methods used by these tests.

    Args:
      pre_response_callback: Hook forwarded to every _MethodHandler created,
        or None for no hook.
    """

    def __init__(self, pre_response_callback):
        self._pre_response_callback = pre_response_callback

    def service(self, handler_call_details):
        """Return a new _MethodHandler for a known method, otherwise None."""
        streaming_by_method = {
            _UNARY_UNARY: (False, False),
            _UNARY_STREAM: (False, True),
            _STREAM_UNARY: (True, False),
            _STREAM_STREAM: (True, True),
        }
        streaming = streaming_by_method.get(handler_call_details.method)
        if streaming is None:
            return None
        request_streaming, response_streaming = streaming
        return _MethodHandler(
            request_streaming, response_streaming, self._pre_response_callback
        )
167
168
@contextlib.contextmanager
def _instrumented_client_server_pair(
    channel_kwargs, server_kwargs, server_handler
):
    """Start a server behind a byte-counting TCP proxy and connect to it.

    Args:
      channel_kwargs: Extra keyword arguments for ``grpc.insecure_channel``.
      server_kwargs: Extra keyword arguments for ``grpc.server``.
      server_handler: Generic RPC handler installed on the server.

    Yields:
      A ``(client_channel, proxy, server)`` tuple. The channel reaches the
      server through ``proxy``.

    The server is stopped on exit in all cases: normal completion, an
    exception in the ``with`` body, or a failure while setting up the proxy
    or channel.
    """
    server = grpc.server(futures.ThreadPoolExecutor(), **server_kwargs)
    # The try covers setup as well as the yield, so a failure in start(),
    # the proxy, or the channel does not leak a running server.
    try:
        server.add_generic_rpc_handlers((server_handler,))
        server_port = server.add_insecure_port("{}:0".format(_HOST))
        server.start()
        with _tcp_proxy.TcpProxy(_HOST, _HOST, server_port) as proxy:
            proxy_port = proxy.get_port()
            with grpc.insecure_channel(
                "{}:{}".format(_HOST, proxy_port), **channel_kwargs
            ) as client_channel:
                yield client_channel, proxy, server
    finally:
        server.stop(None)
186
187
def _get_byte_counts(
    channel_kwargs,
    multicallable_kwargs,
    client_function,
    server_kwargs,
    server_handler,
    message,
):
    """Run ``client_function`` once through a fresh proxied client/server pair.

    Args:
      channel_kwargs: Extra keyword arguments for the client channel.
      multicallable_kwargs: Extra keyword arguments for the RPC invocation.
      client_function: Callable ``(channel, multicallable_kwargs, message)``.
      server_kwargs: Extra keyword arguments for the server.
      server_handler: Generic RPC handler installed on the server.
      message: Payload handed to ``client_function``.

    Returns:
      The proxy's ``get_byte_count()`` result, read before teardown. Callers
      unpack it as ``(bytes_sent, bytes_received)``.
    """
    pair = _instrumented_client_server_pair(
        channel_kwargs, server_kwargs, server_handler
    )
    with pair as (client_channel, proxy, _unused_server):
        client_function(client_channel, multicallable_kwargs, message)
        byte_counts = proxy.get_byte_count()
    return byte_counts
202
203
def _get_compression_ratios(
    client_function,
    first_channel_kwargs,
    first_multicallable_kwargs,
    first_server_kwargs,
    first_server_handler,
    second_channel_kwargs,
    second_multicallable_kwargs,
    second_server_kwargs,
    second_server_handler,
    message,
):
    """Run the same client under two configurations and compare wire bytes.

    Returns:
      A ``(sent_ratio, received_ratio)`` tuple. Each ratio is
      ``(second - first) / first``, so a negative value means the second
      configuration moved fewer bytes than the first.
    """

    def relative_change(before, after):
        return (after - before) / float(before)

    baseline_sent, baseline_received = _get_byte_counts(
        first_channel_kwargs,
        first_multicallable_kwargs,
        client_function,
        first_server_kwargs,
        first_server_handler,
        message,
    )
    candidate_sent, candidate_received = _get_byte_counts(
        second_channel_kwargs,
        second_multicallable_kwargs,
        client_function,
        second_server_kwargs,
        second_server_handler,
        message,
    )
    sent_ratio = relative_change(baseline_sent, candidate_sent)
    received_ratio = relative_change(baseline_received, candidate_received)
    return sent_ratio, received_ratio
237
238
def _unary_unary_client(channel, multicallable_kwargs, message):
    """Send ``message`` over a unary-unary RPC and check it is echoed back.

    Args:
      channel: Client channel to issue the RPC on.
      multicallable_kwargs: Extra keyword arguments for the invocation.
      message: Request bytes.

    Raises:
      RuntimeError: If the response differs from ``message``.
    """
    invoke = channel.unary_unary(_UNARY_UNARY, _registered_method=True)
    echoed = invoke(message, **multicallable_kwargs)
    if echoed != message:
        raise RuntimeError(
            "Request '{}' != Response '{}'".format(message, echoed)
        )
249
250
def _unary_stream_client(channel, multicallable_kwargs, message):
    """Send ``message`` over a unary-stream RPC and check every response.

    Args:
      channel: Client channel to issue the RPC on.
      multicallable_kwargs: Extra keyword arguments for the invocation.
      message: Request bytes.

    Raises:
      RuntimeError: On the first response that differs from ``message``.
    """
    invoke = channel.unary_stream(_UNARY_STREAM, _registered_method=True)
    no_mismatch = object()
    mismatch = next(
        (
            echoed
            for echoed in invoke(message, **multicallable_kwargs)
            if echoed != message
        ),
        no_mismatch,
    )
    if mismatch is not no_mismatch:
        raise RuntimeError(
            "Request '{}' != Response '{}'".format(message, mismatch)
        )
262
263
def _stream_unary_client(channel, multicallable_kwargs, message):
    """Stream ``message`` ``_STREAM_LENGTH`` times and check the reply.

    The server behavior replies with the first request it received, so the
    response must equal ``message``.

    Args:
      channel: Client channel to issue the RPC on.
      multicallable_kwargs: Extra keyword arguments for the invocation.
      message: Request bytes sent for every element of the stream.

    Raises:
      RuntimeError: If the response differs from ``message``.
    """
    multi_callable = channel.stream_unary(
        _STREAM_UNARY,
        _registered_method=True,
    )
    # Send the caller's message itself, so the equality check below compares
    # the response with what was actually sent.
    requests = (message for _ in range(_STREAM_LENGTH))
    response = multi_callable(requests, **multicallable_kwargs)
    if response != message:
        raise RuntimeError(
            "Request '{}' != Response '{}'".format(message, response)
        )
275
276
def _stream_stream_client(channel, multicallable_kwargs, message):
    """Stream numbered requests and check each echoed response.

    Request ``i`` is 100 ASCII zeros followed by ``str(i)``. It is highly
    compressible yet still parses as the integer ``i``, which
    disable_first_compression relies on. ``message`` is not used.

    Args:
      channel: Client channel to issue the RPC on.
      multicallable_kwargs: Extra keyword arguments for the invocation.
      message: Unused; present for a uniform client-function signature.

    Raises:
      RuntimeError: If response ``i`` does not parse to ``i``.
    """
    invoke = channel.stream_stream(_STREAM_STREAM, _registered_method=True)
    padding = b"0" * 100
    requests = (
        padding + str(index).encode("ascii")
        for index in range(_STREAM_LENGTH)
    )
    responses = invoke(requests, **multicallable_kwargs)
    for expected, response in enumerate(responses):
        if int(response.decode("ascii")) != expected:
            raise RuntimeError(
                "Request '{}' != Response '{}'".format(expected, response)
            )
292
293
class CompressionTest(unittest.TestCase):
    """End-to-end compression tests run through a byte-counting TCP proxy.

    Per-configuration ``test*`` methods are attached to this class at import
    time, one per point of the ``_TEST_OPTIONS`` matrix.
    """

    def assertCompressed(self, compression_ratio):
        """Assert the ratio shows a reduction larger than the threshold."""
        self.assertLess(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg="Actual compression ratio: {}".format(compression_ratio),
        )

    def assertNotCompressed(self, compression_ratio):
        """Assert the ratio shows no reduction larger than the threshold."""
        self.assertGreaterEqual(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg="Actual compression ratio: {}".format(compression_ratio),
        )

    def assertConfigurationCompressed(
        self,
        client_streaming,
        server_streaming,
        channel_compression,
        multicallable_compression,
        server_compression,
        server_call_compression,
    ):
        """Run one RPC shape under the given compression settings.

        A baseline run with default channel, call and server options and a
        hook-less handler is followed by a run with the requested settings.
        The client helpers raise RuntimeError if a response does not match
        its request, so this fails if the RPC does not round-trip.

        NOTE(review): the byte-count ratios returned by
        _get_compression_ratios are not asserted on here, and
        assertCompressed / assertNotCompressed are currently unused. Confirm
        whether ratio checks should be re-enabled.

        Args:
          client_streaming: Whether the client streams requests.
          server_streaming: Whether the server streams responses.
          channel_compression: Channel-level compression, or None.
          multicallable_compression: Per-call compression, or None.
          server_compression: Server-wide compression, or None.
          server_call_compression: Compression set from within the handler
            via set_call_compression, or None.
        """
        # Use explicit None checks: grpc.Compression.NoCompression is an
        # IntEnum member with value 0 (falsy), and truthiness would silently
        # drop it.
        channel_kwargs = (
            {}
            if channel_compression is None
            else {"compression": channel_compression}
        )
        multicallable_kwargs = (
            {}
            if multicallable_compression is None
            else {"compression": multicallable_compression}
        )
        server_kwargs = (
            {}
            if server_compression is None
            else {"compression": server_compression}
        )

        client_function = {
            (False, False): _unary_unary_client,
            (False, True): _unary_stream_client,
            (True, False): _stream_unary_client,
            (True, True): _stream_stream_client,
        }[(bool(client_streaming), bool(server_streaming))]

        if server_call_compression is None:
            server_handler = _GenericHandler(None)
        else:
            # Forward the requested method rather than a hard-coded one.
            server_handler = _GenericHandler(
                functools.partial(
                    set_call_compression, server_call_compression
                )
            )

        _get_compression_ratios(
            client_function,
            {},
            {},
            {},
            _GenericHandler(None),
            channel_kwargs,
            multicallable_kwargs,
            server_kwargs,
            server_handler,
            _REQUEST,
        )

    def _run_stream_with_server_deflate(self, pre_response_callback):
        """Compare a plain stream-stream run with a server-Deflate run.

        In the second run, ``pre_response_callback`` is installed as the
        per-message server hook.
        """
        server_kwargs = {
            "compression": grpc.Compression.Deflate,
        }
        _get_compression_ratios(
            _stream_stream_client,
            {},
            {},
            {},
            _GenericHandler(None),
            {},
            {},
            server_kwargs,
            _GenericHandler(pre_response_callback),
            _REQUEST,
        )

    def testDisableNextCompressionStreaming(self):
        """Disable compression before every streamed response."""
        self._run_stream_with_server_deflate(disable_next_compression)

    def testDisableNextCompressionStreamingResets(self):
        """Disable compression only before the first streamed response."""
        self._run_stream_with_server_deflate(disable_first_compression)
407
408
def _get_compression_str(name, value):
    """Return ``name`` followed by the display name of compression ``value``.

    Raises:
      KeyError: If ``value`` is not in _COMPRESSION_NAMES.
    """
    return "{prefix}{suffix}".format(
        prefix=name, suffix=_COMPRESSION_NAMES[value]
    )
411
412
def _get_compression_test_name(
    client_streaming,
    server_streaming,
    channel_compression,
    multicallable_compression,
    server_compression,
    server_call_compression,
):
    """Build the ``test...`` method name describing one configuration.

    The name is ``"test"``, then the client and server arity ("Stream" or
    "Unary"), then one labelled fragment per compression setting.
    """
    arity = "".join(
        "Stream" if streaming else "Unary"
        for streaming in (client_streaming, server_streaming)
    )
    labelled_settings = (
        ("Channel", channel_compression),
        ("Multicallable", multicallable_compression),
        ("Server", server_compression),
        ("ServerCall", server_call_compression),
    )
    compression_parts = [
        _get_compression_str(label, setting)
        for label, setting in labelled_settings
    ]
    return "test" + arity + "".join(compression_parts)
441
442
def _test_options():
    """Yield one keyword-argument dict per point of the test matrix.

    Each dict maps every _TEST_OPTIONS key to one of that key's candidate
    values. Points are produced in ``itertools.product`` order.
    """
    option_names = list(_TEST_OPTIONS)
    candidate_values = [_TEST_OPTIONS[name] for name in option_names]
    for combination in itertools.product(*candidate_values):
        yield {
            name: value for name, value in zip(option_names, combination)
        }
446
447
def _make_compression_test(**kwargs):
    """Return a CompressionTest method that checks one configuration.

    The name deliberately lacks a ``test`` prefix so that collectors such as
    pytest do not treat this module-level factory as a test function.
    """

    def _test_compression(self):
        self.assertConfigurationCompressed(**kwargs)

    return _test_compression


def _add_compression_tests():
    """Attach one generated test method per _test_options() combination.

    Registration happens inside a function so the loop variable does not
    leak into the module namespace.
    """
    for options in _test_options():
        setattr(
            CompressionTest,
            _get_compression_test_name(**options),
            _make_compression_test(**options),
        )


_add_compression_tests()

if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)
465