1# Copyright 2016 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Transport adapter for Requests."""
16
17from __future__ import absolute_import
18
19import functools
20import logging
21import numbers
22import os
23import time
24
25try:
26    import requests
27except ImportError as caught_exc:  # pragma: NO COVER
28    import six
29
30    six.raise_from(
31        ImportError(
32            "The requests library is not installed, please install the "
33            "requests package to use the requests transport."
34        ),
35        caught_exc,
36    )
37import requests.adapters  # pylint: disable=ungrouped-imports
38import requests.exceptions  # pylint: disable=ungrouped-imports
39from requests.packages.urllib3.util.ssl_ import (
40    create_urllib3_context,
41)  # pylint: disable=ungrouped-imports
42import six  # pylint: disable=ungrouped-imports
43
44from google.auth import environment_vars
45from google.auth import exceptions
46from google.auth import transport
47import google.auth.transport._mtls_helper
48from google.oauth2 import service_account
49
# Logger for this module; records are emitted under the module's dotted name.
_LOGGER = logging.getLogger(__name__)

# Per-request timeout, in seconds, applied when callers do not pass one.
_DEFAULT_TIMEOUT = 120
53
54
class _Response(transport.Response):
    """Requests transport response adapter.

    Exposes a :class:`requests.Response` through the
    :class:`google.auth.transport.Response` interface without copying it.

    Args:
        response (requests.Response): The raw Requests response to wrap.
    """

    def __init__(self, response):
        # Nothing is copied here; each property reads from the wrapped object.
        self._response = response

    @property
    def status(self):
        """The HTTP status code (``requests.Response.status_code``)."""
        raw = self._response
        return raw.status_code

    @property
    def headers(self):
        """The response headers (``requests.Response.headers``)."""
        raw = self._response
        return raw.headers

    @property
    def data(self):
        """The response body (``requests.Response.content``)."""
        raw = self._response
        return raw.content
76
77
class TimeoutGuard(object):
    """Context manager enforcing an overall time budget on a code section.

    When the ``with`` block exits normally, the time spent inside it is
    subtracted from the budget. If nothing is left, an error is raised. If
    the block itself raised, the guard does nothing and that exception
    propagates unchanged.

    Args:
        timeout (Union[None, Union[float, Tuple[float, float]]]):
            The time budget in seconds. When given as a tuple, each element is
            reduced by the elapsed time, and the guard fires once the smallest
            of them is used up. ``None`` disables the check entirely.
        timeout_error_type (Optional[Exception]):
            Exception class that is instantiated with no arguments and raised
            when the budget is exhausted. Defaults to
            :class:`requests.exceptions.Timeout`.

    Attributes:
        remaining_timeout: Initially equal to ``timeout``. After a normal exit
            with a non-``None`` timeout, holds what is left of the budget: a
            number, or a tuple mirroring ``timeout``. It is zero or negative
            when the guard raised.
    """

    def __init__(self, timeout, timeout_error_type=requests.exceptions.Timeout):
        self._timeout = timeout
        self._timeout_error_type = timeout_error_type
        # Public so callers can chain guards, feeding the leftover budget of
        # one section into the next.
        self.remaining_timeout = timeout

    def __enter__(self):
        self._start = time.time()  # wall-clock timestamp of entry
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Nothing to enforce if the section already failed or no budget was set;
        # returning None lets any in-flight exception propagate.
        if exc_value or self._timeout is None:
            return

        elapsed = time.time() - self._start
        budget = self._timeout

        if isinstance(budget, numbers.Number):
            self.remaining_timeout = budget - elapsed
            tightest = self.remaining_timeout
        else:
            self.remaining_timeout = tuple(limit - elapsed for limit in budget)
            tightest = min(self.remaining_timeout)

        if tightest <= 0:
            raise self._timeout_error_type()
120
121
class Request(transport.Request):
    """Requests request adapter.

    This class is used internally for making requests using various transports
    in a consistent way. If you use :class:`AuthorizedSession` you do not need
    to construct or use this class directly.

    This class can be useful if you want to manually refresh a
    :class:`~google.auth.credentials.Credentials` instance::

        import google.auth.transport.requests
        import requests

        request = google.auth.transport.requests.Request()

        credentials.refresh(request)

    Args:
        session (requests.Session): An instance of :class:`requests.Session`
            used to make HTTP requests. If not specified, a new session will
            be created.

    .. automethod:: __call__
    """

    def __init__(self, session=None):
        # Test for None explicitly instead of relying on truthiness, so that
        # any session object supplied by the caller is always used as-is.
        if session is None:
            session = requests.Session()

        self.session = session

    def __call__(
        self,
        url,
        method="GET",
        body=None,
        headers=None,
        timeout=_DEFAULT_TIMEOUT,
        **kwargs
    ):
        """Make an HTTP request using requests.

        Args:
            url (str): The URI to be requested.
            method (str): The HTTP method to use for the request. Defaults
                to 'GET'.
            body (bytes): The payload or body in HTTP request.
            headers (Mapping[str, str]): Request headers.
            timeout (Optional[Union[float, Tuple[float, float]]]): The number
                of seconds to wait for a response from the server. It may also
                be a ``(connect_timeout, read_timeout)`` tuple. If not
                specified, defaults to 120 seconds. An explicit ``None`` is
                forwarded to Requests unchanged, in which case Requests
                applies no timeout at all.
            kwargs: Additional arguments passed through to the underlying
                requests :meth:`~requests.Session.request` method.

        Returns:
            google.auth.transport.Response: The HTTP response.

        Raises:
            google.auth.exceptions.TransportError: If Requests raised a
                :class:`requests.exceptions.RequestException`; the original
                exception is attached as the cause.
        """
        try:
            _LOGGER.debug("Making request: %s %s", method, url)
            response = self.session.request(
                method, url, data=body, headers=headers, timeout=timeout, **kwargs
            )
            return _Response(response)
        except requests.exceptions.RequestException as caught_exc:
            # Normalize Requests' errors into google-auth's transport error
            # while keeping the original exception as the cause.
            new_exc = exceptions.TransportError(caught_exc)
            six.raise_from(new_exc, caught_exc)
190
191
class _MutualTlsAdapter(requests.adapters.HTTPAdapter):
    """
    A TransportAdapter that enables mutual TLS.

    Args:
        cert (bytes): client certificate in PEM format
        key (bytes): client private key in PEM format

    Raises:
        ImportError: if certifi or pyOpenSSL is not installed
        OpenSSL.crypto.Error: if client cert or key is invalid
    """

    def __init__(self, cert, key):
        import certifi
        from OpenSSL import crypto
        import urllib3.contrib.pyopenssl

        # Switch urllib3 to its pyOpenSSL backend (a process-wide change); the
        # ``_ctx.use_certificate``/``use_privatekey`` calls below take
        # pyOpenSSL objects and rely on that backend.
        urllib3.contrib.pyopenssl.inject_into_urllib3()

        pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key)
        x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
        ca_file = certifi.where()

        def _make_context():
            # Build an SSL context that trusts the certifi CA bundle and
            # presents the client certificate and private key.
            ctx = create_urllib3_context()
            ctx.load_verify_locations(cafile=ca_file)
            ctx._ctx.use_certificate(x509)
            ctx._ctx.use_privatekey(pkey)
            return ctx

        # Separate, identically configured contexts: one for direct
        # connections and one for connections made through a proxy.
        self._ctx_poolmanager = _make_context()
        self._ctx_proxymanager = _make_context()

        # HTTPAdapter.__init__ calls init_poolmanager(), which reads
        # self._ctx_poolmanager, so the contexts must exist before this call.
        super(_MutualTlsAdapter, self).__init__()

    def init_poolmanager(self, *args, **kwargs):
        # Force the mutual-TLS context onto the pool manager for direct
        # connections.
        kwargs["ssl_context"] = self._ctx_poolmanager
        super(_MutualTlsAdapter, self).init_poolmanager(*args, **kwargs)

    def proxy_manager_for(self, *args, **kwargs):
        # Same as init_poolmanager, but for proxied connections.
        kwargs["ssl_context"] = self._ctx_proxymanager
        return super(_MutualTlsAdapter, self).proxy_manager_for(*args, **kwargs)
236
237
class AuthorizedSession(requests.Session):
    """A Requests Session class with credentials.

    This class is used to perform requests to API endpoints that require
    authorization::

        from google.auth.transport.requests import AuthorizedSession

        authed_session = AuthorizedSession(credentials)

        response = authed_session.request(
            'GET', 'https://www.googleapis.com/storage/v1/b')


    The underlying :meth:`request` implementation handles adding the
    credentials' headers to the request and refreshing credentials as needed.

    This class also supports mutual TLS via the :meth:`configure_mtls_channel`
    method. In order to use this method, the `GOOGLE_API_USE_CLIENT_CERTIFICATE`
    environment variable must be explicitly set to ``true``; otherwise the
    method does nothing. Assuming the environment variable is set to ``true``,
    the method behaves in the following manner:

    If client_cert_callback is provided, client certificate and private
    key are loaded using the callback; if client_cert_callback is None,
    application default SSL credentials will be used. Exceptions are raised if
    there are problems with the certificate, private key, or the loading process,
    so it should be called within a try/except block.

    First we set the environment variable to ``true``, then create an :class:`AuthorizedSession`
    instance and specify the endpoints::

        regular_endpoint = 'https://pubsub.googleapis.com/v1/projects/{my_project_id}/topics'
        mtls_endpoint = 'https://pubsub.mtls.googleapis.com/v1/projects/{my_project_id}/topics'

        authed_session = AuthorizedSession(credentials)

    Now we can pass a callback to :meth:`configure_mtls_channel`::

        def my_cert_callback():
            # some code to load client cert bytes and private key bytes, both in
            # PEM format.
            some_code_to_load_client_cert_and_key()
            if loaded:
                return cert, key
            raise MyClientCertFailureException()

        # Always call configure_mtls_channel within a try/except block.
        try:
            authed_session.configure_mtls_channel(my_cert_callback)
        except:
            # handle exceptions.

        if authed_session.is_mtls:
            response = authed_session.request('GET', mtls_endpoint)
        else:
            response = authed_session.request('GET', regular_endpoint)


    You can alternatively use application default SSL credentials like this::

        try:
            authed_session.configure_mtls_channel()
        except:
            # handle exceptions.

    Args:
        credentials (google.auth.credentials.Credentials): The credentials to
            add to the request.
        refresh_status_codes (Sequence[int]): Which HTTP status codes indicate
            that credentials should be refreshed and the request should be
            retried.
        max_refresh_attempts (int): The maximum number of times to attempt to
            refresh the credentials and retry the request.
        refresh_timeout (Optional[int]): The timeout value in seconds for
            credential refresh HTTP requests. Note: this value is stored on
            the instance but is not read by :meth:`request`; credential
            requests made from :meth:`request` use that call's ``timeout``
            argument instead.
        auth_request (google.auth.transport.requests.Request):
            (Optional) An instance of
            :class:`~google.auth.transport.requests.Request` used when
            refreshing credentials. If not passed,
            an instance of :class:`~google.auth.transport.requests.Request`
            is created.
        default_host (Optional[str]): A host like "pubsub.googleapis.com".
            This is used when a self-signed JWT is created from service
            account credentials.
    """

    def __init__(
        self,
        credentials,
        refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES,
        max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS,
        refresh_timeout=None,
        auth_request=None,
        default_host=None,
    ):
        super(AuthorizedSession, self).__init__()
        self.credentials = credentials
        self._refresh_status_codes = refresh_status_codes
        self._max_refresh_attempts = max_refresh_attempts
        # NOTE(review): not read anywhere in this class; kept for API
        # compatibility.
        self._refresh_timeout = refresh_timeout
        self._is_mtls = False
        self._default_host = default_host

        if auth_request is None:
            self._auth_request_session = requests.Session()

            # Using an adapter to make HTTP requests robust to network errors.
            # This adapter retries HTTP requests when network errors occur
            # and the request seems safely retryable.
            retry_adapter = requests.adapters.HTTPAdapter(max_retries=3)
            self._auth_request_session.mount("https://", retry_adapter)

            # Do not pass `self` as the session here, as it can lead to
            # infinite recursion.
            auth_request = Request(self._auth_request_session)
        else:
            # The caller owns the transport; close() must not touch it.
            self._auth_request_session = None

        # Request instance used by internal methods (for example,
        # credentials.refresh).
        self._auth_request = auth_request

        # https://google.aip.dev/auth/4111
        # Attempt to use self-signed JWTs when a service account is used.
        if isinstance(self.credentials, service_account.Credentials):
            self.credentials._create_self_signed_jwt(
                "https://{}/".format(self._default_host) if self._default_host else None
            )

    def configure_mtls_channel(self, client_cert_callback=None):
        """Configure the client certificate and key for SSL connection.

        The function does nothing unless `GOOGLE_API_USE_CLIENT_CERTIFICATE` is
        explicitly set to `true`. In this case if client certificate and key are
        successfully obtained (from the given client_cert_callback or from application
        default SSL credentials), a :class:`_MutualTlsAdapter` instance will be mounted
        to "https://" prefix.

        Args:
            client_cert_callback (Optional[Callable[[], (bytes, bytes)]]):
                The optional callback returns the client certificate and private
                key bytes both in PEM format.
                If the callback is None, application default SSL credentials
                will be used.

        Raises:
            google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel
                creation failed for any reason.
        """
        use_client_cert = os.getenv(
            environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE, "false"
        )
        # Only the exact string "true" enables mutual TLS.
        if use_client_cert != "true":
            self._is_mtls = False
            return

        # pyOpenSSL is an optional dependency; report its absence as a
        # channel error rather than a bare ImportError.
        try:
            import OpenSSL
        except ImportError as caught_exc:
            new_exc = exceptions.MutualTLSChannelError(caught_exc)
            six.raise_from(new_exc, caught_exc)

        try:
            (
                self._is_mtls,
                cert,
                key,
            ) = google.auth.transport._mtls_helper.get_client_cert_and_key(
                client_cert_callback
            )

            if self._is_mtls:
                mtls_adapter = _MutualTlsAdapter(cert, key)
                self.mount("https://", mtls_adapter)
        except (
            exceptions.ClientCertError,
            ImportError,
            OpenSSL.crypto.Error,
        ) as caught_exc:
            new_exc = exceptions.MutualTLSChannelError(caught_exc)
            six.raise_from(new_exc, caught_exc)

    def request(
        self,
        method,
        url,
        data=None,
        headers=None,
        max_allowed_time=None,
        timeout=_DEFAULT_TIMEOUT,
        **kwargs
    ):
        """Implementation of Requests' request.

        Adds the credentials' headers to the request. If the response status
        is one of ``refresh_status_codes``, the credentials are refreshed and
        the request is retried, up to ``max_refresh_attempts`` times.

        Args:
            method (str): The HTTP method to use for the request.
            url (str): The URL to be requested.
            data: The request body, passed through to
                :meth:`requests.Session.request`.
            headers (Optional[dict]): Request headers. A copy is made, so the
                caller's mapping is not modified by the credentials.
            timeout (Optional[Union[float, Tuple[float, float]]]):
                The amount of time in seconds to wait for the server response
                with each individual request. Can also be passed as a tuple
                ``(connect_timeout, read_timeout)``. See :meth:`requests.Session.request`
                documentation for details. The same value is also applied to
                the credential requests made here, unless it is ``None``.
            max_allowed_time (Optional[float]):
                If the method runs longer than this, a ``Timeout`` exception is
                automatically raised. Unlike the ``timeout`` parameter, this
                value applies to the total method execution time, even if
                multiple requests are made under the hood.

                Note that the timeout error is not guaranteed to be raised
                at ``max_allowed_time``. It might take longer, for example, if
                an underlying request takes a lot of time, but the request
                itself does not time out, e.g. if a large file is being
                transmitted. The timeout error will be raised after such
                request completes.
            kwargs: Additional arguments passed through to
                :meth:`requests.Session.request`.

        Returns:
            requests.Response: The response to the last request made.

        Raises:
            requests.exceptions.Timeout: If ``max_allowed_time`` is exceeded.
        """
        # pylint: disable=arguments-differ
        # Requests has a ton of arguments to request, but only two
        # (method, url) are required. We pass through all of the other
        # arguments to super, so no need to exhaustively list them here.

        # Use a kwarg for this instead of an attribute to maintain
        # thread-safety.
        _credential_refresh_attempt = kwargs.pop("_credential_refresh_attempt", 0)

        # Make a copy of the headers. They will be modified by the credentials
        # and we want to pass the original headers if we recurse.
        request_headers = headers.copy() if headers is not None else {}

        # Do not apply the timeout unconditionally in order to not override the
        # _auth_request's default timeout. The same callable serves both
        # before_request() and the credential refresh below, so build it once.
        auth_request = (
            self._auth_request
            if timeout is None
            else functools.partial(self._auth_request, timeout=timeout)
        )

        remaining_time = max_allowed_time

        with TimeoutGuard(remaining_time) as guard:
            self.credentials.before_request(auth_request, method, url, request_headers)
        remaining_time = guard.remaining_timeout

        with TimeoutGuard(remaining_time) as guard:
            response = super(AuthorizedSession, self).request(
                method,
                url,
                data=data,
                headers=request_headers,
                timeout=timeout,
                **kwargs
            )
        remaining_time = guard.remaining_timeout

        # If the response indicated that the credentials needed to be
        # refreshed, then refresh the credentials and re-attempt the
        # request.
        # A stored token may expire between the time it is retrieved and
        # the time the request is made, so we may need to try twice.
        if (
            response.status_code in self._refresh_status_codes
            and _credential_refresh_attempt < self._max_refresh_attempts
        ):

            _LOGGER.info(
                "Refreshing credentials due to a %s response. Attempt %s/%s.",
                response.status_code,
                _credential_refresh_attempt + 1,
                self._max_refresh_attempts,
            )

            with TimeoutGuard(remaining_time) as guard:
                self.credentials.refresh(auth_request)
            remaining_time = guard.remaining_timeout

            # Recurse. Pass in the original headers, not our modified set, but
            # do pass the adjusted max allowed time (i.e. the remaining total time).
            return self.request(
                method,
                url,
                data=data,
                headers=headers,
                max_allowed_time=remaining_time,
                timeout=timeout,
                _credential_refresh_attempt=_credential_refresh_attempt + 1,
                **kwargs
            )

        return response

    @property
    def is_mtls(self):
        """Indicates if the created SSL channel is mutual TLS."""
        return self._is_mtls

    def close(self):
        """Close this session and the internal auth session, if one was created.

        A caller-supplied ``auth_request`` is left untouched.
        """
        if self._auth_request_session is not None:
            self._auth_request_session.close()
        super(AuthorizedSession, self).close()
543