# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import inspect
import traceback
import functools


cdef int _EMPTY_FLAG = 0
cdef str _RPC_FINISHED_DETAILS = 'RPC already finished.'
cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.'

cdef _augment_metadata(tuple metadata, object compression):
    """Prepends the compression request header to ``metadata``.

    Returns ``metadata`` unchanged when ``compression`` is None; otherwise
    returns a new tuple whose first entry maps
    GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY to the string looked up for
    ``compression`` in _COMPRESSION_METADATA_STRING_MAPPING.
    """
    if compression is None:
        return metadata
    else:
        return ((
            GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY,
            _COMPRESSION_METADATA_STRING_MAPPING[compression]
        ),) + metadata


cdef class _HandlerCallDetails:
    """Method name and invocation metadata passed to generic handlers."""

    def __cinit__(self, str method, tuple invocation_metadata):
        self.method = method
        self.invocation_metadata = invocation_metadata


class _ServerStoppedError(BaseError):
    """Raised if the server is stopped."""


cdef class RPCState:
    """Book-keeping for a single server-side RPC.

    Owns the Core call handle plus the call details / request metadata that
    grpc_server_request_call fills in, and records the status, details,
    trailing metadata and compression settings the handler has chosen so far.
    """

    def __cinit__(self, AioServer server):
        init_grpc_aio()
        self.call = NULL
        self.server = server
        grpc_metadata_array_init(&self.request_metadata)
        grpc_call_details_init(&self.details)
        self.client_closed = False
        self.abort_exception = None
        self.metadata_sent = False
        self.status_sent = False
        self.status_code = StatusCode.ok
        self.py_status_code = None
        self.status_details = ''
        self.trailing_metadata = _IMMUTABLE_EMPTY_METADATA
        self.compression_algorithm = None
        self.disable_next_compression = False
        self.callbacks = []

    cdef bytes method(self):
        """Returns the requested method name as bytes."""
        return _slice_bytes(self.details.method)

    cdef tuple invocation_metadata(self):
        """Returns the request metadata received with this RPC."""
        return _metadata(&self.request_metadata)

    cdef void raise_for_termination(self) except *:
        """Raise exceptions if RPC is not running.

        Server method handlers may suppress the abort exception. We need to halt
        the RPC execution in that case. This function needs to be called after
        running application code.

        Also, the server may stop unexpectedly. We need to check before calling
        into Core functions, otherwise, segfault.
        """
        if self.abort_exception is not None:
            raise self.abort_exception
        if self.status_sent:
            raise UsageError(_RPC_FINISHED_DETAILS)
        if self.server._status == AIO_SERVER_STATUS_STOPPED:
            raise _ServerStoppedError(_SERVER_STOPPED_DETAILS)

    cdef int get_write_flag(self):
        """Returns the write flag for the next message.

        The "disable compression" request is one-shot: it is consumed (reset)
        by the first call that observes it.
        """
        if self.disable_next_compression:
            self.disable_next_compression = False
            return WriteFlag.no_compress
        else:
            return _EMPTY_FLAG

    cdef Operation create_send_initial_metadata_op_if_not_sent(self):
        """Builds a send-initial-metadata op, or returns None if already sent.

        The op carries only the compression header (if any); the caller is
        responsible for marking ``metadata_sent`` once the op is executed.
        """
        cdef SendInitialMetadataOperation op
        if self.metadata_sent:
            return None
        else:
            op = SendInitialMetadataOperation(
                _augment_metadata(_IMMUTABLE_EMPTY_METADATA, self.compression_algorithm),
                _EMPTY_FLAG
            )
            return op

    def __dealloc__(self):
        """Cleans the Core objects."""
        grpc_call_details_destroy(&self.details)
        grpc_metadata_array_destroy(&self.request_metadata)
        if self.call:
            grpc_call_unref(self.call)
        shutdown_grpc_aio()


cdef class _ServicerContext:
    """Async servicer context handed to application method handlers.

    Every operation goes through the shared RPCState; message I/O is awaited
    on ``loop``.
    """

    def __cinit__(self,
                  RPCState rpc_state,
                  object request_deserializer,
                  object response_serializer,
                  object loop):
        self._rpc_state = rpc_state
        self._request_deserializer = request_deserializer
        self._response_serializer = response_serializer
        self._loop = loop

    async def read(self):
        """Receives and deserializes the next request message.

        Returns:
          The deserialized message, or EOF when _receive_message yields None.
        """
        cdef bytes raw_message
        self._rpc_state.raise_for_termination()

        raw_message = await _receive_message(self._rpc_state, self._loop)
        # The RPC may have been aborted / the server stopped while waiting.
        self._rpc_state.raise_for_termination()

        if raw_message is None:
            return EOF
        else:
            return deserialize(self._request_deserializer,
                               raw_message)

    async def write(self, object message):
        """Serializes and sends one response message.

        If initial metadata has not been sent yet, the send-initial-metadata
        op is handed to _send_message together with the message.
        """
        self._rpc_state.raise_for_termination()

        await _send_message(self._rpc_state,
                            serialize(self._response_serializer, message),
                            self._rpc_state.create_send_initial_metadata_op_if_not_sent(),
                            self._rpc_state.get_write_flag(),
                            self._loop)
        self._rpc_state.metadata_sent = True

    async def send_initial_metadata(self, object metadata):
        """Sends ``metadata`` (plus compression header, if set) as initial metadata.

        Raises:
          UsageError: if initial metadata was already sent.
        """
        self._rpc_state.raise_for_termination()

        if self._rpc_state.metadata_sent:
            raise UsageError('Send initial metadata failed: already sent')
        else:
            await _send_initial_metadata(
                self._rpc_state,
                _augment_metadata(tuple(metadata), self._rpc_state.compression_algorithm),
                _EMPTY_FLAG,
                self._loop
            )
            self._rpc_state.metadata_sent = True

    async def abort(self,
                    object code,
                    str details='',
                    tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA):
        """Sends the given status to the client and raises AbortError.

        When ``details`` / ``trailing_metadata`` are left at their defaults,
        values previously stored via set_details / set_trailing_metadata are
        used instead.

        Raises:
          UsageError: if abort was already called for this RPC.
          AbortError: always, once the status has been sent.
        """
        if self._rpc_state.abort_exception is not None:
            raise UsageError('Abort already called!')
        else:
            # Keeps track of the exception object. After abort happens, the
            # RPC should stop execution. However, if users decided to suppress
            # it, it could lead to undefined behavior.
            self._rpc_state.abort_exception = AbortError('Locally aborted.')

            if trailing_metadata == _IMMUTABLE_EMPTY_METADATA and self._rpc_state.trailing_metadata:
                trailing_metadata = self._rpc_state.trailing_metadata
            else:
                raise_if_not_valid_trailing_metadata(trailing_metadata)
                self._rpc_state.trailing_metadata = trailing_metadata

            if details == '' and self._rpc_state.status_details:
                details = self._rpc_state.status_details
            else:
                self._rpc_state.status_details = details

            actual_code = get_status_code(code)
            self._rpc_state.py_status_code = code
            self._rpc_state.status_code = actual_code

            self._rpc_state.status_sent = True
            await _send_error_status_from_server(
                self._rpc_state,
                actual_code,
                details,
                trailing_metadata,
                self._rpc_state.create_send_initial_metadata_op_if_not_sent(),
                self._loop
            )

            raise self._rpc_state.abort_exception

    async def abort_with_status(self, object status):
        """Aborts using the code, details and trailing metadata of ``status``."""
        await self.abort(status.code, status.details, status.trailing_metadata)

    def set_trailing_metadata(self, object metadata):
        """Validates and stores trailing metadata to send with the status."""
        raise_if_not_valid_trailing_metadata(metadata)
        self._rpc_state.trailing_metadata = tuple(metadata)

    def trailing_metadata(self):
        return self._rpc_state.trailing_metadata

    def invocation_metadata(self):
        return self._rpc_state.invocation_metadata()

    def set_code(self, object code):
        """Stores both the Core status code and the user-supplied code object."""
        self._rpc_state.status_code = get_status_code(code)
        self._rpc_state.py_status_code = code

    def code(self):
        return self._rpc_state.py_status_code

    def set_details(self, str details):
        self._rpc_state.status_details = details

    def details(self):
        return self._rpc_state.status_details

    def set_compression(self, object compression):
        """Selects the compression algorithm; only valid before initial metadata."""
        if self._rpc_state.metadata_sent:
            raise RuntimeError('Compression setting must be specified before sending initial metadata')
        else:
            self._rpc_state.compression_algorithm = compression

    def disable_next_message_compression(self):
        self._rpc_state.disable_next_compression = True

    def peer(self):
        """Returns the peer address reported by Core, decoded as UTF-8."""
        cdef char *c_peer = NULL
        c_peer = grpc_call_get_peer(self._rpc_state.call)
        peer = (<bytes>c_peer).decode('utf8')
        # The buffer returned by Core is owned by the caller.
        gpr_free(c_peer)
        return peer

    def peer_identities(self):
        # A temporary Call wrapper borrows the raw call pointer; it is cleared
        # afterwards so the wrapper does not act on the borrowed call.
        cdef Call query_call = Call()
        query_call.c_call = self._rpc_state.call
        identities = peer_identities(query_call)
        query_call.c_call = NULL
        return identities

    def peer_identity_key(self):
        cdef Call query_call = Call()
        query_call.c_call = self._rpc_state.call
        identity_key = peer_identity_key(query_call)
        query_call.c_call = NULL
        if identity_key:
            return identity_key.decode('utf8')
        else:
            return None

    def auth_context(self):
        """Returns the auth context with its keys decoded as UTF-8 strings."""
        cdef Call query_call = Call()
        query_call.c_call = self._rpc_state.call
        bytes_ctx = auth_context(query_call)
        query_call.c_call = NULL
        if bytes_ctx:
            ctx = {}
            for key in bytes_ctx:
                ctx[key.decode('utf8')] = bytes_ctx[key]
            return ctx
        else:
            return {}

    def time_remaining(self):
        """Returns seconds until the deadline (floored at 0), or None if unset."""
        if self._rpc_state.details.deadline.seconds == _GPR_INF_FUTURE.seconds:
            return None
        else:
            return max(_time_from_timespec(self._rpc_state.details.deadline) - time.time(), 0)

    def add_done_callback(self, callback):
        """Registers ``callback(context)`` to run once the RPC task is done."""
        cb = functools.partial(callback, self)
        self._rpc_state.callbacks.append(cb)

    def done(self):
        return self._rpc_state.status_sent

    def cancelled(self):
        return self._rpc_state.status_code == StatusCode.cancelled


cdef class _SyncServicerContext:
    """Sync servicer context for sync handler compatibility.

    Sync handlers run in the server thread pool; operations that need the
    event loop are submitted to it with asyncio.run_coroutine_threadsafe.
    """

    def __cinit__(self,
                  _ServicerContext context):
        self._context = context
        self._callbacks = []
        self._loop = context._loop

    def abort(self,
              object code,
              str details='',
              tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA):
        future = asyncio.run_coroutine_threadsafe(
            self._context.abort(code, details, trailing_metadata),
            self._loop)
        # The abort coroutine always ends with an AbortError. exception()
        # blocks until it is done and returns that error without raising it
        # here; the handler keeps running, and the recorded abort exception is
        # raised by raise_for_termination() once control returns to the loop.
        future.exception()

    def send_initial_metadata(self, object metadata):
        future = asyncio.run_coroutine_threadsafe(
            self._context.send_initial_metadata(metadata),
            self._loop)
        future.result()

    def set_trailing_metadata(self, object metadata):
        self._context.set_trailing_metadata(metadata)

    def invocation_metadata(self):
        return self._context.invocation_metadata()

    def set_code(self, object code):
        self._context.set_code(code)

    def set_details(self, str details):
        self._context.set_details(details)

    def set_compression(self, object compression):
        self._context.set_compression(compression)

    def disable_next_message_compression(self):
        self._context.disable_next_message_compression()

    def add_callback(self, object callback):
        self._callbacks.append(callback)

    def peer(self):
        return self._context.peer()

    def peer_identities(self):
        return self._context.peer_identities()

    def peer_identity_key(self):
        return self._context.peer_identity_key()

    def auth_context(self):
        return self._context.auth_context()

    def time_remaining(self):
        return self._context.time_remaining()


async def _run_interceptor(object interceptors, object query_handler,
                           object handler_call_details):
    """Runs the remaining interceptors in order, then the handler lookup."""
    interceptor = next(interceptors, None)
    if interceptor:
        continuation = functools.partial(_run_interceptor, interceptors,
                                         query_handler)
        return await interceptor.intercept_service(continuation, handler_call_details)
    else:
        return query_handler(handler_call_details)


def _is_async_handler(object handler):
    """Inspect if a method handler is async or sync."""
    return inspect.isawaitable(handler) or inspect.iscoroutinefunction(handler) or inspect.isasyncgenfunction(handler)


async def _find_method_handler(str method, tuple metadata, list generic_handlers,
                               tuple interceptors):
    """Returns the first method handler a generic handler provides, or None.

    When interceptors are configured, the lookup is routed through them.
    """
    def query_handlers(handler_call_details):
        for generic_handler in generic_handlers:
            method_handler = generic_handler.service(handler_call_details)
            if method_handler is not None:
                return method_handler
        return None

    cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method,
                                                                        metadata)
    # interceptor
    if interceptors:
        return await _run_interceptor(iter(interceptors), query_handlers,
                                      handler_call_details)
    else:
        return query_handlers(handler_call_details)


async def _finish_handler_with_unary_response(RPCState rpc_state,
                                              object unary_handler,
                                              object request,
                                              _ServicerContext servicer_context,
                                              object response_serializer,
                                              object loop):
    """Finishes server method handler with a single response.

    This function executes the application handler, and handles response
    sending, as well as errors. It is shared between unary-unary and
    stream-unary handlers.
    """
    # Executes application logic
    cdef object response_message
    cdef _SyncServicerContext sync_servicer_context
    install_context_from_request_call_event_aio(rpc_state)

    if _is_async_handler(unary_handler):
        # Run async method handlers in this coroutine
        response_message = await unary_handler(
            request,
            servicer_context,
        )
    else:
        # Run sync method handlers in the thread pool
        sync_servicer_context = _SyncServicerContext(servicer_context)
        response_message = await loop.run_in_executor(
            rpc_state.server.thread_pool(),
            unary_handler,
            request,
            sync_servicer_context,
        )
        # Support sync-stack callback
        for callback in sync_servicer_context._callbacks:
            callback()

    # Raises exception if aborted
    rpc_state.raise_for_termination()

    # Serializes the response message
    cdef bytes response_raw
    if rpc_state.status_code == StatusCode.ok:
        response_raw = serialize(
            response_serializer,
            response_message,
        )
    else:
        # Discards the response message if the status code is non-OK.
        response_raw = b''

    # Assembles the batch operations
    cdef tuple finish_ops
    finish_ops = (
        SendMessageOperation(response_raw, rpc_state.get_write_flag()),
        SendStatusFromServerOperation(
            rpc_state.trailing_metadata,
            rpc_state.status_code,
            rpc_state.status_details,
            _EMPTY_FLAG,
        ),
    )
    if not rpc_state.metadata_sent:
        finish_ops = prepend_send_initial_metadata_op(
            finish_ops,
            None)
    rpc_state.metadata_sent = True
    rpc_state.status_sent = True
    await execute_batch(rpc_state, finish_ops, loop)
    uninstall_context()


async def _finish_handler_with_stream_responses(RPCState rpc_state,
                                                object stream_handler,
                                                object request,
                                                _ServicerContext servicer_context,
                                                object loop):
    """Finishes server method handler with multiple responses.

    This function executes the application handler, and handles response
    sending, as well as errors. It is shared between unary-stream and
    stream-stream handlers.
    """
    cdef object async_response_generator
    cdef object response_message
    install_context_from_request_call_event_aio(rpc_state)

    if inspect.iscoroutinefunction(stream_handler):
        # Case 1: Coroutine async handler - using reader-writer API
        # The handler uses reader / writer API, returns None.
        await stream_handler(
            request,
            servicer_context,
        )
    else:
        if inspect.isasyncgenfunction(stream_handler):
            # Case 2: Async handler - async generator
            # The handler uses async generator API
            async_response_generator = stream_handler(
                request,
                servicer_context,
            )
        else:
            # Case 3: Sync handler - normal generator
            # NOTE(lidiz) Streaming handler in sync stack is either a generator
            # function or a function returns a generator.
            sync_servicer_context = _SyncServicerContext(servicer_context)
            gen = stream_handler(request, sync_servicer_context)
            async_response_generator = generator_to_async_generator(gen,
                                                                    loop,
                                                                    rpc_state.server.thread_pool())

        # Consumes messages from the generator
        async for response_message in async_response_generator:
            # Raises exception if aborted
            rpc_state.raise_for_termination()

            await servicer_context.write(response_message)

    # Raises exception if aborted
    rpc_state.raise_for_termination()

    # Sends the final status of this RPC
    cdef SendStatusFromServerOperation op = SendStatusFromServerOperation(
        rpc_state.trailing_metadata,
        rpc_state.status_code,
        rpc_state.status_details,
        _EMPTY_FLAG,
    )

    cdef tuple finish_ops = (op,)
    if not rpc_state.metadata_sent:
        finish_ops = prepend_send_initial_metadata_op(
            finish_ops,
            None
        )
    rpc_state.metadata_sent = True
    rpc_state.status_sent = True
    await execute_batch(rpc_state, finish_ops, loop)
    uninstall_context()


async def _handle_unary_unary_rpc(object method_handler,
                                  RPCState rpc_state,
                                  object loop):
    """Receives the single request, then runs the unary-unary handler."""
    # Receives request message
    cdef bytes request_raw = await _receive_message(rpc_state, loop)
    if request_raw is None:
        # The RPC was cancelled immediately after start on client side.
        return

    # Deserializes the request message
    cdef object request_message = deserialize(
        method_handler.request_deserializer,
        request_raw,
    )

    # Creates a dedicated ServicerContext
    cdef _ServicerContext servicer_context = _ServicerContext(
        rpc_state,
        None,
        None,
        loop,
    )

    # Finishes the application handler
    await _finish_handler_with_unary_response(
        rpc_state,
        method_handler.unary_unary,
        request_message,
        servicer_context,
        method_handler.response_serializer,
        loop
    )


async def _handle_unary_stream_rpc(object method_handler,
                                   RPCState rpc_state,
                                   object loop):
    """Receives the single request, then runs the unary-stream handler."""
    # Receives request message
    cdef bytes request_raw = await _receive_message(rpc_state, loop)
    if request_raw is None:
        return

    # Deserializes the request message
    cdef object request_message = deserialize(
        method_handler.request_deserializer,
        request_raw,
    )

    # Creates a dedicated ServicerContext
    cdef _ServicerContext servicer_context = _ServicerContext(
        rpc_state,
        method_handler.request_deserializer,
        method_handler.response_serializer,
        loop,
    )

    # Finishes the application handler
    await _finish_handler_with_stream_responses(
        rpc_state,
        method_handler.unary_stream,
        request_message,
        servicer_context,
        loop,
    )


cdef class _MessageReceiver:
    """Bridge between the async generator API and the reader-writer API."""

    def __cinit__(self, _ServicerContext servicer_context):
        self._servicer_context = servicer_context
        self._agen = None

    async def _async_message_receiver(self):
        """An async generator that receives messages."""
        cdef object message
        while True:
            message = await self._servicer_context.read()
            if message is not EOF:
                yield message
            else:
                break

    def __aiter__(self):
        # Prevents never awaited warning if application never used the async generator
        if self._agen is None:
            self._agen = self._async_message_receiver()
        return self._agen

    async def __anext__(self):
        return await self.__aiter__().__anext__()


async def _handle_stream_unary_rpc(object method_handler,
                                   RPCState rpc_state,
                                   object loop):
    """Runs a stream-unary handler with an async or sync request iterator."""
    # Creates a dedicated ServicerContext
    cdef _ServicerContext servicer_context = _ServicerContext(
        rpc_state,
        method_handler.request_deserializer,
        None,
        loop,
    )

    # Prepares the request generator
    cdef object request_iterator
    if _is_async_handler(method_handler.stream_unary):
        request_iterator = _MessageReceiver(servicer_context)
    else:
        request_iterator = async_generator_to_generator(
            _MessageReceiver(servicer_context),
            loop
        )

    # Finishes the application handler
    await _finish_handler_with_unary_response(
        rpc_state,
        method_handler.stream_unary,
        request_iterator,
        servicer_context,
        method_handler.response_serializer,
        loop
    )


async def _handle_stream_stream_rpc(object method_handler,
                                    RPCState rpc_state,
                                    object loop):
    """Runs a stream-stream handler with an async or sync request iterator."""
    # Creates a dedicated ServicerContext
    cdef _ServicerContext servicer_context = _ServicerContext(
        rpc_state,
        method_handler.request_deserializer,
        method_handler.response_serializer,
        loop,
    )

    # Prepares the request generator
    cdef object request_iterator
    if _is_async_handler(method_handler.stream_stream):
        request_iterator = _MessageReceiver(servicer_context)
    else:
        request_iterator = async_generator_to_generator(
            _MessageReceiver(servicer_context),
            loop
        )

    # Finishes the application handler
    await _finish_handler_with_stream_responses(
        rpc_state,
        method_handler.stream_stream,
        request_iterator,
        servicer_context,
        loop,
    )


async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
    """Awaits ``rpc_coro`` and turns failures into logs and/or an error status.

    Unexpected exceptions result in an error status being sent to the client
    when no status was sent yet and the server is not stopped.
    """
    try:
        try:
            await rpc_coro
        except AbortError as e:
            # Caught AbortError check if it is the same one
            assert rpc_state.abort_exception is e, 'Abort error has been replaced!'
            return
        else:
            # Check if the abort exception got suppressed
            if rpc_state.abort_exception is not None:
                abort_exception = rpc_state.abort_exception
                # The three-argument form works on every Python 3 version;
                # the single-argument form only exists since Python 3.10.
                _LOGGER.error(
                    'Abort error unexpectedly suppressed: %s',
                    ''.join(traceback.format_exception(
                        type(abort_exception),
                        abort_exception,
                        abort_exception.__traceback__,
                    ))
                )
    except (KeyboardInterrupt, SystemExit):
        raise
    except asyncio.CancelledError:
        _LOGGER.debug('RPC cancelled for servicer method [%s]', _decode(rpc_state.method()))
    except _ServerStoppedError:
        _LOGGER.warning('Aborting method [%s] due to server stop.', _decode(rpc_state.method()))
    except ExecuteBatchError:
        # If client closed (aka. cancelled), ignore the failed batch operations.
        if rpc_state.client_closed:
            return
        else:
            _LOGGER.exception('ExecuteBatchError raised in core by servicer method [%s]' % (
                _decode(rpc_state.method())))
            return
    except Exception as e:
        _LOGGER.exception('Unexpected [%s] raised by servicer method [%s]' % (
            type(e).__name__,
            _decode(rpc_state.method()),
        ))
        if not rpc_state.status_sent and rpc_state.server._status != AIO_SERVER_STATUS_STOPPED:
            # Allows users to raise other types of exception with specified status code
            if rpc_state.status_code == StatusCode.ok:
                status_code = StatusCode.unknown
            else:
                status_code = rpc_state.status_code

            rpc_state.status_sent = True
            try:
                await _send_error_status_from_server(
                    rpc_state,
                    status_code,
                    'Unexpected %s: %s' % (type(e), e),
                    rpc_state.trailing_metadata,
                    rpc_state.create_send_initial_metadata_op_if_not_sent(),
                    loop
                )
            except ExecuteBatchError:
                _LOGGER.exception('Failed sending error status from server')
                traceback.print_exc()


cdef _add_callback_handler(object rpc_task, RPCState rpc_state):
    """Runs the callbacks registered on ``rpc_state`` once ``rpc_task`` is done."""

    def handle_callbacks(object unused_task):
        for callback in rpc_state.callbacks:
            # Each callback is isolated so one failure does not prevent the
            # remaining callbacks from running.
            try:
                # The _ServicerContext object is bound in add_done_callback.
                callback()
            except Exception:
                _LOGGER.exception('Error in callback for method [%s]', _decode(rpc_state.method()))

    rpc_task.add_done_callback(handle_callbacks)


async def _handle_cancellation_from_core(object rpc_task,
                                         RPCState rpc_state,
                                         object loop):
    """Waits for the client-close event and cancels ``rpc_task`` if needed."""
    cdef ReceiveCloseOnServerOperation op = ReceiveCloseOnServerOperation(_EMPTY_FLAG)
    cdef tuple ops = (op,)

    # Awaits cancellation from peer.
    await execute_batch(rpc_state, ops, loop)
    rpc_state.client_closed = True
    # If 1) received cancel signal; 2) the Task is not finished; 3) the server
    # wasn't replying final status. For condition 3, it might cause inaccurate
    # log that an RPC is both aborted and cancelled.
    if op.cancelled() and not rpc_task.done() and not rpc_state.status_sent:
        # Injects `CancelledError` to halt the RPC coroutine
        rpc_task.cancel()


async def _schedule_rpc_coro(object rpc_coro,
                             RPCState rpc_state,
                             object loop):
    """Schedules the RPC coroutine, then listens for cancellation from Core."""
    # Schedules the RPC coroutine.
    cdef object rpc_task = loop.create_task(_handle_exceptions(
        rpc_state,
        rpc_coro,
        loop,
    ))
    _add_callback_handler(rpc_task, rpc_state)
    await _handle_cancellation_from_core(rpc_task, rpc_state, loop)


async def _handle_rpc(list generic_handlers, tuple interceptors,
                      RPCState rpc_state, object loop, bint concurrency_exceeded):
    """Dispatches one accepted RPC to the handler matching its arity.

    Replies UNIMPLEMENTED when no handler is found and RESOURCE_EXHAUSTED when
    the concurrency limit was exceeded for this RPC.
    """
    cdef object method_handler
    # Finds the method handler (application logic)
    method_handler = await _find_method_handler(
        rpc_state.method().decode(),
        rpc_state.invocation_metadata(),
        generic_handlers,
        interceptors,
    )
    if method_handler is None:
        rpc_state.status_sent = True
        await _send_error_status_from_server(
            rpc_state,
            StatusCode.unimplemented,
            'Method not found!',
            _IMMUTABLE_EMPTY_METADATA,
            rpc_state.create_send_initial_metadata_op_if_not_sent(),
            loop
        )
        return

    if concurrency_exceeded:
        rpc_state.status_sent = True
        await _send_error_status_from_server(
            rpc_state,
            StatusCode.resource_exhausted,
            'Concurrent RPC limit exceeded!',
            _IMMUTABLE_EMPTY_METADATA,
            rpc_state.create_send_initial_metadata_op_if_not_sent(),
            loop
        )
        return

    # Handles unary-unary case
    if not method_handler.request_streaming and not method_handler.response_streaming:
        await _handle_unary_unary_rpc(method_handler,
                                      rpc_state,
                                      loop)
        return

    # Handles unary-stream case
    if not method_handler.request_streaming and method_handler.response_streaming:
        await _handle_unary_stream_rpc(method_handler,
                                       rpc_state,
                                       loop)
        return

    # Handles stream-unary case
    if method_handler.request_streaming and not method_handler.response_streaming:
        await _handle_stream_unary_rpc(method_handler,
                                       rpc_state,
                                       loop)
        return

    # Handles stream-stream case
    if method_handler.request_streaming and method_handler.response_streaming:
        await _handle_stream_stream_rpc(method_handler,
                                        rpc_state,
                                        loop)
        return


class _RequestCallError(Exception): pass

cdef CallbackFailureHandler REQUEST_CALL_FAILURE_HANDLER = CallbackFailureHandler(
    'grpc_server_request_call', None, _RequestCallError)


cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler(
    'grpc_server_shutdown_and_notify',
    None,
    InternalError)


cdef class _ConcurrentRpcLimiter:
    """Counts in-flight RPCs against ``maximum_concurrent_rpcs``.

    check_before_request_call() either reserves a slot for the next RPC or
    flags that it must be rejected. decrease_once_finished() must only be used
    for admitted RPCs; it releases their slot when the RPC task completes.
    """

    def __cinit__(self, int maximum_concurrent_rpcs):
        if maximum_concurrent_rpcs <= 0:
            raise ValueError("maximum_concurrent_rpcs should be a positive integer")
        self._maximum_concurrent_rpcs = maximum_concurrent_rpcs
        self._active_rpcs = 0
        self.limiter_concurrency_exceeded = False

    def check_before_request_call(self):
        """Reserves a slot, or sets ``limiter_concurrency_exceeded`` if full.

        The flag is updated in both directions so it always describes the
        RPC about to be accepted.
        """
        if self._active_rpcs >= self._maximum_concurrent_rpcs:
            self.limiter_concurrency_exceeded = True
        else:
            self.limiter_concurrency_exceeded = False
            self._active_rpcs += 1

    def _decrease_active_rpcs_count(self, unused_future):
        self._active_rpcs -= 1
        if self._active_rpcs < self._maximum_concurrent_rpcs:
            self.limiter_concurrency_exceeded = False

    def decrease_once_finished(self, object rpc_task):
        """Releases the slot reserved for an admitted RPC when its task ends."""
        rpc_task.add_done_callback(self._decrease_active_rpcs_count)


cdef class AioServer:
    """asyncio front-end over a Core server.

    Accepts calls from Core in a serving task and runs each RPC in its own
    task on ``loop``; sync handlers run in ``thread_pool``.
    """

    def __init__(self, loop, thread_pool, generic_handlers, interceptors,
                 options, maximum_concurrent_rpcs):
        init_grpc_aio()
        # NOTE(lidiz) Core objects won't be deallocated automatically.
        # If AioServer.shutdown is not called, those objects will leak.
        # TODO(rbellevi): Support xDS in aio server.
        self._server = Server(options, False)
        grpc_server_register_completion_queue(
            self._server.c_server,
            global_completion_queue(),
            NULL
        )

        self._loop = loop
        self._status = AIO_SERVER_STATUS_READY
        self._generic_handlers = []
        self.add_generic_rpc_handlers(generic_handlers)
        self._serving_task = None

        self._shutdown_lock = asyncio.Lock()
        self._shutdown_completed = self._loop.create_future()
        self._shutdown_callback_wrapper = CallbackWrapper(
            self._shutdown_completed,
            self._loop,
            SERVER_SHUTDOWN_FAILURE_HANDLER)
        self._crash_exception = None

        if interceptors:
            self._interceptors = tuple(interceptors)
        else:
            self._interceptors = ()

        self._thread_pool = thread_pool
        if maximum_concurrent_rpcs is not None:
            self._limiter = _ConcurrentRpcLimiter(maximum_concurrent_rpcs)

    def add_generic_rpc_handlers(self, object generic_rpc_handlers):
        self._generic_handlers.extend(generic_rpc_handlers)

    def add_insecure_port(self, address):
        return self._server.add_http2_port(address)

    def add_secure_port(self, address, server_credentials):
        return self._server.add_http2_port(address,
                                           server_credentials._credentials)

    async def _request_call(self):
        """Asks Core for the next incoming call and waits until it arrives.

        Returns:
          The RPCState populated by grpc_server_request_call.

        Raises:
          InternalError: if grpc_server_request_call is rejected immediately.
        """
        cdef grpc_call_error error
        cdef RPCState rpc_state = RPCState(self)
        cdef object future = self._loop.create_future()
        cdef CallbackWrapper wrapper = CallbackWrapper(
            future,
            self._loop,
            REQUEST_CALL_FAILURE_HANDLER)
        error = grpc_server_request_call(
            self._server.c_server, &rpc_state.call, &rpc_state.details,
            &rpc_state.request_metadata,
            global_completion_queue(), global_completion_queue(),
            wrapper.c_functor()
        )
        if error != GRPC_CALL_OK:
            raise InternalError("Error in grpc_server_request_call: %s" % error)

        await future
        return rpc_state

    async def _server_main_loop(self,
                                object server_started):
        """Starts Core, then accepts and schedules RPCs while running."""
        self._server.start(backup_queue=False)
        cdef RPCState rpc_state
        server_started.set_result(True)

        while True:
            # When shutdown begins, no more new connections.
            if self._status != AIO_SERVER_STATUS_RUNNING:
                break

            concurrency_exceeded = False
            if self._limiter is not None:
                self._limiter.check_before_request_call()
                concurrency_exceeded = self._limiter.limiter_concurrency_exceeded

            # Accepts new request from Core
            rpc_state = await self._request_call()

            # Creates the dedicated RPC coroutine. If we schedule it right now,
            # there is no guarantee if the cancellation listening coroutine is
            # ready or not. So, we should control the ordering by scheduling
            # the coroutine onto event loop inside of the cancellation
            # coroutine.
            rpc_coro = _handle_rpc(self._generic_handlers,
                                   self._interceptors,
                                   rpc_state,
                                   self._loop,
                                   concurrency_exceeded)

            # Fires off a task that listens on the cancellation from client.
            rpc_task = self._loop.create_task(
                _schedule_rpc_coro(
                    rpc_coro,
                    rpc_state,
                    self._loop
                )
            )

            # Only admitted RPCs reserved a slot in the limiter. Releasing a
            # slot for a rejected RPC would undercount the active RPCs and let
            # the server exceed maximum_concurrent_rpcs.
            if self._limiter is not None and not concurrency_exceeded:
                self._limiter.decrease_once_finished(rpc_task)

    def _serving_task_crash_handler(self, object task):
        """Shutdown the server immediately if unexpectedly exited."""
        if task.cancelled():
            return
        if task.exception() is None:
            return
        if self._status != AIO_SERVER_STATUS_STOPPING:
            self._crash_exception = task.exception()
            # This is a done-callback, not an ``except`` block, so the
            # exception must be passed explicitly to log its traceback.
            _LOGGER.exception(self._crash_exception, exc_info=self._crash_exception)
            self._loop.create_task(self.shutdown(None))

    async def start(self):
        """Starts serving; a no-op if already running.

        Raises:
          UsageError: if the server is neither ready nor running.
        """
        if self._status == AIO_SERVER_STATUS_RUNNING:
            return
        elif self._status != AIO_SERVER_STATUS_READY:
            raise UsageError('Server not in ready state')

        self._status = AIO_SERVER_STATUS_RUNNING
        cdef object server_started = self._loop.create_future()
        self._serving_task = self._loop.create_task(self._server_main_loop(server_started))
        self._serving_task.add_done_callback(self._serving_task_crash_handler)
        # Needs to explicitly wait for the server to start up.
        # Otherwise, the actual start time of the server is un-controllable.
        await server_started

    async def _start_shutting_down(self):
        """Prepares the server to shutting down.

        This coroutine function is NOT coroutine-safe.
        """
        # The shutdown callback won't be called until there is no live RPC.
        grpc_server_shutdown_and_notify(
            self._server.c_server,
            global_completion_queue(),
            self._shutdown_callback_wrapper.c_functor())

        # Ensures the serving task (coroutine) exits.
        try:
            await self._serving_task
        except _RequestCallError:
            pass

    async def shutdown(self, grace):
        """Gracefully shutdown the Core server.

        Application should only call shutdown once.

        Args:
          grace: An optional float indicating the length of grace period in
            seconds.
        """
        if self._status == AIO_SERVER_STATUS_READY or self._status == AIO_SERVER_STATUS_STOPPED:
            return

        async with self._shutdown_lock:
            if self._status == AIO_SERVER_STATUS_RUNNING:
                self._server.is_shutting_down = True
                self._status = AIO_SERVER_STATUS_STOPPING
                await self._start_shutting_down()

        if grace is None:
            # Directly cancels all calls
            grpc_server_cancel_all_calls(self._server.c_server)
            await self._shutdown_completed
        else:
            try:
                await asyncio.wait_for(
                    asyncio.shield(self._shutdown_completed),
                    grace,
                )
            except asyncio.TimeoutError:
                # Cancels all ongoing calls by the end of grace period.
                grpc_server_cancel_all_calls(self._server.c_server)
                await self._shutdown_completed

        async with self._shutdown_lock:
            if self._status == AIO_SERVER_STATUS_STOPPING:
                grpc_server_destroy(self._server.c_server)
                self._server.c_server = NULL
                self._server.is_shutdown = True
                self._status = AIO_SERVER_STATUS_STOPPED

    async def wait_for_termination(self, object timeout):
        """Waits for shutdown to complete.

        Returns:
          True if ``timeout`` elapsed first, False otherwise.

        Raises:
          The exception that crashed the serving task, if any.
        """
        if timeout is None:
            await self._shutdown_completed
        else:
            try:
                await asyncio.wait_for(
                    asyncio.shield(self._shutdown_completed),
                    timeout,
                )
            except asyncio.TimeoutError:
                if self._crash_exception is not None:
                    raise self._crash_exception
                return True
        if self._crash_exception is not None:
            raise self._crash_exception
        return False

    def __dealloc__(self):
        """Deallocation of Core objects are ensured by Python layer."""
        # TODO(lidiz) if users create server, and then dealloc it immediately.
        # There is a potential memory leak of created Core server.
        if self._status != AIO_SERVER_STATUS_STOPPED:
            _LOGGER.debug(
                '__dealloc__ called on running server %s with status %d',
                self,
                self._status
            )
        shutdown_grpc_aio()

    cdef thread_pool(self):
        """Access the thread pool instance."""
        return self._thread_pool

    def is_running(self):
        return self._status == AIO_SERVER_STATUS_RUNNING