xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2019 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16import inspect
17import traceback
18import functools
19
20
# Flag value passed to Core operations in this file when no special flag is
# requested (e.g. sending initial metadata, uncompressed-write opt-out unset).
cdef int _EMPTY_FLAG = 0
# Message of the UsageError raised when an RPC whose status was already sent is
# used again.
cdef str _RPC_FINISHED_DETAILS = 'RPC already finished.'
# Message of the _ServerStoppedError raised when the owning server has stopped.
cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.'
24
cdef _augment_metadata(tuple metadata, object compression):
    """Returns metadata with the compression request entry prepended.

    When `compression` is None the metadata tuple is returned untouched.
    Otherwise a new tuple is built whose first entry maps
    GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY to the string looked up in
    _COMPRESSION_METADATA_STRING_MAPPING for the given compression value.
    """
    if compression is not None:
        compression_entry = (
            GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY,
            _COMPRESSION_METADATA_STRING_MAPPING[compression],
        )
        return (compression_entry,) + metadata
    return metadata
33
34
cdef class _HandlerCallDetails:
    """Pairs an RPC method name with its invocation metadata.

    Instances are handed to interceptors and to the generic handlers'
    `service()` lookup in this file.
    """

    def __cinit__(self, str method, tuple invocation_metadata):
        self.invocation_metadata = invocation_metadata
        self.method = method
39
40
class _ServerStoppedError(BaseError):
    """Raised if the server is stopped.

    RPCState.raise_for_termination raises it when the owning server's status is
    AIO_SERVER_STATUS_STOPPED, and _handle_exceptions logs a warning for it
    instead of treating it as an application failure.
    """
43
44
cdef class RPCState:
    """Holds the Core call objects and bookkeeping flags of one server-side RPC."""

    def __cinit__(self, AioServer server):
        # Paired with shutdown_grpc_aio() in __dealloc__.
        init_grpc_aio()

        # Core-level objects.
        self.call = NULL
        self.server = server
        grpc_metadata_array_init(&self.request_metadata)
        grpc_call_details_init(&self.details)

        # Lifecycle flags.
        self.client_closed = False
        self.abort_exception = None
        self.metadata_sent = False
        self.status_sent = False

        # Final status reported to the peer.
        self.status_code = StatusCode.ok
        self.py_status_code = None
        self.status_details = ''
        self.trailing_metadata = _IMMUTABLE_EMPTY_METADATA

        # Compression settings and done-callbacks.
        self.compression_algorithm = None
        self.disable_next_compression = False
        self.callbacks = []

    cdef bytes method(self):
        """Returns the method name stored in the Core call details as bytes."""
        return _slice_bytes(self.details.method)

    cdef tuple invocation_metadata(self):
        """Returns the request metadata converted by `_metadata`."""
        return _metadata(&self.request_metadata)

    cdef void raise_for_termination(self) except *:
        """Raise exceptions if RPC is not running.

        Server method handlers may suppress the abort exception. We need to halt
        the RPC execution in that case. This function needs to be called after
        running application code.

        Also, the server may stop unexpectedly. We need to check before calling
        into Core functions, otherwise, segfault.

        Raises:
          The stored abort exception if abort() has been called, UsageError if
          the status has already been sent, or _ServerStoppedError if the
          server status is AIO_SERVER_STATUS_STOPPED (checked in that order).
        """
        if self.abort_exception is not None:
            raise self.abort_exception
        elif self.status_sent:
            raise UsageError(_RPC_FINISHED_DETAILS)
        elif self.server._status == AIO_SERVER_STATUS_STOPPED:
            raise _ServerStoppedError(_SERVER_STOPPED_DETAILS)

    cdef int get_write_flag(self):
        """Returns the flag for the next write, consuming the one-shot opt-out."""
        if not self.disable_next_compression:
            return _EMPTY_FLAG
        # The opt-out only applies to a single message, so reset it here.
        self.disable_next_compression = False
        return WriteFlag.no_compress

    cdef Operation create_send_initial_metadata_op_if_not_sent(self):
        """Returns a send-initial-metadata operation, or None if already sent.

        The operation carries empty metadata, augmented with the compression
        request entry when a compression algorithm has been set. This method
        does not flip `metadata_sent`; callers do that themselves.
        """
        cdef SendInitialMetadataOperation op
        if self.metadata_sent:
            return None
        op = SendInitialMetadataOperation(
            _augment_metadata(_IMMUTABLE_EMPTY_METADATA, self.compression_algorithm),
            _EMPTY_FLAG
        )
        return op

    def __dealloc__(self):
        """Cleans the Core objects."""
        grpc_call_details_destroy(&self.details)
        grpc_metadata_array_destroy(&self.request_metadata)
        # The call pointer stays NULL until a call has been attached.
        if self.call != NULL:
            grpc_call_unref(self.call)
        shutdown_grpc_aio()
113
114
cdef class _ServicerContext:
    """Async servicer context handed to application method handlers.

    Wraps an RPCState together with the request deserializer, response
    serializer and event loop used to read requests, write responses and
    report the status of one RPC.
    """

    def __cinit__(self,
                  RPCState rpc_state,
                  object request_deserializer,
                  object response_serializer,
                  object loop):
        self._rpc_state = rpc_state
        self._request_deserializer = request_deserializer
        self._response_serializer = response_serializer
        self._loop = loop

    async def read(self):
        """Receives the next request message.

        Returns:
          EOF when no message was received, otherwise the message passed
          through `deserialize` with the request deserializer.

        Raises:
          Whatever RPCState.raise_for_termination raises when the RPC is
          aborted, already finished, or the server has stopped.
        """
        cdef bytes raw_message
        self._rpc_state.raise_for_termination()

        raw_message = await _receive_message(self._rpc_state, self._loop)
        # The RPC may have terminated while this coroutine was suspended.
        self._rpc_state.raise_for_termination()

        if raw_message is None:
            return EOF
        else:
            return deserialize(self._request_deserializer,
                            raw_message)

    async def write(self, object message):
        """Serializes and sends one response message.

        The initial metadata operation is attached to the same send when the
        initial metadata has not been sent yet.
        """
        self._rpc_state.raise_for_termination()

        await _send_message(self._rpc_state,
                            serialize(self._response_serializer, message),
                            self._rpc_state.create_send_initial_metadata_op_if_not_sent(),
                            self._rpc_state.get_write_flag(),
                            self._loop)
        self._rpc_state.metadata_sent = True

    async def send_initial_metadata(self, object metadata):
        """Sends the given initial metadata, at most once per RPC.

        Raises:
          UsageError: if the initial metadata has already been sent.
        """
        self._rpc_state.raise_for_termination()

        if self._rpc_state.metadata_sent:
            raise UsageError('Send initial metadata failed: already sent')
        else:
            await _send_initial_metadata(
                self._rpc_state,
                _augment_metadata(tuple(metadata), self._rpc_state.compression_algorithm),
                _EMPTY_FLAG,
                self._loop
            )
            self._rpc_state.metadata_sent = True

    async def abort(self,
              object code,
              str details='',
              tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA):
        """Terminates the RPC with a non-OK status and raises AbortError.

        Args:
          code: Status code object; converted with `get_status_code`.
          details: Status details. When left empty, details previously stored
            on the RPC state (if any) are used instead.
          trailing_metadata: Trailing metadata. When left as the empty default,
            trailing metadata previously stored on the RPC state (if any) is
            used instead.

        Raises:
          UsageError: if abort was already called for this RPC.
          AbortError: always, after the error status has been sent.
        """
        if self._rpc_state.abort_exception is not None:
            raise UsageError('Abort already called!')
        else:
            # Keeps track of the exception object. After abort happens, the RPC
            # should stop execution. However, if users decide to suppress it, it
            # could lead to undefined behavior.
            self._rpc_state.abort_exception = AbortError('Locally aborted.')

            # Explicit arguments win; otherwise fall back to values set earlier
            # through set_trailing_metadata / set_details.
            if trailing_metadata == _IMMUTABLE_EMPTY_METADATA and self._rpc_state.trailing_metadata:
                trailing_metadata = self._rpc_state.trailing_metadata
            else:
                raise_if_not_valid_trailing_metadata(trailing_metadata)
                self._rpc_state.trailing_metadata = trailing_metadata

            if details == '' and self._rpc_state.status_details:
                details = self._rpc_state.status_details
            else:
                self._rpc_state.status_details = details

            actual_code = get_status_code(code)
            self._rpc_state.py_status_code = code
            self._rpc_state.status_code = actual_code

            # Marked as sent before awaiting so concurrent code observes the
            # RPC as finished.
            self._rpc_state.status_sent = True
            await _send_error_status_from_server(
                self._rpc_state,
                actual_code,
                details,
                trailing_metadata,
                self._rpc_state.create_send_initial_metadata_op_if_not_sent(),
                self._loop
            )

            raise self._rpc_state.abort_exception

    async def abort_with_status(self, object status):
        """Aborts using the code, details and trailing_metadata of `status`."""
        await self.abort(status.code, status.details, status.trailing_metadata)

    def set_trailing_metadata(self, object metadata):
        """Validates and stores the trailing metadata as a tuple."""
        raise_if_not_valid_trailing_metadata(metadata)
        self._rpc_state.trailing_metadata = tuple(metadata)

    def trailing_metadata(self):
        """Returns the trailing metadata currently stored for this RPC."""
        return self._rpc_state.trailing_metadata

    def invocation_metadata(self):
        """Returns the request metadata of this RPC."""
        return self._rpc_state.invocation_metadata()

    def set_code(self, object code):
        """Stores both the converted status code and the original code object."""
        self._rpc_state.status_code = get_status_code(code)
        self._rpc_state.py_status_code = code

    def code(self):
        """Returns the code object last passed to set_code/abort, or None."""
        return self._rpc_state.py_status_code

    def set_details(self, str details):
        """Stores the status details for this RPC."""
        self._rpc_state.status_details = details

    def details(self):
        """Returns the status details currently stored for this RPC."""
        return self._rpc_state.status_details

    def set_compression(self, object compression):
        """Sets the compression algorithm; only allowed before initial metadata.

        Raises:
          RuntimeError: if the initial metadata has already been sent.
        """
        if self._rpc_state.metadata_sent:
            raise RuntimeError('Compression setting must be specified before sending initial metadata')
        else:
            self._rpc_state.compression_algorithm = compression

    def disable_next_message_compression(self):
        """Requests the no-compress write flag for the next message only."""
        self._rpc_state.disable_next_compression = True

    def peer(self):
        """Returns the peer string of the call, decoded as UTF-8."""
        cdef char *c_peer = NULL
        c_peer = grpc_call_get_peer(self._rpc_state.call)
        peer = (<bytes>c_peer).decode('utf8')
        # The C string is released here once it has been copied into `peer`.
        gpr_free(c_peer)
        return peer

    def peer_identities(self):
        """Returns the result of `peer_identities` for this call."""
        # A temporary Call wrapper borrows the call pointer; it is reset to
        # NULL afterwards, presumably so the wrapper never releases a call it
        # does not own -- TODO confirm against Call's deallocation.
        cdef Call query_call = Call()
        query_call.c_call = self._rpc_state.call
        identities = peer_identities(query_call)
        query_call.c_call = NULL
        return identities

    def peer_identity_key(self):
        """Returns the peer identity key decoded as UTF-8, or None if falsy."""
        # Same borrowed-pointer pattern as in peer_identities.
        cdef Call query_call = Call()
        query_call.c_call = self._rpc_state.call
        identity_key = peer_identity_key(query_call)
        query_call.c_call = NULL
        if identity_key:
            return identity_key.decode('utf8')
        else:
            return None

    def auth_context(self):
        """Returns the auth context as a dict with UTF-8 decoded keys.

        Values are passed through unchanged; an empty dict is returned when the
        underlying lookup yields nothing.
        """
        # Same borrowed-pointer pattern as in peer_identities.
        cdef Call query_call = Call()
        query_call.c_call = self._rpc_state.call
        bytes_ctx = auth_context(query_call)
        query_call.c_call = NULL
        if bytes_ctx:
            ctx = {}
            for key in bytes_ctx:
                ctx[key.decode('utf8')] = bytes_ctx[key]
            return ctx
        else:
            return {}

    def time_remaining(self):
        """Returns the seconds left before the deadline, or None if unbounded.

        The deadline is treated as unbounded when its `seconds` field equals
        that of _GPR_INF_FUTURE. The result is clamped at 0.
        """
        if self._rpc_state.details.deadline.seconds == _GPR_INF_FUTURE.seconds:
            return None
        else:
            return max(_time_from_timespec(self._rpc_state.details.deadline) - time.time(), 0)

    def add_done_callback(self, callback):
        """Registers `callback`, which will later be invoked with this context."""
        # Bind the context now; the stored callable is invoked with no arguments.
        cb = functools.partial(callback, self)
        self._rpc_state.callbacks.append(cb)

    def done(self):
        """Returns whether the final status has been marked as sent."""
        return self._rpc_state.status_sent

    def cancelled(self):
        """Returns whether the stored status code is StatusCode.cancelled."""
        return self._rpc_state.status_code == StatusCode.cancelled
290
291
cdef class _SyncServicerContext:
    """Sync servicer context for sync handler compatibility.

    Thin facade over a _ServicerContext. Coroutine methods of the wrapped
    context are submitted to its event loop with
    asyncio.run_coroutine_threadsafe and waited on; everything else is
    delegated directly.
    """

    def __cinit__(self, _ServicerContext context):
        self._callbacks = []
        self._context = context
        self._loop = context._loop

    def abort(self,
              object code,
              str details='',
              tuple trailing_metadata=_IMMUTABLE_EMPTY_METADATA):
        """Runs the async abort on the event loop and waits for it to finish."""
        abort_coro = self._context.abort(code, details, trailing_metadata)
        abort_future = asyncio.run_coroutine_threadsafe(abort_coro, self._loop)
        # The coroutine is expected to end with an AbortError; exception()
        # waits for completion and retrieves it without re-raising it here.
        abort_future.exception()

    def send_initial_metadata(self, object metadata):
        """Runs the async send_initial_metadata on the loop; re-raises its errors."""
        send_coro = self._context.send_initial_metadata(metadata)
        send_future = asyncio.run_coroutine_threadsafe(send_coro, self._loop)
        send_future.result()

    def set_trailing_metadata(self, object metadata):
        """Delegates to the wrapped context."""
        self._context.set_trailing_metadata(metadata)

    def invocation_metadata(self):
        """Delegates to the wrapped context."""
        return self._context.invocation_metadata()

    def set_code(self, object code):
        """Delegates to the wrapped context."""
        self._context.set_code(code)

    def set_details(self, str details):
        """Delegates to the wrapped context."""
        self._context.set_details(details)

    def set_compression(self, object compression):
        """Delegates to the wrapped context."""
        self._context.set_compression(compression)

    def disable_next_message_compression(self):
        """Delegates to the wrapped context."""
        self._context.disable_next_message_compression()

    def add_callback(self, object callback):
        """Stores a zero-argument callback on this sync context."""
        self._callbacks.append(callback)

    def peer(self):
        """Delegates to the wrapped context."""
        return self._context.peer()

    def peer_identities(self):
        """Delegates to the wrapped context."""
        return self._context.peer_identities()

    def peer_identity_key(self):
        """Delegates to the wrapped context."""
        return self._context.peer_identity_key()

    def auth_context(self):
        """Delegates to the wrapped context."""
        return self._context.auth_context()

    def time_remaining(self):
        """Delegates to the wrapped context."""
        return self._context.time_remaining()
352
353
async def _run_interceptor(object interceptors, object query_handler,
                           object handler_call_details):
    """Runs the next interceptor in the chain, or the handler lookup at its end.

    Args:
      interceptors: Iterator over the remaining interceptors.
      query_handler: Callable resolving call details to a method handler.
      handler_call_details: Details of the incoming call.

    Returns:
      Whatever the interceptor chain, or `query_handler` once the chain is
      exhausted, produces for the call details.
    """
    next_interceptor = next(interceptors, None)
    if not next_interceptor:
        # Chain exhausted: perform the real handler lookup.
        return query_handler(handler_call_details)

    # The continuation resumes this function with the same, partly consumed
    # iterator, so each interceptor runs at most once.
    continuation = functools.partial(_run_interceptor, interceptors,
                                     query_handler)
    return await next_interceptor.intercept_service(continuation,
                                                    handler_call_details)
363
364
def _is_async_handler(object handler):
    """Inspect if a method handler is async or sync.

    A handler counts as async when it is awaitable, a coroutine function, or
    an async generator function; the checks run in that order and stop at the
    first match.
    """
    async_checks = (
        inspect.isawaitable,
        inspect.iscoroutinefunction,
        inspect.isasyncgenfunction,
    )
    return any(check(handler) for check in async_checks)
368
369
async def _find_method_handler(str method, tuple metadata, list generic_handlers,
                          tuple interceptors):
    """Resolves the method handler for an incoming call.

    Args:
      method: Name of the invoked method.
      metadata: Invocation metadata of the call.
      generic_handlers: Handlers queried in order through `service()`.
      interceptors: Interceptors wrapped around the lookup; may be empty.

    Returns:
      The first non-None handler found (possibly transformed by the
      interceptors), or None when no generic handler serves the method.
    """
    def query_handlers(handler_call_details):
        for candidate in generic_handlers:
            found = candidate.service(handler_call_details)
            if found is not None:
                return found
        return None

    cdef _HandlerCallDetails handler_call_details = _HandlerCallDetails(method,
                                                                        metadata)
    if not interceptors:
        # No interceptors installed: query the handlers directly.
        return query_handlers(handler_call_details)

    return await _run_interceptor(iter(interceptors), query_handlers,
                                  handler_call_details)
387
388
async def _finish_handler_with_unary_response(RPCState rpc_state,
                                              object unary_handler,
                                              object request,
                                              _ServicerContext servicer_context,
                                              object response_serializer,
                                              object loop):
    """Finishes server method handler with a single response.

    This function executes the application handler, and handles response
    sending, as well as errors. It is shared between unary-unary and
    stream-unary handlers.

    Args:
      rpc_state: State of the RPC being served.
      unary_handler: Application handler; async handlers are awaited here,
        anything else is run in the server's thread pool.
      request: The request message, or a request iterator for stream-unary.
      servicer_context: Async context passed to (or wrapped for) the handler.
      response_serializer: Serializer applied to the handler's response.
      loop: Event loop used for the executor call and the Core batch.

    Raises:
      Whatever the handler raises, plus the exceptions of
      RPCState.raise_for_termination when the RPC was aborted, already
      finished, or the server stopped while the handler ran.
    """
    # Executes application logic
    cdef object response_message
    cdef _SyncServicerContext sync_servicer_context
    install_context_from_request_call_event_aio(rpc_state)

    if _is_async_handler(unary_handler):
        # Run async method handlers in this coroutine
        response_message = await unary_handler(
            request,
            servicer_context,
        )
    else:
        # Run sync method handlers in the thread pool
        sync_servicer_context = _SyncServicerContext(servicer_context)
        response_message = await loop.run_in_executor(
            rpc_state.server.thread_pool(),
            unary_handler,
            request,
            sync_servicer_context,
        )
        # Support sync-stack callback
        for callback in sync_servicer_context._callbacks:
            callback()

    # Raises exception if aborted
    rpc_state.raise_for_termination()

    # Serializes the response message
    cdef bytes response_raw
    if rpc_state.status_code == StatusCode.ok:
        response_raw = serialize(
            response_serializer,
            response_message,
        )
    else:
        # Discards the response message if the status code is non-OK.
        response_raw = b''

    # Assembles the batch operations
    cdef tuple finish_ops
    finish_ops = (
        SendMessageOperation(response_raw, rpc_state.get_write_flag()),
        SendStatusFromServerOperation(
            rpc_state.trailing_metadata,
            rpc_state.status_code,
            rpc_state.status_details,
            _EMPTY_FLAGS,
        ),
    )
    # Initial metadata rides along in the same batch if not sent before.
    if not rpc_state.metadata_sent:
        finish_ops = prepend_send_initial_metadata_op(
            finish_ops,
            None)
    # Flags are flipped before awaiting the batch so that code running
    # meanwhile already observes the RPC as finished.
    rpc_state.metadata_sent = True
    rpc_state.status_sent = True
    await execute_batch(rpc_state, finish_ops, loop)
    # NOTE(review): only reached on the success path; an exception above leaves
    # the installed context in place -- confirm this is intended.
    uninstall_context()
458
459
async def _finish_handler_with_stream_responses(RPCState rpc_state,
                                                object stream_handler,
                                                object request,
                                                _ServicerContext servicer_context,
                                                object loop):
    """Finishes server method handler with multiple responses.

    This function executes the application handler, and handles response
    sending, as well as errors. It is shared between unary-stream and
    stream-stream handlers.

    Args:
      rpc_state: State of the RPC being served.
      stream_handler: Application handler: a coroutine function using the
        reader/writer API, an async generator function, or a sync callable
        returning a generator.
      request: The request message, or a request iterator for stream-stream.
      servicer_context: Async context used for writing responses.
      loop: Event loop used for the generator bridge and the Core batch.

    Raises:
      Whatever the handler raises, plus the exceptions of
      RPCState.raise_for_termination when the RPC was aborted, already
      finished, or the server stopped in the meantime.
    """
    cdef object async_response_generator
    cdef object response_message
    install_context_from_request_call_event_aio(rpc_state)

    if inspect.iscoroutinefunction(stream_handler):
        # Case 1: Coroutine async handler - using reader-writer API
        # The handler uses reader / writer API, returns None.
        await stream_handler(
            request,
            servicer_context,
        )
    else:
        if inspect.isasyncgenfunction(stream_handler):
            # Case 2: Async handler - async generator
            # The handler uses async generator API
            async_response_generator = stream_handler(
                request,
                servicer_context,
            )
        else:
            # Case 3: Sync handler - normal generator
            # NOTE(lidiz) Streaming handler in sync stack is either a generator
            # function or a function returns a generator.
            sync_servicer_context = _SyncServicerContext(servicer_context)
            gen = stream_handler(request, sync_servicer_context)
            async_response_generator = generator_to_async_generator(gen,
                                                                    loop,
                                                                    rpc_state.server.thread_pool())

        # Consumes messages from the generator
        async for response_message in async_response_generator:
            # Raises exception if aborted
            rpc_state.raise_for_termination()

            await servicer_context.write(response_message)

    # Raises exception if aborted
    rpc_state.raise_for_termination()

    # Sends the final status of this RPC
    cdef SendStatusFromServerOperation op = SendStatusFromServerOperation(
        rpc_state.trailing_metadata,
        rpc_state.status_code,
        rpc_state.status_details,
        _EMPTY_FLAGS,
    )

    cdef tuple finish_ops = (op,)
    # Initial metadata rides along in the same batch if no write sent it.
    if not rpc_state.metadata_sent:
        finish_ops = prepend_send_initial_metadata_op(
            finish_ops,
            None
        )
    # Flags are flipped before awaiting the batch so that code running
    # meanwhile already observes the RPC as finished.
    rpc_state.metadata_sent = True
    rpc_state.status_sent = True
    await execute_batch(rpc_state, finish_ops, loop)
    # NOTE(review): only reached on the success path; an exception above leaves
    # the installed context in place -- confirm this is intended.
    uninstall_context()
528
529
async def _handle_unary_unary_rpc(object method_handler,
                                  RPCState rpc_state,
                                  object loop):
    """Serves a unary-unary RPC: one request in, one response out.

    Returns without invoking the handler when no request message is received.
    """
    cdef bytes request_raw
    cdef object request_message
    cdef _ServicerContext servicer_context

    # Receives request message
    request_raw = await _receive_message(rpc_state, loop)
    if request_raw is None:
        # The RPC was cancelled immediately after start on client side.
        return

    # Deserializes the request message
    request_message = deserialize(method_handler.request_deserializer,
                                  request_raw)

    # Creates a dedicated ServicerContext. No (de)serializers are attached;
    # the response serializer is handed to the finishing helper instead.
    servicer_context = _ServicerContext(rpc_state, None, None, loop)

    # Finishes the application handler
    await _finish_handler_with_unary_response(
        rpc_state,
        method_handler.unary_unary,
        request_message,
        servicer_context,
        method_handler.response_serializer,
        loop
    )
562
563
async def _handle_unary_stream_rpc(object method_handler,
                                   RPCState rpc_state,
                                   object loop):
    """Serves a unary-stream RPC: one request in, a stream of responses out.

    Returns without invoking the handler when no request message is received.
    """
    cdef bytes request_raw
    cdef object request_message
    cdef _ServicerContext servicer_context

    # Receives request message
    request_raw = await _receive_message(rpc_state, loop)
    if request_raw is None:
        return

    # Deserializes the request message
    request_message = deserialize(method_handler.request_deserializer,
                                  request_raw)

    # Creates a dedicated ServicerContext
    servicer_context = _ServicerContext(
        rpc_state,
        method_handler.request_deserializer,
        method_handler.response_serializer,
        loop,
    )

    # Finishes the application handler
    await _finish_handler_with_stream_responses(
        rpc_state,
        method_handler.unary_stream,
        request_message,
        servicer_context,
        loop,
    )
594
595
cdef class _MessageReceiver:
    """Bridge between the async generator API and the reader-writer API.

    Exposes the servicer context's `read()` calls as an async iterator that
    stops once EOF is read.
    """

    def __cinit__(self, _ServicerContext servicer_context):
        self._agen = None
        self._servicer_context = servicer_context

    async def _async_message_receiver(self):
        """An async generator that yields request messages until EOF is read."""
        cdef object message
        message = await self._servicer_context.read()
        while message is not EOF:
            yield message
            message = await self._servicer_context.read()

    def __aiter__(self):
        # The generator is created lazily, which prevents a never-awaited
        # warning if the application never used the async generator.
        if self._agen is None:
            self._agen = self._async_message_receiver()
        return self._agen

    async def __anext__(self):
        message_stream = self.__aiter__()
        return await message_stream.__anext__()
621
622
async def _handle_stream_unary_rpc(object method_handler,
                                   RPCState rpc_state,
                                   object loop):
    """Serves a stream-unary RPC: a stream of requests in, one response out."""
    # Creates a dedicated ServicerContext. Only the request deserializer is
    # attached; the response serializer goes to the finishing helper.
    cdef _ServicerContext servicer_context = _ServicerContext(
        rpc_state,
        method_handler.request_deserializer,
        None,
        loop,
    )

    # Prepares the request generator: async handlers consume the async
    # iterator directly, others get it bridged to a plain generator.
    cdef object request_iterator = _MessageReceiver(servicer_context)
    if not _is_async_handler(method_handler.stream_unary):
        request_iterator = async_generator_to_generator(request_iterator, loop)

    # Finishes the application handler
    await _finish_handler_with_unary_response(
        rpc_state,
        method_handler.stream_unary,
        request_iterator,
        servicer_context,
        method_handler.response_serializer,
        loop
    )
653
654
async def _handle_stream_stream_rpc(object method_handler,
                                    RPCState rpc_state,
                                    object loop):
    """Serves a stream-stream RPC: streams of requests in and responses out."""
    # Creates a dedicated ServicerContext
    cdef _ServicerContext servicer_context = _ServicerContext(
        rpc_state,
        method_handler.request_deserializer,
        method_handler.response_serializer,
        loop,
    )

    # Prepares the request generator: async handlers consume the async
    # iterator directly, others get it bridged to a plain generator.
    cdef object request_iterator = _MessageReceiver(servicer_context)
    if not _is_async_handler(method_handler.stream_stream):
        request_iterator = async_generator_to_generator(request_iterator, loop)

    # Finishes the application handler
    await _finish_handler_with_stream_responses(
        rpc_state,
        method_handler.stream_stream,
        request_iterator,
        servicer_context,
        loop,
    )
684
685
async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
    """Awaits an RPC coroutine and turns its failures into logs and statuses.

    Args:
      rpc_state: State of the RPC being served.
      rpc_coro: Coroutine that runs the RPC (handler lookup and execution).
      loop: Event loop used to send an error status if one is needed.

    Behavior by outcome of `rpc_coro`:
      * AbortError: expected after `abort()`; nothing else to do.
      * Normal return while an abort exception is stored: the handler
        swallowed the AbortError; an error is logged.
      * KeyboardInterrupt / SystemExit: re-raised.
      * asyncio.CancelledError and _ServerStoppedError: logged only.
      * ExecuteBatchError: ignored if the client already closed the call,
        logged otherwise.
      * Any other Exception: logged, and an error status is sent to the peer
        unless a status was already sent or the server has stopped.
    """
    try:
        try:
            await rpc_coro
        except AbortError as e:
            # Caught AbortError check if it is the same one
            assert rpc_state.abort_exception is e, 'Abort error has been replaced!'
            return
        else:
            # Check if the abort exception got suppressed
            if rpc_state.abort_exception is not None:
                suppressed_abort = rpc_state.abort_exception
                # The three-argument form works on every Python 3 version; the
                # single-argument call is only accepted from Python 3.10 on.
                # The list of lines is joined so the log shows a readable
                # traceback rather than a list repr.
                _LOGGER.error(
                    'Abort error unexpectedly suppressed: %s',
                    ''.join(traceback.format_exception(
                        type(suppressed_abort),
                        suppressed_abort,
                        suppressed_abort.__traceback__,
                    ))
                )
    except (KeyboardInterrupt, SystemExit):
        raise
    except asyncio.CancelledError:
        _LOGGER.debug('RPC cancelled for servicer method [%s]', _decode(rpc_state.method()))
    except _ServerStoppedError:
        _LOGGER.warning('Aborting method [%s] due to server stop.', _decode(rpc_state.method()))
    except ExecuteBatchError:
        # If client closed (aka. cancelled), ignore the failed batch operations.
        if rpc_state.client_closed:
            return
        else:
            _LOGGER.exception('ExecuteBatchError raised in core by servicer method [%s]' % (
                _decode(rpc_state.method())))
            return
    except Exception as e:
        _LOGGER.exception('Unexpected [%s] raised by servicer method [%s]' % (
            type(e).__name__,
            _decode(rpc_state.method()),
        ))
        if not rpc_state.status_sent and rpc_state.server._status != AIO_SERVER_STATUS_STOPPED:
            # Allows users to raise other types of exception with specified status code
            if rpc_state.status_code == StatusCode.ok:
                status_code = StatusCode.unknown
            else:
                status_code = rpc_state.status_code

            # Marked before awaiting so concurrent code sees the RPC as finished.
            rpc_state.status_sent = True
            try:
                await _send_error_status_from_server(
                    rpc_state,
                    status_code,
                    'Unexpected %s: %s' % (type(e), e),
                    rpc_state.trailing_metadata,
                    rpc_state.create_send_initial_metadata_op_if_not_sent(),
                    loop
                )
            except ExecuteBatchError:
                _LOGGER.exception('Failed sending error status from server')
                traceback.print_exc()
740
741
cdef _add_callback_handler(object rpc_task, RPCState rpc_state):
    """Arranges for the RPC's done-callbacks to run when `rpc_task` finishes.

    Args:
      rpc_task: Task running the RPC; a done-callback is attached to it.
      rpc_state: State whose `callbacks` list is invoked in order.

    Each callback runs in its own try block, so one failing callback is logged
    and does not prevent the remaining callbacks from running. Only Exception
    subclasses are swallowed; KeyboardInterrupt and SystemExit propagate, in
    line with how _handle_exceptions treats them.
    """

    def handle_callbacks(object unused_task):
        for callback in rpc_state.callbacks:
            try:
                # The _ServicerContext object is bound in add_done_callback.
                callback()
            except Exception:
                _LOGGER.exception('Error in callback for method [%s]', _decode(rpc_state.method()))

    rpc_task.add_done_callback(handle_callbacks)
753
754
async def _handle_cancellation_from_core(object rpc_task,
                                         RPCState rpc_state,
                                         object loop):
    """Waits for the peer to close the call and cancels the RPC task if needed.

    Args:
      rpc_task: Task running the RPC coroutine.
      rpc_state: State of the RPC; `client_closed` is set once the batch ends.
      loop: Event loop used to execute the Core batch.
    """
    cdef ReceiveCloseOnServerOperation close_op = ReceiveCloseOnServerOperation(_EMPTY_FLAG)
    cdef tuple ops = (close_op,)

    # Awaits cancellation from peer.
    await execute_batch(rpc_state, ops, loop)
    rpc_state.client_closed = True

    # Cancel only if 1) received cancel signal; 2) the Task is not finished;
    # 3) the server wasn't replying final status. For condition 3, it might
    # cause inaccurate log that an RPC is both aborted and cancelled.
    if not close_op.cancelled():
        return
    if rpc_task.done() or rpc_state.status_sent:
        return

    # Injects `CancelledError` to halt the RPC coroutine
    rpc_task.cancel()
770
771
async def _schedule_rpc_coro(object rpc_coro,
                             RPCState rpc_state,
                             object loop):
    """Runs an RPC coroutine as a task and watches for peer cancellation.

    The coroutine is wrapped in `_handle_exceptions`, scheduled on `loop`,
    given the done-callback hook, and then this coroutine waits for the
    close-on-server notification of the call.
    """
    # Schedules the RPC coroutine.
    cdef object guarded_coro = _handle_exceptions(rpc_state, rpc_coro, loop)
    cdef object rpc_task = loop.create_task(guarded_coro)

    _add_callback_handler(rpc_task, rpc_state)
    await _handle_cancellation_from_core(rpc_task, rpc_state, loop)
783
784
async def _handle_rpc(list generic_handlers, tuple interceptors,
                      RPCState rpc_state, object loop, bint concurrency_exceeded):
    """Finds the handler of an incoming RPC and dispatches on its arity.

    Args:
      generic_handlers: Handlers queried for the invoked method.
      interceptors: Interceptors wrapped around the handler lookup.
      rpc_state: State of the RPC being served.
      loop: Event loop used for all Core operations.
      concurrency_exceeded: Whether the RPC must be rejected because the
        concurrent RPC limit was reached.

    The RPC is rejected with `unimplemented` when no handler is found and with
    `resource_exhausted` when the limit was exceeded (checked in that order).
    """
    cdef object method_handler
    # Finds the method handler (application logic)
    method_handler = await _find_method_handler(
        rpc_state.method().decode(),
        rpc_state.invocation_metadata(),
        generic_handlers,
        interceptors,
    )

    # Decides whether the RPC has to be rejected before running any handler.
    rejection = None
    if method_handler is None:
        rejection = (StatusCode.unimplemented, 'Method not found!')
    elif concurrency_exceeded:
        rejection = (StatusCode.resource_exhausted, 'Concurrent RPC limit exceeded!')

    if rejection is not None:
        rpc_state.status_sent = True
        await _send_error_status_from_server(
            rpc_state,
            rejection[0],
            rejection[1],
            _IMMUTABLE_EMPTY_METADATA,
            rpc_state.create_send_initial_metadata_op_if_not_sent(),
            loop
        )
        return

    # Dispatches on the request / response streaming combination.
    if method_handler.request_streaming:
        if method_handler.response_streaming:
            # Handles stream-stream case
            await _handle_stream_stream_rpc(method_handler,
                                            rpc_state,
                                            loop)
        else:
            # Handles stream-unary case
            await _handle_stream_unary_rpc(method_handler,
                                           rpc_state,
                                           loop)
    elif method_handler.response_streaming:
        # Handles unary-stream case
        await _handle_unary_stream_rpc(method_handler,
                                       rpc_state,
                                       loop)
    else:
        # Handles unary-unary case
        await _handle_unary_unary_rpc(method_handler,
                                      rpc_state,
                                      loop)
846
847
class _RequestCallError(Exception):
    """Exception type registered with the grpc_server_request_call failure handler."""
849
# Failure handler given to the CallbackWrapper that AioServer._request_call
# registers with grpc_server_request_call.
cdef CallbackFailureHandler REQUEST_CALL_FAILURE_HANDLER = CallbackFailureHandler(
    'grpc_server_request_call',
    None,
    _RequestCallError)


# Failure handler given to the CallbackWrapper that AioServer registers with
# grpc_server_shutdown_and_notify.
cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler(
    'grpc_server_shutdown_and_notify', None, InternalError)
858
859
cdef class _ConcurrentRpcLimiter:
    """Tracks the number of active RPCs against a fixed maximum.

    Usage protocol, as driven by the server main loop:
      1. call check_before_request_call() before requesting the next call;
      2. read limiter_concurrency_exceeded to learn whether that call has to
         be rejected;
      3. call decrease_once_finished(task) so the slot is released when the
         RPC task completes.

    NOTE: the done-callback decrements the counter unconditionally, while
    check_before_request_call() only increments it when the limit was not
    reached. Registering a task for which no slot was taken therefore makes
    the counter drift below the real number of active RPCs.
    """

    def __cinit__(self, int maximum_concurrent_rpcs):
        # A non-positive limit could never admit any RPC; reject it early.
        if maximum_concurrent_rpcs <= 0:
            raise ValueError("maximum_concurrent_rpcs should be a positive integer")
        self._maximum_concurrent_rpcs = maximum_concurrent_rpcs
        self._active_rpcs = 0
        self.limiter_concurrency_exceeded = False

    def check_before_request_call(self):
        """Take a slot for the next RPC, or flag that the limit is reached."""
        if self._active_rpcs >= self._maximum_concurrent_rpcs:
            # No slot is taken in this branch; only the flag is raised.
            self.limiter_concurrency_exceeded = True
        else:
            self._active_rpcs += 1

    def _decrease_active_rpcs_count(self, unused_future):
        """Done-callback: release one slot and clear the flag if below the limit."""
        self._active_rpcs -= 1
        if self._active_rpcs < self._maximum_concurrent_rpcs:
            self.limiter_concurrency_exceeded = False

    def decrease_once_finished(self, object rpc_task):
        """Release one slot when `rpc_task` is done."""
        rpc_task.add_done_callback(self._decrease_active_rpcs_count)
882
883
884cdef class AioServer:
885
886    def __init__(self, loop, thread_pool, generic_handlers, interceptors,
887                 options, maximum_concurrent_rpcs):
888        init_grpc_aio()
889        # NOTE(lidiz) Core objects won't be deallocated automatically.
890        # If AioServer.shutdown is not called, those objects will leak.
891        # TODO(rbellevi): Support xDS in aio server.
892        self._server = Server(options, False)
893        grpc_server_register_completion_queue(
894            self._server.c_server,
895            global_completion_queue(),
896            NULL
897        )
898
899        self._loop = loop
900        self._status = AIO_SERVER_STATUS_READY
901        self._generic_handlers = []
902        self.add_generic_rpc_handlers(generic_handlers)
903        self._serving_task = None
904
905        self._shutdown_lock = asyncio.Lock()
906        self._shutdown_completed = self._loop.create_future()
907        self._shutdown_callback_wrapper = CallbackWrapper(
908            self._shutdown_completed,
909            self._loop,
910            SERVER_SHUTDOWN_FAILURE_HANDLER)
911        self._crash_exception = None
912
913        if interceptors:
914            self._interceptors = tuple(interceptors)
915        else:
916            self._interceptors = ()
917
918        self._thread_pool = thread_pool
919        if maximum_concurrent_rpcs is not None:
920            self._limiter = _ConcurrentRpcLimiter(maximum_concurrent_rpcs)
921
922    def add_generic_rpc_handlers(self, object generic_rpc_handlers):
923        self._generic_handlers.extend(generic_rpc_handlers)
924
925    def add_insecure_port(self, address):
926        return self._server.add_http2_port(address)
927
928    def add_secure_port(self, address, server_credentials):
929        return self._server.add_http2_port(address,
930                                           server_credentials._credentials)
931
932    async def _request_call(self):
933        cdef grpc_call_error error
934        cdef RPCState rpc_state = RPCState(self)
935        cdef object future = self._loop.create_future()
936        cdef CallbackWrapper wrapper = CallbackWrapper(
937            future,
938            self._loop,
939            REQUEST_CALL_FAILURE_HANDLER)
940        error = grpc_server_request_call(
941            self._server.c_server, &rpc_state.call, &rpc_state.details,
942            &rpc_state.request_metadata,
943            global_completion_queue(), global_completion_queue(),
944            wrapper.c_functor()
945        )
946        if error != GRPC_CALL_OK:
947            raise InternalError("Error in grpc_server_request_call: %s" % error)
948
949        await future
950        return rpc_state
951
952    async def _server_main_loop(self,
953                                object server_started):
954        self._server.start(backup_queue=False)
955        cdef RPCState rpc_state
956        server_started.set_result(True)
957
958        while True:
959            # When shutdown begins, no more new connections.
960            if self._status != AIO_SERVER_STATUS_RUNNING:
961                break
962
963            concurrency_exceeded = False
964            if self._limiter is not None:
965                self._limiter.check_before_request_call()
966                concurrency_exceeded = self._limiter.limiter_concurrency_exceeded
967
968            # Accepts new request from Core
969            rpc_state = await self._request_call()
970
971            # Creates the dedicated RPC coroutine. If we schedule it right now,
972            # there is no guarantee if the cancellation listening coroutine is
973            # ready or not. So, we should control the ordering by scheduling
974            # the coroutine onto event loop inside of the cancellation
975            # coroutine.
976            rpc_coro = _handle_rpc(self._generic_handlers,
977                                   self._interceptors,
978                                   rpc_state,
979                                   self._loop,
980                                   concurrency_exceeded)
981
982            # Fires off a task that listens on the cancellation from client.
983            rpc_task = self._loop.create_task(
984                _schedule_rpc_coro(
985                    rpc_coro,
986                    rpc_state,
987                    self._loop
988                )
989            )
990
991            if self._limiter is not None:
992                self._limiter.decrease_once_finished(rpc_task)
993
994    def _serving_task_crash_handler(self, object task):
995        """Shutdown the server immediately if unexpectedly exited."""
996        if task.cancelled():
997            return
998        if task.exception() is None:
999            return
1000        if self._status != AIO_SERVER_STATUS_STOPPING:
1001            self._crash_exception = task.exception()
1002            _LOGGER.exception(self._crash_exception)
1003            self._loop.create_task(self.shutdown(None))
1004
1005    async def start(self):
1006        if self._status == AIO_SERVER_STATUS_RUNNING:
1007            return
1008        elif self._status != AIO_SERVER_STATUS_READY:
1009            raise UsageError('Server not in ready state')
1010
1011        self._status = AIO_SERVER_STATUS_RUNNING
1012        cdef object server_started = self._loop.create_future()
1013        self._serving_task = self._loop.create_task(self._server_main_loop(server_started))
1014        self._serving_task.add_done_callback(self._serving_task_crash_handler)
1015        # Needs to explicitly wait for the server to start up.
1016        # Otherwise, the actual start time of the server is un-controllable.
1017        await server_started
1018
1019    async def _start_shutting_down(self):
1020        """Prepares the server to shutting down.
1021
1022        This coroutine function is NOT coroutine-safe.
1023        """
1024        # The shutdown callback won't be called until there is no live RPC.
1025        grpc_server_shutdown_and_notify(
1026            self._server.c_server,
1027            global_completion_queue(),
1028            self._shutdown_callback_wrapper.c_functor())
1029
1030        # Ensures the serving task (coroutine) exits.
1031        try:
1032            await self._serving_task
1033        except _RequestCallError:
1034            pass
1035
1036    async def shutdown(self, grace):
1037        """Gracefully shutdown the Core server.
1038
1039        Application should only call shutdown once.
1040
1041        Args:
1042          grace: An optional float indicating the length of grace period in
1043            seconds.
1044        """
1045        if self._status == AIO_SERVER_STATUS_READY or self._status == AIO_SERVER_STATUS_STOPPED:
1046            return
1047
1048        async with self._shutdown_lock:
1049            if self._status == AIO_SERVER_STATUS_RUNNING:
1050                self._server.is_shutting_down = True
1051                self._status = AIO_SERVER_STATUS_STOPPING
1052                await self._start_shutting_down()
1053
1054        if grace is None:
1055            # Directly cancels all calls
1056            grpc_server_cancel_all_calls(self._server.c_server)
1057            await self._shutdown_completed
1058        else:
1059            try:
1060                await asyncio.wait_for(
1061                    asyncio.shield(self._shutdown_completed),
1062                    grace,
1063                )
1064            except asyncio.TimeoutError:
1065                # Cancels all ongoing calls by the end of grace period.
1066                grpc_server_cancel_all_calls(self._server.c_server)
1067                await self._shutdown_completed
1068
1069        async with self._shutdown_lock:
1070            if self._status == AIO_SERVER_STATUS_STOPPING:
1071                grpc_server_destroy(self._server.c_server)
1072                self._server.c_server = NULL
1073                self._server.is_shutdown = True
1074                self._status = AIO_SERVER_STATUS_STOPPED
1075
1076    async def wait_for_termination(self, object timeout):
1077        if timeout is None:
1078            await self._shutdown_completed
1079        else:
1080            try:
1081                await asyncio.wait_for(
1082                    asyncio.shield(self._shutdown_completed),
1083                    timeout,
1084                )
1085            except asyncio.TimeoutError:
1086                if self._crash_exception is not None:
1087                    raise self._crash_exception
1088                return True
1089        if self._crash_exception is not None:
1090            raise self._crash_exception
1091        return False
1092
1093    def __dealloc__(self):
1094        """Deallocation of Core objects are ensured by Python layer."""
1095        # TODO(lidiz) if users create server, and then dealloc it immediately.
1096        # There is a potential memory leak of created Core server.
1097        if self._status != AIO_SERVER_STATUS_STOPPED:
1098            _LOGGER.debug(
1099                '__dealloc__ called on running server %s with status %d',
1100                self,
1101                self._status
1102            )
1103        shutdown_grpc_aio()
1104
1105    cdef thread_pool(self):
1106        """Access the thread pool instance."""
1107        return self._thread_pool
1108
1109    def is_running(self):
1110        return self._status == AIO_SERVER_STATUS_RUNNING
1111