xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio/grpc/beta/_server_adaptations.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Translates gRPC's server-side API into gRPC's server-side Beta API."""
15
16import collections
17import threading
18
19import grpc
20from grpc import _common
21from grpc.beta import _metadata
22from grpc.beta import interfaces
23from grpc.framework.common import cardinality
24from grpc.framework.common import style
25from grpc.framework.foundation import abandonment
26from grpc.framework.foundation import logging_pool
27from grpc.framework.foundation import stream
28from grpc.framework.interfaces.face import face
29
# pylint: disable=too-many-return-statements

# Worker-thread count that server() uses when the caller supplies neither a
# thread pool nor an explicit thread_pool_size.
_DEFAULT_POOL_SIZE = 8
33
34
class _ServerProtocolContext(interfaces.GRPCServicerContext):
    """Beta GRPCServicerContext view of a grpc.ServicerContext."""

    def __init__(self, servicer_context):
        self._servicer_context = servicer_context

    def peer(self):
        """Returns whatever the wrapped context's peer() returns."""
        wrapped = self._servicer_context
        return wrapped.peer()

    def disable_next_response_compression(self):
        """Currently a no-op."""
        pass  # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.
44
45
class _FaceServicerContext(face.ServicerContext):
    """Presents a grpc.ServicerContext through the Beta face.ServicerContext API.

    Metadata crossing this adapter is converted between the Beta and current
    representations with grpc.beta._metadata (beta()/unbeta()).
    """

    def __init__(self, servicer_context):
        self._servicer_context = servicer_context

    def is_active(self):
        """Delegates to the wrapped context's is_active()."""
        return self._servicer_context.is_active()

    def time_remaining(self):
        """Delegates to the wrapped context's time_remaining()."""
        return self._servicer_context.time_remaining()

    def add_abortion_callback(self, abortion_callback):
        """Always raises NotImplementedError; unsupported server-side."""
        raise NotImplementedError(
            "add_abortion_callback no longer supported server-side!"
        )

    def cancel(self):
        """Cancels the RPC through the wrapped context."""
        self._servicer_context.cancel()

    def protocol_context(self):
        """Returns a new _ServerProtocolContext over the same wrapped context."""
        return _ServerProtocolContext(self._servicer_context)

    def invocation_metadata(self):
        """Returns the client's invocation metadata converted to Beta form."""
        return _metadata.beta(self._servicer_context.invocation_metadata())

    def initial_metadata(self, initial_metadata):
        """Sends Beta-form initial metadata via send_initial_metadata()."""
        self._servicer_context.send_initial_metadata(
            _metadata.unbeta(initial_metadata)
        )

    def terminal_metadata(self, terminal_metadata):
        """Sets Beta-form terminal (trailing) metadata for the RPC."""
        # grpc.ServicerContext exposes set_trailing_metadata(); it has no
        # set_terminal_metadata(), so calling that name raised AttributeError.
        self._servicer_context.set_trailing_metadata(
            _metadata.unbeta(terminal_metadata)
        )

    def code(self, code):
        """Sets the RPC's status code via set_code()."""
        self._servicer_context.set_code(code)

    def details(self, details):
        """Sets the RPC's status details via set_details()."""
        self._servicer_context.set_details(details)
85
86
def _adapt_unary_request_inline(unary_request_inline):
    """Wraps an inline Beta behavior that takes a single request.

    The returned callable has grpc's (request, servicer_context) signature.
    It hands the Beta behavior the request plus a _FaceServicerContext built
    around servicer_context, and returns whatever the behavior returns.
    """

    def adaptation(request, servicer_context):
        face_context = _FaceServicerContext(servicer_context)
        return unary_request_inline(request, face_context)

    return adaptation
94
95
def _adapt_stream_request_inline(stream_request_inline):
    """Wraps an inline Beta behavior that takes a request iterator.

    The returned callable has grpc's (request_iterator, servicer_context)
    signature. It hands the Beta behavior the iterator plus a
    _FaceServicerContext built around servicer_context, and returns the
    behavior's result.
    """

    def adaptation(request_iterator, servicer_context):
        face_context = _FaceServicerContext(servicer_context)
        return stream_request_inline(request_iterator, face_context)

    return adaptation
103
104
class _Callback(stream.Consumer):
    """Hands values from an event-style Beta servicer to a gRPC handler thread.

    The servicer pushes values, and eventually termination, through the
    stream.Consumer methods. The handler pulls them with draw_one_value() or
    draw_all_values(), which block until data, termination, or cancellation
    arrives. All state is guarded by one threading.Condition.
    """

    def __init__(self):
        self._condition = threading.Condition()
        # A deque makes removal from the front O(1); list.pop(0) is O(n) in
        # the number of buffered values.
        self._values = collections.deque()
        self._terminated = False
        self._cancelled = False

    def consume(self, value):
        """Buffers one value and wakes any waiting drawer."""
        with self._condition:
            self._values.append(value)
            self._condition.notify_all()

    def terminate(self):
        """Marks the value stream as complete and wakes any waiting drawer."""
        with self._condition:
            self._terminated = True
            self._condition.notify_all()

    def consume_and_terminate(self, value):
        """Buffers a final value and marks the stream complete atomically."""
        with self._condition:
            self._values.append(value)
            self._terminated = True
            self._condition.notify_all()

    def cancel(self):
        """Marks the stream cancelled; subsequent draws raise Abandoned."""
        with self._condition:
            self._cancelled = True
            self._condition.notify_all()

    def draw_one_value(self):
        """Blocks until the next value is available and returns it.

        Returns:
          The oldest buffered value, or None once the stream is terminated and
          no buffered values remain.

        Raises:
          abandonment.Abandoned: If cancel() has been called. Cancellation is
            checked before buffered values are returned.
        """
        with self._condition:
            while True:
                if self._cancelled:
                    raise abandonment.Abandoned()
                elif self._values:
                    return self._values.popleft()
                elif self._terminated:
                    return None
                else:
                    self._condition.wait()

    def draw_all_values(self):
        """Blocks until the stream terminates and returns every value.

        Returns:
          A tuple of all buffered values in arrival order. The buffer is then
          replaced with None.

        Raises:
          abandonment.Abandoned: If cancel() has been called.
        """
        with self._condition:
            while True:
                if self._cancelled:
                    raise abandonment.Abandoned()
                elif self._terminated:
                    all_values = tuple(self._values)
                    self._values = None
                    return all_values
                else:
                    self._condition.wait()
156
157
def _run_request_pipe_thread(
    request_iterator, request_consumer, servicer_context
):
    """Feeds requests into a Beta request consumer on a background thread.

    A daemon thread iterates request_iterator and passes each request to
    request_consumer.consume(). servicer_context.is_active() is checked before
    and after each consume(). If it reports False, the thread stops without
    calling terminate(). Otherwise request_consumer.terminate() is called once
    the iterator is exhausted.

    Args:
      request_iterator: Iterator over the RPC's incoming requests.
      request_consumer: Object with consume(value) and terminate() methods,
        e.g. the stream.Consumer returned by a Beta event servicer.
      servicer_context: Object with an is_active() method, e.g. the RPC's
        grpc.ServicerContext.
    """

    def pipe_requests():
        for request in request_iterator:
            if not servicer_context.is_active():
                return
            request_consumer.consume(request)
            if not servicer_context.is_active():
                return
        request_consumer.terminate()

    request_pipe_thread = threading.Thread(target=pipe_requests)
    # Daemonic so this thread never prevents interpreter exit.
    request_pipe_thread.daemon = True
    request_pipe_thread.start()
175
176
def _adapt_unary_unary_event(unary_unary_event):
    """Wraps an event-style Beta unary-unary behavior for grpc.

    The returned callable registers a _Callback's cancel() with the RPC and
    raises abandonment.Abandoned if add_callback() returns a falsy value. It
    then invokes the Beta behavior with consume_and_terminate as the response
    callback, blocks until the callback terminates, and returns the first
    value it received.
    """

    def adaptation(request, servicer_context):
        callback = _Callback()
        registered = servicer_context.add_callback(callback.cancel)
        if not registered:
            raise abandonment.Abandoned()
        face_context = _FaceServicerContext(servicer_context)
        unary_unary_event(request, callback.consume_and_terminate, face_context)
        values = callback.draw_all_values()
        return values[0]

    return adaptation
190
191
def _adapt_unary_stream_event(unary_stream_event):
    """Wraps an event-style Beta unary-stream behavior as a response generator.

    Because adaptation is a generator function, its work starts on the first
    iteration. It registers a _Callback's cancel() with the RPC, raising
    abandonment.Abandoned if add_callback() returns a falsy value. It then
    invokes the Beta behavior with the _Callback as its response consumer and
    yields drawn responses until the callback reports termination.
    """

    def adaptation(request, servicer_context):
        callback = _Callback()
        registered = servicer_context.add_callback(callback.cancel)
        if not registered:
            raise abandonment.Abandoned()
        face_context = _FaceServicerContext(servicer_context)
        unary_stream_event(request, callback, face_context)
        response = callback.draw_one_value()
        while response is not None:
            yield response
            response = callback.draw_one_value()

    return adaptation
208
209
def _adapt_stream_unary_event(stream_unary_event):
    """Wraps an event-style Beta stream-unary behavior for grpc.

    The returned callable registers a _Callback's cancel() with the RPC,
    raising abandonment.Abandoned if add_callback() returns a falsy value.
    It obtains a request consumer from the Beta behavior and pipes the
    request iterator into it on a background thread. It then blocks until
    the response callback terminates and returns the first value received.
    """

    def adaptation(request_iterator, servicer_context):
        callback = _Callback()
        registered = servicer_context.add_callback(callback.cancel)
        if not registered:
            raise abandonment.Abandoned()
        face_context = _FaceServicerContext(servicer_context)
        request_consumer = stream_unary_event(
            callback.consume_and_terminate, face_context
        )
        _run_request_pipe_thread(
            request_iterator, request_consumer, servicer_context
        )
        values = callback.draw_all_values()
        return values[0]

    return adaptation
225
226
def _adapt_stream_stream_event(stream_stream_event):
    """Wraps an event-style Beta stream-stream behavior as a response generator.

    Because adaptation is a generator function, its work starts on the first
    iteration. It registers a _Callback's cancel() with the RPC, raising
    abandonment.Abandoned if add_callback() returns a falsy value. It obtains
    a request consumer from the Beta behavior and pipes the request iterator
    into it on a background thread. It then yields drawn responses until the
    callback reports termination.
    """

    def adaptation(request_iterator, servicer_context):
        callback = _Callback()
        registered = servicer_context.add_callback(callback.cancel)
        if not registered:
            raise abandonment.Abandoned()
        face_context = _FaceServicerContext(servicer_context)
        request_consumer = stream_stream_event(callback, face_context)
        _run_request_pipe_thread(
            request_iterator, request_consumer, servicer_context
        )
        response = callback.draw_one_value()
        while response is not None:
            yield response
            response = callback.draw_one_value()

    return adaptation
246
247
class _SimpleMethodHandler(
    collections.namedtuple(
        "_MethodHandler",
        (
            "request_streaming",
            "response_streaming",
            "request_deserializer",
            "response_serializer",
            "unary_unary",
            "unary_stream",
            "stream_unary",
            "stream_stream",
        ),
    ),
    grpc.RpcMethodHandler,
):
    """Tuple-backed grpc.RpcMethodHandler.

    Positional fields: the two streaming flags, the request deserializer and
    response serializer, then the four behavior slots (unary_unary,
    unary_stream, stream_unary, stream_stream).
    """
265
266
def _populated_method_handler(
    request_streaming,
    response_streaming,
    request_deserializer,
    response_serializer,
    behavior,
):
    """Builds a _SimpleMethodHandler with behavior in the slot matching the flags."""
    return _SimpleMethodHandler(
        request_streaming,
        response_streaming,
        request_deserializer,
        response_serializer,
        behavior if not request_streaming and not response_streaming else None,
        behavior if not request_streaming and response_streaming else None,
        behavior if request_streaming and not response_streaming else None,
        behavior if request_streaming and response_streaming else None,
    )


def _simple_method_handler(
    implementation, request_deserializer, response_serializer
):
    """Adapts one Beta method implementation into a grpc.RpcMethodHandler.

    The implementation's style (INLINE or EVENT) and cardinality select which
    behavior attribute is read and which _adapt_* function wraps it. The
    wrapped behavior fills the one handler slot for that cardinality.

    Args:
      implementation: Beta method implementation exposing style, cardinality,
        and the behavior attribute for that combination.
      request_deserializer: Request deserializer for the method, or None.
      response_serializer: Response serializer for the method, or None.

    Returns:
      A _SimpleMethodHandler.

    Raises:
      ValueError: If the style/cardinality combination is not supported.
    """
    if implementation.style is style.Service.INLINE:
        if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY:
            return _populated_method_handler(
                False,
                False,
                request_deserializer,
                response_serializer,
                _adapt_unary_request_inline(implementation.unary_unary_inline),
            )
        elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM:
            return _populated_method_handler(
                False,
                True,
                request_deserializer,
                response_serializer,
                _adapt_unary_request_inline(implementation.unary_stream_inline),
            )
        elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY:
            return _populated_method_handler(
                True,
                False,
                request_deserializer,
                response_serializer,
                _adapt_stream_request_inline(
                    implementation.stream_unary_inline
                ),
            )
        elif (
            implementation.cardinality is cardinality.Cardinality.STREAM_STREAM
        ):
            return _populated_method_handler(
                True,
                True,
                request_deserializer,
                response_serializer,
                _adapt_stream_request_inline(
                    implementation.stream_stream_inline
                ),
            )
    elif implementation.style is style.Service.EVENT:
        if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY:
            return _populated_method_handler(
                False,
                False,
                request_deserializer,
                response_serializer,
                _adapt_unary_unary_event(implementation.unary_unary_event),
            )
        elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM:
            return _populated_method_handler(
                False,
                True,
                request_deserializer,
                response_serializer,
                _adapt_unary_stream_event(implementation.unary_stream_event),
            )
        elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY:
            return _populated_method_handler(
                True,
                False,
                request_deserializer,
                response_serializer,
                _adapt_stream_unary_event(implementation.stream_unary_event),
            )
        elif (
            implementation.cardinality is cardinality.Cardinality.STREAM_STREAM
        ):
            return _populated_method_handler(
                True,
                True,
                request_deserializer,
                response_serializer,
                _adapt_stream_stream_event(implementation.stream_stream_event),
            )
    # getattr: an object with an unknown style may lack cardinality entirely,
    # and this path must still raise ValueError rather than AttributeError.
    raise ValueError(
        "Unsupported Beta method implementation (style={!r}, "
        "cardinality={!r})".format(
            implementation.style, getattr(implementation, "cardinality", None)
        )
    )
369
370
def _flatten_method_pair_map(method_pair_map):
    """Re-keys a {(service, method): value} mapping by fully-qualified method.

    A falsy method_pair_map (e.g. None) yields a new empty dict.
    """
    if not method_pair_map:
        return {}
    return {
        _common.fully_qualified_method(pair[0], pair[1]): method_pair_map[pair]
        for pair in method_pair_map
    }
378
379
class _GenericRpcHandler(grpc.GenericRpcHandler):
    """Routes incoming RPCs to Beta per-method implementations.

    The {(service, method): value} dictionaries passed in are flattened once,
    at construction, to {fully_qualified_method: value}.
    """

    def __init__(
        self,
        method_implementations,
        multi_method_implementation,
        request_deserializers,
        response_serializers,
    ):
        self._method_implementations = _flatten_method_pair_map(
            method_implementations
        )
        self._request_deserializers = _flatten_method_pair_map(
            request_deserializers
        )
        self._response_serializers = _flatten_method_pair_map(
            response_serializers
        )
        # Stored but not yet consulted by service(); see the TODO there.
        self._multi_method_implementation = multi_method_implementation

    def service(self, handler_call_details):
        """Returns a handler for the called method, or None if it is unknown.

        Args:
          handler_call_details: grpc.HandlerCallDetails; only .method is read.

        Returns:
          A _SimpleMethodHandler when a per-method implementation is
          registered for the method, otherwise None.
        """
        method = handler_call_details.method
        method_implementation = self._method_implementations.get(method)
        if method_implementation is None:
            # TODO(nathaniel): when a multi-method implementation was given,
            # call it here and return None on face.NoSuchMethodError.
            return None
        return _simple_method_handler(
            method_implementation,
            self._request_deserializers.get(method),
            self._response_serializers.get(method),
        )
416
417
class _Server(interfaces.Server):
    """Beta interfaces.Server facade that forwards to a grpc.Server."""

    def __init__(self, grpc_server):
        self._grpc_server = grpc_server

    def add_insecure_port(self, address):
        """Forwards to grpc.Server.add_insecure_port and returns its result."""
        delegate = self._grpc_server
        return delegate.add_insecure_port(address)

    def add_secure_port(self, address, server_credentials):
        """Forwards to grpc.Server.add_secure_port and returns its result."""
        delegate = self._grpc_server
        return delegate.add_secure_port(address, server_credentials)

    def start(self):
        """Starts the wrapped grpc.Server."""
        delegate = self._grpc_server
        delegate.start()

    def stop(self, grace):
        """Forwards to grpc.Server.stop(grace) and returns its result."""
        delegate = self._grpc_server
        return delegate.stop(grace)

    def __enter__(self):
        delegate = self._grpc_server
        delegate.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Stops with grace=None. Returning False means an exception raised in
        # the with-body is never suppressed.
        delegate = self._grpc_server
        delegate.stop(None)
        return False
441
442
def server(
    service_implementations,
    multi_method_implementation,
    request_deserializers,
    response_serializers,
    thread_pool,
    thread_pool_size,
):
    """Creates a Beta interfaces.Server backed by a grpc.Server.

    Args:
      service_implementations: {(service, method): implementation} mapping,
        or None.
      multi_method_implementation: Stored by the RPC handler (currently not
        consulted).
      request_deserializers: {(service, method): deserializer} mapping, or None.
      response_serializers: {(service, method): serializer} mapping, or None.
      thread_pool: Executor for the grpc.Server, or None to create one.
      thread_pool_size: Size of the created pool when thread_pool is None;
        None selects _DEFAULT_POOL_SIZE.

    Returns:
      A _Server wrapping the newly created (not started) grpc.Server.
    """
    handler = _GenericRpcHandler(
        service_implementations,
        multi_method_implementation,
        request_deserializers,
        response_serializers,
    )
    if thread_pool is not None:
        pool = thread_pool
    else:
        if thread_pool_size is None:
            pool_size = _DEFAULT_POOL_SIZE
        else:
            pool_size = thread_pool_size
        pool = logging_pool.pool(pool_size)
    return _Server(grpc.server(pool, handlers=(handler,)))
466