# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import contextlib
import importlib
import os
from os import path
import pkgutil
import platform
import shutil
import sys
import tempfile
import unittest

import grpc
from grpc_tools import protoc

if sys.version_info >= (3, 9, 0):
    from importlib import resources
else:
    import pkg_resources


from tests.unit import test_common

# Byte patterns in the checked-in invocation_testing protos that
# _massage_proto_content rewrites for each generated test case.
_MESSAGES_IMPORT = b'import "messages.proto";'
_SPLIT_NAMESPACE = b"package grpc_protoc_plugin.invocation_testing.split;"
_COMMON_NAMESPACE = b"package grpc_protoc_plugin.invocation_testing;"

# Subdirectories of each test's scratch directory: one holds the rewritten
# .proto inputs, the other receives protoc's generated Python output.
_RELATIVE_PROTO_PATH = "relative_proto_path"
_RELATIVE_PYTHON_OUT = "relative_python_out"

# Directory containing this test file; WellKnownTypesTest resolves its
# relative proto paths against it.
_TEST_DIR = os.path.dirname(os.path.realpath(__file__))


@contextlib.contextmanager
def _system_path(path_insertion):
    """Temporarily splice directories into ``sys.path``.

    The entries of ``path_insertion`` are inserted right after ``sys.path[0]``
    for the duration of the ``with`` block. The previous ``sys.path`` is
    restored on exit, including when the block raises (e.g. a failed import),
    so a failing test cannot leak path entries into later tests.

    Args:
      path_insertion: Iterable of directory names to insert.
    """
    old_system_path = sys.path[:]
    sys.path = sys.path[0:1] + list(path_insertion) + sys.path[1:]
    try:
        yield
    finally:
        sys.path = old_system_path


def _get_resource_file_name(
    package_or_requirement: str, resource_name: str
) -> str:
    """Obtain the filename for a resource on the file system.

    Uses ``importlib.resources`` on Python 3.9+ and falls back to the
    ``pkg_resources`` API on older interpreters.

    Args:
      package_or_requirement: Name of the package that owns the resource.
      resource_name: Resource path relative to that package.

    Returns:
      The resource's file-system path as a string.
    """
    if sys.version_info >= (3, 9, 0):
        file_name = (
            resources.files(package_or_requirement) / resource_name
        ).resolve()
    else:
        file_name = pkg_resources.resource_filename(
            package_or_requirement, resource_name
        )
    return str(file_name)
# NOTE(nathaniel): https://twitter.com/exoplaneteer/status/677259364256747520
# Life lesson "just always default to idempotence" reinforced.
def _create_directory_tree(root, path_components_sequence):
    """Create every directory named by ``path_components_sequence`` under root.

    Args:
      root: Directory under which the tree is created.
      path_components_sequence: Iterable of sequences of path components; each
        sequence names one (possibly nested) directory relative to ``root``.
        An empty sequence names ``root`` itself and creates nothing.

    Directories that already exist, whether named earlier in
    ``path_components_sequence`` or already present on disk, are left alone.
    """
    for path_components in path_components_sequence:
        if path_components:
            os.makedirs(path.join(root, *path_components), exist_ok=True)


def _massage_proto_content(
    proto_content, test_name_bytes, messages_proto_relative_file_name_bytes
):
    """Rewrite a checked-in test proto for one generated test case.

    Both the common and the "split" package declarations are replaced by
    ``grpc_protoc_plugin.invocation_testing.<test_name_bytes>``. The
    ``import "messages.proto";`` statement is redirected to
    ``messages_proto_relative_file_name_bytes``.

    Args:
      proto_content: Raw bytes of the .proto file.
      test_name_bytes: Test-case name used as the final package component.
      messages_proto_relative_file_name_bytes: Import path of the messages
        proto, relative to the proto path.

    Returns:
      The rewritten proto content as bytes.
    """
    package_substitution = (
        b"package grpc_protoc_plugin.invocation_testing."
        + test_name_bytes
        + b";"
    )
    common_namespace_substituted = proto_content.replace(
        _COMMON_NAMESPACE, package_substitution
    )
    split_namespace_substituted = common_namespace_substituted.replace(
        _SPLIT_NAMESPACE, package_substitution
    )
    message_import_replaced = split_namespace_substituted.replace(
        _MESSAGES_IMPORT,
        b'import "' + messages_proto_relative_file_name_bytes + b'";',
    )
    return message_import_replaced


def _packagify(directory):
    """Write an empty ``__init__.py`` into ``directory`` and every subdirectory.

    This makes the generated output tree importable as packages. Any existing
    ``__init__.py`` is truncated.
    """
    for subdirectory, _, _ in os.walk(directory):
        init_file_name = path.join(subdirectory, "__init__.py")
        with open(init_file_name, "wb") as init_file:
            init_file.write(b"")


class _Servicer(object):
    """Test servicer whose ``Call`` returns a default-constructed response."""

    def __init__(self, response_class):
        self._response_class = response_class

    def Call(self, request, context):
        return self._response_class()


def _protoc(
    proto_path,
    python_out,
    grpc_python_out_flag,
    grpc_python_out,
    absolute_proto_file_names,
):
    """Run ``grpc_tools.protoc`` in-process and return its exit code.

    Args:
      proto_path: Directory passed as ``--proto_path``.
      python_out: Directory for ``--python_out``, or None to skip pb2 output.
      grpc_python_out_flag: Plugin parameter (e.g. ``"grpc_2_0"``) prefixed to
        the ``--grpc_python_out`` directory; None passes the directory alone.
      grpc_python_out: Directory for ``--grpc_python_out``, or None to skip
        gRPC output.
      absolute_proto_file_names: Iterable of .proto files to compile.

    Returns:
      The value returned by ``protoc.main``; callers treat 0 as success.
    """
    # The first entry is an argv[0] placeholder (WellKnownTypesTest passes a
    # program name in the same position).
    args = [
        "",
        "--proto_path={}".format(proto_path),
    ]
    if python_out is not None:
        args.append("--python_out={}".format(python_out))
    if grpc_python_out is not None:
        if grpc_python_out_flag is None:
            args.append("--grpc_python_out={}".format(grpc_python_out))
        else:
            args.append(
                "--grpc_python_out={}:{}".format(
                    grpc_python_out_flag, grpc_python_out
                )
            )
    args.extend(absolute_proto_file_names)
    return protoc.main(args)


class _Mid2016ProtocStyle(object):
    """One protoc run emitting pb2 and gRPC code with the ``grpc_1_0`` flag.

    Service stubs are expected in both the ``_pb2`` and ``_pb2_grpc`` modules.
    """

    def name(self):
        return "Mid2016ProtocStyle"

    def grpc_in_pb2_expected(self):
        return True

    def protoc(self, proto_path, python_out, absolute_proto_file_names):
        return (
            _protoc(
                proto_path,
                python_out,
                "grpc_1_0",
                python_out,
                absolute_proto_file_names,
            ),
        )


class _SingleProtocExecutionProtocStyle(object):
    """One protoc run emitting pb2 and gRPC code with the ``grpc_2_0`` flag."""

    def name(self):
        return "SingleProtocExecutionProtocStyle"

    def grpc_in_pb2_expected(self):
        return False

    def protoc(self, proto_path, python_out, absolute_proto_file_names):
        return (
            _protoc(
                proto_path,
                python_out,
                "grpc_2_0",
                python_out,
                absolute_proto_file_names,
            ),
        )


class _ProtoBeforeGrpcProtocStyle(object):
    """Two protoc runs: pb2 code first, then ``grpc_2_0`` gRPC code."""

    def name(self):
        return "ProtoBeforeGrpcProtocStyle"

    def grpc_in_pb2_expected(self):
        return False

    def protoc(self, proto_path, python_out, absolute_proto_file_names):
        pb2_protoc_exit_code = _protoc(
            proto_path, python_out, None, None, absolute_proto_file_names
        )
        pb2_grpc_protoc_exit_code = _protoc(
            proto_path, None, "grpc_2_0", python_out, absolute_proto_file_names
        )
        return pb2_protoc_exit_code, pb2_grpc_protoc_exit_code


class _GrpcBeforeProtoProtocStyle(object):
    """Two protoc runs: ``grpc_2_0`` gRPC code first, then pb2 code."""

    def name(self):
        return "GrpcBeforeProtoProtocStyle"

    def grpc_in_pb2_expected(self):
        return False

    def protoc(self, proto_path, python_out, absolute_proto_file_names):
        pb2_grpc_protoc_exit_code = _protoc(
            proto_path, None, "grpc_2_0", python_out, absolute_proto_file_names
        )
        pb2_protoc_exit_code = _protoc(
            proto_path, python_out, None, None, absolute_proto_file_names
        )
        return pb2_grpc_protoc_exit_code, pb2_protoc_exit_code
_PROTOC_STYLES = (
    _Mid2016ProtocStyle(),
    _SingleProtocExecutionProtocStyle(),
    _ProtoBeforeGrpcProtocStyle(),
    _GrpcBeforeProtoProtocStyle(),
)


@unittest.skipIf(
    platform.python_implementation() == "PyPy", "Skip test if run with PyPy!"
)
class _Test(unittest.TestCase, metaclass=abc.ABCMeta):
    """Base for the generated protoc-invocation test cases.

    Not run directly: _create_test_case_class derives concrete subclasses and
    supplies the class attributes read here: NAME, the
    MESSAGES_/SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES and *_PROTO_FILE_NAME
    settings, the EXPECTED_* generated module names, and PROTOC_STYLE.
    """

    def setUp(self):
        # Scratch tree: <tmp>/relative_proto_path holds the rewritten .proto
        # inputs and <tmp>/relative_python_out receives protoc's output.
        self._directory = tempfile.mkdtemp(suffix=self.NAME, dir=".")
        self._proto_path = path.join(self._directory, _RELATIVE_PROTO_PATH)
        self._python_out = path.join(self._directory, _RELATIVE_PYTHON_OUT)

        os.makedirs(self._proto_path)
        os.makedirs(self._python_out)

        # A set, so that when messages and services share one file (the
        # SameProto configuration) it is written and compiled only once.
        proto_directories_and_names = {
            (
                self.MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES,
                self.MESSAGES_PROTO_FILE_NAME,
            ),
            (
                self.SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES,
                self.SERVICES_PROTO_FILE_NAME,
            ),
        }
        # The rewritten import statement uses "/" regardless of os.sep.
        messages_proto_relative_file_name_forward_slashes = "/".join(
            self.MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES
            + (self.MESSAGES_PROTO_FILE_NAME,)
        )
        _create_directory_tree(
            self._proto_path,
            (
                relative_proto_directory_names
                for relative_proto_directory_names, _ in proto_directories_and_names
            ),
        )
        self._absolute_proto_file_names = set()
        for relative_directory_names, file_name in proto_directories_and_names:
            relative_file_components = relative_directory_names + (file_name,)
            absolute_proto_file_name = path.join(
                self._proto_path, *relative_file_components
            )
            raw_proto_content = pkgutil.get_data(
                "tests.protoc_plugin.protos.invocation_testing",
                path.join(*relative_file_components),
            )
            # Give this test case its own proto package and point the
            # messages import at this case's messages file.
            massaged_proto_content = _massage_proto_content(
                raw_proto_content,
                self.NAME.encode(),
                messages_proto_relative_file_name_forward_slashes.encode(),
            )
            with open(absolute_proto_file_name, "wb") as proto_file:
                proto_file.write(massaged_proto_content)
            self._absolute_proto_file_names.add(absolute_proto_file_name)

    def tearDown(self):
        shutil.rmtree(self._directory)

    def _protoc(self):
        """Compile the scratch protos and import the generated modules.

        Asserts that every protoc invocation made by PROTOC_STYLE returns 0,
        turns the output tree into importable packages, and stores the
        messages, services and services-gRPC modules on ``self``.
        """
        protoc_exit_codes = self.PROTOC_STYLE.protoc(
            self._proto_path, self._python_out, self._absolute_proto_file_names
        )
        for protoc_exit_code in protoc_exit_codes:
            self.assertEqual(0, protoc_exit_code)

        _packagify(self._python_out)

        generated_modules = {}
        # A set: in the SameProto configuration the messages and services
        # pb2 modules are the same module and are imported once.
        expected_generated_full_module_names = {
            self.EXPECTED_MESSAGES_PB2,
            self.EXPECTED_SERVICES_PB2,
            self.EXPECTED_SERVICES_PB2_GRPC,
        }
        # NOTE(review): importlib.import_module returns any module already in
        # sys.modules, so a later test importing the same module name (e.g.
        # "same_pb2") receives the first import rather than the code just
        # generated here. Confirm whether that is intended.
        with _system_path([self._python_out]):
            for full_module_name in expected_generated_full_module_names:
                module = importlib.import_module(full_module_name)
                generated_modules[full_module_name] = module

        self._messages_pb2 = generated_modules[self.EXPECTED_MESSAGES_PB2]
        self._services_pb2 = generated_modules[self.EXPECTED_SERVICES_PB2]
        self._services_pb2_grpc = generated_modules[
            self.EXPECTED_SERVICES_PB2_GRPC
        ]

    def _services_modules(self):
        """Return the generated modules expected to define the service stubs."""
        if self.PROTOC_STYLE.grpc_in_pb2_expected():
            return self._services_pb2, self._services_pb2_grpc
        else:
            return (self._services_pb2_grpc,)

    def test_imported_attributes(self):
        self._protoc()

        # Bare attribute accesses: a missing attribute raises AttributeError
        # and errors the test.
        self._messages_pb2.Request
        self._messages_pb2.Response
        self._services_pb2.DESCRIPTOR.services_by_name["TestService"]
        for services_module in self._services_modules():
            services_module.TestServiceStub
            services_module.TestServiceServicer
            services_module.add_TestServiceServicer_to_server

    def test_call(self):
        self._protoc()

        for services_module in self._services_modules():
            server = test_common.test_server()
            services_module.add_TestServiceServicer_to_server(
                _Servicer(self._messages_pb2.Response), server
            )
            port = server.add_insecure_port("[::]:0")
            server.start()
            # Close the channel and stop the server even if the call or the
            # assertion fails, so a failure does not leak them.
            try:
                with grpc.insecure_channel(
                    "localhost:{}".format(port)
                ) as channel:
                    stub = services_module.TestServiceStub(channel)
                    response = stub.Call(self._messages_pb2.Request())
                    self.assertEqual(self._messages_pb2.Response(), response)
            finally:
                server.stop(None)


def _create_test_case_class(split_proto, protoc_style):
    """Build a concrete _Test subclass for one proto layout and protoc style.

    Args:
      split_proto: True to compile messages and services from separate protos
        in separate directories; False to use the single ``same.proto``.
      protoc_style: One of the _PROTOC_STYLES objects.

    Returns:
      A new class named ``<SplitProto|SameProto><style name>Test``.
    """
    attributes = {}

    name = "{}{}".format(
        "SplitProto" if split_proto else "SameProto", protoc_style.name()
    )
    attributes["NAME"] = name

    if split_proto:
        attributes["MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES"] = (
            "split_messages",
            "sub",
        )
        attributes["MESSAGES_PROTO_FILE_NAME"] = "messages.proto"
        attributes["SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES"] = (
            "split_services",
        )
        attributes["SERVICES_PROTO_FILE_NAME"] = "services.proto"
        attributes["EXPECTED_MESSAGES_PB2"] = "split_messages.sub.messages_pb2"
        attributes["EXPECTED_SERVICES_PB2"] = "split_services.services_pb2"
        attributes[
            "EXPECTED_SERVICES_PB2_GRPC"
        ] = "split_services.services_pb2_grpc"
    else:
        attributes["MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES"] = ()
        attributes["MESSAGES_PROTO_FILE_NAME"] = "same.proto"
        attributes["SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES"] = ()
        attributes["SERVICES_PROTO_FILE_NAME"] = "same.proto"
        attributes["EXPECTED_MESSAGES_PB2"] = "same_pb2"
        attributes["EXPECTED_SERVICES_PB2"] = "same_pb2"
        attributes["EXPECTED_SERVICES_PB2_GRPC"] = "same_pb2_grpc"

    attributes["PROTOC_STYLE"] = protoc_style

    attributes["__module__"] = _Test.__module__

    return type("{}Test".format(name), (_Test,), attributes)


def _create_test_case_classes():
    """Yield a _Test subclass for every (proto layout, protoc style) pair."""
    for split_proto in (
        False,
        True,
    ):
        for protoc_style in _PROTOC_STYLES:
            yield _create_test_case_class(split_proto, protoc_style)


class WellKnownTypesTest(unittest.TestCase):
    """Compiles compiler.proto against grpc_tools' bundled ``_proto`` directory."""

    def testWellKnownTypes(self):
        # The protoc arguments below are relative to the test directory, so
        # run from there. Restore the caller's working directory afterwards,
        # because other tests create their scratch directories relative to it.
        self.addCleanup(os.chdir, os.getcwd())
        os.chdir(_TEST_DIR)
        # Absolute path so the cleanup does not depend on the cwd at the
        # time it runs.
        out_dir = os.path.abspath(
            tempfile.mkdtemp(suffix="wkt_test", dir=".")
        )
        self.addCleanup(shutil.rmtree, out_dir, ignore_errors=True)
        well_known_protos_include = _get_resource_file_name(
            "grpc_tools", "_proto"
        )
        args = [
            "grpc_tools.protoc",
            "--proto_path=protos",
            "--proto_path={}".format(well_known_protos_include),
            "--python_out={}".format(out_dir),
            "--grpc_python_out={}".format(out_dir),
            "protos/invocation_testing/compiler.proto",
        ]
        rc = protoc.main(args)
        self.assertEqual(0, rc)


def load_tests(loader, tests, pattern):
    """``load_tests`` protocol hook for unittest.

    Ignores the default ``tests`` and ``pattern``. Returns a suite containing
    every generated _Test subclass plus WellKnownTypesTest.
    """
    suites = [
        loader.loadTestsFromTestCase(test_case_class)
        for test_case_class in _create_test_case_classes()
    ]
    suites.append(loader.loadTestsFromTestCase(WellKnownTypesTest))
    return unittest.TestSuite(tests=suites)


if __name__ == "__main__":
    unittest.main(verbosity=2)