xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_tests/tests/interop/client.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""The Python implementation of the GRPC interoperability test client."""
15
16import os
17
18from absl import app
19from absl.flags import argparse_flags
20from google import auth as google_auth
21from google.auth import jwt as google_auth_jwt
22import grpc
23
24from src.proto.grpc.testing import test_pb2_grpc
25from tests.interop import methods
26from tests.interop import resources
27
28
def parse_interop_client_args(argv):
    """Parse the interop client's command-line arguments.

    Args:
      argv: The full argument vector; element 0 (the program name) is
        skipped before parsing.

    Returns:
      The namespace returned by the parser's ``parse_args``.
    """
    grpclb_child_policy_help = (
        "If non-empty, set a static service config on channels created by "
        "grpc::CreateTestChannel, that configures the grpclb LB policy "
        "with a child policy being the value of this flag (e.g. round_robin "
        "or pick_first)."
    )
    # (flag name, add_argument keyword arguments), listed in the order the
    # flags are registered and therefore shown in --help.
    argument_specs = (
        (
            "--server_host",
            dict(
                default="localhost",
                type=str,
                help="the host to which to connect",
            ),
        ),
        (
            "--server_port",
            dict(type=int, required=True, help="the port to which to connect"),
        ),
        (
            "--test_case",
            dict(
                default="large_unary",
                type=str,
                help="the test case to execute",
            ),
        ),
        (
            "--use_tls",
            dict(
                default=False,
                type=resources.parse_bool,
                help="require a secure connection",
            ),
        ),
        (
            "--use_alts",
            dict(
                default=False,
                type=resources.parse_bool,
                help="require an ALTS secure connection",
            ),
        ),
        (
            "--use_test_ca",
            dict(
                default=False,
                type=resources.parse_bool,
                help="replace platform root CAs with ca.pem",
            ),
        ),
        (
            "--custom_credentials_type",
            dict(
                choices=["compute_engine_channel_creds"],
                default=None,
                help="use google default credentials",
            ),
        ),
        (
            "--server_host_override",
            dict(type=str, help="the server host to which to claim to connect"),
        ),
        ("--oauth_scope", dict(type=str, help="scope for OAuth tokens")),
        (
            "--default_service_account",
            dict(type=str, help="email address of the default service account"),
        ),
        (
            "--grpc_test_use_grpclb_with_child_policy",
            dict(type=str, help=grpclb_child_policy_help),
        ),
    )

    parser = argparse_flags.ArgumentParser()
    for flag_name, flag_options in argument_specs:
        parser.add_argument(flag_name, **flag_options)
    return parser.parse_args(argv[1:])
97
98
def _create_call_credentials(args):
    """Build the per-call credentials required by the selected test case.

    Args:
      args: Parsed client flags. ``args.test_case`` selects the credential
        kind; the OAuth-based cases also read ``args.oauth_scope``.

    Returns:
      gRPC call credentials for the ``oauth2_auth_token``,
      ``compute_engine_creds`` and ``jwt_token_creds`` test cases; ``None``
      for every other test case.

    Raises:
      KeyError: For ``jwt_token_creds`` when the environment variable named
        by ``google_auth.environment_vars.CREDENTIALS`` is not set.
    """
    # ``from google import auth`` does not itself load the optional
    # ``google.auth.transport.requests`` / ``google.auth.transport.grpc``
    # submodules, so each branch imports exactly what it uses rather than
    # relying on some other module having imported them first. The imports
    # stay inside the branches so test cases that need no call credentials
    # never require the optional ``requests`` dependency.
    if args.test_case == "oauth2_auth_token":
        from google.auth.transport import requests as google_auth_requests

        google_credentials, unused_project_id = google_auth.default(
            scopes=[args.oauth_scope]
        )
        # Fetch the token up front; it is then sent as a static access token.
        google_credentials.refresh(google_auth_requests.Request())
        return grpc.access_token_call_credentials(google_credentials.token)

    if args.test_case == "compute_engine_creds":
        from google.auth.transport import grpc as google_auth_grpc
        from google.auth.transport import requests as google_auth_requests

        google_credentials, unused_project_id = google_auth.default(
            scopes=[args.oauth_scope]
        )
        return grpc.metadata_call_credentials(
            google_auth_grpc.AuthMetadataPlugin(
                credentials=google_credentials,
                request=google_auth_requests.Request(),
            )
        )

    if args.test_case == "jwt_token_creds":
        from google.auth.transport import grpc as google_auth_grpc

        google_credentials = (
            google_auth_jwt.OnDemandCredentials.from_service_account_file(
                os.environ[google_auth.environment_vars.CREDENTIALS]
            )
        )
        return grpc.metadata_call_credentials(
            google_auth_grpc.AuthMetadataPlugin(
                credentials=google_credentials, request=None
            )
        )

    return None
129
130
def _grpclb_service_config(child_policy):
    """Return service-config JSON text selecting grpclb with ``child_policy``.

    ``json.dumps`` escapes the policy name, so names containing quotes or
    backslashes still yield valid JSON. For ordinary names the output matches
    the hand-formatted string used previously, e.g. for ``"round_robin"``::

        {"loadBalancingConfig": [{"grpclb": {"childPolicy": [{"round_robin": {}}]}}]}
    """
    import json

    return json.dumps(
        {
            "loadBalancingConfig": [
                {"grpclb": {"childPolicy": [{child_policy: {}}]}}
            ]
        }
    )


def get_secure_channel_parameters(args):
    """Build the credentials and options for a secure channel to the server.

    Args:
      args: Parsed client flags. ``custom_credentials_type``, ``use_tls`` and
        ``use_alts`` are checked in that order and the first one set decides
        the channel credentials; at least one of them must be set.

    Returns:
      A ``(channel_credentials, channel_opts)`` pair, where ``channel_opts``
      is a tuple of ``(key, value)`` channel-argument pairs.

    Raises:
      ValueError: If ``custom_credentials_type`` is unknown; if it is combined
        with a test case that also produces call credentials; or if none of
        ``custom_credentials_type``, ``use_tls`` or ``use_alts`` is set.
    """
    call_credentials = _create_call_credentials(args)

    channel_opts = ()
    if args.grpc_test_use_grpclb_with_child_policy:
        channel_opts += (
            (
                "grpc.service_config",
                _grpclb_service_config(
                    args.grpc_test_use_grpclb_with_child_policy
                ),
            ),
        )

    if args.custom_credentials_type is not None:
        if args.custom_credentials_type != "compute_engine_channel_creds":
            raise ValueError(
                "Unknown credentials type '{}'".format(
                    args.custom_credentials_type
                )
            )
        if call_credentials is not None:
            # The channel credentials built below embed their own call
            # credentials, so the test case's call credentials would be
            # silently dropped. Raised explicitly (not ``assert``) so the
            # check survives ``python -O``.
            raise ValueError(
                "Test case '{}' supplies its own call credentials and cannot "
                "be combined with custom credentials type '{}'".format(
                    args.test_case, args.custom_credentials_type
                )
            )
        # Imported explicitly: ``from google import auth`` does not load
        # these optional transport submodules.
        from google.auth.transport import grpc as google_auth_grpc
        from google.auth.transport import requests as google_auth_requests

        google_credentials, unused_project_id = google_auth.default(
            scopes=[args.oauth_scope]
        )
        default_call_credentials = grpc.metadata_call_credentials(
            google_auth_grpc.AuthMetadataPlugin(
                credentials=google_credentials,
                request=google_auth_requests.Request(),
            )
        )
        channel_credentials = grpc.compute_engine_channel_credentials(
            default_call_credentials
        )
    elif args.use_tls:
        # ``None`` makes gRPC load its default root certificates.
        root_certificates = (
            resources.test_root_certificates() if args.use_test_ca else None
        )
        channel_credentials = grpc.ssl_channel_credentials(root_certificates)
        if call_credentials is not None:
            channel_credentials = grpc.composite_channel_credentials(
                channel_credentials, call_credentials
            )
        if args.server_host_override:
            channel_opts += (
                ("grpc.ssl_target_name_override", args.server_host_override),
            )
    elif args.use_alts:
        # NOTE(review): call_credentials, if any, are not attached on the
        # ALTS path; behavior kept as-is.
        channel_credentials = grpc.alts_channel_credentials()
    else:
        # Without this branch ``channel_credentials`` would be unbound below.
        raise ValueError(
            "get_secure_channel_parameters requires one of "
            "--custom_credentials_type, --use_tls or --use_alts"
        )

    return channel_credentials, channel_opts
187
188
def _create_channel(args):
    """Open the channel the client uses to reach the server.

    A secure channel is built whenever ``use_tls``, ``use_alts`` or
    ``custom_credentials_type`` is set; otherwise the channel is insecure.

    Args:
      args: Parsed client flags; ``server_host`` and ``server_port`` form
        the target.

    Returns:
      The gRPC channel to ``server_host:server_port``.
    """
    target = f"{args.server_host}:{args.server_port}"
    wants_secure_channel = (
        args.use_tls
        or args.use_alts
        or args.custom_credentials_type is not None
    )
    if not wants_secure_channel:
        return grpc.insecure_channel(target)

    credentials, options = get_secure_channel_parameters(args)
    return grpc.secure_channel(target, credentials, options)
201
202
def create_stub(channel, args):
    """Wrap ``channel`` in the stub type the selected test case calls.

    The ``unimplemented_service`` test case talks to
    ``UnimplementedService``; every other test case uses ``TestService``.
    """
    stub_class = (
        test_pb2_grpc.UnimplementedServiceStub
        if args.test_case == "unimplemented_service"
        else test_pb2_grpc.TestServiceStub
    )
    return stub_class(channel)
208
209
def _test_case_from_arg(test_case_arg):
    """Return the ``methods.TestCase`` member whose value is ``test_case_arg``.

    Raises:
      ValueError: If no member's value equals ``test_case_arg``.
    """
    matching = [
        candidate
        for candidate in methods.TestCase
        if test_case_arg == candidate.value
    ]
    if not matching:
        raise ValueError('No test case "%s"!' % test_case_arg)
    return matching[0]
216
217
def test_interoperability(args):
    """Run the interop test case named by ``args.test_case`` against the server.

    The test case is resolved before any channel is created, so an unknown
    name fails without opening a connection. The channel is used as a
    context manager so it is closed when the test case returns or raises,
    instead of being left open until process exit.

    Args:
      args: Parsed client flags (see ``parse_interop_client_args``).

    Raises:
      ValueError: If ``args.test_case`` does not name a known test case.
    """
    test_case = _test_case_from_arg(args.test_case)
    with _create_channel(args) as channel:
        stub = create_stub(channel, args)
        test_case.test_interoperability(stub, args)
223
224
if __name__ == "__main__":
    # absl hands the raw argv to parse_interop_client_args and passes its
    # result, unchanged, to test_interoperability.
    app.run(
        main=test_interoperability,
        flags_parser=parse_interop_client_args,
    )
227