xref: /aosp_15_r20/external/pigweed/pw_rpc/py/pw_rpc/descriptors.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Types representing the basic pw_rpc concepts: channel, service, method."""
15
16from __future__ import annotations
17
18import abc
19from dataclasses import dataclass
20import enum
21from inspect import Parameter
22from typing import (
23    Any,
24    Callable,
25    Collection,
26    Generic,
27    Iterable,
28    Iterator,
29    TypeVar,
30)
31
32from google.protobuf import descriptor_pb2, message_factory
33from google.protobuf.descriptor import (
34    FieldDescriptor,
35    MethodDescriptor,
36    ServiceDescriptor,
37)
38from google.protobuf.message import Message
39from pw_protobuf_compiler import python_protos
40
41from pw_rpc import ids
42
43
@dataclass(frozen=True)
class Channel:
    """An RPC channel: an integer ID paired with an output callable.

    The dataclass is frozen, so a channel's ID and output cannot be reassigned
    after construction.
    """

    # Identifier of this channel.
    id: int
    # Callable that is handed the bytes to output on this channel.
    output: Callable[[bytes], Any]

    def __repr__(self) -> str:
        # Only the ID is shown; the output callable is left out.
        channel_id = self.id
        return f'Channel({channel_id})'
51
52
class ChannelManipulator(abc.ABC):
    """A pipe interface that may manipulate packets before they're sent.

    ``ChannelManipulator``s allow application-specific packet handling to be
    injected into the packet processing pipeline for an ingress or egress
    channel-like pathway. This is particularly useful for integration testing
    resilience to things like packet loss on a usually-reliable transport. RPC
    server integrations (e.g. ``HdlcRpcLocalServerAndClient``) may provide an
    opportunity to inject a ``ChannelManipulator`` for this use case.

    A ``ChannelManipulator`` should not modify ``send_packet``, as the consumer
    of a ``ChannelManipulator`` will use ``send_packet`` to insert the provided
    ``ChannelManipulator`` into a packet processing path.

    Calling an instance forwards the data to ``process_and_send``, so an
    instance can be used wherever a ``Callable[[bytes], Any]`` is expected,
    such as a ``Channel``'s ``output``.

    For example:

    .. code-block:: python

      class PacketLogger(ChannelManipulator):
          def process_and_send(self, packet: bytes) -> None:
              _LOG.debug('Received packet with payload: %s', str(packet))
              self.send_packet(packet)


      packet_logger = PacketLogger()

      # Configure actual send command.
      packet_logger.send_packet = socket.sendall

      # Route the output channel through the PacketLogger channel manipulator.
      channels = (Channel(_DEFAULT_CHANNEL, packet_logger),)

      # Create an RPC client.
      reader = SocketReader(socket)
      with reader:
          client = HdlcRpcClient(reader, protos, channels, stdout)
          with client:
              # Do something with client
    """

    def __init__(self) -> None:
        # Placeholder that discards packets until the consumer installs the
        # real send function.
        self.send_packet: Callable[[bytes], Any] = lambda _: None

    @abc.abstractmethod
    def process_and_send(self, packet: bytes) -> None:
        """Processes a packet before optionally sending it.

        Implementations of this method may send the processed packet, multiple
        packets, or no packets at all via the registered ``send_packet()``
        handler.
        """

    def __call__(self, data: bytes) -> None:
        """Forwards ``data`` to ``process_and_send``."""
        self.process_and_send(data)
107
108
@dataclass(frozen=True, eq=False)
class Service:
    """Describes an RPC service.

    Equality and hashing are by identity (``eq=False``).
    """

    # Protobuf descriptor the service was created from.
    _descriptor: ServiceDescriptor
    # Service ID; from_descriptor derives it from the descriptor's full name.
    id: int
    # The service's methods; from_descriptor fills this in after creation.
    methods: Methods

    @property
    def name(self):
        """Short name of the service, read from its descriptor."""
        descriptor = self._descriptor
        return descriptor.name

    @property
    def full_name(self):
        """Fully qualified name of the service, read from its descriptor."""
        descriptor = self._descriptor
        return descriptor.full_name

    @property
    def package(self):
        """Package of the .proto file that declares the service."""
        proto_file = self._descriptor.file
        return proto_file.package

    @classmethod
    def from_descriptor(cls, descriptor: ServiceDescriptor) -> Service:
        """Builds a Service, including all of its Methods, from a descriptor."""
        # Every Method keeps a reference to its Service, so the Service is
        # created first with a placeholder for its methods...
        service = cls(
            descriptor,
            ids.calculate(descriptor.full_name),
            None,  # type: ignore[arg-type]
        )
        method_collection = Methods(
            Method.from_descriptor(method_descriptor, service)
            for method_descriptor in descriptor.methods
        )
        # ...and the methods are attached afterwards. The dataclass is frozen,
        # so plain attribute assignment would raise; bypass it.
        object.__setattr__(service, 'methods', method_collection)
        return service

    def __repr__(self) -> str:
        return f'Service({self.full_name!r})'

    def __str__(self) -> str:
        return self.full_name
152
153
def _streaming_attributes(method) -> tuple[bool, bool]:
    """Returns ``(server_streaming, client_streaming)`` for a method.

    Args:
      method: a protobuf MethodDescriptor.
    """
    # TODO(hepler): Investigate adding server_streaming and client_streaming
    #     attributes to the generated protobuf code. As a workaround,
    #     deserialize the FileDescriptorProto to get that information.
    service = method.containing_service

    file_pb = descriptor_pb2.FileDescriptorProto()
    file_pb.MergeFromString(service.file.serialized_pb)

    # Keep the pragma on the line that actually accesses the generated field.
    service_pb = file_pb.service[service.index]  # pylint: disable=no-member
    method_pb = service_pb.method[method.index]
    return method_pb.server_streaming, method_pb.client_streaming
167
168
# Python types used to describe scalar proto fields in help text and
# signature annotations. Enum fields map to int. Message and group fields
# (TYPE_MESSAGE, TYPE_GROUP) intentionally have no entry.
_PROTO_FIELD_TYPES = {
    # Non-numeric scalars.
    FieldDescriptor.TYPE_BOOL: bool,
    FieldDescriptor.TYPE_BYTES: bytes,
    FieldDescriptor.TYPE_STRING: str,
    # Floating point.
    FieldDescriptor.TYPE_DOUBLE: float,
    FieldDescriptor.TYPE_FLOAT: float,
    # Integers, in every encoding, plus enums.
    FieldDescriptor.TYPE_ENUM: int,
    FieldDescriptor.TYPE_FIXED32: int,
    FieldDescriptor.TYPE_FIXED64: int,
    FieldDescriptor.TYPE_INT32: int,
    FieldDescriptor.TYPE_INT64: int,
    FieldDescriptor.TYPE_SFIXED32: int,
    FieldDescriptor.TYPE_SFIXED64: int,
    FieldDescriptor.TYPE_SINT32: int,
    FieldDescriptor.TYPE_SINT64: int,
    FieldDescriptor.TYPE_UINT32: int,
    FieldDescriptor.TYPE_UINT64: int,
}
190
191
def _field_type_annotation(field: FieldDescriptor):
    """Returns an annotation describing a field's type, for help text only.

    Message fields are annotated with their generated message class. Other
    fields use their entry in _PROTO_FIELD_TYPES, or ``Parameter.empty`` if
    there is none. Repeated fields are wrapped in ``Iterable[...]``.
    """
    if field.type != FieldDescriptor.TYPE_MESSAGE:
        element_type = _PROTO_FIELD_TYPES.get(field.type, Parameter.empty)
    else:
        element_type = message_factory.GetMessageClass(field.message_type)

    if field.label != FieldDescriptor.LABEL_REPEATED:
        return element_type

    return Iterable[element_type]  # type: ignore[valid-type]
203
204
def field_help(proto_message, *, annotations: bool = False) -> Iterator[str]:
    """Yields argument strings for proto fields for use in a help message.

    Args:
      proto_message: a protobuf message class or instance; only its
          ``DESCRIPTOR`` is read.
      annotations: if True, yield ``name: type = default`` strings; otherwise
          yield ``name=default``.

    Yields:
      One string per field of the message, in descriptor order.
    """
    for field in proto_message.DESCRIPTOR.fields:
        if field.message_type is not None:
            # Message (and group) fields have no _PROTO_FIELD_TYPES entry, so
            # looking one up would raise KeyError; use the message type name.
            type_name = field.message_type.full_name
            value = repr(field.default_value)
        elif field.type == FieldDescriptor.TYPE_ENUM:
            type_name = field.enum_type.full_name
            if field.label == FieldDescriptor.LABEL_REPEATED:
                # Repeated fields default to an empty list rather than an
                # enum number, which is not a valid values_by_number key.
                value = repr(field.default_value)
            else:
                by_number = field.enum_type.values_by_number
                enumerator = by_number[field.default_value].name
                # Qualify the enumerator with the enum's enclosing scope
                # (its full name minus the last component).
                value = f'{type_name.rsplit(".", 1)[0]}.{enumerator}'
        else:
            type_name = _PROTO_FIELD_TYPES[field.type].__name__
            value = repr(field.default_value)

        if annotations:
            yield f'{field.name}: {type_name} = {value}'
        else:
            yield f'{field.name}={value}'
220
221
@dataclass(frozen=True, eq=False)
class Method:
    """Describes a method in a service.

    Instances are usually built by ``from_descriptor``, which fills in the ID,
    streaming flags, and request/response message classes from a protobuf
    ``MethodDescriptor``. With ``eq=False``, methods compare by identity.
    """

    _descriptor: MethodDescriptor
    # The Service this method belongs to.
    service: Service
    # Method ID; from_descriptor computes it from the method's short name.
    id: int
    server_streaming: bool
    client_streaming: bool
    # Request and response message classes.
    request_type: Any
    response_type: Any

    @classmethod
    def from_descriptor(
        cls, descriptor: MethodDescriptor, service: Service
    ) -> Method:
        """Creates a Method from its protobuf descriptor.

        Args:
          descriptor: the protobuf MethodDescriptor to describe.
          service: the Service that contains the method.
        """
        # Construct through cls (as Service.from_descriptor does) so that
        # subclasses get instances of themselves.
        return cls(
            descriptor,
            service,
            ids.calculate(descriptor.name),
            # Expands to (server_streaming, client_streaming).
            *_streaming_attributes(descriptor),
            message_factory.GetMessageClass(descriptor.input_type),
            message_factory.GetMessageClass(descriptor.output_type),
        )

    class Type(enum.Enum):
        """The kind of RPC, determined by the two streaming flags."""

        UNARY = 0
        SERVER_STREAMING = 1
        CLIENT_STREAMING = 2
        BIDIRECTIONAL_STREAMING = 3

        def sentence_name(self) -> str:
            """Returns the name as lowercase words, e.g. 'server streaming'."""
            # The pragma sits on the line that performs the member access.
            name: str = self.name  # pylint: disable=no-member
            return name.lower().replace('_', ' ')

    @property
    def name(self) -> str:
        return self._descriptor.name

    @property
    def full_name(self) -> str:
        return self._descriptor.full_name

    @property
    def package(self) -> str:
        return self._descriptor.containing_service.file.package

    @property
    def type(self) -> Method.Type:
        """The Method.Type implied by server_streaming/client_streaming."""
        if self.server_streaming and self.client_streaming:
            return self.Type.BIDIRECTIONAL_STREAMING

        if self.server_streaming:
            return self.Type.SERVER_STREAMING

        if self.client_streaming:
            return self.Type.CLIENT_STREAMING

        return self.Type.UNARY

    def get_request(
        self, proto: Message | None, proto_kwargs: dict[str, Any] | None
    ) -> Message:
        """Returns a request_type protobuf message.

        The client implementation may use this to support providing a request
        as either a message object or as keyword arguments for the message's
        fields (but not both).

        Args:
          proto: an existing request message, or None to build one from
              proto_kwargs.
          proto_kwargs: field values for a new request message; None is
              treated as no fields.

        Raises:
          TypeError: both proto and non-empty proto_kwargs were given, or
              proto is not an instance of request_type.
        """
        if proto_kwargs is None:
            proto_kwargs = {}

        # Compare with None instead of testing truthiness, so a message
        # object that happens to be falsy still counts as "provided" and the
        # keyword arguments are not silently dropped.
        if proto is not None and proto_kwargs:
            proto_str = repr(proto).strip() or "''"
            raise TypeError(
                'Requests must be provided either as a message object or a '
                'series of keyword args, but both were provided '
                f"({proto_str} and {proto_kwargs!r})"
            )

        if proto is None:
            return self.request_type(**proto_kwargs)

        if not isinstance(proto, self.request_type):
            try:
                bad_type = proto.DESCRIPTOR.full_name
            except AttributeError:
                bad_type = type(proto).__name__

            raise TypeError(
                f'Expected a message of type '
                f'{self.request_type.DESCRIPTOR.full_name}, '
                f'got {bad_type}'
            )

        return proto

    def request_parameters(self) -> Iterator[Parameter]:
        """Yields inspect.Parameters corresponding to the request's fields.

        This can be used to make function signatures match the request proto.
        Each field becomes a keyword-only parameter whose default is the
        field's descriptor default value; the annotation is for help only.
        """
        for field in self.request_type.DESCRIPTOR.fields:
            yield Parameter(
                field.name,
                Parameter.KEYWORD_ONLY,
                annotation=_field_type_annotation(field),
                default=field.default_value,
            )

    def __repr__(self) -> str:
        req = self._method_parameter(self.request_type, self.client_streaming)
        res = self._method_parameter(self.response_type, self.server_streaming)
        return f'<{self.full_name}({req}) returns ({res})>'

    def _method_parameter(self, proto, streaming: bool) -> str:
        """Returns a description of the method's request or response type.

        The package is omitted when it matches the service's package, and
        ``stream `` is prefixed when that side of the RPC streams.
        """
        stream = 'stream ' if streaming else ''

        if proto.DESCRIPTOR.file.package == self.service.package:
            return stream + proto.DESCRIPTOR.name

        return stream + proto.DESCRIPTOR.full_name

    def __str__(self) -> str:
        return self.full_name
349
350
# Type of the items a ServiceAccessor holds and yields when iterated, e.g.
# Method for Methods and Service for Services.
T = TypeVar('T')
352
353
def _name(item: Service | Method) -> str:
    """Returns the name used to key an item by name in a ServiceAccessor.

    Services are keyed by their full name; methods by their short name.
    """
    if isinstance(item, Service):
        return item.full_name
    return item.name
356
357
class _AccessByName(Generic[T]):
    """Wrapper for accessing types by name within a proto package structure.

    The constructor stores ``item`` as an attribute called ``name`` on the new
    instance, so the item is reachable as ``instance.<name>``.
    """

    def __init__(self, name: str, item: T) -> None:
        setattr(self, name, item)
363
364
class ServiceAccessor(Collection[T]):
    """Navigates RPC services by name or ID.

    Items are reachable by indexing with a string name (converted to an ID
    with ``ids.calculate``) or with an integer ID. Depending on ``as_attrs``,
    they may also be reachable as attributes.
    """

    def __init__(self, members, as_attrs: str = ''):
        """Creates accessor from an {item: value} dict or [values] iterable.

        Args:
          members: a dict mapping each Service or Method to the value to store
              for it, or an iterable of Services or Methods that are stored as
              their own values.
          as_attrs: '' (the default) disables attribute access; 'members'
              exposes each value under its key's name (full name for
              services, short name for methods); 'packages' exposes values
              through the package objects returned by
              ``python_protos.as_packages``, keyed by their string form.

        Raises:
          ValueError: as_attrs is not one of '', 'members', or 'packages'.
        """
        # If the members arg was passed as a [values] iterable, convert it to
        # an equivalent dictionary.
        if not isinstance(members, dict):
            members = {m: m for m in members}

        by_name: dict[str, Any] = {_name(k): v for k, v in members.items()}
        self._by_id = {k.id: v for k, v in members.items()}
        # Note: a dictionary is used rather than `setattr` in order to
        # (1) Hint to the type checker that there will be extra fields
        # (2) Ensure that built-in attributes such as ``_by_id`` are not
        #     overwritten.
        self._attrs: dict[str, Any] = {}

        if as_attrs == 'members':
            for name, member in by_name.items():
                self._attrs[name] = member
        elif as_attrs == 'packages':
            for package in python_protos.as_packages(
                (m.package, _AccessByName(m.name, members[m])) for m in members
            ).packages:
                self._attrs[str(package)] = package
        elif as_attrs:
            raise ValueError(f'Unexpected value {as_attrs!r} for as_attrs')

    def __getattr__(self, name: str) -> Any:
        """Looks up an attribute registered through ``as_attrs``.

        Python only calls this after normal attribute lookup fails.

        Raises:
          AttributeError: no attribute with this name was registered.
        """
        # Read _attrs through __dict__. On an instance whose __init__ has not
        # run (e.g. one being rebuilt by copy or pickle), ``self._attrs``
        # would re-enter __getattr__ and recurse without bound.
        attrs = self.__dict__.get('_attrs', {})
        try:
            return attrs[name]
        except KeyError:
            # The attribute protocol requires AttributeError; a KeyError here
            # breaks hasattr(), getattr() with a default, and copy.copy().
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}'
            ) from None

    def __getitem__(self, name_or_id: str | int) -> Any:
        """Accesses a service/method by the string name or ID.

        Raises:
          KeyError: no item has the requested ID.
        """
        try:
            return self._by_id[_id(name_or_id)]
        except KeyError:
            pass

        name = f' ("{name_or_id}")' if isinstance(name_or_id, str) else ''
        raise KeyError(f'Unknown ID {_id(name_or_id)}{name}')

    def __iter__(self) -> Iterator[T]:
        return iter(self._by_id.values())

    def __len__(self) -> int:
        return len(self._by_id)

    def __contains__(self, name_or_id) -> bool:
        return _id(name_or_id) in self._by_id

    def __repr__(self) -> str:
        members = ', '.join(repr(m) for m in self._by_id.values())
        return f'{self.__class__.__name__}({members})'
418        return f'{self.__class__.__name__}({members})'
419
420
def _id(handle: str | int) -> int:
    """Returns the ID for a handle.

    String names are converted with ``ids.calculate``; anything else is
    assumed to already be an ID and is returned unchanged.
    """
    if not isinstance(handle, str):
        return handle
    return ids.calculate(handle)
423
424
class Methods(ServiceAccessor[Method]):
    """The Method descriptors of a Service, indexed by name or ID."""

    def __init__(self, method: Iterable[Method]):
        # Each Method is stored as its own value; attribute access stays off
        # because as_attrs is left at its default.
        super().__init__(members=method)
430
431
class Services(ServiceAccessor[Service]):
    """A collection of Service descriptors, indexed by name or ID."""

    def __init__(self, services: Iterable[Service]):
        # Each Service is stored as its own value; attribute access stays off
        # because as_attrs is left at its default.
        super().__init__(members=services)
437
438
def get_method(service_accessor: ServiceAccessor, name: str):
    """Returns a method matching the given full name in a ServiceAccessor.

    Args:
      service_accessor: accessor indexed by service name. If the item found
          is a Service, the method is looked up in its ``methods``; otherwise
          the item itself is indexed by method name.
      name: name as package.Service/Method or package.Service.Method.

    Raises:
      ValueError: the method name is not properly formatted
      KeyError: the method is not present
    """
    if '/' in name:
        parts = name.split('/')
    else:
        parts = name.rsplit('.', 1)

    # Exactly one separator is required: 'a/b/c' splits into three parts and
    # a name containing neither '/' nor '.' stays in one piece.
    if len(parts) != 2:
        raise ValueError(
            f'Invalid method name {name!r}; expected '
            'package.Service/Method or package.Service.Method'
        )

    service_name, method_name = parts

    service = service_accessor[service_name]
    if isinstance(service, Service):
        service = service.methods

    return service[method_name]
459
460
@dataclass(frozen=True)
class RpcIds:
    """Integer IDs that uniquely identify a remote procedure call.

    As a frozen dataclass, instances compare by value and are hashable.
    """

    # Fields in constructor order: channel, service, method, then call.
    channel_id: int
    service_id: int
    method_id: int
    call_id: int
469
470
@dataclass(frozen=True)
class PendingRpc:
    """Tracks an active RPC call.

    Equality and hashing are based only on ``ids()``: two PendingRpcs are
    equal when their channel, service, method, and call IDs match, even if
    they reference different Channel/Service/Method objects.
    """

    channel: Channel
    service: Service
    method: Method
    call_id: int

    @property
    def channel_id(self) -> int:
        """ID of the channel the call is on."""
        return self.channel.id

    @property
    def service_id(self) -> int:
        """ID of the service being called."""
        return self.service.id

    @property
    def method_id(self) -> int:
        """ID of the method being called."""
        return self.method.id

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, PendingRpc):
            return NotImplemented
        return self.ids() == other.ids()

    def __hash__(self) -> int:
        rpc_ids = self.ids()
        return hash(rpc_ids)

    def ids(self) -> RpcIds:
        """Returns the four integer IDs that identify this call."""
        return RpcIds(
            channel_id=self.channel.id,
            service_id=self.service.id,
            method_id=self.method.id,
            call_id=self.call_id,
        )

    def __str__(self) -> str:
        parts = (
            f'channel={self.channel.id}',
            f'method={self.method}',
            f'call_id={self.call_id}',
        )
        return f'PendingRpc({", ".join(parts)})'

    def matches_channel_service_method(self, other: PendingRpc) -> bool:
        """True if both calls share channel, service, and method IDs.

        The call ID is ignored.
        """
        if self.channel.id != other.channel.id:
            return False
        if self.service.id != other.service.id:
            return False
        return self.method.id == other.method.id
518