xref: /aosp_15_r20/external/pigweed/pw_rpc/py/pw_rpc/callback_client/impl.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""The callback-based pw_rpc client implementation."""
15
16from __future__ import annotations
17
18import inspect
19import logging
20import textwrap
21from typing import Any, Callable, Iterable, Type
22
23from dataclasses import dataclass
24from pw_status import Status
25from google.protobuf.message import Message
26
27from pw_rpc import client, descriptors
28from pw_rpc.client import PendingRpc, PendingRpcs
29from pw_rpc.descriptors import Channel, Method, Service
30
31from pw_rpc.callback_client.call import (
32    UseDefault,
33    OptionalTimeout,
34    CallTypeT,
35    UnaryResponse,
36    StreamResponse,
37    Call,
38    UnaryCall,
39    ServerStreamingCall,
40    ClientStreamingCall,
41    BidirectionalStreamingCall,
42    OnNextCallback,
43    OnCompletedCallback,
44    OnErrorCallback,
45)
46
# Logger for this module, named after the enclosing package.
_LOG = logging.getLogger(__package__)


# Default ``max_responses`` used by server and bidirectional streaming method
# clients (2**14 == 16384).
DEFAULT_MAX_STREAM_RESPONSES = 1 << 14
51
52
@dataclass(frozen=True)
class CallInfo:
    """Immutable description of an RPC, handed to ``Impl.on_call_hook``.

    Instances are hashable and compare equal when they wrap the same method.
    """

    # The RPC method that is about to be invoked.
    method: Method

    @property
    def service(self) -> Service:
        """The service that ``method`` belongs to."""
        owning_service = self.method.service
        return owning_service
60
61
62class _MethodClient:
63    """A method that can be invoked for a particular channel."""
64
65    def __init__(
66        self,
67        client_impl: Impl,
68        rpcs: PendingRpcs,
69        channel: Channel,
70        method: Method,
71        default_timeout_s: float | None,
72    ) -> None:
73        self._impl = client_impl
74        self._rpcs = rpcs
75        self._channel = channel
76        self._method = method
77        self.default_timeout_s: float | None = default_timeout_s
78
79    @property
80    def channel(self) -> Channel:
81        return self._channel
82
83    @property
84    def method(self) -> Method:
85        return self._method
86
87    @property
88    def service(self) -> Service:
89        return self._method.service
90
91    @property
92    def request(self) -> type:
93        """Returns the request proto class."""
94        return self.method.request_type
95
96    @property
97    def response(self) -> type:
98        """Returns the response proto class."""
99        return self.method.response_type
100
101    def __repr__(self) -> str:
102        return self.help()
103
104    def help(self) -> str:
105        """Returns a help message about this RPC."""
106        function_call = self.method.full_name + '('
107
108        docstring = inspect.getdoc(self.__call__)  # type: ignore[operator] # pylint: disable=no-member
109        assert docstring is not None
110
111        annotation = inspect.signature(self).return_annotation  # type: ignore[arg-type] # pylint: disable=line-too-long
112        if isinstance(annotation, type):
113            annotation = annotation.__name__
114
115        arg_sep = f',\n{" " * len(function_call)}'
116        return (
117            f'{function_call}'
118            f'{arg_sep.join(descriptors.field_help(self.method.request_type))})'
119            f'\n\n{textwrap.indent(docstring, "  ")}\n\n'
120            f'  Returns {annotation}.'
121        )
122
123    def _start_call(
124        self,
125        call_type: Type[CallTypeT],
126        request: Message | None,
127        timeout_s: OptionalTimeout,
128        on_next: OnNextCallback | None,
129        on_completed: OnCompletedCallback | None,
130        on_error: OnErrorCallback | None,
131        max_responses: int,
132    ) -> CallTypeT:
133        """Creates the Call object and invokes the RPC using it."""
134        if timeout_s is UseDefault.VALUE:
135            timeout_s = self.default_timeout_s
136
137        if self._impl.on_call_hook:
138            self._impl.on_call_hook(CallInfo(self._method))
139
140        rpc = PendingRpc(
141            self._channel,
142            self.service,
143            self.method,
144            self._rpcs.allocate_call_id(),
145        )
146        call = call_type(
147            self._rpcs,
148            rpc,
149            timeout_s,
150            on_next,
151            on_completed,
152            on_error,
153            max_responses,
154        )
155        call._invoke(request)  # pylint: disable=protected-access
156        return call
157
158    def _open_call(
159        self,
160        call_type: Type[CallTypeT],
161        on_next: OnNextCallback | None,
162        on_completed: OnCompletedCallback | None,
163        on_error: OnErrorCallback | None,
164        max_responses: int,
165    ) -> CallTypeT:
166        """Creates a Call object with the open call ID."""
167        rpc = PendingRpc(
168            self._channel,
169            self.service,
170            self.method,
171            client.OPEN_CALL_ID,
172        )
173        call = call_type(
174            self._rpcs,
175            rpc,
176            None,
177            on_next,
178            on_completed,
179            on_error,
180            max_responses,
181        )
182        call._open()  # pylint: disable=protected-access
183        return call
184
185    def _client_streaming_call_type(
186        self, base: Type[CallTypeT]
187    ) -> Type[CallTypeT]:
188        """Creates a client or bidirectional stream call type.
189
190        Applies the signature from the request protobuf to the send method.
191        """
192
193        def send(
194            self, request_proto: Message | None = None, /, **request_fields
195        ) -> None:
196            ClientStreamingCall.send(self, request_proto, **request_fields)
197
198        _apply_protobuf_signature(self.method, send)
199
200        return type(
201            f'{self.method.name}_{base.__name__}', (base,), dict(send=send)
202        )
203
204
205def _function_docstring(method: Method) -> str:
206    return f'''\
207Invokes the {method.full_name} {method.type.sentence_name()} RPC.
208
209This function accepts either the request protobuf fields as keyword arguments or
210a request protobuf as a positional argument.
211'''
212
213
def _update_call_method(method: Method, function: Callable) -> None:
    """Makes ``function`` present itself as the ``method`` RPC.

    Renames it to the method's full name, attaches a generated docstring, and
    rewrites its inspected signature via ``_apply_protobuf_signature``.
    """
    function.__name__, function.__doc__ = (
        method.full_name,
        _function_docstring(method),
    )
    _apply_protobuf_signature(method, function)
219
220
221def _apply_protobuf_signature(method: Method, function: Callable) -> None:
222    """Update a function signature to accept proto arguments.
223
224    In order to have good tab completion and help messages, update the function
225    signature to accept only keyword arguments for the proto message fields.
226    This doesn't actually change the function signature -- it just updates how
227    it appears when inspected.
228    """
229    sig = inspect.signature(function)
230
231    params = [next(iter(sig.parameters.values()))]  # Get the "self" parameter
232    params += method.request_parameters()
233    params.append(
234        inspect.Parameter('pw_rpc_timeout_s', inspect.Parameter.KEYWORD_ONLY)
235    )
236
237    function.__signature__ = sig.replace(  # type: ignore[attr-defined]
238        parameters=params
239    )
240
241
class _UnaryMethodClient(_MethodClient):
    """Method client for unary RPCs."""

    def invoke(
        self,
        request: Message | None = None,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        request_args: dict[str, Any] | None = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryCall:
        """Invokes the unary RPC and returns a call object.

        Args:
            request: Request proto; combined with ``request_args`` into the
                request message by ``Method.get_request``.
            on_next: Callback forwarded to the call object.
            on_completed: Callback forwarded to the call object.
            on_error: Callback forwarded to the call object.
            request_args: Request fields, combined with ``request``.
            timeout_s: Call timeout; ``UseDefault.VALUE`` selects this
                client's ``default_timeout_s``.
        """
        return self._start_call(
            UnaryCall,
            self.method.get_request(request, request_args),
            timeout_s,
            on_next,
            on_completed,
            on_error,
            max_responses=1,
        )

    def open(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
    ) -> UnaryCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Unlike ``invoke``, this takes no request. It can be used to listen for
        responses from an RPC server that may not yet be available.
        """
        return self._open_call(
            UnaryCall, on_next, on_completed, on_error, max_responses=1
        )
274
275
class _ServerStreamingMethodClient(_MethodClient):
    """Method client for server streaming RPCs."""

    def invoke(
        self,
        request: Message | None = None,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        max_responses: int = DEFAULT_MAX_STREAM_RESPONSES,
        *,
        request_args: dict[str, Any] | None = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> ServerStreamingCall:
        """Invokes the server streaming RPC and returns a call object.

        ``request`` and ``request_args`` are merged into one request message by
        ``Method.get_request``; the callbacks and ``max_responses`` are handed
        to the new call object.
        """
        request_message = self.method.get_request(request, request_args)
        return self._start_call(
            call_type=ServerStreamingCall,
            request=request_message,
            timeout_s=timeout_s,
            on_next=on_next,
            on_completed=on_completed,
            on_error=on_error,
            max_responses=max_responses,
        )

    def open(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        max_responses: int = DEFAULT_MAX_STREAM_RESPONSES,
    ) -> ServerStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Useful for listening for responses from an RPC server that may not yet
        be available.
        """
        return self._open_call(
            call_type=ServerStreamingCall,
            on_next=on_next,
            on_completed=on_completed,
            on_error=on_error,
            max_responses=max_responses,
        )
314
315
class _ClientStreamingMethodClient(_MethodClient):
    """Method client for client streaming RPCs."""

    def invoke(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> ClientStreamingCall:
        """Invokes the client streaming RPC and returns a call object.

        No initial request is sent with the invocation; requests are sent
        through the returned call's ``send`` method, whose signature mirrors
        the request proto fields.
        """
        return self._start_call(
            self._client_streaming_call_type(ClientStreamingCall),
            None,
            timeout_s,
            on_next,
            on_completed,
            on_error,
            max_responses=1,
        )

    def open(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
    ) -> ClientStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not
        yet be available.
        """
        return self._open_call(
            self._client_streaming_call_type(ClientStreamingCall),
            on_next,
            on_completed,
            on_error,
            max_responses=1,
        )

    def __call__(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryResponse:
        """Invokes the client streaming RPC synchronously.

        Starts the RPC with ``invoke()``, then passes ``requests`` and
        ``timeout_s`` to the call's ``finish_and_wait`` and returns its result.
        """
        return self.invoke().finish_and_wait(requests, timeout_s=timeout_s)
362
363
class _BidirectionalStreamingMethodClient(_MethodClient):
    """Method client for bidirectional streaming RPCs."""

    def invoke(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        max_responses: int = DEFAULT_MAX_STREAM_RESPONSES,
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> BidirectionalStreamingCall:
        """Invokes the bidirectional streaming RPC and returns a call object.

        No initial request is sent with the invocation; requests are sent
        through the returned call's ``send`` method, whose signature mirrors
        the request proto fields.
        """
        return self._start_call(
            self._client_streaming_call_type(BidirectionalStreamingCall),
            None,
            timeout_s,
            on_next,
            on_completed,
            on_error,
            max_responses=max_responses,
        )

    def open(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        max_responses: int = DEFAULT_MAX_STREAM_RESPONSES,
    ) -> BidirectionalStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not
        yet be available.
        """
        return self._open_call(
            self._client_streaming_call_type(BidirectionalStreamingCall),
            on_next,
            on_completed,
            on_error,
            max_responses=max_responses,
        )

    def __call__(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> StreamResponse:
        """Invokes the bidirectional streaming RPC synchronously.

        Starts the RPC with ``invoke()``, then passes ``requests`` and
        ``timeout_s`` to the call's ``finish_and_wait`` and returns its result.
        """
        return self.invoke().finish_and_wait(requests, timeout_s=timeout_s)
412
413
414def _method_client_docstring(method: Method) -> str:
415    return f'''\
416Class that invokes the {method.full_name} {method.type.sentence_name()} RPC.
417
418Calling this directly invokes the RPC synchronously. The RPC can be invoked
419asynchronously using the invoke method.
420'''
421
422
class Impl(client.ClientImpl):
    """Callback-based ClientImpl, for use with pw_rpc.Client.

    Args:
        default_unary_timeout_s: Default timeout given to unary and client
            streaming method clients.
        default_stream_timeout_s: Default timeout given to server streaming
            and bidirectional streaming method clients.
        on_call_hook: A callable object to handle RPC method calls. If set, it
            is called with a ``CallInfo`` before each RPC is invoked. It is not
            called when a call is only opened with ``open()``.
    """

    def __init__(
        self,
        default_unary_timeout_s: float | None = None,
        default_stream_timeout_s: float | None = None,
        on_call_hook: Callable[[CallInfo], Any] | None = None,
    ) -> None:
        super().__init__()
        self._default_unary_timeout_s = default_unary_timeout_s
        self._default_stream_timeout_s = default_stream_timeout_s
        self.on_call_hook = on_call_hook

    @property
    def default_unary_timeout_s(self) -> float | None:
        """Default timeout for unary and client streaming method clients."""
        return self._default_unary_timeout_s

    @property
    def default_stream_timeout_s(self) -> float | None:
        """Default timeout for server and bidirectional streaming clients."""
        return self._default_stream_timeout_s

    def method_client(self, channel: Channel, method: Method) -> _MethodClient:
        """Returns an object that invokes a method using the given channel.

        Unary and client streaming methods get ``default_unary_timeout_s``;
        server and bidirectional streaming methods get
        ``default_stream_timeout_s``.

        Raises:
            AssertionError: If ``method.type`` is not a known method type.
        """

        if method.type is Method.Type.UNARY:
            return self._create_unary_method_client(
                channel, method, self.default_unary_timeout_s
            )

        if method.type is Method.Type.SERVER_STREAMING:
            return self._create_server_streaming_method_client(
                channel, method, self.default_stream_timeout_s
            )

        if method.type is Method.Type.CLIENT_STREAMING:
            # Client streaming RPCs have a single response, so they share the
            # unary default timeout.
            return self._create_method_client(
                _ClientStreamingMethodClient,
                channel,
                method,
                self.default_unary_timeout_s,
            )

        if method.type is Method.Type.BIDIRECTIONAL_STREAMING:
            return self._create_method_client(
                _BidirectionalStreamingMethodClient,
                channel,
                method,
                self.default_stream_timeout_s,
            )

        raise AssertionError(f'Unknown method type {method.type}')

    def _create_method_client(
        self,
        base: type,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
        **fields,
    ) -> _MethodClient:
        """Creates a _MethodClient derived class customized for the method.

        Args:
            base: The _MethodClient subclass to derive from.
            channel: The channel the client makes calls on.
            method: The RPC method; supplies the class name and docstring.
            default_timeout_s: Default timeout for the client's calls.
            **fields: Extra class attributes (e.g. ``__call__``) placed on the
                derived class.

        Returns:
            An instance of the newly created derived class.
        """
        method_client_type = type(
            f'{method.name}{base.__name__}',
            (base,),
            dict(__doc__=_method_client_docstring(method), **fields),
        )
        return method_client_type(
            self, self.rpcs, channel, method, default_timeout_s
        )

    def _create_unary_method_client(
        self,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
    ) -> _UnaryMethodClient:
        """Creates a _UnaryMethodClient with a customized __call__ method."""

        def call(
            self: _UnaryMethodClient,
            request_proto: Message | None = None,
            /,
            *,
            pw_rpc_timeout_s: OptionalTimeout = UseDefault.VALUE,
            **request_fields,
        ) -> UnaryResponse:
            return self.invoke(
                self.method.get_request(request_proto, request_fields)
            ).wait(pw_rpc_timeout_s)

        # Give __call__ the method's name, docstring, and request-field
        # signature so help() and tab completion are meaningful.
        _update_call_method(method, call)
        return self._create_method_client(
            _UnaryMethodClient,
            channel,
            method,
            default_timeout_s,
            __call__=call,
        )

    def _create_server_streaming_method_client(
        self,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
    ) -> _ServerStreamingMethodClient:
        """Creates _ServerStreamingMethodClient with custom __call__ method."""

        def call(
            self: _ServerStreamingMethodClient,
            request_proto: Message | None = None,
            /,
            *,
            pw_rpc_timeout_s: OptionalTimeout = UseDefault.VALUE,
            **request_fields,
        ) -> StreamResponse:
            return self.invoke(
                self.method.get_request(request_proto, request_fields)
            ).wait(pw_rpc_timeout_s)

        # Give __call__ the method's name, docstring, and request-field
        # signature so help() and tab completion are meaningful.
        _update_call_method(method, call)
        return self._create_method_client(
            _ServerStreamingMethodClient,
            channel,
            method,
            default_timeout_s,
            __call__=call,
        )

    def handle_response(
        self,
        rpc: PendingRpc,
        context: Call,
        payload,
    ) -> None:
        """Invokes the callback associated with this RPC.

        Forwards ``payload`` to the call object stored as ``context``.
        """
        context._handle_response(payload)  # pylint: disable=protected-access

    def handle_completion(
        self,
        rpc: PendingRpc,
        context: Call,
        status: Status,
    ) -> None:
        """Forwards the completion ``status`` to the call object."""
        context._handle_completion(status)  # pylint: disable=protected-access

    def handle_error(
        self,
        rpc: PendingRpc,
        context: Call,
        status: Status,
    ) -> None:
        """Forwards the error ``status`` to the call object."""
        context._handle_error(status)  # pylint: disable=protected-access
581