#!/usr/bin/env python3

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom mmi2grpc gRPC compiler."""

from __future__ import annotations

import os
import sys

from typing import Dict, List, Optional, Set, Tuple, Union

from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest, CodeGeneratorResponse
from google.protobuf.descriptor import FieldDescriptor
from google.protobuf.descriptor_pb2 import (
    FileDescriptorProto,
    EnumDescriptorProto,
    DescriptorProto,
    ServiceDescriptorProto,
    MethodDescriptorProto,
    FieldDescriptorProto,
)
_REQUEST = CodeGeneratorRequest.FromString(sys.stdin.buffer.read())


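# Type resolution helpers: protoc includes every transitive .proto file in the
# request, so a type can be resolved by scanning each file's top-level
# messages and enums for a matching package and name.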
def find_type_in_file(proto_file: FileDescriptorProto, type_name: str) -> Optional[Union[DescriptorProto, EnumDescriptorProto]]:
    for enum in proto_file.enum_type:
        if enum.name == type_name:
            return enum
    for message in proto_file.message_type:
        if message.name == type_name:
            return message
    return None


def find_type(package: str, type_name: str) -> Tuple[FileDescriptorProto, Union[DescriptorProto, EnumDescriptorProto]]:
    for file in _REQUEST.proto_file:
        if file.package == package and (type := find_type_in_file(file, type_name)):
            return file, type
    raise Exception(f'Type {package}.{type_name} not found')


def add_import(imports: List[str], import_str: str) -> None:
    if import_str not in imports:
        imports.append(import_str)


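# Resolve a fully-qualified proto type name (e.g. '.package.Message') to the
# expression used to reference it in generated code. Types defined in the
# local file are referenced directly; types from other files go through their
# <name>_pb2 module, whose import is recorded in `imports`. For enums, the
# import needed to reference the default (first) value is returned as well.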
def import_type(imports: List[str], type: str, local: Optional[FileDescriptorProto]) -> Tuple[str, Union[DescriptorProto, EnumDescriptorProto], str]:
    package = type[1:type.rindex('.')]
    type_name = type[type.rindex('.')+1:]
    file, desc = find_type(package, type_name)
    if file == local:
        return f'{type_name}', desc, ''
    python_path = file.name.replace('.proto', '').replace('/', '.')
    module_path = python_path[:python_path.rindex('.')]
    module_name = python_path[python_path.rindex('.')+1:] + '_pb2'
    add_import(imports, f'from {module_path} import {module_name}')
    dft_import = ''
    if isinstance(desc, EnumDescriptorProto):
        dft_import = f'from {module_path}.{module_name} import {desc.value[0].name}'
    return f'{module_name}.{type_name}', desc, dft_import


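# Map a field descriptor to a Python type annotation, a default-value
# expression, and an optional extra import: scalars map to builtins, map
# entries to Dict[...], enums to their first value, messages to a call of
# their own constructor, and repeated fields are wrapped in List[...].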
def collect_type(imports: List[str], parent: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[str, str, str]:
    dft: str
    dft_import: str = ''
    if field.type == FieldDescriptor.TYPE_BYTES:
        type = 'bytes'
        dft = 'b\'\''
    elif field.type == FieldDescriptor.TYPE_STRING:
        type = 'str'
        dft = '\'\''
    elif field.type == FieldDescriptor.TYPE_BOOL:
        type = 'bool'
        dft = 'False'
    elif field.type in [
        FieldDescriptor.TYPE_FLOAT,
        FieldDescriptor.TYPE_DOUBLE
    ]:
        type = 'float'
        dft = '0.0'
    elif field.type in [
        FieldDescriptor.TYPE_INT64,
        FieldDescriptor.TYPE_UINT64,
        FieldDescriptor.TYPE_INT32,
        FieldDescriptor.TYPE_FIXED64,
        FieldDescriptor.TYPE_FIXED32,
        FieldDescriptor.TYPE_UINT32,
        FieldDescriptor.TYPE_SFIXED32,
        FieldDescriptor.TYPE_SFIXED64,
        FieldDescriptor.TYPE_SINT32,
        FieldDescriptor.TYPE_SINT64
    ]:
        type = 'int'
        dft = '0'
    elif field.type in [FieldDescriptor.TYPE_ENUM, FieldDescriptor.TYPE_MESSAGE]:
        parts = field.type_name.split(f".{parent.name}.", 2)
        if len(parts) == 2:
            type = parts[1]
            for nested_type in parent.nested_type:
                if nested_type.name == type:
                    assert nested_type.options.map_entry
                    assert field.label == FieldDescriptor.LABEL_REPEATED
                    key_type, _, _ = collect_type(imports, nested_type, nested_type.field[0], local)
                    val_type, _, _ = collect_type(imports, nested_type, nested_type.field[1], local)
                    add_import(imports, 'from typing import Dict')
                    return f'Dict[{key_type}, {val_type}]', '{}', ''
        type, desc, enum_dft = import_type(imports, field.type_name, local)
        if isinstance(desc, EnumDescriptorProto):
            dft_import = enum_dft
            dft = desc.value[0].name
        else:
            dft = f'{type}()'
    else:
        raise Exception(f'TODO: {field}')

    if field.label == FieldDescriptor.LABEL_REPEATED:
        add_import(imports, 'from typing import List')
        type = f'List[{type}]'
        dft = '[]'

    return type, dft, dft_import


def collect_field(imports: List[str], message: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[Optional[int], str, str, str, str]:
    type, dft, dft_import = collect_type(imports, message, field, local)
    # oneof_index is only meaningful when it is explicitly set on the descriptor.
    oneof_index = field.oneof_index if field.HasField('oneof_index') else None
    return oneof_index, field.name, type, dft, dft_import


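# Collect every field of a message as (name, type, default) triples, grouping
# oneof members separately; oneof fields are also re-added as Optional[...]
# attributes defaulting to None so they show up in the generated stubs.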
def collect_message(imports: List[str], message: DescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[
    List[Tuple[str, str, str]],
    Dict[str, List[Tuple[str, str]]],
]:
    fields: List[Tuple[str, str, str]] = []
    oneof: Dict[str, List[Tuple[str, str]]] = {}

    for field in message.field:
        idx, name, type, dft, dft_import = collect_field(imports, message, field, local)
        if idx is not None:
            oneof_name = message.oneof_decl[idx].name
            oneof.setdefault(oneof_name, []).append((name, type))
        else:
            add_import(imports, dft_import)
            fields.append((name, type, dft))

    for oneof_name, oneof_fields in oneof.items():
        for name, type in oneof_fields:
            add_import(imports, 'from typing import Optional')
            fields.append((name, f'Optional[{type}]', 'None'))

    return fields, oneof


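# Emit the .pyi stub for an enum (typed via EnumTypeWrapper) and register a
# same-named placeholder class in the generated _pb2.py through its
# 'module_scope' insertion point.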
def generate_enum(imports: List[str], file: FileDescriptorProto, enum: EnumDescriptorProto, res: List[CodeGeneratorResponse.File]) -> List[str]:
    res.append(CodeGeneratorResponse.File(
        name=file.name.replace('.proto', '_pb2.py'),
        insertion_point=f'module_scope',
        content=f'class {enum.name}: ...\n\n'
    ))
    add_import(imports, 'from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper')
    return [
        f'class {enum.name}(int, EnumTypeWrapper):',
        f'  pass',
        f'',
        *[f'{value.name}: {enum.name}' for value in enum.value],
        ''
    ]


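# Emit the .pyi stub class for a message. For each oneof group this also
# declares a TypedDict describing its members and injects runtime helpers
# (a `<oneof>` property, `<oneof>_variant()` and `<oneof>_asdict()`) into the
# generated _pb2.py through its 'module_scope' insertion point.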
def generate_message(imports: List[str], file: FileDescriptorProto, message: DescriptorProto, res: List[CodeGeneratorResponse.File]) -> List[str]:
    nested_message_lines: List[str] = []
    message_lines: List[str] = [f'class {message.name}(Message):']

    add_import(imports, 'from google.protobuf.message import Message')
    fields, oneof = collect_message(imports, message, file)

    for (name, type, _) in fields:
        message_lines.append(f'  {name}: {type}')

    args = ', '.join([f'{name}: {type} = {dft}' for name, type, dft in fields])
    if args: args = ', ' + args
    message_lines.extend([
        f'',
        f'  def __init__(self{args}) -> None: ...',
        f''
    ])

    for oneof_name, oneof_fields in oneof.items():
        literals: str = ', '.join((f'Literal[\'{name}\']' for name, _ in oneof_fields))
        types: Set[str] = set((type for _, type in oneof_fields))
        if len(types) == 1:
            type = 'Optional[' + types.pop() + ']'
        else:
            types.add('None')
            type = 'Union[' + ', '.join(sorted(types)) + ']'

        nested_message_lines.extend([
            f'class {message.name}_{oneof_name}_dict(TypedDict, total=False):',
            '\n'.join([f'  {name}: {type}' for name, type in oneof_fields]),
            f'',
        ])

        add_import(imports, 'from typing import Union')
        add_import(imports, 'from typing_extensions import TypedDict')
        add_import(imports, 'from typing_extensions import Literal')
        message_lines.extend([
            f'  @property',
            f'  def {oneof_name}(self) -> {type}: ...',
            f'',
            f'  def {oneof_name}_variant(self) -> Union[{literals}, None]: ...',
            f'',
            f'  def {oneof_name}_asdict(self) -> {message.name}_{oneof_name}_dict: ...',
            f'',
        ])

        return_variant = '\n  '.join([f'if variant == \'{name}\': return unwrap(self.{name})' for name, _ in oneof_fields])
        return_asdict = '\n  '.join([f'if variant == \'{name}\': return {{\'{name}\': unwrap(self.{name})}}  # type: ignore' for name, _ in oneof_fields])
        if return_variant: return_variant += '\n  '
        if return_asdict: return_asdict += '\n  '

        res.append(CodeGeneratorResponse.File(
            name=file.name.replace('.proto', '_pb2.py'),
            insertion_point=f'module_scope',
            content=f"""
def _{message.name}_{oneof_name}(self: {message.name}):
  variant = self.{oneof_name}_variant()
  if variant is None: return None
  {return_variant}raise Exception('Field `{oneof_name}` not found.')

def _{message.name}_{oneof_name}_variant(self: {message.name}):
  return self.WhichOneof('{oneof_name}')  # type: ignore

def _{message.name}_{oneof_name}_asdict(self: {message.name}):
  variant = self.{oneof_name}_variant()
  if variant is None: return {{}}
  {return_asdict}raise Exception('Field `{oneof_name}` not found.')

setattr({message.name}, '{oneof_name}', property(_{message.name}_{oneof_name}))
setattr({message.name}, '{oneof_name}_variant', _{message.name}_{oneof_name}_variant)
setattr({message.name}, '{oneof_name}_asdict', _{message.name}_{oneof_name}_asdict)
"""))

    return message_lines + nested_message_lines


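# Emit one typed client method. Depending on client/server streaming it wraps
# channel.unary_unary, unary_stream, stream_unary or stream_stream; the sync
# and asyncio flavours differ only in the helper types imported from ._utils
# and in the Awaitable return hint. For a plain unary call the generated stub
# looks roughly like this (method and field names are illustrative):
#
#   def Connect(self, address: bytes = b'', ...) -> ConnectResponse:
#       return self.channel.unary_unary('/pkg.Host/Connect', ...)(
#           ConnectRequest(address=address), ...)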
def generate_service_method(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, method: MethodDescriptorProto, sync: bool = True) -> List[str]:
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type, input_msg, _ = import_type(imports, method.input_type, None)
    output_type, _, _ = import_type(imports, method.output_type, None)

    input_type_pb2, _, _ = import_type(imports, method.input_type, None)
    output_type_pb2, _, _ = import_type(imports, method.output_type, None)

    if output_mode == 'stream':
        if input_mode == 'stream':
            output_type_hint = f'StreamStream[{input_type}, {output_type}]'
            if sync:
                add_import(imports, f'from ._utils import Sender')
                add_import(imports, f'from ._utils import Stream')
                add_import(imports, f'from ._utils import StreamStream')
            else:
                add_import(imports, f'from ._utils import AioSender as Sender')
                add_import(imports, f'from ._utils import AioStream as Stream')
                add_import(imports, f'from ._utils import AioStreamStream as StreamStream')
        else:
            output_type_hint = f'Stream[{output_type}]'
            if sync:
                add_import(imports, f'from ._utils import Stream')
            else:
                add_import(imports, f'from ._utils import AioStream as Stream')
    else:
        output_type_hint = output_type if sync else f'Awaitable[{output_type}]'
        if not sync: add_import(imports, f'from typing import Awaitable')

    if input_mode == 'stream' and output_mode == 'stream':
        add_import(imports, f'from typing import Optional')
        return (
            f'def {method.name}(self, timeout: Optional[float] = None) -> {output_type_hint}:\n'
            f'    tx: Sender[{input_type}] = Sender()\n'
            f'    rx: Stream[{output_type}] = self.channel.{input_mode}_{output_mode}(  # type: ignore\n'
            f"        '/{file.package}.{service.name}/{method.name}',\n"
            f'        request_serializer={input_type_pb2}.SerializeToString,  # type: ignore\n'
            f'        response_deserializer={output_type_pb2}.FromString  # type: ignore\n'
            f'    )(tx)\n'
            f'    return StreamStream(tx, rx)'
        ).split('\n')
    if input_mode == 'stream':
        iterator_type = 'Iterator' if sync else 'AsyncIterator'
        add_import(imports, f'from typing import {iterator_type}')
        add_import(imports, f'from typing import Optional')
        return (
            f'def {method.name}(self, iterator: {iterator_type}[{input_type}], timeout: Optional[float] = None) -> {output_type_hint}:\n'
            f'    return self.channel.{input_mode}_{output_mode}(  # type: ignore\n'
            f"        '/{file.package}.{service.name}/{method.name}',\n"
            f'        request_serializer={input_type_pb2}.SerializeToString,  # type: ignore\n'
            f'        response_deserializer={output_type_pb2}.FromString  # type: ignore\n'
            f'    )(iterator)'
        ).split('\n')
    else:
        add_import(imports, f'from typing import Optional')
        assert isinstance(input_msg, DescriptorProto)
        input_fields, _ = collect_message(imports, input_msg, None)
        args = ', '.join([f'{name}: {type} = {dft}' for name, type, dft in input_fields])
        args_name = ', '.join([f'{name}={name}' for name, _, _ in input_fields])
        if args: args = ', ' + args
        return (
            f'def {method.name}(self{args}, wait_for_ready: Optional[bool] = None, timeout: Optional[float] = None) -> {output_type_hint}:\n'
            f'    return self.channel.{input_mode}_{output_mode}(  # type: ignore\n'
            f"        '/{file.package}.{service.name}/{method.name}',\n"
            f'        request_serializer={input_type_pb2}.SerializeToString,  # type: ignore\n'
            f'        response_deserializer={output_type_pb2}.FromString  # type: ignore\n'
            f'    )({input_type_pb2}({args_name}), wait_for_ready=wait_for_ready, timeout=timeout)  # type: ignore'
        ).split('\n')


def generate_service(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    methods = '\n\n    '.join([
        '\n    '.join(
            generate_service_method(imports, file, service, method, sync)
        ) for method in service.method
    ])
    channel_type = 'grpc.Channel' if sync else 'grpc.aio.Channel'
    return (
        f'class {service.name}:\n'
        f'    channel: {channel_type}\n'
        f'\n'
        f'    def __init__(self, channel: {channel_type}) -> None:\n'
        f'        self.channel = channel\n'
        f'\n'
        f'    {methods}\n'
    ).split('\n')


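# Servicer-side stubs: every method is emitted as an UNIMPLEMENTED placeholder
# that sets the status code, raises NotImplementedError, and is meant to be
# overridden by the user; server-streaming variants end with an unreachable
# yield so type checkers treat them as generators.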
def generate_servicer_method(imports: List[str], method: MethodDescriptorProto, sync: bool = True) -> List[str]:
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type, _, _ = import_type(imports, method.input_type, None)
    output_type, _, _ = import_type(imports, method.output_type, None)

    output_type_hint = output_type
    if output_mode == 'stream':
        if sync:
            output_type_hint = f'Generator[{output_type}, None, None]'
            add_import(imports, f'from typing import Generator')
        else:
            output_type_hint = f'AsyncGenerator[{output_type}, None]'
            add_import(imports, f'from typing import AsyncGenerator')

    if input_mode == 'stream':
        iterator_type = 'Iterator' if sync else 'AsyncIterator'
        add_import(imports, f'from typing import {iterator_type}')
        lines = (('' if sync else 'async ') + (
            f'def {method.name}(self, request: {iterator_type}[{input_type}], context: grpc.ServicerContext) -> {output_type_hint}:\n'
            f'    context.set_code(grpc.StatusCode.UNIMPLEMENTED)  # type: ignore\n'
            f'    context.set_details("Method not implemented!")  # type: ignore\n'
            f'    raise NotImplementedError("Method not implemented!")'
        )).split('\n')
    else:
        lines = (('' if sync else 'async ') + (
            f'def {method.name}(self, request: {input_type}, context: grpc.ServicerContext) -> {output_type_hint}:\n'
            f'    context.set_code(grpc.StatusCode.UNIMPLEMENTED)  # type: ignore\n'
            f'    context.set_details("Method not implemented!")  # type: ignore\n'
            f'    raise NotImplementedError("Method not implemented!")'
        )).split('\n')
    if output_mode == 'stream':
        lines.append(f'    yield {output_type}()  # no-op: to make the linter happy')
    return lines


def generate_servicer(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    methods = '\n\n    '.join([
        '\n    '.join(
            generate_servicer_method(imports, method, sync)
        ) for method in service.method
    ])
    if not methods:
        methods = 'pass'
    return (
        f'class {service.name}Servicer:\n'
        f'    {methods}\n'
    ).split('\n')


def generate_rpc_method_handler(imports: List[str], method: MethodDescriptorProto) -> List[str]:
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type, _, _ = import_type(imports, method.input_type, None)
    output_type, _, _ = import_type(imports, method.output_type, None)

    return (
        f"'{method.name}': grpc.{input_mode}_{output_mode}_rpc_method_handler(  # type: ignore\n"
        f'        servicer.{method.name},\n'
        f'        request_deserializer={input_type}.FromString,  # type: ignore\n'
        f'        response_serializer={output_type}.SerializeToString,  # type: ignore\n'
        f'    ),\n'
    ).split('\n')


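# Emit add_<Service>Servicer_to_server(), which maps every method to a grpc
# rpc method handler and registers the lot on the server as a single generic
# handler under the '<package>.<Service>' service name.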
def generate_add_servicer_to_server_method(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    method_handlers = '    '.join([
        '\n    '.join(
            generate_rpc_method_handler(imports, method)
        ) for method in service.method
    ])
    server_type = 'grpc.Server' if sync else 'grpc.aio.Server'
    return (
        f'def add_{service.name}Servicer_to_server(servicer: {service.name}Servicer, server: {server_type}) -> None:\n'
        f'    rpc_method_handlers = {{\n'
        f'        {method_handlers}\n'
        f'    }}\n'
        f'    generic_handler = grpc.method_handlers_generic_handler(  # type: ignore\n'
        f"        '{file.package}.{service.name}', rpc_method_handlers)\n"
        f'    server.add_generic_rpc_handlers((generic_handler,))  # type: ignore\n'
    ).split('\n')


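# Templates for the generated output: _HEADER is prepended to every produced
# file, and _UTILS_PY is the body of the shared _utils.py module that provides
# the Sender/Stream/StreamStream wrappers (plus their asyncio counterparts)
# used by the streaming client stubs above.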
_HEADER = '''# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generated python gRPC interfaces."""

from __future__ import annotations
'''

_UTILS_PY = f'''{_HEADER}

import asyncio
import queue
import grpc
import sys

from typing import Any, AsyncIterable, AsyncIterator, Generic, Iterator, TypeVar


_T_co = TypeVar('_T_co', covariant=True)  # pytype: disable=not-supported-yet
_T = TypeVar('_T')


class Stream(Iterator[_T_co], grpc.RpcContext): ...


class AioStream(AsyncIterable[_T_co], grpc.RpcContext): ...


class Sender(Iterator[_T]):
    if sys.version_info >= (3, 8):
        _inner: queue.Queue[_T]
    else:
        _inner: queue.Queue

    def __init__(self) -> None:
        self._inner = queue.Queue()

    def __iter__(self) -> Iterator[_T]:
        return self

    def __next__(self) -> _T:
        return self._inner.get()

    def send(self, item: _T) -> None:
        self._inner.put(item)


class AioSender(AsyncIterator[_T]):
    if sys.version_info >= (3, 8):
        _inner: asyncio.Queue[_T]
    else:
        _inner: asyncio.Queue

    def __init__(self) -> None:
        self._inner = asyncio.Queue()

    def __aiter__(self) -> AsyncIterator[_T]:
        return self

    async def __anext__(self) -> _T:
        return await self._inner.get()

    async def send(self, item: _T) -> None:
        await self._inner.put(item)

    def send_nowait(self, item: _T) -> None:
        self._inner.put_nowait(item)


class StreamStream(Generic[_T, _T_co], Iterator[_T_co], grpc.RpcContext):
    _sender: Sender[_T]
    _receiver: Stream[_T_co]

    def __init__(self, sender: Sender[_T], receiver: Stream[_T_co]) -> None:
        self._sender = sender
        self._receiver = receiver

    def send(self, item: _T) -> None:
        self._sender.send(item)

    def __iter__(self) -> Iterator[_T_co]:
        return self._receiver.__iter__()

    def __next__(self) -> _T_co:
        return self._receiver.__next__()

    def is_active(self) -> bool:
        return self._receiver.is_active()  # type: ignore

    def time_remaining(self) -> float:
        return self._receiver.time_remaining()  # type: ignore

    def cancel(self) -> None:
        self._receiver.cancel()  # type: ignore

    def add_callback(self, callback: Any) -> None:
        self._receiver.add_callback(callback)  # type: ignore


class AioStreamStream(Generic[_T, _T_co], AsyncIterator[_T_co], grpc.RpcContext):
    _sender: AioSender[_T]
    _receiver: AioStream[_T_co]

    def __init__(self, sender: AioSender[_T], receiver: AioStream[_T_co]) -> None:
        self._sender = sender
        self._receiver = receiver

    def __aiter__(self) -> AsyncIterator[_T_co]:
        return self._receiver.__aiter__()

    async def __anext__(self) -> _T_co:
        return await self._receiver.__aiter__().__anext__()

    async def send(self, item: _T) -> None:
        await self._sender.send(item)

    def send_nowait(self, item: _T) -> None:
        self._sender.send_nowait(item)

    def is_active(self) -> bool:
        return self._receiver.is_active()  # type: ignore

    def time_remaining(self) -> float:
        return self._receiver.time_remaining()  # type: ignore

    def cancel(self) -> None:
        self._receiver.cancel()  # type: ignore

    def add_callback(self, callback: Any) -> None:
        self._receiver.add_callback(callback)  # type: ignore
'''


_FILES: List[CodeGeneratorResponse.File] = []
_UTILS_FILES: Set[str] = set()


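# Main generation loop: for every .proto file protoc asked us to generate,
# emit a typed <name>_pb2.pyi stub, a synchronous <name>_grpc.py module, its
# asyncio twin <name>_grpc_aio.py, one shared _utils.py per output directory,
# and splice a small unwrap() helper into the generated <name>_pb2.py.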
for file_name in _REQUEST.file_to_generate:
    file: FileDescriptorProto = next(filter(lambda x: x.name == file_name, _REQUEST.proto_file))

    _FILES.append(CodeGeneratorResponse.File(
        name=file.name.replace('.proto', '_pb2.py'),
        insertion_point=f'module_scope',
        content='def unwrap(x):\n  assert x\n  return x\n'
    ))

    pyi_imports: List[str] = []
    grpc_imports: List[str] = ['import grpc']
    grpc_aio_imports: List[str] = ['import grpc', 'import grpc.aio']

    enums = '\n'.join(sum([generate_enum(pyi_imports, file, enum, _FILES) for enum in file.enum_type], []))
    messages = '\n'.join(sum([generate_message(pyi_imports, file, message, _FILES) for message in file.message_type], []))

    services = '\n'.join(sum([generate_service(grpc_imports, file, service) for service in file.service], []))
    aio_services = '\n'.join(sum([generate_service(grpc_aio_imports, file, service, False) for service in file.service], []))

    servicers = '\n'.join(sum([generate_servicer(grpc_imports, file, service) for service in file.service], []))
    aio_servicers = '\n'.join(sum([generate_servicer(grpc_aio_imports, file, service, False) for service in file.service], []))

    add_servicer_methods = '\n'.join(sum([generate_add_servicer_to_server_method(grpc_imports, file, service) for service in file.service], []))
    aio_add_servicer_methods = '\n'.join(sum([generate_add_servicer_to_server_method(grpc_aio_imports, file, service, False) for service in file.service], []))

    pyi_imports.sort()
    grpc_imports.sort()
    grpc_aio_imports.sort()

    pyi_imports_str: str = '\n'.join(pyi_imports)
    grpc_imports_str: str = '\n'.join(grpc_imports)
    grpc_aio_imports_str: str = '\n'.join(grpc_aio_imports)

    utils_filename = file_name.replace(os.path.basename(file_name), '_utils.py')
    if utils_filename not in _UTILS_FILES:
        _UTILS_FILES.add(utils_filename)
        _FILES.extend([
            CodeGeneratorResponse.File(
                name=utils_filename,
                content=_UTILS_PY,
            )
        ])

    _FILES.extend([
        CodeGeneratorResponse.File(
            name=file.name.replace('.proto', '_pb2.pyi'),
            content=f'{_HEADER}\n\n{pyi_imports_str}\n\n{enums}\n\n{messages}\n'
        ),
        CodeGeneratorResponse.File(
            name=file_name.replace('.proto', '_grpc.py'),
            content=f'{_HEADER}\n\n{grpc_imports_str}\n\n{services}\n\n{servicers}\n\n{add_servicer_methods}'
        ),
        CodeGeneratorResponse.File(
            name=file_name.replace('.proto', '_grpc_aio.py'),
            content=f'{_HEADER}\n\n{grpc_aio_imports_str}\n\n{aio_services}\n\n{aio_servicers}\n\n{aio_add_servicer_methods}'
        )
    ])


sys.stdout.buffer.write(CodeGeneratorResponse(file=_FILES).SerializeToString())