xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio/grpc/_server.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Service-side implementation of gRPC Python."""
15
16from __future__ import annotations
17
18import collections
19from concurrent import futures
20import contextvars
21import enum
22import logging
23import threading
24import time
25import traceback
26from typing import (
27    Any,
28    Callable,
29    Iterable,
30    Iterator,
31    List,
32    Mapping,
33    Optional,
34    Sequence,
35    Set,
36    Tuple,
37    Union,
38)
39
40import grpc  # pytype: disable=pyi-error
41from grpc import _common  # pytype: disable=pyi-error
42from grpc import _compression  # pytype: disable=pyi-error
43from grpc import _interceptor  # pytype: disable=pyi-error
44from grpc._cython import cygrpc
45from grpc._typing import ArityAgnosticMethodHandler
46from grpc._typing import ChannelArgumentType
47from grpc._typing import DeserializingFunction
48from grpc._typing import MetadataType
49from grpc._typing import NullaryCallbackType
50from grpc._typing import ResponseType
51from grpc._typing import SerializingFunction
52from grpc._typing import ServerCallbackTag
53from grpc._typing import ServerTagCallbackType
54
_LOGGER = logging.getLogger(__name__)

# Tags distinguishing server-level events. NOTE(review): not used in this
# excerpt; presumably attached to shutdown / request-call events by the
# server loop further down the file — confirm.
_SHUTDOWN_TAG = "shutdown"
_REQUEST_CALL_TAG = "request_call"

# Tokens naming the batches of a single RPC. A token sits in
# ``_RPCState.due`` while its batch is outstanding and is removed by the
# batch's completion callback via ``_possibly_finish_call``. Tokens for
# batches that combine operations join the operation names with " * ".
_RECEIVE_CLOSE_ON_SERVER_TOKEN = "receive_close_on_server"
_SEND_INITIAL_METADATA_TOKEN = "send_initial_metadata"
_RECEIVE_MESSAGE_TOKEN = "receive_message"
_SEND_MESSAGE_TOKEN = "send_message"
_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = (
    "send_initial_metadata * send_message"
)
_SEND_STATUS_FROM_SERVER_TOKEN = "send_status_from_server"
_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = (
    "send_initial_metadata * send_status_from_server"
)

# Values of ``_RPCState.client``; compared by identity (``is``).
_OPEN = "open"
_CLOSED = "closed"
_CANCELLED = "cancelled"

# "No flags" value passed to cygrpc operations.
_EMPTY_FLAGS = 0

# NOTE(review): neither constant is used in this excerpt. The names suggest
# a polling period in seconds and an effectively-infinite timeout — confirm
# against their uses later in the file.
_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0
_INF_TIMEOUT = 1e9
80
81
82def _serialized_request(request_event: cygrpc.BaseEvent) -> bytes:
83    return request_event.batch_operations[0].message()
84
85
def _application_code(code: grpc.StatusCode) -> cygrpc.StatusCode:
    """Translate a ``grpc.StatusCode`` into the matching cygrpc code.

    Codes absent from the translation table map to
    ``cygrpc.StatusCode.unknown``.
    """
    translated = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
    if translated is not None:
        return translated
    return cygrpc.StatusCode.unknown
89
90
def _completion_code(state: _RPCState) -> cygrpc.StatusCode:
    """Return the status code reported when an RPC completes normally.

    Defaults to OK unless the application set a code on ``state``.
    """
    if state.code is not None:
        return _application_code(state.code)
    return cygrpc.StatusCode.ok
96
97
98def _abortion_code(
99    state: _RPCState, code: cygrpc.StatusCode
100) -> cygrpc.StatusCode:
101    if state.code is None:
102        return code
103    else:
104        return _application_code(state.code)
105
106
107def _details(state: _RPCState) -> bytes:
108    return b"" if state.details is None else state.details
109
110
class _HandlerCallDetails(
    collections.namedtuple(
        "_HandlerCallDetails", ["method", "invocation_metadata"]
    ),
    grpc.HandlerCallDetails,
):
    """Immutable ``(method, invocation_metadata)`` pair for handler lookup.

    Combines a namedtuple with the ``grpc.HandlerCallDetails`` interface so
    it can be passed to ``GenericRpcHandler.service``.
    """
122
123
124class _RPCState(object):
125    context: contextvars.Context
126    condition: threading.Condition
127    due = Set[str]
128    request: Any
129    client: str
130    initial_metadata_allowed: bool
131    compression_algorithm: Optional[grpc.Compression]
132    disable_next_compression: bool
133    trailing_metadata: Optional[MetadataType]
134    code: Optional[grpc.StatusCode]
135    details: Optional[bytes]
136    statused: bool
137    rpc_errors: List[Exception]
138    callbacks: Optional[List[NullaryCallbackType]]
139    aborted: bool
140
141    def __init__(self):
142        self.context = contextvars.Context()
143        self.condition = threading.Condition()
144        self.due = set()
145        self.request = None
146        self.client = _OPEN
147        self.initial_metadata_allowed = True
148        self.compression_algorithm = None
149        self.disable_next_compression = False
150        self.trailing_metadata = None
151        self.code = None
152        self.details = None
153        self.statused = False
154        self.rpc_errors = []
155        self.callbacks = []
156        self.aborted = False
157
158
def _raise_rpc_error(state: _RPCState) -> None:
    """Record a fresh ``grpc.RpcError`` on ``state`` and raise it.

    Recording the instance lets the exception handlers in this module
    recognise it as framework-raised rather than an application failure.
    """
    error = grpc.RpcError()
    state.rpc_errors.append(error)
    raise error
163
164
def _possibly_finish_call(
    state: _RPCState, token: str
) -> ServerTagCallbackType:
    """Retire the batch named ``token`` and finish the RPC if nothing is left.

    Callers in this module hold ``state.condition``. Once the RPC is
    inactive and no batches remain outstanding, the registered callbacks
    are detached from ``state`` (so later ``add_callback`` calls are
    refused) and returned together with ``state``. Otherwise ``(None, ())``
    is returned.

    Raises:
      KeyError: if ``token`` is not currently in ``state.due``.
    """
    state.due.remove(token)
    if _is_rpc_state_active(state) or state.due:
        return None, ()
    pending_callbacks, state.callbacks = state.callbacks, None
    return state, pending_callbacks
175
176
def _send_status_from_server(state: _RPCState, token: str) -> ServerCallbackTag:
    """Build the completion callback for a batch that sends the final status.

    ``token`` names the batch in ``state.due``; the callback retires it
    while holding the state lock.
    """

    def send_status_from_server(unused_send_status_from_server_event):
        with state.condition:
            outcome = _possibly_finish_call(state, token)
        return outcome

    return send_status_from_server
183
184
185def _get_initial_metadata(
186    state: _RPCState, metadata: Optional[MetadataType]
187) -> Optional[MetadataType]:
188    with state.condition:
189        if state.compression_algorithm:
190            compression_metadata = (
191                _compression.compression_algorithm_to_metadata(
192                    state.compression_algorithm
193                ),
194            )
195            if metadata is None:
196                return compression_metadata
197            else:
198                return compression_metadata + tuple(metadata)
199        else:
200            return metadata
201
202
def _get_initial_metadata_operation(
    state: _RPCState, metadata: Optional[MetadataType]
) -> cygrpc.Operation:
    """Wrap the effective initial metadata in a ``SendInitialMetadataOperation``.

    The compression entry (if any) is added by ``_get_initial_metadata``.
    """
    return cygrpc.SendInitialMetadataOperation(
        _get_initial_metadata(state, metadata), _EMPTY_FLAGS
    )
210
211
def _abort(
    state: _RPCState, call: cygrpc.Call, code: cygrpc.StatusCode, details: bytes
) -> None:
    """Start a batch that fails the RPC with a status.

    Does nothing if the client has cancelled. Code and details recorded on
    ``state`` by the application take precedence over ``code`` and
    ``details``. If initial metadata has not been sent yet, it is sent in
    the same batch. Callers in this module hold ``state.condition``.
    """
    if state.client is _CANCELLED:
        return
    effective_code = _abortion_code(state, code)
    effective_details = details if state.details is None else state.details
    operations = []
    if state.initial_metadata_allowed:
        # Initial metadata is still pending, so it rides in this batch.
        operations.append(_get_initial_metadata_operation(state, None))
        token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
    else:
        token = _SEND_STATUS_FROM_SERVER_TOKEN
    operations.append(
        cygrpc.SendStatusFromServerOperation(
            state.trailing_metadata,
            effective_code,
            effective_details,
            _EMPTY_FLAGS,
        )
    )
    call.start_server_batch(
        tuple(operations), _send_status_from_server(state, token)
    )
    state.statused = True
    state.due.add(token)
244
245
def _receive_close_on_server(state: _RPCState) -> ServerCallbackTag:
    """Build the callback for the receive-close-on-server batch.

    The callback marks the client as cancelled or closed, wakes threads
    waiting on ``state.condition`` and retires the batch's token.
    """

    def receive_close_on_server(receive_close_on_server_event):
        with state.condition:
            close_operation = receive_close_on_server_event.batch_operations[0]
            if close_operation.cancelled():
                state.client = _CANCELLED
            elif state.client is _OPEN:
                # A cancelled client must never be downgraded to closed.
                state.client = _CLOSED
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN)

    return receive_close_on_server
257
258
def _receive_message(
    state: _RPCState,
    call: cygrpc.Call,
    request_deserializer: Optional[DeserializingFunction],
) -> ServerCallbackTag:
    """Build the callback for a receive-message batch.

    The callback handles three cases:

    * no message arrived: the client is marked closed (unless already
      cancelled);
    * the message fails to deserialize: the RPC is aborted with INTERNAL;
    * otherwise: the request is stored on ``state``.

    In every case it then wakes waiters and retires the batch's token.
    """

    def receive_message(receive_message_event):
        serialized_request = _serialized_request(receive_message_event)
        request = None
        if serialized_request is not None:
            # Deserialize before taking the lock.
            request = _common.deserialize(
                serialized_request, request_deserializer
            )
        with state.condition:
            if serialized_request is None:
                if state.client is _OPEN:
                    state.client = _CLOSED
            elif request is None:
                _abort(
                    state,
                    call,
                    cygrpc.StatusCode.internal,
                    b"Exception deserializing request!",
                )
            else:
                state.request = request
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)

    return receive_message
290
291
def _send_initial_metadata(state: _RPCState) -> ServerCallbackTag:
    """Build the callback retiring a standalone send-initial-metadata batch."""
    token = _SEND_INITIAL_METADATA_TOKEN

    def send_initial_metadata(unused_send_initial_metadata_event):
        with state.condition:
            outcome = _possibly_finish_call(state, token)
        return outcome

    return send_initial_metadata
298
299
def _send_message(state: _RPCState, token: str) -> ServerCallbackTag:
    """Build the callback for a send-message batch named by ``token``.

    The callback wakes threads waiting on ``state.condition`` and retires
    ``token``. Waiters resume only after the lock is released, by which
    point ``token`` has already been removed from ``state.due``.
    """

    def send_message(unused_send_message_event):
        with state.condition:
            state.condition.notify_all()
            outcome = _possibly_finish_call(state, token)
        return outcome

    return send_message
307
308
class _Context(grpc.ServicerContext):
    """Server-side ``grpc.ServicerContext`` backed by an ``_RPCState``.

    Most methods read or update the shared per-RPC state while holding
    ``state.condition``; some also start batches on the call carried by
    ``rpc_event``.
    """

    _rpc_event: cygrpc.BaseEvent
    _state: _RPCState
    _request_deserializer: Optional[DeserializingFunction]

    def __init__(
        self,
        rpc_event: cygrpc.BaseEvent,
        state: _RPCState,
        request_deserializer: Optional[DeserializingFunction],
    ):
        self._rpc_event = rpc_event
        self._state = state
        self._request_deserializer = request_deserializer

    def is_active(self) -> bool:
        """Return True until the client cancels or a status has been sent."""
        with self._state.condition:
            return _is_rpc_state_active(self._state)

    def time_remaining(self) -> float:
        """Return the time left before the call deadline, never negative."""
        return max(self._rpc_event.call_details.deadline - time.time(), 0)

    def cancel(self) -> None:
        """Cancel the underlying call."""
        self._rpc_event.call.cancel()

    def add_callback(self, callback: NullaryCallbackType) -> bool:
        """Register ``callback`` with the RPC's pending callbacks.

        Returns:
          True if registered; False if the callbacks were already handed
          off because the RPC finished.
        """
        with self._state.condition:
            if self._state.callbacks is None:
                return False
            else:
                self._state.callbacks.append(callback)
                return True

    def disable_next_message_compression(self) -> None:
        """Send the next response message without compression."""
        with self._state.condition:
            self._state.disable_next_compression = True

    def invocation_metadata(self) -> Optional[MetadataType]:
        """Return the metadata the client sent with the call."""
        return self._rpc_event.invocation_metadata

    def peer(self) -> str:
        """Return the peer address of the call as a ``str``."""
        return _common.decode(self._rpc_event.call.peer())

    def peer_identities(self) -> Optional[Sequence[bytes]]:
        """Return the authenticated peer identities, if any."""
        return cygrpc.peer_identities(self._rpc_event.call)

    def peer_identity_key(self) -> Optional[str]:
        """Return the name of the peer identity property, if any."""
        id_key = cygrpc.peer_identity_key(self._rpc_event.call)
        return id_key if id_key is None else _common.decode(id_key)

    def auth_context(self) -> Mapping[str, Sequence[bytes]]:
        """Return the auth context with ``str`` keys (empty if none)."""
        auth_context = cygrpc.auth_context(self._rpc_event.call)
        auth_context_dict = {} if auth_context is None else auth_context
        return {
            _common.decode(key): value
            for key, value in auth_context_dict.items()
        }

    def set_compression(self, compression: grpc.Compression) -> None:
        """Set the compression algorithm advertised in initial metadata."""
        with self._state.condition:
            self._state.compression_algorithm = compression

    def send_initial_metadata(self, initial_metadata: MetadataType) -> None:
        """Send ``initial_metadata`` immediately in its own batch.

        Raises:
          grpc.RpcError: if the client has cancelled.
          ValueError: if initial metadata was already sent.
        """
        with self._state.condition:
            if self._state.client is _CANCELLED:
                _raise_rpc_error(self._state)
            else:
                if self._state.initial_metadata_allowed:
                    operation = _get_initial_metadata_operation(
                        self._state, initial_metadata
                    )
                    self._rpc_event.call.start_server_batch(
                        (operation,), _send_initial_metadata(self._state)
                    )
                    self._state.initial_metadata_allowed = False
                    self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
                else:
                    raise ValueError("Initial metadata no longer allowed!")

    def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None:
        """Set the metadata sent with the final status."""
        with self._state.condition:
            self._state.trailing_metadata = trailing_metadata

    def trailing_metadata(self) -> Optional[MetadataType]:
        """Return the trailing metadata set so far, or None."""
        return self._state.trailing_metadata

    def abort(self, code: grpc.StatusCode, details: str) -> None:
        """Record an abort status on the RPC and raise to unwind the handler.

        ``StatusCode.OK`` is not a valid abort code: it is logged and
        replaced by ``UNKNOWN`` with empty details. The raised exception is
        a plain ``Exception``; this module detects aborts through
        ``state.aborted`` rather than the exception type.
        """
        # treat OK like other invalid arguments: fail the RPC
        if code == grpc.StatusCode.OK:
            _LOGGER.error(
                "abort() called with StatusCode.OK; returning UNKNOWN"
            )
            code = grpc.StatusCode.UNKNOWN
            details = ""
        with self._state.condition:
            self._state.code = code
            self._state.details = _common.encode(details)
            self._state.aborted = True
            raise Exception()

    def abort_with_status(self, status: grpc.Status) -> None:
        """Abort using the code, details and trailing metadata of ``status``."""
        # Take the state lock like set_trailing_metadata does.
        with self._state.condition:
            self._state.trailing_metadata = status.trailing_metadata
        self.abort(status.code, status.details)

    def set_code(self, code: grpc.StatusCode) -> None:
        """Set the status code reported when the RPC completes."""
        with self._state.condition:
            self._state.code = code

    def code(self) -> Optional[grpc.StatusCode]:
        """Return the status code set so far, or None if unset."""
        return self._state.code

    def set_details(self, details: str) -> None:
        """Set the status details (stored encoded as bytes)."""
        with self._state.condition:
            self._state.details = _common.encode(details)

    def details(self) -> Optional[bytes]:
        """Return the encoded status details set so far, or None if unset."""
        return self._state.details

    def _finalize_state(self) -> None:
        # No-op hook. NOTE(review): presumably invoked by the servicer
        # context wrapper when the handler returns — confirm.
        pass
429
430
class _RequestIterator(object):
    """Blocking iterator over the client's stream of request messages.

    Each step starts one receive-message batch and blocks on
    ``state.condition`` until a request arrives, the stream ends
    (``StopIteration``) or the client cancels (``grpc.RpcError``).
    """

    _state: _RPCState
    _call: cygrpc.Call
    _request_deserializer: Optional[DeserializingFunction]

    def __init__(
        self,
        state: _RPCState,
        call: cygrpc.Call,
        request_deserializer: Optional[DeserializingFunction],
    ):
        self._state = state
        self._call = call
        self._request_deserializer = request_deserializer

    def _raise_or_start_receive_message(self) -> None:
        """Start a receive-message batch, or raise if the RPC cannot continue.

        Called with ``state.condition`` held.
        """
        state = self._state
        if state.client is _CANCELLED:
            _raise_rpc_error(state)
        if not _is_rpc_state_active(state):
            raise StopIteration()
        receive_callback = _receive_message(
            state, self._call, self._request_deserializer
        )
        self._call.start_server_batch(
            (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), receive_callback
        )
        state.due.add(_RECEIVE_MESSAGE_TOKEN)

    def _look_for_request(self) -> Any:
        """Consume the pending request, if any.

        Returns ``None`` while the receive batch is still outstanding.
        Raises ``StopIteration`` when no request is pending and no receive
        batch is outstanding.
        """
        state = self._state
        if state.client is _CANCELLED:
            _raise_rpc_error(state)
        pending = state.request
        if pending is None and _RECEIVE_MESSAGE_TOKEN not in state.due:
            raise StopIteration()
        state.request = None
        return pending

    def _next(self) -> Any:
        with self._state.condition:
            self._raise_or_start_receive_message()
            request = None
            while request is None:
                self._state.condition.wait()
                request = self._look_for_request()
            return request

    def __iter__(self) -> _RequestIterator:
        return self

    def __next__(self) -> Any:
        return self._next()

    def next(self) -> Any:
        return self._next()
492
493
def _unary_request(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    request_deserializer: Optional[DeserializingFunction],
) -> Callable[[], Any]:
    """Return a thunk that blocks until the single request message arrives.

    The thunk starts one receive-message batch and waits on
    ``state.condition``. It returns the deserialized request, or ``None``
    when no request can be delivered:

    * the RPC is (or becomes) inactive, because the client cancelled or a
      status was already sent (e.g. ``_receive_message`` aborted after a
      deserialization failure);
    * the client closed without sending a message, in which case the RPC is
      aborted with UNIMPLEMENTED.
    """

    def unary_request():
        with state.condition:
            if not _is_rpc_state_active(state):
                return None
            rpc_event.call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _receive_message(state, rpc_event.call, request_deserializer),
            )
            state.due.add(_RECEIVE_MESSAGE_TOKEN)
            while True:
                state.condition.wait()
                request = state.request
                if request is not None:
                    state.request = None
                    return request
                if not _is_rpc_state_active(state):
                    # Cancelled, or already statused. Aborting again would
                    # start a second status batch on the same call.
                    return None
                if state.client is _CLOSED:
                    details = '"{}" requires exactly one request message.'.format(
                        rpc_event.call_details.method
                    )
                    _abort(
                        state,
                        rpc_event.call,
                        cygrpc.StatusCode.unimplemented,
                        _common.encode(details),
                    )
                    return None

    return unary_request
533
534
def _call_behavior(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument: Any,
    request_deserializer: Optional[DeserializingFunction],
    send_response_callback: Optional[Callable[[ResponseType], None]] = None,
) -> Tuple[Union[ResponseType, Iterator[ResponseType]], bool]:
    """Invoke the application's method handler inside a servicer context.

    Args:
      rpc_event: The request-call event for this RPC.
      state: The shared per-RPC state.
      behavior: The application handler to call.
      argument: The request (or request iterator) passed to ``behavior``.
      request_deserializer: Passed through to the servicer-context factory.
      send_response_callback: If given, passed to ``behavior`` as a third
        positional argument (used for non-blocking streaming handlers).

    Returns:
      ``(result, True)`` when ``behavior`` returned normally, where
      ``result`` is whatever it returned. ``(None, False)`` when it raised.
      In that case the exception has already been handled here: an abort
      status was started, unless the client had cancelled or the exception
      was an ``RpcError`` this module raised for a cancelled client.
    """
    # Imported at call time rather than at module load. NOTE(review):
    # presumably to avoid an import cycle with the ``grpc`` package — confirm.
    from grpc import _create_servicer_context  # pytype: disable=pyi-error

    with _create_servicer_context(
        rpc_event, state, request_deserializer
    ) as context:
        try:
            response_or_iterator = None
            if send_response_callback is not None:
                response_or_iterator = behavior(
                    argument, context, send_response_callback
                )
            else:
                response_or_iterator = behavior(argument, context)
            return response_or_iterator, True
        except Exception as exception:  # pylint: disable=broad-except
            with state.condition:
                if state.aborted:
                    # context.abort() recorded code/details on the state and
                    # raised; _abort prefers those over the fallbacks given
                    # here.
                    _abort(
                        state,
                        rpc_event.call,
                        cygrpc.StatusCode.unknown,
                        b"RPC Aborted",
                    )
                elif exception not in state.rpc_errors:
                    # A genuine application failure, not an RpcError raised
                    # by this module because the client cancelled. Log it and
                    # fail the RPC with UNKNOWN.
                    try:
                        # Formatting calls str(exception), which may itself
                        # raise.
                        details = "Exception calling application: {}".format(
                            exception
                        )
                    except Exception:  # pylint: disable=broad-except
                        details = (
                            "Calling application raised unprintable Exception!"
                        )
                        # NOTE(review): the first call logs the *list*
                        # returned by format_exception (not a joined string);
                        # print_exc() prints the formatting error currently
                        # being handled, not the original exception.
                        _LOGGER.exception(
                            traceback.format_exception(
                                type(exception),
                                exception,
                                exception.__traceback__,
                            )
                        )
                        traceback.print_exc()
                    _LOGGER.exception(details)
                    _abort(
                        state,
                        rpc_event.call,
                        cygrpc.StatusCode.unknown,
                        _common.encode(details),
                    )
            # Reached only from the except branch; the try body always
            # returns.
            return None, False
591
592
593def _take_response_from_response_iterator(
594    rpc_event: cygrpc.BaseEvent,
595    state: _RPCState,
596    response_iterator: Iterator[ResponseType],
597) -> Tuple[ResponseType, bool]:
598    try:
599        return next(response_iterator), True
600    except StopIteration:
601        return None, True
602    except Exception as exception:  # pylint: disable=broad-except
603        with state.condition:
604            if state.aborted:
605                _abort(
606                    state,
607                    rpc_event.call,
608                    cygrpc.StatusCode.unknown,
609                    b"RPC Aborted",
610                )
611            elif exception not in state.rpc_errors:
612                details = "Exception iterating responses: {}".format(exception)
613                _LOGGER.exception(details)
614                _abort(
615                    state,
616                    rpc_event.call,
617                    cygrpc.StatusCode.unknown,
618                    _common.encode(details),
619                )
620        return None, False
621
622
def _serialize_response(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    response: Any,
    response_serializer: Optional[SerializingFunction],
) -> Optional[bytes]:
    """Serialize ``response``, aborting the RPC with INTERNAL on failure.

    Returns the serialized bytes, or ``None`` if serialization failed.
    """
    serialized_response = _common.serialize(response, response_serializer)
    if serialized_response is not None:
        return serialized_response
    with state.condition:
        _abort(
            state,
            rpc_event.call,
            cygrpc.StatusCode.internal,
            b"Failed to serialize response!",
        )
    return None
641
642
def _get_send_message_op_flags_from_state(
    state: _RPCState,
) -> Union[int, cygrpc.WriteFlag]:
    """Return the write flags for the next outgoing message.

    ``no_compress`` when the handler disabled compression for the next
    message, otherwise no flags.
    """
    return (
        cygrpc.WriteFlag.no_compress
        if state.disable_next_compression
        else _EMPTY_FLAGS
    )
650
651
652def _reset_per_message_state(state: _RPCState) -> None:
653    with state.condition:
654        state.disable_next_compression = False
655
656
def _send_response(
    rpc_event: cygrpc.BaseEvent, state: _RPCState, serialized_response: bytes
) -> bool:
    """Send one serialized response message and wait for its batch to finish.

    Initial metadata is included in the batch if it has not been sent yet.

    Returns:
      False without sending if the RPC is already inactive; otherwise
      whether the RPC is still active once the send has completed.
    """
    with state.condition:
        if not _is_rpc_state_active(state):
            return False
        operations = []
        if state.initial_metadata_allowed:
            operations.append(_get_initial_metadata_operation(state, None))
            state.initial_metadata_allowed = False
            token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
        else:
            token = _SEND_MESSAGE_TOKEN
        operations.append(
            cygrpc.SendMessageOperation(
                serialized_response,
                _get_send_message_op_flags_from_state(state),
            )
        )
        rpc_event.call.start_server_batch(
            tuple(operations), _send_message(state, token)
        )
        state.due.add(token)
        _reset_per_message_state(state)
        # The token was just added, so this waits at least once; the
        # completion callback removes it and notifies.
        while token in state.due:
            state.condition.wait()
        return _is_rpc_state_active(state)
691
692
def _status(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    serialized_response: Optional[bytes],
) -> None:
    """Send the final status, plus any unsent initial metadata and message.

    Does nothing if the client has cancelled. Otherwise one batch is started
    carrying the status (code and details from ``state``), the initial
    metadata if it has not been sent yet, and ``serialized_response`` when
    it is not ``None``. The RPC is then marked as statused.
    """
    with state.condition:
        if state.client is _CANCELLED:
            return
        status_operation = cygrpc.SendStatusFromServerOperation(
            state.trailing_metadata,
            _completion_code(state),
            _details(state),
            _EMPTY_FLAGS,
        )
        operations = [status_operation]
        if state.initial_metadata_allowed:
            operations.append(_get_initial_metadata_operation(state, None))
        if serialized_response is not None:
            message_flags = _get_send_message_op_flags_from_state(state)
            operations.append(
                cygrpc.SendMessageOperation(serialized_response, message_flags)
            )
        rpc_event.call.start_server_batch(
            operations,
            _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN),
        )
        state.statused = True
        _reset_per_message_state(state)
        state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
723
724
def _unary_response_in_pool(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument_thunk: Callable[[], Any],
    request_deserializer: Optional[DeserializingFunction],
    response_serializer: Optional[SerializingFunction],
) -> None:
    """Run a unary-response handler on a worker thread and send its result.

    Args:
      rpc_event: The request-call event for this RPC.
      state: The shared per-RPC state.
      behavior: The application handler.
      argument_thunk: Produces the handler's argument (a request or request
        iterator). It returns ``None`` when no request can be delivered, in
        which case the handler is not called.
      request_deserializer: Forwarded to the servicer context.
      response_serializer: Used to serialize the handler's response.

    Unexpected errors are printed rather than propagated. The cygrpc
    context installed for the event is always uninstalled.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)

    try:
        argument = argument_thunk()
        if argument is None:
            return
        response, proceed = _call_behavior(
            rpc_event, state, behavior, argument, request_deserializer
        )
        if not proceed:
            # _call_behavior already handled the handler's exception.
            return
        serialized_response = _serialize_response(
            rpc_event, state, response, response_serializer
        )
        if serialized_response is not None:
            _status(rpc_event, state, serialized_response)
    except Exception:  # pylint: disable=broad-except
        traceback.print_exc()
    finally:
        cygrpc.uninstall_context()
751
752
def _stream_response_in_pool(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    behavior: ArityAgnosticMethodHandler,
    argument_thunk: Callable[[], Any],
    request_deserializer: Optional[DeserializingFunction],
    response_serializer: Optional[SerializingFunction],
) -> None:
    """Run a streaming-response handler on a worker thread.

    Handlers flagged ``experimental_non_blocking`` receive a
    ``send_response`` callback and push responses themselves. Other
    handlers return an iterator, which is drained through the same
    callback. Passing ``None`` to the callback sends the final status.
    Unexpected errors are printed, and the cygrpc context is always
    uninstalled.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)

    def send_response(response: Any) -> None:
        if response is None:
            _status(rpc_event, state, None)
            return
        serialized_response = _serialize_response(
            rpc_event, state, response, response_serializer
        )
        if serialized_response is not None:
            _send_response(rpc_event, state, serialized_response)

    try:
        argument = argument_thunk()
        if argument is not None:
            non_blocking = getattr(behavior, "experimental_non_blocking", False)
            if non_blocking:
                _call_behavior(
                    rpc_event,
                    state,
                    behavior,
                    argument,
                    request_deserializer,
                    send_response_callback=send_response,
                )
            else:
                response_iterator, proceed = _call_behavior(
                    rpc_event, state, behavior, argument, request_deserializer
                )
                if proceed:
                    _send_message_callback_to_blocking_iterator_adapter(
                        rpc_event, state, send_response, response_iterator
                    )
    except Exception:  # pylint: disable=broad-except
        traceback.print_exc()
    finally:
        cygrpc.uninstall_context()
800
801
def _is_rpc_state_active(state: _RPCState) -> bool:
    """Return True unless the client cancelled or a status has been sent."""
    return not (state.client is _CANCELLED or state.statused)
804
805
def _send_message_callback_to_blocking_iterator_adapter(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    send_response_callback: Callable[[ResponseType], None],
    response_iterator: Iterator[ResponseType],
) -> None:
    """Drain ``response_iterator`` into ``send_response_callback``.

    Exhaustion is forwarded as ``None``, which sends the final status.
    The loop stops when the iterator raises or the RPC becomes inactive.
    """
    while True:
        response, proceed = _take_response_from_response_iterator(
            rpc_event, state, response_iterator
        )
        if not proceed:
            return
        send_response_callback(response)
        if not _is_rpc_state_active(state):
            return
822
823
824def _select_thread_pool_for_behavior(
825    behavior: ArityAgnosticMethodHandler,
826    default_thread_pool: futures.ThreadPoolExecutor,
827) -> futures.ThreadPoolExecutor:
828    if hasattr(behavior, "experimental_thread_pool") and isinstance(
829        behavior.experimental_thread_pool, futures.ThreadPoolExecutor
830    ):
831        return behavior.experimental_thread_pool
832    else:
833        return default_thread_pool
834
835
def _handle_unary_unary(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Schedule a unary-unary RPC on the selected thread pool.

    The single request is read on the worker thread through the
    ``_unary_request`` thunk. The worker runs inside ``state.context``.
    """
    deserializer = method_handler.request_deserializer
    argument_thunk = _unary_request(rpc_event, state, deserializer)
    behavior = method_handler.unary_unary
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(
        state.context.run,
        _unary_response_in_pool,
        rpc_event,
        state,
        behavior,
        argument_thunk,
        deserializer,
        method_handler.response_serializer,
    )
858
859
def _handle_unary_stream(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Schedule a unary-stream RPC on the selected thread pool.

    The single request is read on the worker thread through the
    ``_unary_request`` thunk. Responses are streamed by
    ``_stream_response_in_pool`` inside ``state.context``.
    """
    deserializer = method_handler.request_deserializer
    argument_thunk = _unary_request(rpc_event, state, deserializer)
    behavior = method_handler.unary_stream
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(
        state.context.run,
        _stream_response_in_pool,
        rpc_event,
        state,
        behavior,
        argument_thunk,
        deserializer,
        method_handler.response_serializer,
    )
882
883
def _handle_stream_unary(
    rpc_event: cygrpc.BaseEvent,
    state: _RPCState,
    method_handler: grpc.RpcMethodHandler,
    default_thread_pool: futures.ThreadPoolExecutor,
) -> futures.Future:
    """Schedule a stream-unary RPC on the selected thread pool.

    The handler receives a ``_RequestIterator`` over the client's requests.
    The worker runs inside ``state.context``.
    """
    deserializer = method_handler.request_deserializer
    request_iterator = _RequestIterator(state, rpc_event.call, deserializer)

    def argument_thunk():
        return request_iterator

    behavior = method_handler.stream_unary
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(
        state.context.run,
        _unary_response_in_pool,
        rpc_event,
        state,
        behavior,
        argument_thunk,
        deserializer,
        method_handler.response_serializer,
    )
906
907
908def _handle_stream_stream(
909    rpc_event: cygrpc.BaseEvent,
910    state: _RPCState,
911    method_handler: grpc.RpcMethodHandler,
912    default_thread_pool: futures.ThreadPoolExecutor,
913) -> futures.Future:
914    request_iterator = _RequestIterator(
915        state, rpc_event.call, method_handler.request_deserializer
916    )
917    thread_pool = _select_thread_pool_for_behavior(
918        method_handler.stream_stream, default_thread_pool
919    )
920    return thread_pool.submit(
921        state.context.run,
922        _stream_response_in_pool,
923        rpc_event,
924        state,
925        method_handler.stream_stream,
926        lambda: request_iterator,
927        method_handler.request_deserializer,
928        method_handler.response_serializer,
929    )
930
931
932def _find_method_handler(
933    rpc_event: cygrpc.BaseEvent,
934    state: _RPCState,
935    generic_handlers: List[grpc.GenericRpcHandler],
936    interceptor_pipeline: Optional[_interceptor._ServicePipeline],
937) -> Optional[grpc.RpcMethodHandler]:
938    def query_handlers(
939        handler_call_details: _HandlerCallDetails,
940    ) -> Optional[grpc.RpcMethodHandler]:
941        for generic_handler in generic_handlers:
942            method_handler = generic_handler.service(handler_call_details)
943            if method_handler is not None:
944                return method_handler
945        return None
946
947    handler_call_details = _HandlerCallDetails(
948        _common.decode(rpc_event.call_details.method),
949        rpc_event.invocation_metadata,
950    )
951
952    if interceptor_pipeline is not None:
953        return state.context.run(
954            interceptor_pipeline.execute, query_handlers, handler_call_details
955        )
956    else:
957        return state.context.run(query_handlers, handler_call_details)
958
959
960def _reject_rpc(
961    rpc_event: cygrpc.BaseEvent,
962    rpc_state: _RPCState,
963    status: cygrpc.StatusCode,
964    details: bytes,
965):
966    operations = (
967        _get_initial_metadata_operation(rpc_state, None),
968        cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
969        cygrpc.SendStatusFromServerOperation(
970            None, status, details, _EMPTY_FLAGS
971        ),
972    )
973    rpc_event.call.start_server_batch(
974        operations,
975        lambda ignored_event: (
976            rpc_state,
977            (),
978        ),
979    )
980
981
982def _handle_with_method_handler(
983    rpc_event: cygrpc.BaseEvent,
984    state: _RPCState,
985    method_handler: grpc.RpcMethodHandler,
986    thread_pool: futures.ThreadPoolExecutor,
987) -> futures.Future:
988    with state.condition:
989        rpc_event.call.start_server_batch(
990            (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
991            _receive_close_on_server(state),
992        )
993        state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
994        if method_handler.request_streaming:
995            if method_handler.response_streaming:
996                return _handle_stream_stream(
997                    rpc_event, state, method_handler, thread_pool
998                )
999            else:
1000                return _handle_stream_unary(
1001                    rpc_event, state, method_handler, thread_pool
1002                )
1003        else:
1004            if method_handler.response_streaming:
1005                return _handle_unary_stream(
1006                    rpc_event, state, method_handler, thread_pool
1007                )
1008            else:
1009                return _handle_unary_unary(
1010                    rpc_event, state, method_handler, thread_pool
1011                )
1012
1013
1014def _handle_call(
1015    rpc_event: cygrpc.BaseEvent,
1016    generic_handlers: List[grpc.GenericRpcHandler],
1017    interceptor_pipeline: Optional[_interceptor._ServicePipeline],
1018    thread_pool: futures.ThreadPoolExecutor,
1019    concurrency_exceeded: bool,
1020) -> Tuple[Optional[_RPCState], Optional[futures.Future]]:
1021    if not rpc_event.success:
1022        return None, None
1023    if rpc_event.call_details.method is not None:
1024        rpc_state = _RPCState()
1025        try:
1026            method_handler = _find_method_handler(
1027                rpc_event, rpc_state, generic_handlers, interceptor_pipeline
1028            )
1029        except Exception as exception:  # pylint: disable=broad-except
1030            details = "Exception servicing handler: {}".format(exception)
1031            _LOGGER.exception(details)
1032            _reject_rpc(
1033                rpc_event,
1034                rpc_state,
1035                cygrpc.StatusCode.unknown,
1036                b"Error in service handler!",
1037            )
1038            return rpc_state, None
1039        if method_handler is None:
1040            _reject_rpc(
1041                rpc_event,
1042                rpc_state,
1043                cygrpc.StatusCode.unimplemented,
1044                b"Method not found!",
1045            )
1046            return rpc_state, None
1047        elif concurrency_exceeded:
1048            _reject_rpc(
1049                rpc_event,
1050                rpc_state,
1051                cygrpc.StatusCode.resource_exhausted,
1052                b"Concurrent RPC limit exceeded!",
1053            )
1054            return rpc_state, None
1055        else:
1056            return (
1057                rpc_state,
1058                _handle_with_method_handler(
1059                    rpc_event, rpc_state, method_handler, thread_pool
1060                ),
1061            )
1062    else:
1063        return None, None
1064
1065
1066@enum.unique
1067class _ServerStage(enum.Enum):
1068    STOPPED = "stopped"
1069    STARTED = "started"
1070    GRACE = "grace"
1071
1072
1073class _ServerState(object):
1074    lock: threading.RLock
1075    completion_queue: cygrpc.CompletionQueue
1076    server: cygrpc.Server
1077    generic_handlers: List[grpc.GenericRpcHandler]
1078    interceptor_pipeline: Optional[_interceptor._ServicePipeline]
1079    thread_pool: futures.ThreadPoolExecutor
1080    stage: _ServerStage
1081    termination_event: threading.Event
1082    shutdown_events: List[threading.Event]
1083    maximum_concurrent_rpcs: Optional[int]
1084    active_rpc_count: int
1085    rpc_states: Set[_RPCState]
1086    due: Set[str]
1087    server_deallocated: bool
1088
1089    # pylint: disable=too-many-arguments
1090    def __init__(
1091        self,
1092        completion_queue: cygrpc.CompletionQueue,
1093        server: cygrpc.Server,
1094        generic_handlers: Sequence[grpc.GenericRpcHandler],
1095        interceptor_pipeline: Optional[_interceptor._ServicePipeline],
1096        thread_pool: futures.ThreadPoolExecutor,
1097        maximum_concurrent_rpcs: Optional[int],
1098    ):
1099        self.lock = threading.RLock()
1100        self.completion_queue = completion_queue
1101        self.server = server
1102        self.generic_handlers = list(generic_handlers)
1103        self.interceptor_pipeline = interceptor_pipeline
1104        self.thread_pool = thread_pool
1105        self.stage = _ServerStage.STOPPED
1106        self.termination_event = threading.Event()
1107        self.shutdown_events = [self.termination_event]
1108        self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
1109        self.active_rpc_count = 0
1110
1111        # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
1112        self.rpc_states = set()
1113        self.due = set()
1114
1115        # A "volatile" flag to interrupt the daemon serving thread
1116        self.server_deallocated = False
1117
1118
1119def _add_generic_handlers(
1120    state: _ServerState, generic_handlers: Iterable[grpc.GenericRpcHandler]
1121) -> None:
1122    with state.lock:
1123        state.generic_handlers.extend(generic_handlers)
1124
1125
1126def _add_insecure_port(state: _ServerState, address: bytes) -> int:
1127    with state.lock:
1128        return state.server.add_http2_port(address)
1129
1130
1131def _add_secure_port(
1132    state: _ServerState,
1133    address: bytes,
1134    server_credentials: grpc.ServerCredentials,
1135) -> int:
1136    with state.lock:
1137        return state.server.add_http2_port(
1138            address, server_credentials._credentials
1139        )
1140
1141
1142def _request_call(state: _ServerState) -> None:
1143    state.server.request_call(
1144        state.completion_queue, state.completion_queue, _REQUEST_CALL_TAG
1145    )
1146    state.due.add(_REQUEST_CALL_TAG)
1147
1148
1149# TODO(https://github.com/grpc/grpc/issues/6597): delete this function.
1150def _stop_serving(state: _ServerState) -> bool:
1151    if not state.rpc_states and not state.due:
1152        state.server.destroy()
1153        for shutdown_event in state.shutdown_events:
1154            shutdown_event.set()
1155        state.stage = _ServerStage.STOPPED
1156        return True
1157    else:
1158        return False
1159
1160
1161def _on_call_completed(state: _ServerState) -> None:
1162    with state.lock:
1163        state.active_rpc_count -= 1
1164
1165
1166def _process_event_and_continue(
1167    state: _ServerState, event: cygrpc.BaseEvent
1168) -> bool:
1169    should_continue = True
1170    if event.tag is _SHUTDOWN_TAG:
1171        with state.lock:
1172            state.due.remove(_SHUTDOWN_TAG)
1173            if _stop_serving(state):
1174                should_continue = False
1175    elif event.tag is _REQUEST_CALL_TAG:
1176        with state.lock:
1177            state.due.remove(_REQUEST_CALL_TAG)
1178            concurrency_exceeded = (
1179                state.maximum_concurrent_rpcs is not None
1180                and state.active_rpc_count >= state.maximum_concurrent_rpcs
1181            )
1182            rpc_state, rpc_future = _handle_call(
1183                event,
1184                state.generic_handlers,
1185                state.interceptor_pipeline,
1186                state.thread_pool,
1187                concurrency_exceeded,
1188            )
1189            if rpc_state is not None:
1190                state.rpc_states.add(rpc_state)
1191            if rpc_future is not None:
1192                state.active_rpc_count += 1
1193                rpc_future.add_done_callback(
1194                    lambda unused_future: _on_call_completed(state)
1195                )
1196            if state.stage is _ServerStage.STARTED:
1197                _request_call(state)
1198            elif _stop_serving(state):
1199                should_continue = False
1200    else:
1201        rpc_state, callbacks = event.tag(event)
1202        for callback in callbacks:
1203            try:
1204                callback()
1205            except Exception:  # pylint: disable=broad-except
1206                _LOGGER.exception("Exception calling callback!")
1207        if rpc_state is not None:
1208            with state.lock:
1209                state.rpc_states.remove(rpc_state)
1210                if _stop_serving(state):
1211                    should_continue = False
1212    return should_continue
1213
1214
1215def _serve(state: _ServerState) -> None:
1216    while True:
1217        timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S
1218        event = state.completion_queue.poll(timeout)
1219        if state.server_deallocated:
1220            _begin_shutdown_once(state)
1221        if event.completion_type != cygrpc.CompletionType.queue_timeout:
1222            if not _process_event_and_continue(state, event):
1223                return
1224        # We want to force the deletion of the previous event
1225        # ~before~ we poll again; if the event has a reference
1226        # to a shutdown Call object, this can induce spinlock.
1227        event = None
1228
1229
1230def _begin_shutdown_once(state: _ServerState) -> None:
1231    with state.lock:
1232        if state.stage is _ServerStage.STARTED:
1233            state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
1234            state.stage = _ServerStage.GRACE
1235            state.due.add(_SHUTDOWN_TAG)
1236
1237
1238def _stop(state: _ServerState, grace: Optional[float]) -> threading.Event:
1239    with state.lock:
1240        if state.stage is _ServerStage.STOPPED:
1241            shutdown_event = threading.Event()
1242            shutdown_event.set()
1243            return shutdown_event
1244        else:
1245            _begin_shutdown_once(state)
1246            shutdown_event = threading.Event()
1247            state.shutdown_events.append(shutdown_event)
1248            if grace is None:
1249                state.server.cancel_all_calls()
1250            else:
1251
1252                def cancel_all_calls_after_grace():
1253                    shutdown_event.wait(timeout=grace)
1254                    with state.lock:
1255                        state.server.cancel_all_calls()
1256
1257                thread = threading.Thread(target=cancel_all_calls_after_grace)
1258                thread.start()
1259                return shutdown_event
1260    shutdown_event.wait()
1261    return shutdown_event
1262
1263
1264def _start(state: _ServerState) -> None:
1265    with state.lock:
1266        if state.stage is not _ServerStage.STOPPED:
1267            raise ValueError("Cannot start already-started server!")
1268        state.server.start()
1269        state.stage = _ServerStage.STARTED
1270        _request_call(state)
1271        thread = threading.Thread(target=_serve, args=(state,))
1272        thread.daemon = True
1273        thread.start()
1274
1275
1276def _validate_generic_rpc_handlers(
1277    generic_rpc_handlers: Iterable[grpc.GenericRpcHandler],
1278) -> None:
1279    for generic_rpc_handler in generic_rpc_handlers:
1280        service_attribute = getattr(generic_rpc_handler, "service", None)
1281        if service_attribute is None:
1282            raise AttributeError(
1283                '"{}" must conform to grpc.GenericRpcHandler type but does '
1284                'not have "service" method!'.format(generic_rpc_handler)
1285            )
1286
1287
1288def _augment_options(
1289    base_options: Sequence[ChannelArgumentType],
1290    compression: Optional[grpc.Compression],
1291) -> Sequence[ChannelArgumentType]:
1292    compression_option = _compression.create_channel_option(compression)
1293    return tuple(base_options) + compression_option
1294
1295
1296class _Server(grpc.Server):
1297    _state: _ServerState
1298
1299    # pylint: disable=too-many-arguments
1300    def __init__(
1301        self,
1302        thread_pool: futures.ThreadPoolExecutor,
1303        generic_handlers: Sequence[grpc.GenericRpcHandler],
1304        interceptors: Sequence[grpc.ServerInterceptor],
1305        options: Sequence[ChannelArgumentType],
1306        maximum_concurrent_rpcs: Optional[int],
1307        compression: Optional[grpc.Compression],
1308        xds: bool,
1309    ):
1310        completion_queue = cygrpc.CompletionQueue()
1311        server = cygrpc.Server(_augment_options(options, compression), xds)
1312        server.register_completion_queue(completion_queue)
1313        self._state = _ServerState(
1314            completion_queue,
1315            server,
1316            generic_handlers,
1317            _interceptor.service_pipeline(interceptors),
1318            thread_pool,
1319            maximum_concurrent_rpcs,
1320        )
1321
1322    def add_generic_rpc_handlers(
1323        self, generic_rpc_handlers: Iterable[grpc.GenericRpcHandler]
1324    ) -> None:
1325        _validate_generic_rpc_handlers(generic_rpc_handlers)
1326        _add_generic_handlers(self._state, generic_rpc_handlers)
1327
1328    def add_insecure_port(self, address: str) -> int:
1329        return _common.validate_port_binding_result(
1330            address, _add_insecure_port(self._state, _common.encode(address))
1331        )
1332
1333    def add_secure_port(
1334        self, address: str, server_credentials: grpc.ServerCredentials
1335    ) -> int:
1336        return _common.validate_port_binding_result(
1337            address,
1338            _add_secure_port(
1339                self._state, _common.encode(address), server_credentials
1340            ),
1341        )
1342
1343    def start(self) -> None:
1344        _start(self._state)
1345
1346    def wait_for_termination(self, timeout: Optional[float] = None) -> bool:
1347        # NOTE(https://bugs.python.org/issue35935)
1348        # Remove this workaround once threading.Event.wait() is working with
1349        # CTRL+C across platforms.
1350        return _common.wait(
1351            self._state.termination_event.wait,
1352            self._state.termination_event.is_set,
1353            timeout=timeout,
1354        )
1355
1356    def stop(self, grace: Optional[float]) -> threading.Event:
1357        return _stop(self._state, grace)
1358
1359    def __del__(self):
1360        if hasattr(self, "_state"):
1361            # We can not grab a lock in __del__(), so set a flag to signal the
1362            # serving daemon thread (if it exists) to initiate shutdown.
1363            self._state.server_deallocated = True
1364
1365
1366def create_server(
1367    thread_pool: futures.ThreadPoolExecutor,
1368    generic_rpc_handlers: Sequence[grpc.GenericRpcHandler],
1369    interceptors: Sequence[grpc.ServerInterceptor],
1370    options: Sequence[ChannelArgumentType],
1371    maximum_concurrent_rpcs: Optional[int],
1372    compression: Optional[grpc.Compression],
1373    xds: bool,
1374) -> _Server:
1375    _validate_generic_rpc_handlers(generic_rpc_handlers)
1376    return _Server(
1377        thread_pool,
1378        generic_rpc_handlers,
1379        interceptors,
1380        options,
1381        maximum_concurrent_rpcs,
1382        compression,
1383        xds,
1384    )
1385