xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import abc
16import contextlib
17import importlib
18import os
19from os import path
20import pkgutil
21import platform
22import shutil
23import sys
24import tempfile
25import unittest
26
27import grpc
28from grpc_tools import protoc
29
30if sys.version_info >= (3, 9, 0):
31    from importlib import resources
32else:
33    import pkg_resources
34
35
36from tests.unit import test_common
37
38_MESSAGES_IMPORT = b'import "messages.proto";'
39_SPLIT_NAMESPACE = b"package grpc_protoc_plugin.invocation_testing.split;"
40_COMMON_NAMESPACE = b"package grpc_protoc_plugin.invocation_testing;"
41
42_RELATIVE_PROTO_PATH = "relative_proto_path"
43_RELATIVE_PYTHON_OUT = "relative_python_out"
44
45_TEST_DIR = os.path.dirname(os.path.realpath(__file__))
46
47
48@contextlib.contextmanager
49def _system_path(path_insertion):
50    old_system_path = sys.path[:]
51    sys.path = sys.path[0:1] + path_insertion + sys.path[1:]
52    yield
53    sys.path = old_system_path
54
55
56def _get_resource_file_name(
57    package_or_requirement: str, resource_name: str
58) -> str:
59    """Obtain the filename for a resource on the file system."""
60    file_name = None
61    if sys.version_info >= (3, 9, 0):
62        file_name = (
63            resources.files(package_or_requirement) / resource_name
64        ).resolve()
65    else:
66        file_name = pkg_resources.resource_filename(
67            package_or_requirement, resource_name
68        )
69    return str(file_name)
70
71
72# NOTE(nathaniel): https://twitter.com/exoplaneteer/status/677259364256747520
73# Life lesson "just always default to idempotence" reinforced.
74def _create_directory_tree(root, path_components_sequence):
75    created = set()
76    for path_components in path_components_sequence:
77        thus_far = ""
78        for path_component in path_components:
79            relative_path = path.join(thus_far, path_component)
80            if relative_path not in created:
81                os.makedirs(path.join(root, relative_path))
82                created.add(relative_path)
83            thus_far = path.join(thus_far, path_component)
84
85
def _massage_proto_content(
    proto_content, test_name_bytes, messages_proto_relative_file_name_bytes
):
    """Rewrite checked-in proto source for one generated test case.

    Both the common and the split package declarations become
    ``package grpc_protoc_plugin.invocation_testing.<test name>;``, and the
    ``import "messages.proto";`` statement is pointed at the given path.

    Args:
      proto_content: Raw proto file contents (bytes).
      test_name_bytes: The test case name (bytes).
      messages_proto_relative_file_name_bytes: Forward-slash path of the
        messages proto relative to the proto path (bytes).

    Returns:
      The rewritten proto contents.
    """
    new_package = b"".join(
        (
            b"package grpc_protoc_plugin.invocation_testing.",
            test_name_bytes,
            b";",
        )
    )
    new_import = b"".join(
        (b'import "', messages_proto_relative_file_name_bytes, b'";')
    )
    rewritten = proto_content
    # Applied in this order: common package, split package, messages import.
    for old_text, new_text in (
        (_COMMON_NAMESPACE, new_package),
        (_SPLIT_NAMESPACE, new_package),
        (_MESSAGES_IMPORT, new_import),
    ):
        rewritten = rewritten.replace(old_text, new_text)
    return rewritten
105
106
107def _packagify(directory):
108    for subdirectory, _, _ in os.walk(directory):
109        init_file_name = path.join(subdirectory, "__init__.py")
110        with open(init_file_name, "wb") as init_file:
111            init_file.write(b"")
112
113
114class _Servicer(object):
115    def __init__(self, response_class):
116        self._response_class = response_class
117
118    def Call(self, request, context):
119        return self._response_class()
120
121
def _protoc(
    proto_path,
    python_out,
    grpc_python_out_flag,
    grpc_python_out,
    absolute_proto_file_names,
):
    """Invoke grpc_tools' protoc once and return whatever protoc.main returns.

    Args:
      proto_path: Value for --proto_path.
      python_out: Value for --python_out, or None to omit that flag.
      grpc_python_out_flag: Plugin option placed before the colon in
        --grpc_python_out (for example "grpc_2_0").
      grpc_python_out: Output directory for --grpc_python_out, or None to
        omit that flag.
      absolute_proto_file_names: Iterable of proto files to compile.

    Returns:
      The exit code from protoc.main (callers in this file expect 0).
    """
    # The leading "" occupies the argv[0] slot (presumably ignored by
    # protoc.main; WellKnownTypesTest passes a program name there instead).
    command = ["", f"--proto_path={proto_path}"]
    if python_out is not None:
        command += [f"--python_out={python_out}"]
    if grpc_python_out is not None:
        command += [
            f"--grpc_python_out={grpc_python_out_flag}:{grpc_python_out}"
        ]
    command += list(absolute_proto_file_names)
    return protoc.main(command)
143
144
145class _Mid2016ProtocStyle(object):
146    def name(self):
147        return "Mid2016ProtocStyle"
148
149    def grpc_in_pb2_expected(self):
150        return True
151
152    def protoc(self, proto_path, python_out, absolute_proto_file_names):
153        return (
154            _protoc(
155                proto_path,
156                python_out,
157                "grpc_1_0",
158                python_out,
159                absolute_proto_file_names,
160            ),
161        )
162
163
164class _SingleProtocExecutionProtocStyle(object):
165    def name(self):
166        return "SingleProtocExecutionProtocStyle"
167
168    def grpc_in_pb2_expected(self):
169        return False
170
171    def protoc(self, proto_path, python_out, absolute_proto_file_names):
172        return (
173            _protoc(
174                proto_path,
175                python_out,
176                "grpc_2_0",
177                python_out,
178                absolute_proto_file_names,
179            ),
180        )
181
182
183class _ProtoBeforeGrpcProtocStyle(object):
184    def name(self):
185        return "ProtoBeforeGrpcProtocStyle"
186
187    def grpc_in_pb2_expected(self):
188        return False
189
190    def protoc(self, proto_path, python_out, absolute_proto_file_names):
191        pb2_protoc_exit_code = _protoc(
192            proto_path, python_out, None, None, absolute_proto_file_names
193        )
194        pb2_grpc_protoc_exit_code = _protoc(
195            proto_path, None, "grpc_2_0", python_out, absolute_proto_file_names
196        )
197        return pb2_protoc_exit_code, pb2_grpc_protoc_exit_code
198
199
200class _GrpcBeforeProtoProtocStyle(object):
201    def name(self):
202        return "GrpcBeforeProtoProtocStyle"
203
204    def grpc_in_pb2_expected(self):
205        return False
206
207    def protoc(self, proto_path, python_out, absolute_proto_file_names):
208        pb2_grpc_protoc_exit_code = _protoc(
209            proto_path, None, "grpc_2_0", python_out, absolute_proto_file_names
210        )
211        pb2_protoc_exit_code = _protoc(
212            proto_path, python_out, None, None, absolute_proto_file_names
213        )
214        return pb2_grpc_protoc_exit_code, pb2_protoc_exit_code
215
216
217_PROTOC_STYLES = (
218    _Mid2016ProtocStyle(),
219    _SingleProtocExecutionProtocStyle(),
220    _ProtoBeforeGrpcProtocStyle(),
221    _GrpcBeforeProtoProtocStyle(),
222)
223
224
225@unittest.skipIf(
226    platform.python_implementation() == "PyPy", "Skip test if run with PyPy!"
227)
228class _Test(unittest.TestCase, metaclass=abc.ABCMeta):
229    def setUp(self):
230        self._directory = tempfile.mkdtemp(suffix=self.NAME, dir=".")
231        self._proto_path = path.join(self._directory, _RELATIVE_PROTO_PATH)
232        self._python_out = path.join(self._directory, _RELATIVE_PYTHON_OUT)
233
234        os.makedirs(self._proto_path)
235        os.makedirs(self._python_out)
236
237        proto_directories_and_names = {
238            (
239                self.MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES,
240                self.MESSAGES_PROTO_FILE_NAME,
241            ),
242            (
243                self.SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES,
244                self.SERVICES_PROTO_FILE_NAME,
245            ),
246        }
247        messages_proto_relative_file_name_forward_slashes = "/".join(
248            self.MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES
249            + (self.MESSAGES_PROTO_FILE_NAME,)
250        )
251        _create_directory_tree(
252            self._proto_path,
253            (
254                relative_proto_directory_names
255                for relative_proto_directory_names, _ in proto_directories_and_names
256            ),
257        )
258        self._absolute_proto_file_names = set()
259        for relative_directory_names, file_name in proto_directories_and_names:
260            absolute_proto_file_name = path.join(
261                self._proto_path, *relative_directory_names + (file_name,)
262            )
263            raw_proto_content = pkgutil.get_data(
264                "tests.protoc_plugin.protos.invocation_testing",
265                path.join(*relative_directory_names + (file_name,)),
266            )
267            massaged_proto_content = _massage_proto_content(
268                raw_proto_content,
269                self.NAME.encode(),
270                messages_proto_relative_file_name_forward_slashes.encode(),
271            )
272            with open(absolute_proto_file_name, "wb") as proto_file:
273                proto_file.write(massaged_proto_content)
274            self._absolute_proto_file_names.add(absolute_proto_file_name)
275
276    def tearDown(self):
277        shutil.rmtree(self._directory)
278
279    def _protoc(self):
280        protoc_exit_codes = self.PROTOC_STYLE.protoc(
281            self._proto_path, self._python_out, self._absolute_proto_file_names
282        )
283        for protoc_exit_code in protoc_exit_codes:
284            self.assertEqual(0, protoc_exit_code)
285
286        _packagify(self._python_out)
287
288        generated_modules = {}
289        expected_generated_full_module_names = {
290            self.EXPECTED_MESSAGES_PB2,
291            self.EXPECTED_SERVICES_PB2,
292            self.EXPECTED_SERVICES_PB2_GRPC,
293        }
294        with _system_path([self._python_out]):
295            for full_module_name in expected_generated_full_module_names:
296                module = importlib.import_module(full_module_name)
297                generated_modules[full_module_name] = module
298
299        self._messages_pb2 = generated_modules[self.EXPECTED_MESSAGES_PB2]
300        self._services_pb2 = generated_modules[self.EXPECTED_SERVICES_PB2]
301        self._services_pb2_grpc = generated_modules[
302            self.EXPECTED_SERVICES_PB2_GRPC
303        ]
304
305    def _services_modules(self):
306        if self.PROTOC_STYLE.grpc_in_pb2_expected():
307            return self._services_pb2, self._services_pb2_grpc
308        else:
309            return (self._services_pb2_grpc,)
310
311    def test_imported_attributes(self):
312        self._protoc()
313
314        self._messages_pb2.Request
315        self._messages_pb2.Response
316        self._services_pb2.DESCRIPTOR.services_by_name["TestService"]
317        for services_module in self._services_modules():
318            services_module.TestServiceStub
319            services_module.TestServiceServicer
320            services_module.add_TestServiceServicer_to_server
321
322    def test_call(self):
323        self._protoc()
324
325        for services_module in self._services_modules():
326            server = test_common.test_server()
327            services_module.add_TestServiceServicer_to_server(
328                _Servicer(self._messages_pb2.Response), server
329            )
330            port = server.add_insecure_port("[::]:0")
331            server.start()
332            channel = grpc.insecure_channel("localhost:{}".format(port))
333            stub = services_module.TestServiceStub(channel)
334            response = stub.Call(self._messages_pb2.Request())
335            self.assertEqual(self._messages_pb2.Response(), response)
336            server.stop(None)
337
338
339def _create_test_case_class(split_proto, protoc_style):
340    attributes = {}
341
342    name = "{}{}".format(
343        "SplitProto" if split_proto else "SameProto", protoc_style.name()
344    )
345    attributes["NAME"] = name
346
347    if split_proto:
348        attributes["MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES"] = (
349            "split_messages",
350            "sub",
351        )
352        attributes["MESSAGES_PROTO_FILE_NAME"] = "messages.proto"
353        attributes["SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES"] = (
354            "split_services",
355        )
356        attributes["SERVICES_PROTO_FILE_NAME"] = "services.proto"
357        attributes["EXPECTED_MESSAGES_PB2"] = "split_messages.sub.messages_pb2"
358        attributes["EXPECTED_SERVICES_PB2"] = "split_services.services_pb2"
359        attributes[
360            "EXPECTED_SERVICES_PB2_GRPC"
361        ] = "split_services.services_pb2_grpc"
362    else:
363        attributes["MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES"] = ()
364        attributes["MESSAGES_PROTO_FILE_NAME"] = "same.proto"
365        attributes["SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES"] = ()
366        attributes["SERVICES_PROTO_FILE_NAME"] = "same.proto"
367        attributes["EXPECTED_MESSAGES_PB2"] = "same_pb2"
368        attributes["EXPECTED_SERVICES_PB2"] = "same_pb2"
369        attributes["EXPECTED_SERVICES_PB2_GRPC"] = "same_pb2_grpc"
370
371    attributes["PROTOC_STYLE"] = protoc_style
372
373    attributes["__module__"] = _Test.__module__
374
375    return type("{}Test".format(name), (_Test,), attributes)
376
377
378def _create_test_case_classes():
379    for split_proto in (
380        False,
381        True,
382    ):
383        for protoc_style in _PROTOC_STYLES:
384            yield _create_test_case_class(split_proto, protoc_style)
385
386
387class WellKnownTypesTest(unittest.TestCase):
388    def testWellKnownTypes(self):
389        os.chdir(_TEST_DIR)
390        out_dir = tempfile.mkdtemp(suffix="wkt_test", dir=".")
391        well_known_protos_include = _get_resource_file_name(
392            "grpc_tools", "_proto"
393        )
394        args = [
395            "grpc_tools.protoc",
396            "--proto_path=protos",
397            "--proto_path={}".format(well_known_protos_include),
398            "--python_out={}".format(out_dir),
399            "--grpc_python_out={}".format(out_dir),
400            "protos/invocation_testing/compiler.proto",
401        ]
402        rc = protoc.main(args)
403        self.assertEqual(0, rc)
404
405
406def load_tests(loader, tests, pattern):
407    tests = tuple(
408        loader.loadTestsFromTestCase(test_case_class)
409        for test_case_class in _create_test_case_classes()
410    ) + tuple(loader.loadTestsFromTestCase(WellKnownTypesTest))
411    return unittest.TestSuite(tests=tests)
412
413
414if __name__ == "__main__":
415    unittest.main(verbosity=2)
416