1# Copyright 2020 The Pigweed Authors 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); you may not 4# use this file except in compliance with the License. You may obtain a copy of 5# the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12# License for the specific language governing permissions and limitations under 13# the License. 14"""Types representing the basic pw_rpc concepts: channel, service, method.""" 15 16from __future__ import annotations 17 18import abc 19from dataclasses import dataclass 20import enum 21from inspect import Parameter 22from typing import ( 23 Any, 24 Callable, 25 Collection, 26 Generic, 27 Iterable, 28 Iterator, 29 TypeVar, 30) 31 32from google.protobuf import descriptor_pb2, message_factory 33from google.protobuf.descriptor import ( 34 FieldDescriptor, 35 MethodDescriptor, 36 ServiceDescriptor, 37) 38from google.protobuf.message import Message 39from pw_protobuf_compiler import python_protos 40 41from pw_rpc import ids 42 43 44@dataclass(frozen=True) 45class Channel: 46 id: int 47 output: Callable[[bytes], Any] 48 49 def __repr__(self) -> str: 50 return f'Channel({self.id})' 51 52 53class ChannelManipulator(abc.ABC): 54 """A a pipe interface that may manipulate packets before they're sent. 55 56 ``ChannelManipulator``s allow application-specific packet handling to be 57 injected into the packet processing pipeline for an ingress or egress 58 channel-like pathway. This is particularly useful for integration testing 59 resilience to things like packet loss on a usually-reliable transport. RPC 60 server integrations (e.g. ``HdlcRpcLocalServerAndClient``) may provide an 61 opportunity to inject a ``ChannelManipulator`` for this use case. 62 63 A ``ChannelManipulator`` should not modify send_packet, as the consumer of a 64 ``ChannelManipulator`` will use ``send_packet`` to insert the provided 65 ``ChannelManipulator`` into a packet processing path. 66 67 For example: 68 69 .. code-block:: python 70 71 class PacketLogger(ChannelManipulator): 72 def process_and_send(self, packet: bytes) -> None: 73 _LOG.debug('Received packet with payload: %s', str(packet)) 74 self.send_packet(packet) 75 76 77 packet_logger = PacketLogger() 78 79 # Configure actual send command. 80 packet_logger.send_packet = socket.sendall 81 82 # Route the output channel through the PacketLogger channel manipulator. 83 channels = tuple(Channel(_DEFAULT_CHANNEL, packet_logger)) 84 85 # Create a RPC client. 86 reader = SocketReader(socket) 87 with reader: 88 client = HdlcRpcClient(reader, protos, channels, stdout) 89 with client: 90 # Do something with client 91 """ 92 93 def __init__(self) -> None: 94 self.send_packet: Callable[[bytes], Any] = lambda _: None 95 96 @abc.abstractmethod 97 def process_and_send(self, packet: bytes) -> None: 98 """Processes an incoming packet before optionally sending it. 99 100 Implementations of this method may send the processed packet, multiple 101 packets, or no packets at all via the registered `send_packet()` 102 handler. 103 """ 104 105 def __call__(self, data: bytes) -> None: 106 self.process_and_send(data) 107 108 109@dataclass(frozen=True, eq=False) 110class Service: 111 """Describes an RPC service.""" 112 113 _descriptor: ServiceDescriptor 114 id: int 115 methods: Methods 116 117 @property 118 def name(self): 119 return self._descriptor.name 120 121 @property 122 def full_name(self): 123 return self._descriptor.full_name 124 125 @property 126 def package(self): 127 return self._descriptor.file.package 128 129 @classmethod 130 def from_descriptor(cls, descriptor: ServiceDescriptor) -> Service: 131 service = cls( 132 descriptor, 133 ids.calculate(descriptor.full_name), 134 None, # type: ignore[arg-type] 135 ) 136 object.__setattr__( 137 service, 138 'methods', 139 Methods( 140 Method.from_descriptor(method_descriptor, service) 141 for method_descriptor in descriptor.methods 142 ), 143 ) 144 145 return service 146 147 def __repr__(self) -> str: 148 return f'Service({self.full_name!r})' 149 150 def __str__(self) -> str: 151 return self.full_name 152 153 154def _streaming_attributes(method) -> tuple[bool, bool]: 155 # TODO(hepler): Investigate adding server_streaming and client_streaming 156 # attributes to the generated protobuf code. As a workaround, 157 # deserialize the FileDescriptorProto to get that information. 158 service = method.containing_service 159 160 file_pb = descriptor_pb2.FileDescriptorProto() 161 file_pb.MergeFromString(service.file.serialized_pb) 162 163 method_pb = file_pb.service[service.index].method[ 164 method.index 165 ] # pylint: disable=no-member 166 return method_pb.server_streaming, method_pb.client_streaming 167 168 169_PROTO_FIELD_TYPES = { 170 FieldDescriptor.TYPE_BOOL: bool, 171 FieldDescriptor.TYPE_BYTES: bytes, 172 FieldDescriptor.TYPE_DOUBLE: float, 173 FieldDescriptor.TYPE_ENUM: int, 174 FieldDescriptor.TYPE_FIXED32: int, 175 FieldDescriptor.TYPE_FIXED64: int, 176 FieldDescriptor.TYPE_FLOAT: float, 177 FieldDescriptor.TYPE_INT32: int, 178 FieldDescriptor.TYPE_INT64: int, 179 FieldDescriptor.TYPE_SFIXED32: int, 180 FieldDescriptor.TYPE_SFIXED64: int, 181 FieldDescriptor.TYPE_SINT32: int, 182 FieldDescriptor.TYPE_SINT64: int, 183 FieldDescriptor.TYPE_STRING: str, 184 FieldDescriptor.TYPE_UINT32: int, 185 FieldDescriptor.TYPE_UINT64: int, 186 # These types are not annotated: 187 # FieldDescriptor.TYPE_GROUP = 10 188 # FieldDescriptor.TYPE_MESSAGE = 11 189} 190 191 192def _field_type_annotation(field: FieldDescriptor): 193 """Creates a field type annotation to use in the help message only.""" 194 if field.type == FieldDescriptor.TYPE_MESSAGE: 195 annotation = message_factory.GetMessageClass(field.message_type) 196 else: 197 annotation = _PROTO_FIELD_TYPES.get(field.type, Parameter.empty) 198 199 if field.label == FieldDescriptor.LABEL_REPEATED: 200 return Iterable[annotation] # type: ignore[valid-type] 201 202 return annotation 203 204 205def field_help(proto_message, *, annotations: bool = False) -> Iterator[str]: 206 """Yields argument strings for proto fields for use in a help message.""" 207 for field in proto_message.DESCRIPTOR.fields: 208 if field.type == FieldDescriptor.TYPE_ENUM: 209 value = field.enum_type.values_by_number[field.default_value].name 210 type_name = field.enum_type.full_name 211 value = f'{type_name.rsplit(".", 1)[0]}.{value}' 212 else: 213 type_name = _PROTO_FIELD_TYPES[field.type].__name__ 214 value = repr(field.default_value) 215 216 if annotations: 217 yield f'{field.name}: {type_name} = {value}' 218 else: 219 yield f'{field.name}={value}' 220 221 222@dataclass(frozen=True, eq=False) 223class Method: 224 """Describes a method in a service.""" 225 226 _descriptor: MethodDescriptor 227 service: Service 228 id: int 229 server_streaming: bool 230 client_streaming: bool 231 request_type: Any 232 response_type: Any 233 234 @classmethod 235 def from_descriptor( 236 cls, descriptor: MethodDescriptor, service: Service 237 ) -> Method: 238 return Method( 239 descriptor, 240 service, 241 ids.calculate(descriptor.name), 242 *_streaming_attributes(descriptor), 243 message_factory.GetMessageClass(descriptor.input_type), 244 message_factory.GetMessageClass(descriptor.output_type), 245 ) 246 247 class Type(enum.Enum): 248 UNARY = 0 249 SERVER_STREAMING = 1 250 CLIENT_STREAMING = 2 251 BIDIRECTIONAL_STREAMING = 3 252 253 def sentence_name(self) -> str: 254 return self.name.lower().replace( 255 '_', ' ' 256 ) # pylint: disable=no-member 257 258 @property 259 def name(self) -> str: 260 return self._descriptor.name 261 262 @property 263 def full_name(self) -> str: 264 return self._descriptor.full_name 265 266 @property 267 def package(self) -> str: 268 return self._descriptor.containing_service.file.package 269 270 @property 271 def type(self) -> Method.Type: 272 if self.server_streaming and self.client_streaming: 273 return self.Type.BIDIRECTIONAL_STREAMING 274 275 if self.server_streaming: 276 return self.Type.SERVER_STREAMING 277 278 if self.client_streaming: 279 return self.Type.CLIENT_STREAMING 280 281 return self.Type.UNARY 282 283 def get_request( 284 self, proto: Message | None, proto_kwargs: dict[str, Any] | None 285 ) -> Message: 286 """Returns a request_type protobuf message. 287 288 The client implementation may use this to support providing a request 289 as either a message object or as keyword arguments for the message's 290 fields (but not both). 291 """ 292 if proto_kwargs is None: 293 proto_kwargs = {} 294 295 if proto and proto_kwargs: 296 proto_str = repr(proto).strip() or "''" 297 raise TypeError( 298 'Requests must be provided either as a message object or a ' 299 'series of keyword args, but both were provided ' 300 f"({proto_str} and {proto_kwargs!r})" 301 ) 302 303 if proto is None: 304 return self.request_type(**proto_kwargs) 305 306 if not isinstance(proto, self.request_type): 307 try: 308 bad_type = proto.DESCRIPTOR.full_name 309 except AttributeError: 310 bad_type = type(proto).__name__ 311 312 raise TypeError( 313 f'Expected a message of type ' 314 f'{self.request_type.DESCRIPTOR.full_name}, ' 315 f'got {bad_type}' 316 ) 317 318 return proto 319 320 def request_parameters(self) -> Iterator[Parameter]: 321 """Yields inspect.Parameters corresponding to the request's fields. 322 323 This can be used to make function signatures match the request proto. 324 """ 325 for field in self.request_type.DESCRIPTOR.fields: 326 yield Parameter( 327 field.name, 328 Parameter.KEYWORD_ONLY, 329 annotation=_field_type_annotation(field), 330 default=field.default_value, 331 ) 332 333 def __repr__(self) -> str: 334 req = self._method_parameter(self.request_type, self.client_streaming) 335 res = self._method_parameter(self.response_type, self.server_streaming) 336 return f'<{self.full_name}({req}) returns ({res})>' 337 338 def _method_parameter(self, proto, streaming: bool) -> str: 339 """Returns a description of the method's request or response type.""" 340 stream = 'stream ' if streaming else '' 341 342 if proto.DESCRIPTOR.file.package == self.service.package: 343 return stream + proto.DESCRIPTOR.name 344 345 return stream + proto.DESCRIPTOR.full_name 346 347 def __str__(self) -> str: 348 return self.full_name 349 350 351T = TypeVar('T') 352 353 354def _name(item: Service | Method) -> str: 355 return item.full_name if isinstance(item, Service) else item.name 356 357 358class _AccessByName(Generic[T]): 359 """Wrapper for accessing types by name within a proto package structure.""" 360 361 def __init__(self, name: str, item: T): 362 setattr(self, name, item) 363 364 365class ServiceAccessor(Collection[T]): 366 """Navigates RPC services by name or ID.""" 367 368 def __init__(self, members, as_attrs: str = ''): 369 """Creates accessor from an {item: value} dict or [values] iterable.""" 370 # If the members arg was passed as a [values] iterable, convert it to 371 # an equivalent dictionary. 372 if not isinstance(members, dict): 373 members = {m: m for m in members} 374 375 by_name: dict[str, Any] = {_name(k): v for k, v in members.items()} 376 self._by_id = {k.id: v for k, v in members.items()} 377 # Note: a dictionary is used rather than `setattr` in order to 378 # (1) Hint to the type checker that there will be extra fields 379 # (2) Ensure that built-in attributes such as `_by_id`` are not 380 # overwritten. 381 self._attrs: dict[str, Any] = {} 382 383 if as_attrs == 'members': 384 for name, member in by_name.items(): 385 self._attrs[name] = member 386 elif as_attrs == 'packages': 387 for package in python_protos.as_packages( 388 (m.package, _AccessByName(m.name, members[m])) for m in members 389 ).packages: 390 self._attrs[str(package)] = package 391 elif as_attrs: 392 raise ValueError(f'Unexpected value {as_attrs!r} for as_attrs') 393 394 def __getattr__(self, name: str) -> Any: 395 return self._attrs[name] 396 397 def __getitem__(self, name_or_id: str | int) -> Any: 398 """Accesses a service/method by the string name or ID.""" 399 try: 400 return self._by_id[_id(name_or_id)] 401 except KeyError: 402 pass 403 404 name = f' ("{name_or_id}")' if isinstance(name_or_id, str) else '' 405 raise KeyError(f'Unknown ID {_id(name_or_id)}{name}') 406 407 def __iter__(self) -> Iterator[T]: 408 return iter(self._by_id.values()) 409 410 def __len__(self) -> int: 411 return len(self._by_id) 412 413 def __contains__(self, name_or_id) -> bool: 414 return _id(name_or_id) in self._by_id 415 416 def __repr__(self) -> str: 417 members = ', '.join(repr(m) for m in self._by_id.values()) 418 return f'{self.__class__.__name__}({members})' 419 420 421def _id(handle: str | int) -> int: 422 return ids.calculate(handle) if isinstance(handle, str) else handle 423 424 425class Methods(ServiceAccessor[Method]): 426 """A collection of Method descriptors in a Service.""" 427 428 def __init__(self, method: Iterable[Method]): 429 super().__init__(method) 430 431 432class Services(ServiceAccessor[Service]): 433 """A collection of Service descriptors.""" 434 435 def __init__(self, services: Iterable[Service]): 436 super().__init__(services) 437 438 439def get_method(service_accessor: ServiceAccessor, name: str): 440 """Returns a method matching the given full name in a ServiceAccessor. 441 442 Args: 443 name: name as package.Service/Method or package.Service.Method. 444 445 Raises: 446 ValueError: the method name is not properly formatted 447 KeyError: the method is not present 448 """ 449 if '/' in name: 450 service_name, method_name = name.split('/') 451 else: 452 service_name, method_name = name.rsplit('.', 1) 453 454 service = service_accessor[service_name] 455 if isinstance(service, Service): 456 service = service.methods 457 458 return service[method_name] 459 460 461@dataclass(frozen=True) 462class RpcIds: 463 """Integer IDs that uniquely identify a remote procedure call.""" 464 465 channel_id: int 466 service_id: int 467 method_id: int 468 call_id: int 469 470 471@dataclass(frozen=True) 472class PendingRpc: 473 """Tracks an active RPC call.""" 474 475 channel: Channel 476 service: Service 477 method: Method 478 call_id: int 479 480 @property 481 def channel_id(self) -> int: 482 return self.channel.id 483 484 @property 485 def service_id(self) -> int: 486 return self.service.id 487 488 @property 489 def method_id(self) -> int: 490 return self.method.id 491 492 def __eq__(self, other: Any) -> bool: 493 if isinstance(other, PendingRpc): 494 return self.ids() == other.ids() 495 496 return NotImplemented 497 498 def __hash__(self) -> int: 499 return hash(self.ids()) 500 501 def ids(self) -> RpcIds: 502 return RpcIds( 503 self.channel.id, self.service.id, self.method.id, self.call_id 504 ) 505 506 def __str__(self) -> str: 507 return ( 508 f'PendingRpc(channel={self.channel.id}, method={self.method}, ' 509 f'call_id={self.call_id})' 510 ) 511 512 def matches_channel_service_method(self, other: PendingRpc) -> bool: 513 return ( 514 self.channel.id == other.channel.id 515 and self.service.id == other.service.id 516 and self.method.id == other.method.id 517 ) 518