xref: /aosp_15_r20/external/pigweed/pw_rpc/py/pw_rpc/codegen_nanopb.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""This module generates the code for nanopb-based pw_rpc services."""
15
16import os
17from typing import Iterable, NamedTuple
18
19from pw_protobuf.output_file import OutputFile
20from pw_protobuf.proto_tree import ProtoServiceMethod
21from pw_protobuf.proto_tree import build_node_tree
22from pw_rpc import codegen
23from pw_rpc.codegen import (
24    client_call_type,
25    get_id,
26    CodeGenerator,
27    RPC_NAMESPACE,
28)
29
# Suffix appended after '.rpc' to form the generated RPC header name.
PROTO_H_EXTENSION: str = '.pb.h'
# NOTE(review): not referenced elsewhere in this module; other tooling may
# import it, so confirm before removing.
PROTO_CC_EXTENSION: str = '.pb.cc'
# Suffix of the nanopb-generated header holding a .proto file's messages.
NANOPB_H_EXTENSION: str = '.pb.h'
33
34
def _serde(method: ProtoServiceMethod) -> str:
    """Returns the C++ NanopbMethodSerde expression for this method.

    The expression instantiates ``kNanopbMethodSerde`` with the nanopb field
    descriptors of the method's request and response messages.
    """
    request_fields = method.request_type().nanopb_fields()
    response_fields = method.response_type().nanopb_fields()
    return (
        f'{RPC_NAMESPACE}::internal::kNanopbMethodSerde<'
        f'{request_fields}, {response_fields}>'
    )
42
43
def _proto_filename_to_nanopb_header(proto_file: str) -> str:
    """Returns the generated nanopb header name for a .proto file.

    The file's extension (e.g. ``.proto``) is replaced with
    NANOPB_H_EXTENSION; the directory part is kept.
    """
    stem, _extension = os.path.splitext(proto_file)
    return f'{stem}{NANOPB_H_EXTENSION}'
47
48
def _proto_filename_to_generated_header(proto_file: str) -> str:
    """Returns the generated C++ RPC header name for a .proto file.

    The file's extension is replaced with ``.rpc`` followed by
    PROTO_H_EXTENSION, e.g. ``foo.proto`` -> ``foo.rpc.pb.h``.
    """
    stem = os.path.splitext(proto_file)[0]
    return stem + '.rpc' + PROTO_H_EXTENSION
53
54
def _client_call(
    method: ProtoServiceMethod, response: str | None = None
) -> str:
    """Returns the C++ client call type for a method.

    Client-streaming methods also take the request struct as the first
    template argument. ``response`` overrides the response type (it defaults
    to the method's nanopb response struct).
    """
    leading_args = (
        [method.request_type().nanopb_struct()]
        if method.client_streaming()
        else []
    )

    if response is None:
        response = method.response_type().nanopb_struct()

    template_list = ', '.join([*leading_args, response])
    return f'{client_call_type(method, "Nanopb")}<{template_list}>'
69
70
def _function(
    method: ProtoServiceMethod,
    response: str | None = None,
    name: str | None = None,
) -> str:
    """Returns "<client call type> <function name>" for a client function.

    ``name`` defaults to the method's own name, and ``response`` is forwarded
    to _client_call.
    """
    function_name = method.name() if name is None else name
    return f'{_client_call(method, response)} {function_name}'
80
81
def _user_args(
    method: ProtoServiceMethod, response: str | None = None
) -> Iterable[str]:
    """Yields the C++ parameter declarations a client call takes from users.

    The parameters are, in order:
    - the request (omitted for client-streaming methods);
    - ``on_next`` and a status-only ``on_completed`` for server-streaming
      methods, or an ``on_completed`` taking the response and status
      otherwise;
    - ``on_error``.

    All callbacks default to nullptr.
    """
    if not method.client_streaming():
        request_struct = method.request_type().nanopb_struct()
        yield f'const {request_struct}& request'

    response_struct = (
        method.response_type().nanopb_struct() if response is None else response
    )
    status_callback = '::pw::Function<void(::pw::Status)>&&'

    if method.server_streaming():
        yield f'::pw::Function<void(const {response_struct}&)>&& on_next = nullptr'
        yield f'{status_callback} on_completed = nullptr'
    else:
        yield (
            f'::pw::Function<void(const {response_struct}&, ::pw::Status)>&& '
            'on_completed = nullptr'
        )

    yield f'{status_callback} on_error = nullptr'
101
102
class NanopbCodeGenerator(CodeGenerator):
    """Generates an RPC service and client using the Nanopb API."""

    def name(self) -> str:
        """Returns the name of this code generator."""
        return 'nanopb'

    def method_union_name(self) -> str:
        """Returns the C++ method union class used for Nanopb methods."""
        return 'NanopbMethodUnion'

    def includes(self, proto_file_name: str) -> Iterable[str]:
        """Yields the #include directives the generated RPC header needs."""
        yield '#include "pw_rpc/nanopb/client_reader_writer.h"'
        yield '#include "pw_rpc/nanopb/internal/method_union.h"'
        yield '#include "pw_rpc/nanopb/server_reader_writer.h"'

        # Include the corresponding nanopb header file for this proto file, in
        # which the file's messages and enums are generated. All other files
        # imported from the .proto file are #included in there.
        nanopb_header = _proto_filename_to_nanopb_header(proto_file_name)
        yield f'#include "{nanopb_header}"'

    def service_aliases(self) -> None:
        """Emits ServerWriter, ServerReader, and ServerReaderWriter aliases.

        Each alias template forwards to the matching Nanopb reader/writer
        class template in RPC_NAMESPACE.
        """
        self.line('template <typename Response>')
        self.line(
            'using ServerWriter = '
            f'{RPC_NAMESPACE}::NanopbServerWriter<Response>;'
        )
        self.line('template <typename Request, typename Response>')
        self.line(
            'using ServerReader = '
            f'{RPC_NAMESPACE}::NanopbServerReader<Request, Response>;'
        )
        self.line('template <typename Request, typename Response>')
        self.line(
            'using ServerReaderWriter = '
            f'{RPC_NAMESPACE}::NanopbServerReaderWriter<Request, Response>;'
        )

    def method_descriptor(self, method: ProtoServiceMethod) -> None:
        """Emits the method-table entry for ``method``.

        The entry is a GetNanopbOrRawMethodFor<> call. It is parameterized on
        the implementation's member function, the method type, and the
        request/response structs, and is passed the method ID and its serde.
        """
        self.line(
            f'{RPC_NAMESPACE}::internal::'
            f'GetNanopbOrRawMethodFor<&Implementation::{method.name()}, '
            f'{method.type().cc_enum()}, '
            f'{method.request_type().nanopb_struct()}, '
            f'{method.response_type().nanopb_struct()}>('
        )
        with self.indent(4):
            self.line(f'{get_id(method)},  // Hash of "{method.name()}"')
            self.line(f'{_serde(method)}),')

    def _emit_response_template_header(
        self, method: ProtoServiceMethod
    ) -> None:
        """Emits ``template <typename Response = <response struct>>``.

        This line precedes the ``<Name>Template`` overloads. Those overloads
        use the ``Response`` template parameter, which defaults to the
        method's nanopb response struct, in place of a fixed response type.
        """
        # Keep a space after '=' so the generated C++ is formatted the same
        # way as the rest of the emitted code.
        self.line(
            'template <typename Response = '
            f'{method.response_type().nanopb_struct()}>'
        )

    def _client_member_function(
        self,
        method: ProtoServiceMethod,
        response: str | None = None,
        name: str | None = None,
    ) -> None:
        """Emits a const client member function that starts a call.

        Args:
            method: The RPC method to generate a call function for.
            response: C++ response type; defaults to the method's nanopb
                response struct.
            name: Generated function name; defaults to the method's name.
        """
        if response is None:
            response = method.response_type().nanopb_struct()

        if name is None:
            name = method.name()

        self.line(f'{_function(method, response, name)}(')
        self.indented_list(*_user_args(method, response), end=') const {')

        with self.indent():
            client_call = _client_call(method, response)
            base = 'Stream' if method.server_streaming() else 'Unary'
            self.line(
                f'return {RPC_NAMESPACE}::internal::'
                f'Nanopb{base}ResponseClientCall<{response}>::'
                f'template Start<{client_call}>('
            )

            service_client = RPC_NAMESPACE + '::internal::ServiceClient'

            args = [
                f'{service_client}::client()',
                f'{service_client}::channel_id()',
                'kServiceId',
                get_id(method),
                _serde(method),
            ]
            if method.server_streaming():
                args.append('std::move(on_next)')

            args.append('std::move(on_completed)')
            args.append('std::move(on_error)')

            # Client-streaming calls send requests later, through the returned
            # call object, so no initial request is forwarded.
            if not method.client_streaming():
                args.append('request')

            self.indented_list(*args, end=');')

        self.line('}')

    def client_member_function(
        self, method: ProtoServiceMethod, *, dynamic: bool
    ) -> None:
        """Outputs client code for a single RPC method.

        Emits the plain member function, then a ``<Name>Template`` variant
        whose response type is a template parameter. Nothing but a comment is
        emitted when ``dynamic`` is set, because DynamicClient is not
        implemented for Nanopb.
        """
        if dynamic:
            self.line('// DynamicClient is not implemented for Nanopb')
            return

        self._client_member_function(method)

        self._emit_response_template_header(method)
        self._client_member_function(
            method, 'Response', method.name() + 'Template'
        )

    def _client_static_function(
        self,
        method: ProtoServiceMethod,
        response: str | None = None,
        name: str | None = None,
    ) -> None:
        """Emits a static function that forwards to a client member function.

        The generated function constructs ``Client(client, channel_id)`` and
        calls its ``name`` member with the user arguments.

        Args:
            method: The RPC method to generate a call function for.
            response: C++ response type; defaults to the method's nanopb
                response struct.
            name: Generated (and forwarded-to) function name; defaults to the
                method's name.
        """
        if response is None:
            response = method.response_type().nanopb_struct()

        if name is None:
            name = method.name()

        self.line(f'static {_function(method, response, name)}(')
        self.indented_list(
            f'{RPC_NAMESPACE}::Client& client',
            'uint32_t channel_id',
            *_user_args(method, response),
            end=') {',
        )

        with self.indent():
            self.line(f'return Client(client, channel_id).{name}(')

            args = []
            if not method.client_streaming():
                args.append('request')
            if method.server_streaming():
                args.append('std::move(on_next)')
            args += ['std::move(on_completed)', 'std::move(on_error)']

            self.indented_list(*args, end=');')

        self.line('}')

    def client_static_function(self, method: ProtoServiceMethod) -> None:
        """Outputs static client functions for a single RPC method.

        Emits the plain static function, then a ``<Name>Template`` variant
        whose response type is a template parameter.
        """
        self._client_static_function(method)

        self._emit_response_template_header(method)
        self._client_static_function(
            method, 'Response', method.name() + 'Template'
        )

    def method_info_specialization(self, method: ProtoServiceMethod) -> None:
        """Emits Request/Response aliases and a serde() accessor for a method."""
        self.line()
        self.line(f'using Request = {method.request_type().nanopb_struct()};')
        self.line(f'using Response = {method.response_type().nanopb_struct()};')
        self.line()
        self.line(
            f'static constexpr const {RPC_NAMESPACE}::internal::'
            'NanopbMethodSerde& serde() {'
        )
        with self.indent():
            self.line(f'return {_serde(method)};')
        self.line('}')
279
280
281class _CallbackFunction(NamedTuple):
282    """Represents a callback function parameter in a client RPC call."""
283
284    function_type: str
285    name: str
286    default_value: str | None = None
287
288    def __str__(self):
289        param = f'::pw::Function<{self.function_type}>&& {self.name}'
290        if self.default_value:
291            param += f' = {self.default_value}'
292        return param
293
294
class StubGenerator(codegen.StubGenerator):
    """Generates Nanopb RPC stubs.

    The *_signature methods return a C++ service-method declaration whose
    name is ``prefix`` followed by the method name, using nanopb structs.
    unary_stub writes the placeholder body for unary methods.
    """

    @staticmethod
    def _void_signature(
        method: ProtoServiceMethod, prefix: str, params: str
    ) -> str:
        """Returns a ``void`` method declaration taking ``params``."""
        return f'void {prefix}{method.name()}( {params})'

    def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str:
        request = method.request_type().nanopb_struct()
        response = method.response_type().nanopb_struct()
        return (
            f'::pw::Status {prefix}{method.name()}( '
            f'const {request}& request, {response}& response)'
        )

    def unary_stub(
        self, method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        # The stub marks both arguments as unused and reports Unimplemented.
        for stub_line in (
            codegen.STUB_REQUEST_TODO,
            'static_cast<void>(request);',
            codegen.STUB_RESPONSE_TODO,
            'static_cast<void>(response);',
            'return ::pw::Status::Unimplemented();',
        ):
            output.write_line(stub_line)

    def server_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        request = method.request_type().nanopb_struct()
        response = method.response_type().nanopb_struct()
        return self._void_signature(
            method,
            prefix,
            f'const {request}& request, ServerWriter<{response}>& writer',
        )

    def client_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        request = method.request_type().nanopb_struct()
        response = method.response_type().nanopb_struct()
        return self._void_signature(
            method, prefix, f'ServerReader<{request}, {response}>& reader'
        )

    def bidirectional_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        request = method.request_type().nanopb_struct()
        response = method.response_type().nanopb_struct()
        return self._void_signature(
            method,
            prefix,
            f'ServerReaderWriter<{request}, {response}>& reader_writer',
        )
340
341
def process_proto_file(proto_file) -> Iterable[OutputFile]:
    """Generates code for a single .proto file.

    Args:
        proto_file: The .proto file to generate for. Its ``name`` determines
            the output header's file name.

    Returns:
        A one-element list holding the generator's output. That output is
        written by both the package code generation and the stub generation.
    """
    _, package_root = build_node_tree(proto_file)

    generator = NanopbCodeGenerator(
        _proto_filename_to_generated_header(proto_file.name)
    )
    codegen.generate_package(proto_file, package_root, generator)
    codegen.package_stubs(package_root, generator, StubGenerator())

    return [generator.output]
353