# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""The callback-based pw_rpc client implementation."""

from __future__ import annotations

import inspect
import logging
import textwrap
from typing import Any, Callable, Iterable, Type

from dataclasses import dataclass
from pw_status import Status
from google.protobuf.message import Message

from pw_rpc import client, descriptors
from pw_rpc.client import PendingRpc, PendingRpcs
from pw_rpc.descriptors import Channel, Method, Service

from pw_rpc.callback_client.call import (
    UseDefault,
    OptionalTimeout,
    CallTypeT,
    UnaryResponse,
    StreamResponse,
    Call,
    UnaryCall,
    ServerStreamingCall,
    ClientStreamingCall,
    BidirectionalStreamingCall,
    OnNextCallback,
    OnCompletedCallback,
    OnErrorCallback,
)

_LOG = logging.getLogger(__package__)


# Default value of the max_responses argument for server streaming and
# bidirectional streaming calls.
DEFAULT_MAX_STREAM_RESPONSES = 2**14


@dataclass(eq=True, frozen=True)
class CallInfo:
    """Describes an RPC about to be started; passed to Impl.on_call_hook."""

    method: Method

    @property
    def service(self) -> Service:
        """The service that the method belongs to."""
        return self.method.service


class _MethodClient:
    """A method that can be invoked for a particular channel."""

    def __init__(
        self,
        client_impl: Impl,
        rpcs: PendingRpcs,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
    ) -> None:
        self._impl = client_impl
        self._rpcs = rpcs
        self._channel = channel
        self._method = method
        # Used by _start_call when the caller passes UseDefault.VALUE.
        self.default_timeout_s: float | None = default_timeout_s

    @property
    def channel(self) -> Channel:
        return self._channel

    @property
    def method(self) -> Method:
        return self._method

    @property
    def service(self) -> Service:
        return self._method.service

    @property
    def request(self) -> type:
        """Returns the request proto class."""
        return self.method.request_type

    @property
    def response(self) -> type:
        """Returns the response proto class."""
        return self.method.response_type

    def __repr__(self) -> str:
        return self.help()

    def help(self) -> str:
        """Returns a help message about this RPC."""
        function_call = self.method.full_name + '('

        # __call__ is defined by the concrete method client subclasses, not by
        # this base class, hence the type/lint suppressions.
        docstring = inspect.getdoc(self.__call__)  # type: ignore[operator]  # pylint: disable=no-member
        assert docstring is not None

        annotation = inspect.signature(self).return_annotation  # type: ignore[arg-type]  # pylint: disable=line-too-long
        if isinstance(annotation, type):
            annotation = annotation.__name__

        # Align continuation lines of the argument list under the '('.
        arg_sep = f',\n{" " * len(function_call)}'
        return (
            f'{function_call}'
            f'{arg_sep.join(descriptors.field_help(self.method.request_type))})'
            f'\n\n{textwrap.indent(docstring, " ")}\n\n'
            f' Returns {annotation}.'
        )

    def _start_call(
        self,
        call_type: Type[CallTypeT],
        request: Message | None,
        timeout_s: OptionalTimeout,
        on_next: OnNextCallback | None,
        on_completed: OnCompletedCallback | None,
        on_error: OnErrorCallback | None,
        max_responses: int,
    ) -> CallTypeT:
        """Creates the Call object and invokes the RPC using it.

        Resolves UseDefault.VALUE to this client's default timeout, runs the
        implementation's on_call_hook (if one is set), allocates a new call ID,
        and invokes the call with the given request.
        """
        if timeout_s is UseDefault.VALUE:
            timeout_s = self.default_timeout_s

        if self._impl.on_call_hook is not None:
            self._impl.on_call_hook(CallInfo(self._method))

        rpc = PendingRpc(
            self._channel,
            self.service,
            self.method,
            self._rpcs.allocate_call_id(),
        )
        call = call_type(
            self._rpcs,
            rpc,
            timeout_s,
            on_next,
            on_completed,
            on_error,
            max_responses,
        )
        call._invoke(request)  # pylint: disable=protected-access
        return call

    def _open_call(
        self,
        call_type: Type[CallTypeT],
        on_next: OnNextCallback | None,
        on_completed: OnCompletedCallback | None,
        on_error: OnErrorCallback | None,
        max_responses: int,
    ) -> CallTypeT:
        """Creates a Call object with the open call ID.

        Unlike _start_call, no timeout is applied (None is passed) and the
        on_call_hook is not run.
        """
        rpc = PendingRpc(
            self._channel,
            self.service,
            self.method,
            client.OPEN_CALL_ID,
        )
        call = call_type(
            self._rpcs,
            rpc,
            None,
            on_next,
            on_completed,
            on_error,
            max_responses,
        )
        call._open()  # pylint: disable=protected-access
        return call

    def _client_streaming_call_type(
        self, base: Type[CallTypeT]
    ) -> Type[CallTypeT]:
        """Creates a client or bidirectional stream call type.

        Applies the signature from the request protobuf to the send method.
        No pw_rpc_timeout_s parameter is advertised on send: send takes no
        timeout, and such a keyword would be forwarded as a request field.
        """

        def send(
            self, request_proto: Message | None = None, /, **request_fields
        ) -> None:
            # NOTE(review): this delegates to ClientStreamingCall.send even when
            # base is BidirectionalStreamingCall; it relies on that function
            # working on any call object that supports a client stream.
            ClientStreamingCall.send(self, request_proto, **request_fields)

        _apply_protobuf_signature(self.method, send, include_timeout=False)

        return type(
            f'{self.method.name}_{base.__name__}', (base,), dict(send=send)
        )


def _function_docstring(method: Method) -> str:
    """Returns the docstring used for a generated __call__ function."""
    return f'''\
Invokes the {method.full_name} {method.type.sentence_name()} RPC.

This function accepts either the request protobuf fields as keyword arguments or
a request protobuf as a positional argument.
'''


def _update_call_method(method: Method, function: Callable) -> None:
    """Updates the name, docstring, and parameters to match a method."""
    function.__name__ = method.full_name
    function.__doc__ = _function_docstring(method)
    _apply_protobuf_signature(method, function)


def _apply_protobuf_signature(
    method: Method, function: Callable, *, include_timeout: bool = True
) -> None:
    """Update a function signature to accept proto arguments.

    In order to have good tab completion and help messages, update the function
    signature to accept only keyword arguments for the proto message fields.
    This doesn't actually change the function signature -- it just updates how
    it appears when inspected.

    Args:
        method: The RPC method whose request parameters are advertised.
        function: The function whose __signature__ is replaced. Its first
            parameter (self) is kept as-is.
        include_timeout: Whether to append a keyword-only pw_rpc_timeout_s
            parameter. Only set this for functions that really accept it.
    """
    sig = inspect.signature(function)

    params = [next(iter(sig.parameters.values()))]  # Get the "self" parameter
    params += method.request_parameters()
    if include_timeout:
        params.append(
            inspect.Parameter(
                'pw_rpc_timeout_s', inspect.Parameter.KEYWORD_ONLY
            )
        )

    function.__signature__ = sig.replace(  # type: ignore[attr-defined]
        parameters=params
    )


class _UnaryMethodClient(_MethodClient):
    def invoke(
        self,
        request: Message | None = None,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        request_args: dict[str, Any] | None = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryCall:
        """Invokes the unary RPC and returns a call object."""
        return self._start_call(
            UnaryCall,
            self.method.get_request(request, request_args),
            timeout_s,
            on_next,
            on_completed,
            on_error,
            max_responses=1,
        )

    def open(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
    ) -> UnaryCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not be
        available yet.
        """
        return self._open_call(
            UnaryCall, on_next, on_completed, on_error, max_responses=1
        )


class _ServerStreamingMethodClient(_MethodClient):
    def invoke(
        self,
        request: Message | None = None,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        max_responses: int = DEFAULT_MAX_STREAM_RESPONSES,
        *,
        request_args: dict[str, Any] | None = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> ServerStreamingCall:
        """Invokes the server streaming RPC and returns a call object."""
        return self._start_call(
            ServerStreamingCall,
            self.method.get_request(request, request_args),
            timeout_s,
            on_next,
            on_completed,
            on_error,
            max_responses=max_responses,
        )

    def open(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        max_responses: int = DEFAULT_MAX_STREAM_RESPONSES,
    ) -> ServerStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not be
        available yet.
        """
        return self._open_call(
            ServerStreamingCall,
            on_next,
            on_completed,
            on_error,
            max_responses=max_responses,
        )


class _ClientStreamingMethodClient(_MethodClient):
    def invoke(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> ClientStreamingCall:
        """Invokes the client streaming RPC and returns a call object"""
        return self._start_call(
            self._client_streaming_call_type(ClientStreamingCall),
            None,
            timeout_s,
            on_next,
            on_completed,
            on_error,
            max_responses=1,
        )

    def open(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
    ) -> ClientStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not be
        available yet.
        """
        return self._open_call(
            self._client_streaming_call_type(ClientStreamingCall),
            on_next,
            on_completed,
            on_error,
            max_responses=1,
        )

    def __call__(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryResponse:
        """Invokes the client streaming RPC synchronously.

        Starts the call, then passes the requests to finish_and_wait with the
        given timeout and returns its result.
        """
        return self.invoke().finish_and_wait(requests, timeout_s=timeout_s)


class _BidirectionalStreamingMethodClient(_MethodClient):
    def invoke(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        max_responses: int = DEFAULT_MAX_STREAM_RESPONSES,
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> BidirectionalStreamingCall:
        """Invokes the bidirectional streaming RPC and returns a call object."""
        return self._start_call(
            self._client_streaming_call_type(BidirectionalStreamingCall),
            None,
            timeout_s,
            on_next,
            on_completed,
            on_error,
            max_responses=max_responses,
        )

    def open(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        max_responses: int = DEFAULT_MAX_STREAM_RESPONSES,
    ) -> BidirectionalStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not be
        available yet.
        """
        return self._open_call(
            self._client_streaming_call_type(BidirectionalStreamingCall),
            on_next,
            on_completed,
            on_error,
            max_responses=max_responses,
        )

    def __call__(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> StreamResponse:
        """Invokes the bidirectional streaming RPC synchronously.

        Starts the call, then passes the requests to finish_and_wait with the
        given timeout and returns its result.
        """
        return self.invoke().finish_and_wait(requests, timeout_s=timeout_s)


def _method_client_docstring(method: Method) -> str:
    """Returns the class docstring for a generated method client type."""
    return f'''\
Class that invokes the {method.full_name} {method.type.sentence_name()} RPC.

Calling this directly invokes the RPC synchronously. The RPC can be invoked
asynchronously using the invoke method.
'''


class Impl(client.ClientImpl):
    """Callback-based ClientImpl, for use with pw_rpc.Client.

    Args:
        default_unary_timeout_s: Default timeout for unary and client
            streaming RPCs; None means no timeout is applied by default.
        default_stream_timeout_s: Default timeout for server streaming and
            bidirectional streaming RPCs; None means no timeout by default.
        on_call_hook: A callable object to handle RPC method calls.
            If hook is set, it will be called with a CallInfo before RPC
            execution. It is not called for calls created with open().
    """

    def __init__(
        self,
        default_unary_timeout_s: float | None = None,
        default_stream_timeout_s: float | None = None,
        on_call_hook: Callable[[CallInfo], Any] | None = None,
    ) -> None:
        super().__init__()
        self._default_unary_timeout_s = default_unary_timeout_s
        self._default_stream_timeout_s = default_stream_timeout_s
        self.on_call_hook = on_call_hook

    @property
    def default_unary_timeout_s(self) -> float | None:
        return self._default_unary_timeout_s

    @property
    def default_stream_timeout_s(self) -> float | None:
        return self._default_stream_timeout_s

    def method_client(self, channel: Channel, method: Method) -> _MethodClient:
        """Returns an object that invokes a method using the given channel.

        Client streaming RPCs use the unary default timeout, since they
        complete with a single response.
        """

        if method.type is Method.Type.UNARY:
            return self._create_unary_method_client(
                channel, method, self.default_unary_timeout_s
            )

        if method.type is Method.Type.SERVER_STREAMING:
            return self._create_server_streaming_method_client(
                channel, method, self.default_stream_timeout_s
            )

        if method.type is Method.Type.CLIENT_STREAMING:
            return self._create_method_client(
                _ClientStreamingMethodClient,
                channel,
                method,
                self.default_unary_timeout_s,
            )

        if method.type is Method.Type.BIDIRECTIONAL_STREAMING:
            return self._create_method_client(
                _BidirectionalStreamingMethodClient,
                channel,
                method,
                self.default_stream_timeout_s,
            )

        raise AssertionError(f'Unknown method type {method.type}')

    def _create_method_client(
        self,
        base: type,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
        **fields,
    ):
        """Creates a _MethodClient derived class customized for the method.

        Args:
            base: The _MethodClient subclass to derive from.
            channel: The channel the client invokes the method on.
            method: The RPC method; its name prefixes the new class name.
            default_timeout_s: Default timeout passed to the new instance.
            **fields: Extra class attributes (e.g. __call__) for the new type.

        Returns:
            An instance of the newly created class.
        """
        method_client_type = type(
            f'{method.name}{base.__name__}',
            (base,),
            dict(__doc__=_method_client_docstring(method), **fields),
        )
        return method_client_type(
            self, self.rpcs, channel, method, default_timeout_s
        )

    def _create_unary_method_client(
        self,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
    ) -> _UnaryMethodClient:
        """Creates a _UnaryMethodClient with a customized __call__ method."""

        def call(
            self: _UnaryMethodClient,
            request_proto: Message | None = None,
            /,
            *,
            pw_rpc_timeout_s: OptionalTimeout = UseDefault.VALUE,
            **request_fields,
        ) -> UnaryResponse:
            return self.invoke(
                self.method.get_request(request_proto, request_fields)
            ).wait(pw_rpc_timeout_s)

        _update_call_method(method, call)
        return self._create_method_client(
            _UnaryMethodClient,
            channel,
            method,
            default_timeout_s,
            __call__=call,
        )

    def _create_server_streaming_method_client(
        self,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
    ) -> _ServerStreamingMethodClient:
        """Creates _ServerStreamingMethodClient with custom __call__ method."""

        def call(
            self: _ServerStreamingMethodClient,
            request_proto: Message | None = None,
            /,
            *,
            pw_rpc_timeout_s: OptionalTimeout = UseDefault.VALUE,
            **request_fields,
        ) -> StreamResponse:
            return self.invoke(
                self.method.get_request(request_proto, request_fields)
            ).wait(pw_rpc_timeout_s)

        _update_call_method(method, call)
        return self._create_method_client(
            _ServerStreamingMethodClient,
            channel,
            method,
            default_timeout_s,
            __call__=call,
        )

    def handle_response(
        self,
        rpc: PendingRpc,
        context: Call,
        payload,
    ) -> None:
        """Invokes the callback associated with this RPC."""
        context._handle_response(payload)  # pylint: disable=protected-access

    def handle_completion(
        self,
        rpc: PendingRpc,
        context: Call,
        status: Status,
    ) -> None:
        """Forwards the RPC's completion status to its call object."""
        context._handle_completion(status)  # pylint: disable=protected-access

    def handle_error(
        self,
        rpc: PendingRpc,
        context: Call,
        status: Status,
    ) -> None:
        """Forwards an RPC error status to its call object."""
        context._handle_error(status)  # pylint: disable=protected-access