# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests server and client side compression."""

from concurrent import futures
import contextlib
import functools
import itertools
import logging
import os
import unittest

import grpc
from grpc import _grpcio_metadata

from tests.unit import _tcp_proxy
from tests.unit.framework.common import test_constants

_UNARY_UNARY = "/test/UnaryUnary"
_UNARY_STREAM = "/test/UnaryStream"
_STREAM_UNARY = "/test/StreamUnary"
_STREAM_STREAM = "/test/StreamStream"

# Cut down on test time.
_STREAM_LENGTH = test_constants.STREAM_LENGTH // 16

_HOST = "localhost"

# 100 zero bytes: a highly compressible payload, so enabling compression
# yields a clearly measurable drop in bytes on the wire.
_REQUEST = b"\x00" * 100
# A configuration counts as "compressed" when it moves more than 5% fewer
# bytes than the uncompressed baseline (see assertCompressed below).
_COMPRESSION_RATIO_THRESHOLD = 0.05
_COMPRESSION_METHODS = (
    None,
    # Disabled for test tractability.
    # grpc.Compression.NoCompression,
    # grpc.Compression.Deflate,
    grpc.Compression.Gzip,
)
_COMPRESSION_NAMES = {
    None: "Uncompressed",
    grpc.Compression.NoCompression: "NoCompression",
    grpc.Compression.Deflate: "DeflateCompression",
    grpc.Compression.Gzip: "GzipCompression",
}

# Every combination (Cartesian product) of these options becomes one generated
# test method on CompressionTest; see the loop at the bottom of this module.
_TEST_OPTIONS = {
    "client_streaming": (True, False),
    "server_streaming": (True, False),
    "channel_compression": _COMPRESSION_METHODS,
    "multicallable_compression": _COMPRESSION_METHODS,
    "server_compression": _COMPRESSION_METHODS,
    "server_call_compression": _COMPRESSION_METHODS,
}


def _make_handle_unary_unary(pre_response_callback):
    """Returns a unary-unary handler that echoes its request.

    Args:
      pre_response_callback: Optional callable invoked as
        ``pre_response_callback(request, servicer_context)`` before the
        response is returned.
    """

    def _handle_unary(request, servicer_context):
        if pre_response_callback:
            pre_response_callback(request, servicer_context)
        return request

    return _handle_unary


def _make_handle_unary_stream(pre_response_callback):
    """Returns a unary-stream handler that echoes its request _STREAM_LENGTH times.

    Args:
      pre_response_callback: Optional callable invoked once as
        ``pre_response_callback(request, servicer_context)`` before the first
        response is yielded.
    """

    def _handle_unary_stream(request, servicer_context):
        if pre_response_callback:
            pre_response_callback(request, servicer_context)
        for _ in range(_STREAM_LENGTH):
            yield request

    return _handle_unary_stream


def _make_handle_stream_unary(pre_response_callback):
    """Returns a stream-unary handler that responds with the first request.

    Args:
      pre_response_callback: Optional callable invoked as
        ``pre_response_callback(request_iterator, servicer_context)``. It
        receives the request iterator itself, not an individual request.
    """

    def _handle_stream_unary(request_iterator, servicer_context):
        if pre_response_callback:
            pre_response_callback(request_iterator, servicer_context)
        response = None
        for request in request_iterator:
            # Remember the first request but keep draining the stream. Compare
            # against None so an empty first message is still echoed back.
            if response is None:
                response = request
        return response

    return _handle_stream_unary


def _make_handle_stream_stream(pre_response_callback):
    """Returns a stream-stream handler that echoes each request as it arrives.

    Args:
      pre_response_callback: Optional callable invoked as
        ``pre_response_callback(request, servicer_context)`` before each
        individual response is yielded.
    """

    def _handle_stream(request_iterator, servicer_context):
        # TODO(issue:#6891) We should be able to remove this loop,
        # and replace with return; yield
        for request in request_iterator:
            if pre_response_callback:
                pre_response_callback(request, servicer_context)
            yield request

    return _handle_stream


def set_call_compression(
    compression_method, request_or_iterator, servicer_context
):
    """Pre-response callback that sets per-call compression on the server.

    Intended to be bound with functools.partial so that the remaining
    signature matches the (request_or_iterator, servicer_context) callback
    shape used by the handlers above.
    """
    del request_or_iterator
    servicer_context.set_compression(compression_method)


def disable_next_compression(request, servicer_context):
    """Pre-response callback that disables compression of the next response."""
    del request
    servicer_context.disable_next_message_compression()


def disable_first_compression(request, servicer_context):
    """Disables compression only for the response to the message numbered 0.

    _stream_stream_client sends ASCII decimal payloads whose integer value is
    the message index, so this affects only the first response of the stream.
    """
    if int(request.decode("ascii")) == 0:
        servicer_context.disable_next_message_compression()


class _MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler for one of the four test methods.

    Serialization is disabled (raw bytes in, raw bytes out), and exactly one
    of the four behavior attributes is populated according to the arity.
    """

    def __init__(
        self, request_streaming, response_streaming, pre_response_callback
    ):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        if self.request_streaming and self.response_streaming:
            self.stream_stream = _make_handle_stream_stream(
                pre_response_callback
            )
        elif not self.request_streaming and not self.response_streaming:
            self.unary_unary = _make_handle_unary_unary(pre_response_callback)
        elif not self.request_streaming and self.response_streaming:
            self.unary_stream = _make_handle_unary_stream(pre_response_callback)
        else:
            self.stream_unary = _make_handle_stream_unary(pre_response_callback)


class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the four test method names to echo handlers.

    Each handler runs the shared ``pre_response_callback`` (may be None).
    """

    def __init__(self, pre_response_callback):
        self._pre_response_callback = pre_response_callback

    def service(self, handler_call_details):
        if handler_call_details.method == _UNARY_UNARY:
            return _MethodHandler(False, False, self._pre_response_callback)
        elif handler_call_details.method == _UNARY_STREAM:
            return _MethodHandler(False, True, self._pre_response_callback)
        elif handler_call_details.method == _STREAM_UNARY:
            return _MethodHandler(True, False, self._pre_response_callback)
        elif handler_call_details.method == _STREAM_STREAM:
            return _MethodHandler(True, True, self._pre_response_callback)
        else:
            return None


@contextlib.contextmanager
def _instrumented_client_server_pair(
    channel_kwargs, server_kwargs, server_handler
):
    """Yields (client_channel, proxy, server) with a TCP proxy in between.

    The client talks to the server through _tcp_proxy.TcpProxy, so the proxy
    can count bytes on the wire. The server is stopped on exit even if
    setting up the proxy or the channel fails.
    """
    server = grpc.server(futures.ThreadPoolExecutor(), **server_kwargs)
    server.add_generic_rpc_handlers((server_handler,))
    server_port = server.add_insecure_port("{}:0".format(_HOST))
    server.start()
    try:
        with _tcp_proxy.TcpProxy(_HOST, _HOST, server_port) as proxy:
            proxy_port = proxy.get_port()
            with grpc.insecure_channel(
                "{}:{}".format(_HOST, proxy_port), **channel_kwargs
            ) as client_channel:
                yield client_channel, proxy, server
    finally:
        server.stop(None)


def _get_byte_counts(
    channel_kwargs,
    multicallable_kwargs,
    client_function,
    server_kwargs,
    server_handler,
    message,
):
    """Runs ``client_function`` once through the proxy and returns its byte count.

    Returns:
      The proxy's ``get_byte_count()`` result, unpacked by callers as
      (bytes_sent, bytes_received).
    """
    with _instrumented_client_server_pair(
        channel_kwargs, server_kwargs, server_handler
    ) as (client_channel, proxy, _):
        client_function(client_channel, multicallable_kwargs, message)
        return proxy.get_byte_count()


def _get_compression_ratios(
    client_function,
    first_channel_kwargs,
    first_multicallable_kwargs,
    first_server_kwargs,
    first_server_handler,
    second_channel_kwargs,
    second_multicallable_kwargs,
    second_server_kwargs,
    second_server_handler,
    message,
):
    """Compares the bytes moved by two configurations of the same RPC.

    Returns:
      A (sent_ratio, received_ratio) tuple. Each ratio is
      ``(second - first) / first``, so a negative value means the second
      configuration moved fewer bytes than the first (baseline) one.
    """
    first_bytes_sent, first_bytes_received = _get_byte_counts(
        first_channel_kwargs,
        first_multicallable_kwargs,
        client_function,
        first_server_kwargs,
        first_server_handler,
        message,
    )
    second_bytes_sent, second_bytes_received = _get_byte_counts(
        second_channel_kwargs,
        second_multicallable_kwargs,
        client_function,
        second_server_kwargs,
        second_server_handler,
        message,
    )
    return (
        (second_bytes_sent - first_bytes_sent) / float(first_bytes_sent),
        (second_bytes_received - first_bytes_received)
        / float(first_bytes_received),
    )


def _unary_unary_client(channel, multicallable_kwargs, message):
    """Sends ``message`` once and checks that it is echoed back."""
    multi_callable = channel.unary_unary(
        _UNARY_UNARY,
        _registered_method=True,
    )
    response = multi_callable(message, **multicallable_kwargs)
    if response != message:
        raise RuntimeError(
            "Request '{}' != Response '{}'".format(message, response)
        )


def _unary_stream_client(channel, multicallable_kwargs, message):
    """Sends ``message`` once and checks that every streamed response echoes it."""
    multi_callable = channel.unary_stream(
        _UNARY_STREAM,
        _registered_method=True,
    )
    response_iterator = multi_callable(message, **multicallable_kwargs)
    for response in response_iterator:
        if response != message:
            raise RuntimeError(
                "Request '{}' != Response '{}'".format(message, response)
            )


def _stream_unary_client(channel, multicallable_kwargs, message):
    """Streams ``message`` _STREAM_LENGTH times and checks the single response."""
    multi_callable = channel.stream_unary(
        _STREAM_UNARY,
        _registered_method=True,
    )
    requests = (message for _ in range(_STREAM_LENGTH))
    response = multi_callable(requests, **multicallable_kwargs)
    if response != message:
        raise RuntimeError(
            "Request '{}' != Response '{}'".format(message, response)
        )


def _stream_stream_client(channel, multicallable_kwargs, message):
    """Streams numbered ASCII payloads and checks that each is echoed in order.

    ``message`` is unused: the payloads are b"0" * 100 followed by the decimal
    message index, which keeps them compressible while still letting the
    server identify each message.
    """
    multi_callable = channel.stream_stream(
        _STREAM_STREAM,
        _registered_method=True,
    )
    request_prefix = str(0).encode("ascii") * 100
    requests = (
        request_prefix + str(i).encode("ascii") for i in range(_STREAM_LENGTH)
    )
    response_iterator = multi_callable(requests, **multicallable_kwargs)
    for i, response in enumerate(response_iterator):
        if int(response.decode("ascii")) != i:
            raise RuntimeError(
                "Request '{}' != Response '{}'".format(i, response)
            )


class CompressionTest(unittest.TestCase):
    def assertCompressed(self, compression_ratio):
        """Asserts the ratio shows more than a _COMPRESSION_RATIO_THRESHOLD reduction."""
        self.assertLess(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg="Actual compression ratio: {}".format(compression_ratio),
        )

    def assertNotCompressed(self, compression_ratio):
        """Asserts the ratio shows no reduction beyond _COMPRESSION_RATIO_THRESHOLD."""
        self.assertGreaterEqual(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg="Actual compression ratio: {}".format(compression_ratio),
        )

    def assertConfigurationCompressed(
        self,
        client_streaming,
        server_streaming,
        channel_compression,
        multicallable_compression,
        server_compression,
        server_call_compression,
    ):
        """Checks one compression configuration against an uncompressed baseline.

        Client-to-server traffic must be compressed iff channel or
        multicallable compression is set. Server-to-client traffic must be
        compressed iff server or server-call compression is set.
        """
        # NOTE(review): options are tested for truthiness, so any falsy value
        # behaves like None. That is exact while only None and Gzip are
        # enabled; revisit if NoCompression is re-enabled (grpc.Compression
        # may be an IntEnum in which NoCompression == 0).
        client_side_compressed = (
            channel_compression or multicallable_compression
        )
        server_side_compressed = server_compression or server_call_compression
        channel_kwargs = (
            {
                "compression": channel_compression,
            }
            if channel_compression
            else {}
        )
        multicallable_kwargs = (
            {
                "compression": multicallable_compression,
            }
            if multicallable_compression
            else {}
        )

        if not client_streaming and not server_streaming:
            client_function = _unary_unary_client
        elif not client_streaming and server_streaming:
            client_function = _unary_stream_client
        elif client_streaming and not server_streaming:
            client_function = _stream_unary_client
        else:
            client_function = _stream_stream_client

        server_kwargs = (
            {
                "compression": server_compression,
            }
            if server_compression
            else {}
        )
        server_handler = (
            _GenericHandler(
                functools.partial(
                    set_call_compression, server_call_compression
                )
            )
            if server_call_compression
            else _GenericHandler(None)
        )
        sent_ratio, received_ratio = _get_compression_ratios(
            client_function,
            {},
            {},
            {},
            _GenericHandler(None),
            channel_kwargs,
            multicallable_kwargs,
            server_kwargs,
            server_handler,
            _REQUEST,
        )

        if client_side_compressed:
            self.assertCompressed(sent_ratio)
        else:
            self.assertNotCompressed(sent_ratio)

        if server_side_compressed:
            self.assertCompressed(received_ratio)
        else:
            self.assertNotCompressed(received_ratio)

    def testDisableNextCompressionStreaming(self):
        """Disabling compression for every message leaves responses uncompressed."""
        server_kwargs = {
            "compression": grpc.Compression.Deflate,
        }
        _, received_ratio = _get_compression_ratios(
            _stream_stream_client,
            {},
            {},
            {},
            _GenericHandler(None),
            {},
            {},
            server_kwargs,
            _GenericHandler(disable_next_compression),
            _REQUEST,
        )
        self.assertNotCompressed(received_ratio)

    def testDisableNextCompressionStreamingResets(self):
        """Disabling only the first message's compression still compresses the rest."""
        server_kwargs = {
            "compression": grpc.Compression.Deflate,
        }
        _, received_ratio = _get_compression_ratios(
            _stream_stream_client,
            {},
            {},
            {},
            _GenericHandler(None),
            {},
            {},
            server_kwargs,
            _GenericHandler(disable_first_compression),
            _REQUEST,
        )
        self.assertCompressed(received_ratio)


def _get_compression_str(name, value):
    """Returns ``name`` followed by the display name of compression ``value``."""
    return "{}{}".format(name, _COMPRESSION_NAMES[value])


def _get_compression_test_name(
    client_streaming,
    server_streaming,
    channel_compression,
    multicallable_compression,
    server_compression,
    server_call_compression,
):
    """Builds a unique ``test...`` method name from one option combination."""
    client_arity = "Stream" if client_streaming else "Unary"
    server_arity = "Stream" if server_streaming else "Unary"
    arity = "{}{}".format(client_arity, server_arity)
    channel_compression_str = _get_compression_str(
        "Channel", channel_compression
    )
    multicallable_compression_str = _get_compression_str(
        "Multicallable", multicallable_compression
    )
    server_compression_str = _get_compression_str("Server", server_compression)
    server_call_compression_str = _get_compression_str(
        "ServerCall", server_call_compression
    )
    return "test{}{}{}{}{}".format(
        arity,
        channel_compression_str,
        multicallable_compression_str,
        server_compression_str,
        server_call_compression_str,
    )


def _test_options():
    """Yields one keyword-argument dict per combination in _TEST_OPTIONS."""
    for test_parameters in itertools.product(*_TEST_OPTIONS.values()):
        yield dict(zip(_TEST_OPTIONS.keys(), test_parameters))


def _make_compression_test(**kwargs):
    """Returns a test method that checks the configuration given by ``kwargs``.

    Binding ``kwargs`` through this factory gives each generated method its
    own options. The private name keeps test collectors from treating the
    factory itself as a test.
    """

    def _test_compression(self):
        self.assertConfigurationCompressed(**kwargs)

    return _test_compression


for options in _test_options():
    setattr(
        CompressionTest,
        _get_compression_test_name(**options),
        _make_compression_test(**options),
    )

if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)