1# Copyright 2016 gRPC authors. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14"""Service-side implementation of gRPC Python.""" 15 16from __future__ import annotations 17 18import collections 19from concurrent import futures 20import contextvars 21import enum 22import logging 23import threading 24import time 25import traceback 26from typing import ( 27 Any, 28 Callable, 29 Iterable, 30 Iterator, 31 List, 32 Mapping, 33 Optional, 34 Sequence, 35 Set, 36 Tuple, 37 Union, 38) 39 40import grpc # pytype: disable=pyi-error 41from grpc import _common # pytype: disable=pyi-error 42from grpc import _compression # pytype: disable=pyi-error 43from grpc import _interceptor # pytype: disable=pyi-error 44from grpc._cython import cygrpc 45from grpc._typing import ArityAgnosticMethodHandler 46from grpc._typing import ChannelArgumentType 47from grpc._typing import DeserializingFunction 48from grpc._typing import MetadataType 49from grpc._typing import NullaryCallbackType 50from grpc._typing import ResponseType 51from grpc._typing import SerializingFunction 52from grpc._typing import ServerCallbackTag 53from grpc._typing import ServerTagCallbackType 54 55_LOGGER = logging.getLogger(__name__) 56 57_SHUTDOWN_TAG = "shutdown" 58_REQUEST_CALL_TAG = "request_call" 59 60_RECEIVE_CLOSE_ON_SERVER_TOKEN = "receive_close_on_server" 61_SEND_INITIAL_METADATA_TOKEN = "send_initial_metadata" 62_RECEIVE_MESSAGE_TOKEN = "receive_message" 63_SEND_MESSAGE_TOKEN = "send_message" 
_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = (
    "send_initial_metadata * send_message"
)
_SEND_STATUS_FROM_SERVER_TOKEN = "send_status_from_server"
_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = (
    "send_initial_metadata * send_status_from_server"
)

# Values of `_RPCState.client`, describing the client's half of the call.
_OPEN = "open"
_CLOSED = "closed"
_CANCELLED = "cancelled"

_EMPTY_FLAGS = 0

# Poll timeout, in seconds, of the serving loop. After every poll the loop
# checks whether the owning server object has been deallocated.
_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0
_INF_TIMEOUT = 1e9


def _serialized_request(request_event: cygrpc.BaseEvent) -> Optional[bytes]:
    """Return the payload carried by a receive-message batch event.

    Returns:
      The serialized request, or ``None`` when no message was received.
      Callers treat ``None`` as the client having finished sending.
    """
    return request_event.batch_operations[0].message()


def _application_code(code: grpc.StatusCode) -> cygrpc.StatusCode:
    """Translate an application-level status code into its cygrpc form.

    Codes without a known mapping are reported as ``unknown``.
    """
    cygrpc_code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
    return cygrpc.StatusCode.unknown if cygrpc_code is None else cygrpc_code


def _completion_code(state: _RPCState) -> cygrpc.StatusCode:
    """Status code for a normally completed RPC.

    The result is ``ok`` unless the application set a code.
    """
    if state.code is None:
        return cygrpc.StatusCode.ok
    else:
        return _application_code(state.code)


def _abortion_code(
    state: _RPCState, code: cygrpc.StatusCode
) -> cygrpc.StatusCode:
    """Status code for an aborted RPC.

    A code set by the application takes precedence over *code*.
    """
    if state.code is None:
        return code
    else:
        return _application_code(state.code)


def _details(state: _RPCState) -> bytes:
    """Return the encoded status details, or ``b""`` if none were set."""
    return b"" if state.details is None else state.details


class _HandlerCallDetails(
    collections.namedtuple(
        "_HandlerCallDetails",
        (
            "method",
            "invocation_metadata",
        ),
    ),
    grpc.HandlerCallDetails,
):
    """Immutable (method, invocation_metadata) pair given to generic handlers."""


class _RPCState(object):
    """Mutable state of a single server-side RPC.

    The state is shared between the application thread and the
    completion-queue callbacks. Reads and writes are generally made while
    holding ``condition``, and waiters are woken with ``notify_all`` as batch
    operations complete.
    """

    context: contextvars.Context
    condition: threading.Condition
    # Tokens of batches started on the call whose events have not arrived yet.
    due: Set[str]
    # Most recently received request that has not been consumed yet.
    request: Any
    # One of _OPEN, _CLOSED or _CANCELLED.
    client: str
    # True until initial metadata has been sent.
    initial_metadata_allowed: bool
    compression_algorithm: Optional[grpc.Compression]
    # When True, the next outgoing message is sent with no_compress.
    disable_next_compression: bool
    trailing_metadata: Optional[MetadataType]
    code: Optional[grpc.StatusCode]
    details: Optional[bytes]
    # True once a status has been sent for this RPC.
    statused: bool
    # RpcErrors raised into the application because the client cancelled.
    # They are not reported as application failures.
    rpc_errors: List[Exception]
    # Termination callbacks; None once the RPC has finished.
    callbacks: Optional[List[NullaryCallbackType]]
    # Set by _Context.abort().
    aborted: bool

    def __init__(self):
        self.context = contextvars.Context()
        self.condition = threading.Condition()
        self.due = set()
        self.request = None
        self.client = _OPEN
        self.initial_metadata_allowed = True
        self.compression_algorithm = None
        self.disable_next_compression = False
        self.trailing_metadata = None
        self.code = None
        self.details = None
        self.statused = False
        self.rpc_errors = []
        self.callbacks = []
        self.aborted = False


def _raise_rpc_error(state: _RPCState) -> None:
    """Raise a new ``grpc.RpcError`` and remember it on *state*.

    Remembering the error lets ``_call_behavior`` tell it apart from a
    genuine application failure.
    """
    rpc_error = grpc.RpcError()
    state.rpc_errors.append(rpc_error)
    raise rpc_error


def _possibly_finish_call(
    state: _RPCState, token: str
) -> ServerTagCallbackType:
    """Mark the batch identified by *token* as completed.

    Callers hold ``state.condition``.

    Returns:
      ``(state, callbacks)`` when the RPC is no longer active and no batches
      are outstanding. In that case the callbacks are detached, so later
      ``add_callback`` calls are refused. Otherwise returns ``(None, ())``.
    """
    state.due.remove(token)
    if not _is_rpc_state_active(state) and not state.due:
        callbacks = state.callbacks
        state.callbacks = None
        return state, callbacks
    else:
        return None, ()


def _send_status_from_server(state: _RPCState, token: str) -> ServerCallbackTag:
    """Return the completion callback for a batch that sends the status."""

    def send_status_from_server(unused_send_status_from_server_event):
        with state.condition:
            return _possibly_finish_call(state, token)

    return send_status_from_server


def _get_initial_metadata(
    state: _RPCState, metadata: Optional[MetadataType]
) -> Optional[MetadataType]:
    """Return *metadata*, prefixed with compression metadata when needed.

    The prefix is added only if a compression algorithm was set on the RPC.
    """
    with state.condition:
        if state.compression_algorithm:
            compression_metadata = (
                _compression.compression_algorithm_to_metadata(
                    state.compression_algorithm
                ),
            )
            if metadata is None:
                return compression_metadata
            else:
                return compression_metadata + tuple(metadata)
        else:
            return metadata


def _get_initial_metadata_operation(
    state: _RPCState, metadata: Optional[MetadataType]
) -> cygrpc.Operation:
    """Build a send-initial-metadata operation for *metadata*."""
    operation = cygrpc.SendInitialMetadataOperation(
        _get_initial_metadata(state, metadata), _EMPTY_FLAGS
    )
    return operation


def _abort(
    state: _RPCState, call: cygrpc.Call, code: cygrpc.StatusCode, details: bytes
) -> None:
    """Terminate the RPC by sending a status, unless the client cancelled.

    Application-set code and details take precedence over *code* and
    *details*. Initial metadata is sent in the same batch if it has not been
    sent yet. Callers hold ``state.condition``.
    """
    if state.client is not _CANCELLED:
        effective_code = _abortion_code(state, code)
        effective_details = details if state.details is None else state.details
        if state.initial_metadata_allowed:
            operations = (
                _get_initial_metadata_operation(state, None),
                cygrpc.SendStatusFromServerOperation(
                    state.trailing_metadata,
                    effective_code,
                    effective_details,
                    _EMPTY_FLAGS,
                ),
            )
            token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
        else:
            operations = (
                cygrpc.SendStatusFromServerOperation(
                    state.trailing_metadata,
                    effective_code,
                    effective_details,
                    _EMPTY_FLAGS,
                ),
            )
            token = _SEND_STATUS_FROM_SERVER_TOKEN
        call.start_server_batch(
            operations, _send_status_from_server(state, token)
        )
        state.statused = True
        state.due.add(token)


def _receive_close_on_server(state: _RPCState) -> ServerCallbackTag:
    """Return the completion callback for the receive-close batch.

    The callback records whether the client cancelled or simply closed.
    """

    def receive_close_on_server(receive_close_on_server_event):
        with state.condition:
            if receive_close_on_server_event.batch_operations[0].cancelled():
                state.client = _CANCELLED
            elif state.client is _OPEN:
                state.client = _CLOSED
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN)

    return receive_close_on_server


def _receive_message(
    state: _RPCState,
    call: cygrpc.Call,
    request_deserializer: Optional[DeserializingFunction],
) -> ServerCallbackTag:
    """Return the completion callback for a receive-message batch.

    An empty receive marks the client as closed. A message that fails to
    deserialize aborts the RPC with ``internal``. Otherwise the request is
    stored on *state* for the waiting reader.
    """

    def receive_message(receive_message_event):
        serialized_request = _serialized_request(receive_message_event)
        if serialized_request is None:
            with state.condition:
                if state.client is _OPEN:
                    state.client = _CLOSED
                state.condition.notify_all()
                return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
        else:
            request = _common.deserialize(
                serialized_request, request_deserializer
            )
            with state.condition:
                if request is None:
                    _abort(
                        state,
                        call,
                        cygrpc.StatusCode.internal,
                        b"Exception deserializing request!",
                    )
                else:
                    state.request = request
                state.condition.notify_all()
                return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)

    return receive_message


def _send_initial_metadata(state: _RPCState) -> ServerCallbackTag:
    """Return the completion callback for a send-initial-metadata batch."""

    def send_initial_metadata(unused_send_initial_metadata_event):
        with state.condition:
            return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN)

    return send_initial_metadata


def _send_message(state: _RPCState, token: str) -> ServerCallbackTag:
    """Return the completion callback for a send-message batch.

    The callback wakes the sender blocked in ``_send_response``.
    """

    def send_message(unused_send_message_event):
        with state.condition:
            state.condition.notify_all()
            return _possibly_finish_call(state, token)

    return send_message


class _Context(grpc.ServicerContext):
    """``grpc.ServicerContext`` backed by a request-call event and _RPCState."""

    _rpc_event: cygrpc.BaseEvent
    _state: _RPCState
    _request_deserializer: Optional[DeserializingFunction]

    def __init__(
        self,
        rpc_event: cygrpc.BaseEvent,
        state: _RPCState,
        request_deserializer: Optional[DeserializingFunction],
    ):
        self._rpc_event = rpc_event
        self._state = state
        self._request_deserializer = request_deserializer

    def is_active(self) -> bool:
        with self._state.condition:
            return _is_rpc_state_active(self._state)

    def time_remaining(self) -> float:
        # Seconds until the call deadline, clamped at zero.
        return max(self._rpc_event.call_details.deadline - time.time(), 0)

    def cancel(self) -> None:
        self._rpc_event.call.cancel()

    def add_callback(self, callback: NullaryCallbackType) -> bool:
        # Returns False once the RPC has finished and its callbacks were
        # handed off (see _possibly_finish_call).
        with self._state.condition:
            if self._state.callbacks is None:
                return False
            else:
                self._state.callbacks.append(callback)
                return True

    def disable_next_message_compression(self) -> None:
        with self._state.condition:
            self._state.disable_next_compression = True

    def invocation_metadata(self) -> Optional[MetadataType]:
        return self._rpc_event.invocation_metadata

    def peer(self) -> str:
        return _common.decode(self._rpc_event.call.peer())

    def peer_identities(self) -> Optional[Sequence[bytes]]:
        return cygrpc.peer_identities(self._rpc_event.call)

    def peer_identity_key(self) -> Optional[str]:
        id_key = cygrpc.peer_identity_key(self._rpc_event.call)
        return id_key if id_key is None else _common.decode(id_key)

    def auth_context(self) -> Mapping[str, Sequence[bytes]]:
        auth_context = cygrpc.auth_context(self._rpc_event.call)
        auth_context_dict = {} if auth_context is None else auth_context
        return {
            _common.decode(key): value
            for key, value in auth_context_dict.items()
        }

    def set_compression(self, compression: grpc.Compression) -> None:
        with self._state.condition:
            self._state.compression_algorithm = compression

    def send_initial_metadata(self, initial_metadata: MetadataType) -> None:
        """Send initial metadata now.

        Raises:
          grpc.RpcError: if the client has cancelled.
          ValueError: if initial metadata was already sent.
        """
        with self._state.condition:
            if self._state.client is _CANCELLED:
                _raise_rpc_error(self._state)
            else:
                if self._state.initial_metadata_allowed:
                    operation = _get_initial_metadata_operation(
                        self._state, initial_metadata
                    )
                    self._rpc_event.call.start_server_batch(
                        (operation,), _send_initial_metadata(self._state)
                    )
                    self._state.initial_metadata_allowed = False
                    self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
                else:
                    raise ValueError("Initial metadata no longer allowed!")

    def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None:
        with self._state.condition:
            self._state.trailing_metadata = trailing_metadata

    def trailing_metadata(self) -> Optional[MetadataType]:
        return self._state.trailing_metadata

    def abort(self, code: grpc.StatusCode, details: str) -> None:
        """Record an abort status and unwind the handler with an exception.

        ``_call_behavior`` catches the raised exception and, seeing
        ``state.aborted``, sends the recorded status. An ``OK`` code is
        rejected and replaced by ``UNKNOWN`` with empty details.
        """
        # treat OK like other invalid arguments: fail the RPC
        if code == grpc.StatusCode.OK:
            _LOGGER.error(
                "abort() called with StatusCode.OK; returning UNKNOWN"
            )
            code = grpc.StatusCode.UNKNOWN
            details = ""
        with self._state.condition:
            self._state.code = code
            self._state.details = _common.encode(details)
            self._state.aborted = True
            raise Exception()

    def abort_with_status(self, status: grpc.Status) -> None:
        """Abort with *status*, including its trailing metadata."""
        # Take the lock, consistent with set_trailing_metadata().
        with self._state.condition:
            self._state.trailing_metadata = status.trailing_metadata
        self.abort(status.code, status.details)

    def set_code(self, code: grpc.StatusCode) -> None:
        with self._state.condition:
            self._state.code = code

    def code(self) -> grpc.StatusCode:
        return self._state.code

    def set_details(self, details: str) -> None:
        with self._state.condition:
            self._state.details = _common.encode(details)

    def details(self) -> bytes:
        return self._state.details

    def _finalize_state(self) -> None:
        pass


class _RequestIterator(object):
    """Blocking iterator over the request messages of a client-streaming RPC.

    Each ``next()`` starts a receive-message batch and waits on
    ``state.condition`` until a message arrives. Iteration stops when the
    client finishes sending or the RPC is no longer active. It raises
    ``grpc.RpcError`` if the client cancels.
    """

    _state: _RPCState
    _call: cygrpc.Call
    _request_deserializer: Optional[DeserializingFunction]

    def __init__(
        self,
        state: _RPCState,
        call: cygrpc.Call,
        request_deserializer: Optional[DeserializingFunction],
    ):
        self._state = state
        self._call = call
        self._request_deserializer = request_deserializer

    def _raise_or_start_receive_message(self) -> None:
        # Called with self._state.condition held (see _next).
        if self._state.client is _CANCELLED:
            _raise_rpc_error(self._state)
        elif not _is_rpc_state_active(self._state):
            raise StopIteration()
        else:
            self._call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _receive_message(
                    self._state, self._call, self._request_deserializer
                ),
            )
            self._state.due.add(_RECEIVE_MESSAGE_TOKEN)

    def _look_for_request(self) -> Any:
        # Returns None while the receive batch is still outstanding, so the
        # caller keeps waiting.
        if self._state.client is _CANCELLED:
            _raise_rpc_error(self._state)
        elif (
            self._state.request is None
            and _RECEIVE_MESSAGE_TOKEN not in self._state.due
        ):
            raise StopIteration()
        else:
            request = self._state.request
            self._state.request = None
            return request

        raise AssertionError()  # should never run

    def _next(self) -> Any:
        with self._state.condition:
            self._raise_or_start_receive_message()
            while True:
                self._state.condition.wait()
                request = self._look_for_request()
                if request is not None:
                    return request

    def __iter__(self) -> _RequestIterator:
        return self

    def __next__(self) -> Any:
        return self._next()

    def next(self) -> Any:
        return self._next()


def _unary_request(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    request_deserializer: Optional[DeserializingFunction],
) -> Callable[[], Any]:
    """Return a thunk that receives the single request of a unary-request RPC.

    The thunk blocks until a message arrives. It returns ``None`` if the RPC
    is inactive or the client cancels. If the client half-closes without
    sending a message, it first aborts the RPC with ``unimplemented`` and
    then returns ``None``.
    """

    def unary_request():
        with state.condition:
            if not _is_rpc_state_active(state):
                return None
            else:
                rpc_event.call.start_server_batch(
                    (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                    _receive_message(
                        state, rpc_event.call, request_deserializer
                    ),
                )
                state.due.add(_RECEIVE_MESSAGE_TOKEN)
                while True:
                    state.condition.wait()
                    if state.request is None:
                        if state.client is _CLOSED:
                            details = '"{}" requires exactly one request message.'.format(
                                rpc_event.call_details.method
                            )
                            _abort(
                                state,
                                rpc_event.call,
                                cygrpc.StatusCode.unimplemented,
                                _common.encode(details),
                            )
                            return None
                        elif state.client is _CANCELLED:
                            return None
                    else:
                        request = state.request
                        state.request = None
                        return request

    return unary_request


def _call_behavior(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument: Any,
    request_deserializer: Optional[DeserializingFunction],
    send_response_callback: Optional[Callable[[ResponseType], None]] = None,
) -> Tuple[Union[ResponseType, Iterator[ResponseType]], bool]:
    """Invoke the application handler inside a servicer context.

    Returns:
      ``(result, True)`` on success. ``(None, False)`` if the handler raised;
      in that case the RPC is aborted, unless the exception is an RpcError
      this module raised because the client cancelled.
    """
    from grpc import _create_servicer_context  # pytype: disable=pyi-error

    with _create_servicer_context(
        rpc_event, state, request_deserializer
    ) as context:
        try:
            response_or_iterator = None
            if send_response_callback is not None:
                response_or_iterator = behavior(
                    argument, context, send_response_callback
                )
            else:
                response_or_iterator = behavior(argument, context)
            return response_or_iterator, True
        except Exception as exception:  # pylint: disable=broad-except
            with state.condition:
                if state.aborted:
                    _abort(
                        state,
                        rpc_event.call,
                        cygrpc.StatusCode.unknown,
                        b"RPC Aborted",
                    )
                elif exception not in state.rpc_errors:
                    try:
                        details = "Exception calling application: {}".format(
                            exception
                        )
                    except Exception:  # pylint: disable=broad-except
                        details = (
                            "Calling application raised unprintable Exception!"
                        )
                        _LOGGER.exception(
                            traceback.format_exception(
                                type(exception),
                                exception,
                                exception.__traceback__,
                            )
                        )
                        traceback.print_exc()
                    _LOGGER.exception(details)
                    _abort(
                        state,
                        rpc_event.call,
                        cygrpc.StatusCode.unknown,
                        _common.encode(details),
                    )
            return None, False


def _take_response_from_response_iterator(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    response_iterator: Iterator[ResponseType],
) -> Tuple[Optional[ResponseType], bool]:
    """Pull the next response from an application response iterator.

    Returns:
      ``(response, True)`` for a response, ``(None, True)`` on exhaustion,
      and ``(None, False)`` if iteration raised. In the last case the RPC is
      aborted unless the client had cancelled.
    """
    try:
        return next(response_iterator), True
    except StopIteration:
        return None, True
    except Exception as exception:  # pylint: disable=broad-except
        with state.condition:
            if state.aborted:
                _abort(
                    state,
                    rpc_event.call,
                    cygrpc.StatusCode.unknown,
                    b"RPC Aborted",
                )
            elif exception not in state.rpc_errors:
                details = "Exception iterating responses: {}".format(exception)
                _LOGGER.exception(details)
                _abort(
                    state,
                    rpc_event.call,
                    cygrpc.StatusCode.unknown,
                    _common.encode(details),
                )
        return None, False


def _serialize_response(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    response: Any,
    response_serializer: Optional[SerializingFunction],
) -> Optional[bytes]:
    """Serialize *response*.

    If serialization fails, the RPC is aborted with ``internal`` and
    ``None`` is returned.
    """
    serialized_response = _common.serialize(response, response_serializer)
    if serialized_response is None:
        with state.condition:
            _abort(
                state,
                rpc_event.call,
                cygrpc.StatusCode.internal,
                b"Failed to serialize response!",
            )
        return None
    else:
        return serialized_response


def _get_send_message_op_flags_from_state(
    state: _RPCState,
) -> Union[int, cygrpc.WriteFlag]:
    """Return the write flags for the next outgoing message."""
    if state.disable_next_compression:
        return cygrpc.WriteFlag.no_compress
    else:
        return _EMPTY_FLAGS


def _reset_per_message_state(state: _RPCState) -> None:
    """Clear settings that apply to a single outgoing message only."""
    with state.condition:
        state.disable_next_compression = False


def _send_response(
    rpc_event: cygrpc.BaseEvent, state: _RPCState, serialized_response: bytes
) -> bool:
    """Send one response message and block until its batch completes.

    Initial metadata goes out in the same batch if not yet sent.

    Returns:
      Whether the RPC is still active afterwards. Returns ``False`` without
      sending if the RPC is already inactive.
    """
    with state.condition:
        if not _is_rpc_state_active(state):
            return False
        else:
            if state.initial_metadata_allowed:
                operations = (
                    _get_initial_metadata_operation(state, None),
                    cygrpc.SendMessageOperation(
                        serialized_response,
                        _get_send_message_op_flags_from_state(state),
                    ),
                )
                state.initial_metadata_allowed = False
                token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
            else:
                operations = (
                    cygrpc.SendMessageOperation(
                        serialized_response,
                        _get_send_message_op_flags_from_state(state),
                    ),
                )
                token = _SEND_MESSAGE_TOKEN
            rpc_event.call.start_server_batch(
                operations, _send_message(state, token)
            )
            state.due.add(token)
            _reset_per_message_state(state)
            while True:
                state.condition.wait()
                if token not in state.due:
                    return _is_rpc_state_active(state)


def _status(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    serialized_response: Optional[bytes],
) -> None:
    """Send the final status, unless the client has cancelled.

    Pending initial metadata and an optional final message are sent in the
    same batch.
    """
    with state.condition:
        if state.client is not _CANCELLED:
            code = _completion_code(state)
            details = _details(state)
            operations = [
                cygrpc.SendStatusFromServerOperation(
                    state.trailing_metadata, code, details, _EMPTY_FLAGS
                ),
            ]
            if state.initial_metadata_allowed:
                operations.append(_get_initial_metadata_operation(state, None))
            if serialized_response is not None:
                operations.append(
                    cygrpc.SendMessageOperation(
                        serialized_response,
                        _get_send_message_op_flags_from_state(state),
                    )
                )
            rpc_event.call.start_server_batch(
                operations,
                _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN),
            )
            state.statused = True
            _reset_per_message_state(state)
            state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)


def _unary_response_in_pool(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument_thunk: Callable[[], Any],
    request_deserializer: Optional[DeserializingFunction],
    response_serializer: Optional[SerializingFunction],
) -> None:
    """Thread-pool body for RPCs that produce a single response.

    Obtains the argument, runs the handler, and sends its response together
    with the status.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)

    try:
        argument = argument_thunk()
        if argument is not None:
            response, proceed = _call_behavior(
                rpc_event, state, behavior, argument, request_deserializer
            )
            if proceed:
                serialized_response = _serialize_response(
                    rpc_event, state, response, response_serializer
                )
                if serialized_response is not None:
                    _status(rpc_event, state, serialized_response)
    except Exception:  # pylint: disable=broad-except
        traceback.print_exc()
    finally:
        cygrpc.uninstall_context()


def _stream_response_in_pool(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument_thunk: Callable[[], Any],
    request_deserializer: Optional[DeserializingFunction],
    response_serializer: Optional[SerializingFunction],
) -> None:
    """Thread-pool body for response-streaming RPCs.

    A handler flagged ``experimental_non_blocking`` receives a
    ``send_response`` callback (``None`` sends the status). Otherwise the
    handler's returned iterator is drained and each item sent.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)

    def send_response(response: Any) -> None:
        if response is None:
            _status(rpc_event, state, None)
        else:
            serialized_response = _serialize_response(
                rpc_event, state, response, response_serializer
            )
            if serialized_response is not None:
                _send_response(rpc_event, state, serialized_response)

    try:
        argument = argument_thunk()
        if argument is not None:
            if (
                hasattr(behavior, "experimental_non_blocking")
                and behavior.experimental_non_blocking
            ):
                _call_behavior(
                    rpc_event,
                    state,
                    behavior,
                    argument,
                    request_deserializer,
                    send_response_callback=send_response,
                )
            else:
                response_iterator, proceed = _call_behavior(
                    rpc_event, state, behavior, argument, request_deserializer
                )
                if proceed:
                    _send_message_callback_to_blocking_iterator_adapter(
                        rpc_event, state, send_response, response_iterator
                    )
    except Exception:  # pylint: disable=broad-except
        traceback.print_exc()
    finally:
        cygrpc.uninstall_context()


def _is_rpc_state_active(state: _RPCState) -> bool:
    """An RPC is active until the client cancels or a status has been sent."""
    return state.client is not _CANCELLED and not state.statused


def _send_message_callback_to_blocking_iterator_adapter(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    send_response_callback: Callable[[ResponseType], None],
    response_iterator: Iterator[ResponseType],
) -> None:
    """Feed each item of *response_iterator* to *send_response_callback*.

    Exhaustion forwards a final ``None``, which makes the callback send the
    status. The loop stops once the RPC is inactive or iteration fails.
    """
    while True:
        response, proceed = _take_response_from_response_iterator(
            rpc_event, state, response_iterator
        )
        if proceed:
            send_response_callback(response)
            if not _is_rpc_state_active(state):
                break
        else:
            break


def _select_thread_pool_for_behavior(
    behavior: ArityAgnosticMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.ThreadPoolExecutor:
    """Choose the thread pool that runs *behavior*.

    A handler-specific ``experimental_thread_pool`` is used when it is a
    ThreadPoolExecutor; otherwise the server's default pool is used.
    """
    if hasattr(behavior, "experimental_thread_pool") and isinstance(
        behavior.experimental_thread_pool, futures.ThreadPoolExecutor
    ):
        return behavior.experimental_thread_pool
    else:
        return default_thread_pool
def _submit_to_thread_pool(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    behavior: ArityAgnosticMethodHandler,
    argument_thunk: Callable[[], Any],
    response_in_pool: Callable[..., None],
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Schedule *response_in_pool* for one RPC on the pool chosen for *behavior*.

    The work runs inside ``state.context``, the same ``contextvars.Context``
    that ``_find_method_handler`` uses for handler lookup.
    """
    thread_pool = _select_thread_pool_for_behavior(
        behavior, default_thread_pool
    )
    return thread_pool.submit(
        state.context.run,
        response_in_pool,
        rpc_event,
        state,
        behavior,
        argument_thunk,
        method_handler.request_deserializer,
        method_handler.response_serializer,
    )


def _handle_unary_unary(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Dispatch a unary-request / unary-response RPC to a thread pool."""
    unary_request = _unary_request(
        rpc_event, state, method_handler.request_deserializer
    )
    return _submit_to_thread_pool(
        rpc_event,
        state,
        method_handler,
        method_handler.unary_unary,
        unary_request,
        _unary_response_in_pool,
        default_thread_pool,
    )


def _handle_unary_stream(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Dispatch a unary-request / streaming-response RPC to a thread pool."""
    unary_request = _unary_request(
        rpc_event, state, method_handler.request_deserializer
    )
    return _submit_to_thread_pool(
        rpc_event,
        state,
        method_handler,
        method_handler.unary_stream,
        unary_request,
        _stream_response_in_pool,
        default_thread_pool,
    )


def _handle_stream_unary(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Dispatch a streaming-request / unary-response RPC to a thread pool."""
    request_iterator = _RequestIterator(
        state, rpc_event.call, method_handler.request_deserializer
    )
    return _submit_to_thread_pool(
        rpc_event,
        state,
        method_handler,
        method_handler.stream_unary,
        lambda: request_iterator,
        _unary_response_in_pool,
        default_thread_pool,
    )


def _handle_stream_stream(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Dispatch a streaming-request / streaming-response RPC to a thread pool."""
    request_iterator = _RequestIterator(
        state, rpc_event.call, method_handler.request_deserializer
    )
    return _submit_to_thread_pool(
        rpc_event,
        state,
        method_handler,
        method_handler.stream_stream,
        lambda: request_iterator,
        _stream_response_in_pool,
        default_thread_pool,
    )


def _find_method_handler(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    generic_handlers: List[grpc.GenericRpcHandler],
    interceptor_pipeline: Optional[_interceptor._ServicePipeline],
) -> Optional[grpc.RpcMethodHandler]:
    """Resolve the method handler for an incoming call.

    Generic handlers are asked in order and the first non-None answer wins.
    When an interceptor pipeline is configured, the lookup goes through it.
    The lookup runs inside ``state.context``.

    Returns:
      The handler, or ``None`` if no generic handler serves the method.
    """

    def query_handlers(
        handler_call_details: _HandlerCallDetails,
    ) -> Optional[grpc.RpcMethodHandler]:
        for generic_handler in generic_handlers:
            method_handler = generic_handler.service(handler_call_details)
            if method_handler is not None:
                return method_handler
        return None

    handler_call_details = _HandlerCallDetails(
        _common.decode(rpc_event.call_details.method),
        rpc_event.invocation_metadata,
    )

    if interceptor_pipeline is not None:
        return state.context.run(
            interceptor_pipeline.execute, query_handlers, handler_call_details
        )
    else:
        return state.context.run(query_handlers, handler_call_details)


def _reject_rpc(
    rpc_event: cygrpc.BaseEvent,
    rpc_state: _RPCState,
    status: cygrpc.StatusCode,
    details: bytes,
) -> None:
    """Fail a call before any handler runs.

    Sends initial metadata and the given *status*/*details*, and receives
    close, all in one batch. The batch callback reports *rpc_state* as
    finished, with no termination callbacks.
    """
    operations = (
        _get_initial_metadata_operation(rpc_state, None),
        cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
        cygrpc.SendStatusFromServerOperation(
            None, status, details, _EMPTY_FLAGS
        ),
    )
    rpc_event.call.start_server_batch(
        operations,
        lambda ignored_event: (
            rpc_state,
            (),
        ),
    )


def _handle_with_method_handler(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Start watching for client close, then dispatch by streaming arity."""
    with state.condition:
        rpc_event.call.start_server_batch(
            (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
            _receive_close_on_server(state),
        )
        state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
        if method_handler.request_streaming:
            if method_handler.response_streaming:
                return _handle_stream_stream(
                    rpc_event, state, method_handler, thread_pool
                )
            else:
                return _handle_stream_unary(
                    rpc_event, state, method_handler, thread_pool
                )
        else:
            if method_handler.response_streaming:
                return _handle_unary_stream(
                    rpc_event, state, method_handler, thread_pool
                )
            else:
                return _handle_unary_unary(
                    rpc_event, state, method_handler, thread_pool
                )


def _handle_call(
    rpc_event: cygrpc.BaseEvent,
    generic_handlers: List[grpc.GenericRpcHandler],
    interceptor_pipeline: Optional[_interceptor._ServicePipeline],
    thread_pool: futures.ThreadPoolExecutor,
    concurrency_exceeded: bool,
) -> Tuple[Optional[_RPCState], Optional[futures.Future]]:
    """Accept or reject a newly requested call.

    Returns:
      ``(None, None)`` if the request-call event failed or carried no method.
      ``(rpc_state, None)`` when the call is rejected: handler lookup raised,
      no handler was found, or the concurrency limit was hit.
      ``(rpc_state, future)`` when the call was dispatched to a thread pool.
    """
    if not rpc_event.success or rpc_event.call_details.method is None:
        return None, None
    rpc_state = _RPCState()
    try:
        method_handler = _find_method_handler(
            rpc_event, rpc_state, generic_handlers, interceptor_pipeline
        )
    except Exception as exception:  # pylint: disable=broad-except
        details = "Exception servicing handler: {}".format(exception)
        _LOGGER.exception(details)
        _reject_rpc(
            rpc_event,
            rpc_state,
            cygrpc.StatusCode.unknown,
            b"Error in service handler!",
        )
        return rpc_state, None
    if method_handler is None:
        _reject_rpc(
            rpc_event,
            rpc_state,
            cygrpc.StatusCode.unimplemented,
            b"Method not found!",
        )
        return rpc_state, None
    if concurrency_exceeded:
        _reject_rpc(
            rpc_event,
            rpc_state,
            cygrpc.StatusCode.resource_exhausted,
            b"Concurrent RPC limit exceeded!",
        )
        return rpc_state, None
    return (
        rpc_state,
        _handle_with_method_handler(
            rpc_event, rpc_state, method_handler, thread_pool
        ),
    )


@enum.unique
class _ServerStage(enum.Enum):
    """Lifecycle stage of a server: stopped, serving, or shutting down."""

    STOPPED = "stopped"
    STARTED = "started"
    GRACE = "grace"


class _ServerState(object):
    """Mutable server-wide state; fields are guarded by ``lock``."""

    lock: threading.RLock
    completion_queue: cygrpc.CompletionQueue
    server: cygrpc.Server
    generic_handlers: List[grpc.GenericRpcHandler]
    interceptor_pipeline: Optional[_interceptor._ServicePipeline]
    thread_pool: futures.ThreadPoolExecutor
    stage: _ServerStage
    termination_event: threading.Event
    # Events set once the server has fully stopped.
    shutdown_events: List[threading.Event]
    maximum_concurrent_rpcs: Optional[int]
    # Number of RPCs dispatched to a thread pool and not yet completed.
    active_rpc_count: int
    rpc_states: Set[_RPCState]
    # Server-level tags (request call / shutdown) still awaiting an event.
    due: Set[str]
    server_deallocated: bool

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        completion_queue: cygrpc.CompletionQueue,
        server: cygrpc.Server,
        generic_handlers: Sequence[grpc.GenericRpcHandler],
        interceptor_pipeline: Optional[_interceptor._ServicePipeline],
        thread_pool: futures.ThreadPoolExecutor,
        maximum_concurrent_rpcs: Optional[int],
    ):
        self.lock = threading.RLock()
        self.completion_queue = completion_queue
        self.server = server
        self.generic_handlers = list(generic_handlers)
        self.interceptor_pipeline = interceptor_pipeline
        self.thread_pool = thread_pool
        self.stage = _ServerStage.STOPPED
        self.termination_event = threading.Event()
        self.shutdown_events = [self.termination_event]
        self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
        self.active_rpc_count = 0

        # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
        self.rpc_states = set()
        self.due = set()

        # A "volatile" flag to interrupt the daemon serving thread
        self.server_deallocated = False


def _add_generic_handlers(
    state: _ServerState, generic_handlers: Iterable[grpc.GenericRpcHandler]
) -> None:
    """Append *generic_handlers* to the server's handler list."""
    with state.lock:
        state.generic_handlers.extend(generic_handlers)


def _add_insecure_port(state: _ServerState, address: bytes) -> int:
    """Bind an insecure HTTP/2 port and return the core server's result."""
    with state.lock:
        return state.server.add_http2_port(address)


def _add_secure_port(
    state: _ServerState,
    address: bytes,
    server_credentials: grpc.ServerCredentials,
) -> int:
    """Bind a secure HTTP/2 port and return the core server's result."""
    with state.lock:
        return state.server.add_http2_port(
            address, server_credentials._credentials
        )


def _request_call(state: _ServerState) -> None:
    """Ask the core server for the next incoming call.

    The request is tracked in ``state.due`` until its event arrives.
    """
    state.server.request_call(
        state.completion_queue, state.completion_queue, _REQUEST_CALL_TAG
    )
    state.due.add(_REQUEST_CALL_TAG)


# TODO(https://github.com/grpc/grpc/issues/6597): delete this function.
def _stop_serving(state: _ServerState) -> bool:
    """Destroy the server if nothing is outstanding.

    All call sites in ``_process_event_and_continue`` invoke this while
    holding ``state.lock``.

    Returns:
      True if there were no tracked RPC states and no due tags. In that case
      the server was destroyed, every shutdown event was set, and the stage
      was moved to STOPPED. Returns False otherwise.
    """
    if not state.rpc_states and not state.due:
        state.server.destroy()
        for shutdown_event in state.shutdown_events:
            shutdown_event.set()
        state.stage = _ServerStage.STOPPED
        return True
    else:
        return False


def _on_call_completed(state: _ServerState) -> None:
    """Decrement the active-RPC counter; used as a future done-callback."""
    with state.lock:
        state.active_rpc_count -= 1


def _process_event_and_continue(
    state: _ServerState, event: cygrpc.BaseEvent
) -> bool:
    """Dispatch a single completion-queue event.

    There are three kinds of tag:

    * ``_SHUTDOWN_TAG``, issued by ``_begin_shutdown_once``.
    * ``_REQUEST_CALL_TAG``, issued by ``_request_call``.
    * Anything else, which is called as ``event.tag(event)`` and must return
      ``(rpc_state, callbacks)``.

    Returns:
      False when the server has been stopped and the serving loop should
      exit; True otherwise.
    """
    should_continue = True
    # Identity comparison is intended: the same module-level tag objects are
    # passed to the core when the operations are issued.
    if event.tag is _SHUTDOWN_TAG:
        with state.lock:
            state.due.remove(_SHUTDOWN_TAG)
            if _stop_serving(state):
                should_continue = False
    elif event.tag is _REQUEST_CALL_TAG:
        with state.lock:
            state.due.remove(_REQUEST_CALL_TAG)
            # The limit check and the increment below happen under the same
            # lock acquisition, so they cannot interleave with
            # _on_call_completed.
            concurrency_exceeded = (
                state.maximum_concurrent_rpcs is not None
                and state.active_rpc_count >= state.maximum_concurrent_rpcs
            )
            rpc_state, rpc_future = _handle_call(
                event,
                state.generic_handlers,
                state.interceptor_pipeline,
                state.thread_pool,
                concurrency_exceeded,
            )
            if rpc_state is not None:
                state.rpc_states.add(rpc_state)
            if rpc_future is not None:
                state.active_rpc_count += 1
                rpc_future.add_done_callback(
                    lambda unused_future: _on_call_completed(state)
                )
            # Keep accepting calls only while STARTED; otherwise try to
            # finish shutting down.
            if state.stage is _ServerStage.STARTED:
                _request_call(state)
            elif _stop_serving(state):
                should_continue = False
    else:
        rpc_state, callbacks = event.tag(event)
        # Callbacks run without holding state.lock; a failing callback is
        # logged and does not prevent the others from running.
        for callback in callbacks:
            try:
                callback()
            except Exception:  # pylint: disable=broad-except
                _LOGGER.exception("Exception calling callback!")
        if rpc_state is not None:
            with state.lock:
                state.rpc_states.remove(rpc_state)
                if _stop_serving(state):
                    should_continue = False
    return should_continue


def _serve(state: _ServerState) -> None:
    """Poll the completion queue until the server stops.

    This is the target of the daemon thread started by ``_start``.
    """
    while True:
        # The poll argument is an absolute time one check period from now.
        # This bounds each wait so that ``server_deallocated`` is re-checked
        # even while idle.
        timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S
        event = state.completion_queue.poll(timeout)
        if state.server_deallocated:
            _begin_shutdown_once(state)
        if event.completion_type != cygrpc.CompletionType.queue_timeout:
            if not _process_event_and_continue(state, event):
                return
        # We want to force the deletion of the previous event
        # ~before~ we poll again; if the event has a reference
        # to a shutdown Call object, this can induce spinlock.
        event = None


def _begin_shutdown_once(state: _ServerState) -> None:
    """Start core shutdown if the server is STARTED.

    The stage moves to GRACE. Calling this in any other stage does nothing.
    """
    with state.lock:
        if state.stage is _ServerStage.STARTED:
            state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
            state.stage = _ServerStage.GRACE
            state.due.add(_SHUTDOWN_TAG)


def _stop(state: _ServerState, grace: Optional[float]) -> threading.Event:
    """Stop the server and return an event that is set once it has stopped.

    Behaviour depends on the current stage and on ``grace``:

    * Already STOPPED: an already-set event is returned immediately.
    * ``grace`` is None: all calls are cancelled immediately, and this
      function blocks until shutdown completes.
    * Otherwise: the function returns at once. A helper thread calls
      ``cancel_all_calls`` after ``grace`` seconds, or as soon as shutdown
      completes if that happens first.
    """
    with state.lock:
        if state.stage is _ServerStage.STOPPED:
            shutdown_event = threading.Event()
            shutdown_event.set()
            return shutdown_event
        else:
            _begin_shutdown_once(state)
            shutdown_event = threading.Event()
            state.shutdown_events.append(shutdown_event)
            if grace is None:
                state.server.cancel_all_calls()
            else:

                def cancel_all_calls_after_grace():
                    shutdown_event.wait(timeout=grace)
                    with state.lock:
                        state.server.cancel_all_calls()

                thread = threading.Thread(target=cancel_all_calls_after_grace)
                thread.start()
                return shutdown_event
    shutdown_event.wait()
    return shutdown_event


def _start(state: _ServerState) -> None:
    """Start the core server, request the first call and spawn ``_serve``.

    Raises:
      ValueError: if the server is not in the STOPPED stage.
    """
    with state.lock:
        if state.stage is not _ServerStage.STOPPED:
            raise ValueError("Cannot start already-started server!")
        state.server.start()
        state.stage = _ServerStage.STARTED
        _request_call(state)
        thread = threading.Thread(target=_serve, args=(state,))
        thread.daemon = True
        thread.start()


def _validate_generic_rpc_handlers(
    generic_rpc_handlers: Iterable[grpc.GenericRpcHandler],
) -> None:
    """Check that every handler has a non-None ``service`` attribute.

    This iterates ``generic_rpc_handlers`` completely, so callers that still
    need the handlers afterwards must pass a re-iterable collection.

    Raises:
      AttributeError: if a handler has no ``service`` attribute, or the
        attribute is None.
    """
    for generic_rpc_handler in generic_rpc_handlers:
        service_attribute = getattr(generic_rpc_handler, "service", None)
        if service_attribute is None:
            raise AttributeError(
                '"{}" must conform to grpc.GenericRpcHandler type but does '
                'not have "service" method!'.format(generic_rpc_handler)
            )


def _augment_options(
    base_options: Sequence[ChannelArgumentType],
    compression: Optional[grpc.Compression],
) -> Sequence[ChannelArgumentType]:
    """Return ``base_options`` as a tuple with the compression option(s) appended."""
    compression_option = _compression.create_channel_option(compression)
    return tuple(base_options) + compression_option


class _Server(grpc.Server):
    """``grpc.Server`` implementation backed by a ``cygrpc.Server``."""

    _state: _ServerState

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        thread_pool: futures.ThreadPoolExecutor,
        generic_handlers: Sequence[grpc.GenericRpcHandler],
        interceptors: Sequence[grpc.ServerInterceptor],
        options: Sequence[ChannelArgumentType],
        maximum_concurrent_rpcs: Optional[int],
        compression: Optional[grpc.Compression],
        xds: bool,
    ):
        completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server(_augment_options(options, compression), xds)
        server.register_completion_queue(completion_queue)
        self._state = _ServerState(
            completion_queue,
            server,
            generic_handlers,
            _interceptor.service_pipeline(interceptors),
            thread_pool,
            maximum_concurrent_rpcs,
        )

    def add_generic_rpc_handlers(
        self, generic_rpc_handlers: Iterable[grpc.GenericRpcHandler]
    ) -> None:
        # Materialize first. Validation iterates the handlers, so a one-shot
        # iterator such as a generator would otherwise be exhausted before
        # anything was registered, and the handlers would be silently dropped.
        generic_rpc_handlers = tuple(generic_rpc_handlers)
        _validate_generic_rpc_handlers(generic_rpc_handlers)
        _add_generic_handlers(self._state, generic_rpc_handlers)

    def add_insecure_port(self, address: str) -> int:
        return _common.validate_port_binding_result(
            address, _add_insecure_port(self._state, _common.encode(address))
        )

    def add_secure_port(
        self, address: str, server_credentials: grpc.ServerCredentials
    ) -> int:
        return _common.validate_port_binding_result(
            address,
            _add_secure_port(
                self._state, _common.encode(address), server_credentials
            ),
        )

    def start(self) -> None:
        _start(self._state)

    def wait_for_termination(self, timeout: Optional[float] = None) -> bool:
        # NOTE(https://bugs.python.org/issue35935)
        # Remove this workaround once threading.Event.wait() is working with
        # CTRL+C across platforms.
        return _common.wait(
            self._state.termination_event.wait,
            self._state.termination_event.is_set,
            timeout=timeout,
        )

    def stop(self, grace: Optional[float]) -> threading.Event:
        return _stop(self._state, grace)

    def __del__(self):
        if hasattr(self, "_state"):
            # We can not grab a lock in __del__(), so set a flag to signal the
            # serving daemon thread (if it exists) to initiate shutdown.
            self._state.server_deallocated = True


def create_server(
    thread_pool: futures.ThreadPoolExecutor,
    generic_rpc_handlers: Sequence[grpc.GenericRpcHandler],
    interceptors: Sequence[grpc.ServerInterceptor],
    options: Sequence[ChannelArgumentType],
    maximum_concurrent_rpcs: Optional[int],
    compression: Optional[grpc.Compression],
    xds: bool,
) -> _Server:
    """Validate the handlers and build a ``_Server``.

    Raises:
      AttributeError: if any handler lacks a ``service`` attribute.
    """
    # Materialize once, so that validation cannot exhaust a one-shot iterable
    # before the handlers reach _Server.
    generic_rpc_handlers = tuple(generic_rpc_handlers)
    _validate_generic_rpc_handlers(generic_rpc_handlers)
    return _Server(
        thread_pool,
        generic_rpc_handlers,
        interceptors,
        options,
        maximum_concurrent_rpcs,
        compression,
        xds,
    )