1#!/usr/bin/env python3 2 3# Copyright 2022 Google LLC 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# https://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16 17"""Custom mmi2grpc gRPC compiler.""" 18 19from __future__ import annotations 20 21import os 22import sys 23 24from typing import Dict, List, Optional, Set, Tuple, Union 25 26from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest, CodeGeneratorResponse 27from google.protobuf.descriptor import ( 28 FieldDescriptor 29) 30from google.protobuf.descriptor_pb2 import ( 31 FileDescriptorProto, 32 EnumDescriptorProto, 33 DescriptorProto, 34 ServiceDescriptorProto, 35 MethodDescriptorProto, 36 FieldDescriptorProto, 37) 38 39_REQUEST = CodeGeneratorRequest.FromString(sys.stdin.buffer.read()) 40 41 42def find_type_in_file(proto_file: FileDescriptorProto, type_name: str) -> Optional[Union[DescriptorProto, EnumDescriptorProto]]: 43 for enum in proto_file.enum_type: 44 if enum.name == type_name: 45 return enum 46 for message in proto_file.message_type: 47 if message.name == type_name: 48 return message 49 return None 50 51 52def find_type(package: str, type_name: str) -> Tuple[FileDescriptorProto, Union[DescriptorProto, EnumDescriptorProto]]: 53 for file in _REQUEST.proto_file: 54 if file.package == package and (type := find_type_in_file(file, type_name)): 55 return file, type 56 raise Exception(f'Type {package}.{type_name} not found') 57 58 59def add_import(imports: List[str], import_str: str) -> None: 60 if not import_str in imports: 61 imports.append(import_str) 62 63 64def import_type(imports: List[str], type: str, local: Optional[FileDescriptorProto]) -> Tuple[str, Union[DescriptorProto, EnumDescriptorProto], str]: 65 package = type[1:type.rindex('.')] 66 type_name = type[type.rindex('.')+1:] 67 file, desc = find_type(package, type_name) 68 if file == local: 69 return f'{type_name}', desc, '' 70 python_path = file.name.replace('.proto', '').replace('/', '.') 71 module_path = python_path[:python_path.rindex('.')] 72 module_name = python_path[python_path.rindex('.')+1:] + '_pb2' 73 add_import(imports, f'from {module_path} import {module_name}') 74 dft_import = '' 75 if isinstance(desc, EnumDescriptorProto): 76 dft_import = f'from {module_path}.{module_name} import {desc.value[0].name}' 77 return f'{module_name}.{type_name}', desc, dft_import 78 79 80def collect_type(imports: List[str], parent: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[str, str, str]: 81 dft: str 82 dft_import: str = '' 83 if field.type == FieldDescriptor.TYPE_BYTES: 84 type = 'bytes' 85 dft = 'b\'\'' 86 elif field.type == FieldDescriptor.TYPE_STRING: 87 type = 'str' 88 dft = '\'\'' 89 elif field.type == FieldDescriptor.TYPE_BOOL: 90 type = 'bool' 91 dft = 'False' 92 elif field.type in [ 93 FieldDescriptor.TYPE_FLOAT, 94 FieldDescriptor.TYPE_DOUBLE 95 ]: 96 type = 'float' 97 dft = '0.0' 98 elif field.type in [ 99 FieldDescriptor.TYPE_INT64, 100 FieldDescriptor.TYPE_UINT64, 101 FieldDescriptor.TYPE_INT32, 102 FieldDescriptor.TYPE_FIXED64, 103 FieldDescriptor.TYPE_FIXED32, 104 FieldDescriptor.TYPE_UINT32, 105 FieldDescriptor.TYPE_SFIXED32, 106 FieldDescriptor.TYPE_SFIXED64, 107 FieldDescriptor.TYPE_SINT32, 108 FieldDescriptor.TYPE_SINT64 109 ]: 110 type = 'int' 111 dft = '0' 112 elif field.type in [FieldDescriptor.TYPE_ENUM, FieldDescriptor.TYPE_MESSAGE]: 113 parts = field.type_name.split(f".{parent.name}.", 2) 114 if len(parts) == 2: 115 type = parts[1] 116 for nested_type in parent.nested_type: 117 if nested_type.name == type: 118 assert nested_type.options.map_entry 119 assert field.label == FieldDescriptor.LABEL_REPEATED 120 key_type, _, _ = collect_type(imports, nested_type, nested_type.field[0], local) 121 val_type, _, _ = collect_type(imports, nested_type, nested_type.field[1], local) 122 add_import(imports, 'from typing import Dict') 123 return f'Dict[{key_type}, {val_type}]', '{}', '' 124 type, desc, enum_dft = import_type(imports, field.type_name, local) 125 if isinstance(desc, EnumDescriptorProto): 126 dft_import = enum_dft 127 dft = desc.value[0].name 128 else: 129 dft = f'{type}()' 130 else: 131 raise Exception(f'TODO: {field}') 132 133 if field.label == FieldDescriptor.LABEL_REPEATED: 134 add_import(imports, 'from typing import List') 135 type = f'List[{type}]' 136 dft = '[]' 137 138 return type, dft, dft_import 139 140 141def collect_field(imports: List[str], message: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[Optional[int], str, str, str, str]: 142 type, dft, dft_import = collect_type(imports, message, field, local) 143 oneof_index = field.oneof_index if 'oneof_index' in f'{field}' else None 144 return oneof_index, field.name, type, dft, dft_import 145 146 147def collect_message(imports: List[str], message: DescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[ 148 List[Tuple[str, str, str]], 149 Dict[str, List[Tuple[str, str]]], 150]: 151 fields: List[Tuple[str, str, str]] = [] 152 oneof: Dict[str, List[Tuple[str, str]]] = {} 153 154 for field in message.field: 155 idx, name, type, dft, dft_import = collect_field(imports, message, field, local) 156 if idx is not None: 157 oneof_name = message.oneof_decl[idx].name 158 oneof.setdefault(oneof_name, []) 159 oneof[oneof_name].append((name, type)) 160 else: 161 add_import(imports, dft_import) 162 fields.append((name, type, dft)) 163 164 for oneof_name, oneof_fields in oneof.items(): 165 for name, type in oneof_fields: 166 add_import(imports, 'from typing import Optional') 167 fields.append((name, f'Optional[{type}]', 'None')) 168 169 return fields, oneof 170 171 172def generate_enum(imports: List[str], file: FileDescriptorProto, enum: EnumDescriptorProto, res: List[CodeGeneratorResponse.File]) -> List[str]: 173 res.append(CodeGeneratorResponse.File( 174 name=file.name.replace('.proto', '_pb2.py'), 175 insertion_point=f'module_scope', 176 content=f'class {enum.name}: ...\n\n' 177 )) 178 add_import(imports, 'from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper') 179 return [ 180 f'class {enum.name}(int, EnumTypeWrapper):', 181 f' pass', 182 f'', 183 *[f'{value.name}: {enum.name}' for value in enum.value], 184 '' 185 ] 186 187 188def generate_message(imports: List[str], file: FileDescriptorProto, message: DescriptorProto, res: List[CodeGeneratorResponse.File]) -> List[str]: 189 nested_message_lines: List[str] = [] 190 message_lines: List[str] = [f'class {message.name}(Message):'] 191 192 add_import(imports, 'from google.protobuf.message import Message') 193 fields, oneof = collect_message(imports, message, file) 194 195 for (name, type, _) in fields: 196 message_lines.append(f' {name}: {type}') 197 198 args = ', '.join([f'{name}: {type} = {dft}' for name, type, dft in fields]) 199 if args: args = ', ' + args 200 message_lines.extend([ 201 f'', 202 f' def __init__(self{args}) -> None: ...', 203 f'' 204 ]) 205 206 for oneof_name, oneof_fields in oneof.items(): 207 literals: str = ', '.join((f'Literal[\'{name}\']' for name, _ in oneof_fields)) 208 types: Set[str] = set((type for _, type in oneof_fields)) 209 if len(types) == 1: 210 type = 'Optional[' + types.pop() + ']' 211 else: 212 types.add('None') 213 type = 'Union[' + ', '.join(sorted(types)) + ']' 214 215 nested_message_lines.extend([ 216 f'class {message.name}_{oneof_name}_dict(TypedDict, total=False):', 217 '\n'.join([f' {name}: {type}' for name, type in oneof_fields]), 218 f'', 219 ]) 220 221 add_import(imports, 'from typing import Union') 222 add_import(imports, 'from typing_extensions import TypedDict') 223 add_import(imports, 'from typing_extensions import Literal') 224 message_lines.extend([ 225 f' @property', 226 f' def {oneof_name}(self) -> {type}: ...' 227 f'', 228 f' def {oneof_name}_variant(self) -> Union[{literals}, None]: ...' 229 f'', 230 f' def {oneof_name}_asdict(self) -> {message.name}_{oneof_name}_dict: ...', 231 f'', 232 ]) 233 234 return_variant = '\n '.join([f'if variant == \'{name}\': return unwrap(self.{name})' for name, _ in oneof_fields]) 235 return_asdict = '\n '.join([f'if variant == \'{name}\': return {{\'{name}\': unwrap(self.{name})}} # type: ignore' for name, _ in oneof_fields]) 236 if return_variant: return_variant += '\n ' 237 if return_asdict: return_asdict += '\n ' 238 239 res.append(CodeGeneratorResponse.File( 240 name=file.name.replace('.proto', '_pb2.py'), 241 insertion_point=f'module_scope', 242 content=f""" 243def _{message.name}_{oneof_name}(self: {message.name}): 244 variant = self.{oneof_name}_variant() 245 if variant is None: return None 246 {return_variant}raise Exception('Field `{oneof_name}` not found.') 247 248def _{message.name}_{oneof_name}_variant(self: {message.name}): 249 return self.WhichOneof('{oneof_name}') # type: ignore 250 251def _{message.name}_{oneof_name}_asdict(self: {message.name}): 252 variant = self.{oneof_name}_variant() 253 if variant is None: return {{}} 254 {return_asdict}raise Exception('Field `{oneof_name}` not found.') 255 256setattr({message.name}, '{oneof_name}', property(_{message.name}_{oneof_name})) 257setattr({message.name}, '{oneof_name}_variant', _{message.name}_{oneof_name}_variant) 258setattr({message.name}, '{oneof_name}_asdict', _{message.name}_{oneof_name}_asdict) 259""")) 260 261 return message_lines + nested_message_lines 262 263 264def generate_service_method(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, method: MethodDescriptorProto, sync: bool = True) -> List[str]: 265 input_mode = 'stream' if method.client_streaming else 'unary' 266 output_mode = 'stream' if method.server_streaming else 'unary' 267 268 input_type, input_msg, _ = import_type(imports, method.input_type, None) 269 output_type, _, _ = import_type(imports, method.output_type, None) 270 271 input_type_pb2, _, _ = import_type(imports, method.input_type, None) 272 output_type_pb2, _, _ = import_type(imports, method.output_type, None) 273 274 if output_mode == 'stream': 275 if input_mode == 'stream': 276 output_type_hint = f'StreamStream[{input_type}, {output_type}]' 277 if sync: 278 add_import(imports, f'from ._utils import Sender') 279 add_import(imports, f'from ._utils import Stream') 280 add_import(imports, f'from ._utils import StreamStream') 281 else: 282 add_import(imports, f'from ._utils import AioSender as Sender') 283 add_import(imports, f'from ._utils import AioStream as Stream') 284 add_import(imports, f'from ._utils import AioStreamStream as StreamStream') 285 else: 286 output_type_hint = f'Stream[{output_type}]' 287 if sync: 288 add_import(imports, f'from ._utils import Stream') 289 else: 290 add_import(imports, f'from ._utils import AioStream as Stream') 291 else: 292 output_type_hint = output_type if sync else f'Awaitable[{output_type}]' 293 if not sync: add_import(imports, f'from typing import Awaitable') 294 295 if input_mode == 'stream' and output_mode == 'stream': 296 add_import(imports, f'from typing import Optional') 297 return ( 298 f'def {method.name}(self, timeout: Optional[float] = None) -> {output_type_hint}:\n' 299 f' tx: Sender[{input_type}] = Sender()\n' 300 f' rx: Stream[{output_type}] = self.channel.{input_mode}_{output_mode}( # type: ignore\n' 301 f" '/{file.package}.{service.name}/{method.name}',\n" 302 f' request_serializer={input_type_pb2}.SerializeToString, # type: ignore\n' 303 f' response_deserializer={output_type_pb2}.FromString # type: ignore\n' 304 f' )(tx)\n' 305 f' return StreamStream(tx, rx)' 306 ).split('\n') 307 if input_mode == 'stream': 308 iterator_type = 'Iterator' if sync else 'AsyncIterator' 309 add_import(imports, f'from typing import {iterator_type}') 310 add_import(imports, f'from typing import Optional') 311 return ( 312 f'def {method.name}(self, iterator: {iterator_type}[{input_type}], timeout: Optional[float] = None) -> {output_type_hint}:\n' 313 f' return self.channel.{input_mode}_{output_mode}( # type: ignore\n' 314 f" '/{file.package}.{service.name}/{method.name}',\n" 315 f' request_serializer={input_type_pb2}.SerializeToString, # type: ignore\n' 316 f' response_deserializer={output_type_pb2}.FromString # type: ignore\n' 317 f' )(iterator)' 318 ).split('\n') 319 else: 320 add_import(imports, f'from typing import Optional') 321 assert isinstance(input_msg, DescriptorProto) 322 input_fields, _ = collect_message(imports, input_msg, None) 323 args = ', '.join([f'{name}: {type} = {dft}' for name, type, dft in input_fields]) 324 args_name = ', '.join([f'{name}={name}' for name, _, _ in input_fields]) 325 if args: args = ', ' + args 326 return ( 327 f'def {method.name}(self{args}, wait_for_ready: Optional[bool] = None, timeout: Optional[float] = None) -> {output_type_hint}:\n' 328 f' return self.channel.{input_mode}_{output_mode}( # type: ignore\n' 329 f" '/{file.package}.{service.name}/{method.name}',\n" 330 f' request_serializer={input_type_pb2}.SerializeToString, # type: ignore\n' 331 f' response_deserializer={output_type_pb2}.FromString # type: ignore\n' 332 f' )({input_type_pb2}({args_name}), wait_for_ready=wait_for_ready, timeout=timeout) # type: ignore' 333 ).split('\n') 334 335 336def generate_service(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]: 337 methods = '\n\n '.join([ 338 '\n '.join( 339 generate_service_method(imports, file, service, method, sync) 340 ) for method in service.method 341 ]) 342 channel_type = 'grpc.Channel' if sync else 'grpc.aio.Channel' 343 return ( 344 f'class {service.name}:\n' 345 f' channel: {channel_type}\n' 346 f'\n' 347 f' def __init__(self, channel: {channel_type}) -> None:\n' 348 f' self.channel = channel\n' 349 f'\n' 350 f' {methods}\n' 351 ).split('\n') 352 353 354def generate_servicer_method(imports: List[str], method: MethodDescriptorProto, sync: bool = True) -> List[str]: 355 input_mode = 'stream' if method.client_streaming else 'unary' 356 output_mode = 'stream' if method.server_streaming else 'unary' 357 358 input_type, _, _ = import_type(imports, method.input_type, None) 359 output_type, _, _ = import_type(imports, method.output_type, None) 360 361 output_type_hint = output_type 362 if output_mode == 'stream': 363 if sync: 364 output_type_hint = f'Generator[{output_type}, None, None]' 365 add_import(imports, f'from typing import Generator') 366 else: 367 output_type_hint = f'AsyncGenerator[{output_type}, None]' 368 add_import(imports, f'from typing import AsyncGenerator') 369 370 iterator_type = 'Iterator' if sync else 'AsyncIterator' 371 372 if input_mode == 'stream': 373 iterator_type = 'Iterator' if sync else 'AsyncIterator' 374 add_import(imports, f'from typing import {iterator_type}') 375 lines = (('' if sync else 'async ') + ( 376 f'def {method.name}(self, request: {iterator_type}[{input_type}], context: grpc.ServicerContext) -> {output_type_hint}:\n' 377 f' context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore\n' 378 f' context.set_details("Method not implemented!") # type: ignore\n' 379 f' raise NotImplementedError("Method not implemented!")' 380 )).split('\n') 381 else: 382 lines = (('' if sync else 'async ') + ( 383 f'def {method.name}(self, request: {input_type}, context: grpc.ServicerContext) -> {output_type_hint}:\n' 384 f' context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore\n' 385 f' context.set_details("Method not implemented!") # type: ignore\n' 386 f' raise NotImplementedError("Method not implemented!")' 387 )).split('\n') 388 if output_mode == 'stream': 389 lines.append(f' yield {output_type}() # no-op: to make the linter happy') 390 return lines 391 392 393def generate_servicer(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]: 394 methods = '\n\n '.join([ 395 '\n '.join( 396 generate_servicer_method(imports, method, sync) 397 ) for method in service.method 398 ]) 399 if not methods: 400 methods = 'pass' 401 return ( 402 f'class {service.name}Servicer:\n' 403 f' {methods}\n' 404 ).split('\n') 405 406 407def generate_rpc_method_handler(imports: List[str], method: MethodDescriptorProto) -> List[str]: 408 input_mode = 'stream' if method.client_streaming else 'unary' 409 output_mode = 'stream' if method.server_streaming else 'unary' 410 411 input_type, _, _ = import_type(imports, method.input_type, None) 412 output_type, _, _ = import_type(imports, method.output_type, None) 413 414 return ( 415 f"'{method.name}': grpc.{input_mode}_{output_mode}_rpc_method_handler( # type: ignore\n" 416 f' servicer.{method.name},\n' 417 f' request_deserializer={input_type}.FromString, # type: ignore\n' 418 f' response_serializer={output_type}.SerializeToString, # type: ignore\n' 419 f' ),\n' 420 ).split('\n') 421 422 423def generate_add_servicer_to_server_method(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]: 424 method_handlers = ' '.join([ 425 '\n '.join( 426 generate_rpc_method_handler(imports, method) 427 ) for method in service.method 428 ]) 429 server_type = 'grpc.Server' if sync else 'grpc.aio.Server' 430 return ( 431 f'def add_{service.name}Servicer_to_server(servicer: {service.name}Servicer, server: {server_type}) -> None:\n' 432 f' rpc_method_handlers = {{\n' 433 f' {method_handlers}\n' 434 f' }}\n' 435 f' generic_handler = grpc.method_handlers_generic_handler( # type: ignore\n' 436 f" '{file.package}.{service.name}', rpc_method_handlers)\n" 437 f' server.add_generic_rpc_handlers((generic_handler,)) # type: ignore\n' 438 ).split('\n') 439 440 441_HEADER = '''# Copyright 2022 Google LLC 442# 443# Licensed under the Apache License, Version 2.0 (the "License"); 444# you may not use this file except in compliance with the License. 445# You may obtain a copy of the License at 446# 447# https://www.apache.org/licenses/LICENSE-2.0 448# 449# Unless required by applicable law or agreed to in writing, software 450# distributed under the License is distributed on an "AS IS" BASIS, 451# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 452# See the License for the specific language governing permissions and 453# limitations under the License. 454 455"""Generated python gRPC interfaces.""" 456 457from __future__ import annotations 458''' 459 460_UTILS_PY = f'''{_HEADER} 461 462import asyncio 463import queue 464import grpc 465import sys 466 467from typing import Any, AsyncIterable, AsyncIterator, Generic, Iterator, TypeVar 468 469 470_T_co = TypeVar('_T_co', covariant=True) # pytype: disable=not-supported-yet 471_T = TypeVar('_T') 472 473 474class Stream(Iterator[_T_co], grpc.RpcContext): ... 475 476 477class AioStream(AsyncIterable[_T_co], grpc.RpcContext): ... 478 479 480class Sender(Iterator[_T]): 481 if sys.version_info >= (3, 8): 482 _inner: queue.Queue[_T] 483 else: 484 _inner: queue.Queue 485 486 def __init__(self) -> None: 487 self._inner = queue.Queue() 488 489 def __iter__(self) -> Iterator[_T]: 490 return self 491 492 def __next__(self) -> _T: 493 return self._inner.get() 494 495 def send(self, item: _T) -> None: 496 self._inner.put(item) 497 498 499class AioSender(AsyncIterator[_T]): 500 if sys.version_info >= (3, 8): 501 _inner: asyncio.Queue[_T] 502 else: 503 _inner: asyncio.Queue 504 505 def __init__(self) -> None: 506 self._inner = asyncio.Queue() 507 508 def __iter__(self) -> AsyncIterator[_T]: 509 return self 510 511 async def __anext__(self) -> _T: 512 return await self._inner.get() 513 514 async def send(self, item: _T) -> None: 515 await self._inner.put(item) 516 517 def send_nowait(self, item: _T) -> None: 518 self._inner.put_nowait(item) 519 520 521class StreamStream(Generic[_T, _T_co], Iterator[_T_co], grpc.RpcContext): 522 _sender: Sender[_T] 523 _receiver: Stream[_T_co] 524 525 def __init__(self, sender: Sender[_T], receiver: Stream[_T_co]) -> None: 526 self._sender = sender 527 self._receiver = receiver 528 529 def send(self, item: _T) -> None: 530 self._sender.send(item) 531 532 def __iter__(self) -> Iterator[_T_co]: 533 return self._receiver.__iter__() 534 535 def __next__(self) -> _T_co: 536 return self._receiver.__next__() 537 538 def is_active(self) -> bool: 539 return self._receiver.is_active() # type: ignore 540 541 def time_remaining(self) -> float: 542 return self._receiver.time_remaining() # type: ignore 543 544 def cancel(self) -> None: 545 self._receiver.cancel() # type: ignore 546 547 def add_callback(self, callback: Any) -> None: 548 self._receiver.add_callback(callback) # type: ignore 549 550 551class AioStreamStream(Generic[_T, _T_co], AsyncIterator[_T_co], grpc.RpcContext): 552 _sender: AioSender[_T] 553 _receiver: AioStream[_T_co] 554 555 def __init__(self, sender: AioSender[_T], receiver: AioStream[_T_co]) -> None: 556 self._sender = sender 557 self._receiver = receiver 558 559 def __aiter__(self) -> AsyncIterator[_T_co]: 560 return self._receiver.__aiter__() 561 562 async def __anext__(self) -> _T_co: 563 return await self._receiver.__aiter__().__anext__() 564 565 async def send(self, item: _T) -> None: 566 await self._sender.send(item) 567 568 def send_nowait(self, item: _T) -> None: 569 self._sender.send_nowait(item) 570 571 def is_active(self) -> bool: 572 return self._receiver.is_active() # type: ignore 573 574 def time_remaining(self) -> float: 575 return self._receiver.time_remaining() # type: ignore 576 577 def cancel(self) -> None: 578 self._receiver.cancel() # type: ignore 579 580 def add_callback(self, callback: Any) -> None: 581 self._receiver.add_callback(callback) # type: ignore 582''' 583 584 585_FILES: List[CodeGeneratorResponse.File] = [] 586_UTILS_FILES: Set[str] = set() 587 588 589for file_name in _REQUEST.file_to_generate: 590 file: FileDescriptorProto = next(filter(lambda x: x.name == file_name, _REQUEST.proto_file)) 591 592 _FILES.append(CodeGeneratorResponse.File( 593 name=file.name.replace('.proto', '_pb2.py'), 594 insertion_point=f'module_scope', 595 content='def unwrap(x):\n assert x\n return x\n' 596 )) 597 598 pyi_imports: List[str] = [] 599 grpc_imports: List[str] = ['import grpc'] 600 grpc_aio_imports: List[str] = ['import grpc', 'import grpc.aio'] 601 602 enums = '\n'.join(sum([generate_enum(pyi_imports, file, enum, _FILES) for enum in file.enum_type], [])) 603 messages = '\n'.join(sum([generate_message(pyi_imports, file, message, _FILES) for message in file.message_type], [])) 604 605 services = '\n'.join(sum([generate_service(grpc_imports, file, service) for service in file.service], [])) 606 aio_services = '\n'.join(sum([generate_service(grpc_aio_imports, file, service, False) for service in file.service], [])) 607 608 servicers = '\n'.join(sum([generate_servicer(grpc_imports, file, service) for service in file.service], [])) 609 aio_servicers = '\n'.join(sum([generate_servicer(grpc_aio_imports, file, service, False) for service in file.service], [])) 610 611 add_servicer_methods = '\n'.join(sum([generate_add_servicer_to_server_method(grpc_imports, file, service) for service in file.service], [])) 612 aio_add_servicer_methods = '\n'.join(sum([generate_add_servicer_to_server_method(grpc_aio_imports, file, service, False) for service in file.service], [])) 613 614 pyi_imports.sort() 615 grpc_imports.sort() 616 grpc_aio_imports.sort() 617 618 pyi_imports_str: str = '\n'.join(pyi_imports) 619 grpc_imports_str: str = '\n'.join(grpc_imports) 620 grpc_aio_imports_str: str = '\n'.join(grpc_aio_imports) 621 622 utils_filename = file_name.replace(os.path.basename(file_name), '_utils.py') 623 if utils_filename not in _UTILS_FILES: 624 _UTILS_FILES.add(utils_filename) 625 _FILES.extend([ 626 CodeGeneratorResponse.File( 627 name=utils_filename, 628 content=_UTILS_PY, 629 ) 630 ]) 631 632 _FILES.extend([ 633 CodeGeneratorResponse.File( 634 name=file.name.replace('.proto', '_pb2.pyi'), 635 content=f'{_HEADER}\n\n{pyi_imports_str}\n\n{enums}\n\n{messages}\n' 636 ), 637 CodeGeneratorResponse.File( 638 name=file_name.replace('.proto', '_grpc.py'), 639 content=f'{_HEADER}\n\n{grpc_imports_str}\n\n{services}\n\n{servicers}\n\n{add_servicer_methods}' 640 ), 641 CodeGeneratorResponse.File( 642 name=file_name.replace('.proto', '_grpc_aio.py'), 643 content=f'{_HEADER}\n\n{grpc_aio_imports_str}\n\n{aio_services}\n\n{aio_servicers}\n\n{aio_add_servicer_methods}' 644 ) 645 ]) 646 647 648sys.stdout.buffer.write(CodeGeneratorResponse(file=_FILES).SerializeToString()) 649