1# Copyright 2016 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Authorization support for gRPC."""
16
17from __future__ import absolute_import
18
19import logging
20import os
21
22import six
23
24from google.auth import environment_vars
25from google.auth import exceptions
26from google.auth.transport import _mtls_helper
27from google.oauth2 import service_account
28
29try:
30    import grpc
31except ImportError as caught_exc:  # pragma: NO COVER
32    six.raise_from(
33        ImportError(
34            "gRPC is not installed, please install the grpcio package "
35            "to use the gRPC transport."
36        ),
37        caught_exc,
38    )
39
# Module-level logger, named after this module via ``__name__``.
_LOGGER = logging.getLogger(__name__)
41
42
class AuthMetadataPlugin(grpc.AuthMetadataPlugin):
    """A `gRPC AuthMetadataPlugin`_ that adds the credentials' authorization
    headers to each request.

    .. _gRPC AuthMetadataPlugin:
        http://www.grpc.io/grpc/python/grpc.html#grpc.AuthMetadataPlugin

    Args:
        credentials (google.auth.credentials.Credentials): The credentials
            whose headers are added to requests.
        request (google.auth.transport.Request): A HTTP transport request
            object used to refresh credentials as needed.
        default_host (Optional[str]): A host like "pubsub.googleapis.com".
            It is used to build the audience when a self-signed JWT is
            created from service account credentials.
    """

    def __init__(self, credentials, request, default_host=None):
        # pylint: disable=no-value-for-parameter
        # The superclass shares this class's name, so pylint misjudges the
        # arguments accepted by the super initializer (it takes none).
        super(AuthMetadataPlugin, self).__init__()
        self._default_host = default_host
        self._request = request
        self._credentials = credentials

    def _get_authorization_headers(self, context):
        """Builds the authorization headers for a single request.

        Args:
            context (grpc.AuthMetadataContext): The RPC context. Its
                ``method_name`` and ``service_url`` are forwarded to the
                credentials' ``before_request``.

        Returns:
            Sequence[Tuple[str, str]]: A list of request headers (key, value)
                to add to the request.
        """
        metadata = {}
        creds = self._credentials

        # https://google.aip.dev/auth/4111
        # Service account credentials get the chance to use a self-signed
        # JWT. The audience is derived from the explicitly supplied default
        # host because it cannot always be determined from
        # context.service_url.
        if isinstance(creds, service_account.Credentials):
            if self._default_host:
                audience = "https://{}/".format(self._default_host)
            else:
                audience = None
            creds._create_self_signed_jwt(audience)

        # The credentials fill ``metadata`` in place (refreshing through
        # ``self._request`` if they need to).
        creds.before_request(
            self._request, context.method_name, context.service_url, metadata
        )

        return [(name, value) for name, value in six.iteritems(metadata)]

    def __call__(self, context, callback):
        """Hands the authorization metadata to the given callback.

        Args:
            context (grpc.AuthMetadataContext): The RPC context.
            callback (grpc.AuthMetadataPluginCallback): The callback that is
                invoked with the authorization metadata and ``None`` as the
                error argument.
        """
        auth_metadata = self._get_authorization_headers(context)
        callback(auth_metadata, None)
102
103
def secure_authorized_channel(
    credentials,
    request,
    target,
    ssl_credentials=None,
    client_cert_callback=None,
    **kwargs
):
    """Creates a secure authorized gRPC channel.

    This creates a channel with SSL and :class:`AuthMetadataPlugin`. This
    channel can be used to create a stub that can make authorized requests.
    Users can configure client certificate or rely on device certificates to
    establish a mutual TLS channel, if the `GOOGLE_API_USE_CLIENT_CERTIFICATE`
    variable is explicitly set to `true`.

    Example::

        import google.auth
        import google.auth.transport.grpc
        import google.auth.transport.requests
        from google.cloud.speech.v1 import cloud_speech_pb2

        # Get credentials.
        credentials, _ = google.auth.default()

        # Get an HTTP request function to refresh credentials.
        request = google.auth.transport.requests.Request()

        # Create a channel.
        channel = google.auth.transport.grpc.secure_authorized_channel(
            credentials, request, 'speech.googleapis.com:443',
            ssl_credentials=grpc.ssl_channel_credentials())

        # Use the channel to create a stub.
        cloud_speech.create_Speech_stub(channel)

    Usage:

    There are actually a couple of options to create a channel, depending on if
    you want to create a regular or mutual TLS channel.

    First let's list the endpoints (regular vs mutual TLS) to choose from::

        regular_endpoint = 'speech.googleapis.com:443'
        mtls_endpoint = 'speech.mtls.googleapis.com:443'

    Option 1: create a regular (non-mutual) TLS channel by explicitly setting
    the ssl_credentials::

        regular_ssl_credentials = grpc.ssl_channel_credentials()

        channel = google.auth.transport.grpc.secure_authorized_channel(
            credentials, request, regular_endpoint,
            ssl_credentials=regular_ssl_credentials)

    Option 2: create a mutual TLS channel by calling a callback which returns
    the client side certificate and the key (Note that
    `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable must be explicitly
    set to `true`)::

        def my_client_cert_callback():
            code_to_load_client_cert_and_key()
            if loaded:
                return (pem_cert_bytes, pem_key_bytes)
            raise MyClientCertFailureException()

        try:
            channel = google.auth.transport.grpc.secure_authorized_channel(
                credentials, request, mtls_endpoint,
                client_cert_callback=my_client_cert_callback)
        except MyClientCertFailureException:
            # handle the exception

    Option 3: use application default SSL credentials. It searches and uses
    the command in a context aware metadata file, which is available on devices
    with endpoint verification support (Note that
    `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable must be explicitly
    set to `true`).
    See https://cloud.google.com/endpoint-verification/docs/overview::

        try:
            default_ssl_credentials = SslCredentials()
            channel_ssl_credentials = default_ssl_credentials.ssl_credentials
        except google.auth.exceptions.MutualTLSChannelError:
            # Raised when the client certificate cannot be loaded, e.g. if
            # the context aware metadata is malformed.
            # See :class:`SslCredentials` for the possible exceptions.
            # handle the exception

        # Choose the endpoint based on the SSL credentials type.
        if default_ssl_credentials.is_mtls:
            endpoint_to_use = mtls_endpoint
        else:
            endpoint_to_use = regular_endpoint
        channel = google.auth.transport.grpc.secure_authorized_channel(
            credentials, request, endpoint_to_use,
            ssl_credentials=channel_ssl_credentials)

    Option 4: not setting ssl_credentials and client_cert_callback. For devices
    without endpoint verification support or `GOOGLE_API_USE_CLIENT_CERTIFICATE`
    environment variable is not `true`, a regular TLS channel is created;
    otherwise, a mutual TLS channel is created, however, the call should be
    wrapped in a try/except block in case of malformed context aware metadata.

    The following code uses regular_endpoint, it works the same no matter the
    created channel is regular or mutual TLS. Regular endpoint ignores client
    certificate and key::

        channel = google.auth.transport.grpc.secure_authorized_channel(
            credentials, request, regular_endpoint)

    The following code uses mtls_endpoint, if the created channel is regular,
    and API mtls_endpoint is configured to require client SSL credentials, API
    calls using this channel will be rejected::

        channel = google.auth.transport.grpc.secure_authorized_channel(
            credentials, request, mtls_endpoint)

    Args:
        credentials (google.auth.credentials.Credentials): The credentials to
            add to requests.
        request (google.auth.transport.Request): A HTTP transport request
            object used to refresh credentials as needed. Even though gRPC
            is a separate transport, there's no way to refresh the credentials
            without using a standard http transport.
        target (str): The host and port of the service.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
            This argument is mutually exclusive with client_cert_callback;
            providing both will raise an exception.
            If ssl_credentials and client_cert_callback are None, application
            default SSL credentials are used if `GOOGLE_API_USE_CLIENT_CERTIFICATE`
            environment variable is explicitly set to `true`, otherwise one way TLS
            SSL credentials are used.
        client_cert_callback (Callable[[], (bytes, bytes)]): Optional
            callback function to obtain client certificate and key for mutual
            TLS connection. This argument is mutually exclusive with
            ssl_credentials; providing both will raise an exception.
            This argument does nothing unless `GOOGLE_API_USE_CLIENT_CERTIFICATE`
            environment variable is explicitly set to `true`.
        kwargs: Additional arguments to pass to :func:`grpc.secure_channel`.

    Returns:
        grpc.Channel: The created gRPC channel.

    Raises:
        ValueError: If both ssl_credentials and client_cert_callback are
            provided.
        google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel
            creation failed for any reason.
    """
    # Validate the arguments first so that a misconfigured call fails before
    # any credentials objects are built.
    if ssl_credentials and client_cert_callback:
        raise ValueError(
            "Received both ssl_credentials and client_cert_callback; "
            "these are mutually exclusive."
        )

    # Create the metadata plugin for inserting the authorization header.
    metadata_plugin = AuthMetadataPlugin(credentials, request)

    # Create a set of grpc.CallCredentials using the metadata plugin.
    google_auth_credentials = grpc.metadata_call_credentials(metadata_plugin)

    # If SSL credentials are not explicitly set, try client_cert_callback and ADC.
    if not ssl_credentials:
        use_client_cert = os.getenv(
            environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE, "false"
        )
        if use_client_cert == "true" and client_cert_callback:
            # Use the callback if provided. Exceptions raised by the callback
            # propagate to the caller unchanged.
            cert, key = client_cert_callback()
            ssl_credentials = grpc.ssl_channel_credentials(
                certificate_chain=cert, private_key=key
            )
        elif use_client_cert == "true":
            # Use application default SSL credentials.
            adc_ssl_credentials = SslCredentials()
            ssl_credentials = adc_ssl_credentials.ssl_credentials
        else:
            # Client certificates are not enabled: plain one-way TLS.
            ssl_credentials = grpc.ssl_channel_credentials()

    # Combine the ssl credentials and the authorization credentials.
    composite_credentials = grpc.composite_channel_credentials(
        ssl_credentials, google_auth_credentials
    )

    return grpc.secure_channel(target, composite_credentials, **kwargs)
287
288
class SslCredentials(object):
    """Class for application default SSL credentials.

    The behavior is controlled by `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment
    variable whose default value is `false`. Client certificate will not be used
    unless the environment variable is explicitly set to `true`. See
    https://google.aip.dev/auth/4114

    If the environment variable is `true`, then for devices with endpoint verification
    support, a device certificate will be automatically loaded and mutual TLS will
    be established.
    See https://cloud.google.com/endpoint-verification/docs/overview.
    """

    def __init__(self):
        use_client_cert = os.getenv(
            environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE, "false"
        )
        if use_client_cert != "true":
            # Client certificates are opt-in; anything other than the exact
            # string "true" disables mutual TLS.
            self._is_mtls = False
        else:
            # Mutual TLS is used only when the context aware metadata path
            # check returns a path.
            metadata_path = _mtls_helper._check_dca_metadata_path(
                _mtls_helper.CONTEXT_AWARE_METADATA_PATH
            )
            self._is_mtls = metadata_path is not None

    @property
    def ssl_credentials(self):
        """Get the created SSL channel credentials.

        For devices with endpoint verification support, if the device certificate
        loading has any problems, corresponding exceptions will be raised. For
        a device without endpoint verification support, no exceptions will be
        raised.

        The credentials are built anew on every access; for mutual TLS this
        means the client certificate and key are loaded again each time.

        Returns:
            grpc.ChannelCredentials: The created grpc channel credentials.

        Raises:
            google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel
                creation failed for any reason.
        """
        if self._is_mtls:
            try:
                _, cert, key, _ = _mtls_helper.get_client_ssl_credentials()
                self._ssl_credentials = grpc.ssl_channel_credentials(
                    certificate_chain=cert, private_key=key
                )
            except exceptions.ClientCertError as caught_exc:
                # Surface certificate loading problems as a channel error
                # while keeping the original exception as the cause.
                new_exc = exceptions.MutualTLSChannelError(caught_exc)
                six.raise_from(new_exc, caught_exc)
        else:
            self._ssl_credentials = grpc.ssl_channel_credentials()

        return self._ssl_credentials

    @property
    def is_mtls(self):
        """Indicates if the created SSL channel credentials is mutual TLS."""
        return self._is_mtls
350