xref: /aosp_15_r20/external/autotest/utils/frozen_chromite/third_party/infra_libs/httplib2_utils.py (revision 9c5db1993ded3edbeafc8092d69fe5de2ee02df7)
1# Copyright 2015 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import base64
6import collections
7import copy
8import json
9import logging
10import re
11import socket
12import time
13
14import httplib2
15import oauth2client.client
16import six
17from six.moves import http_client as httplib
18
19from googleapiclient import errors
20from infra_libs.ts_mon.common import http_metrics
21
# TODO(nxia): crbug.com/790760 upgrade oauth2client to 4.1.2.
# oauth2client.util may not be importable depending on the installed
# oauth2client version. Record whether it was, so that
# DelegateServiceAccountCredentials can report the problem when it is
# constructed instead of this whole module failing to import.
try:
  from oauth2client import util
except ImportError:
  oauth2client_util_imported = False
else:
  oauth2client_util_imported = True


# Default timeout for HTTP requests, in seconds.
DEFAULT_TIMEOUT = 30
33
34
class AuthError(Exception):
  """Raised when the IAM API refuses to sign a blob for a delegated account."""
37
38
class DelegateServiceAccountCredentials(
    oauth2client.client.AssertionCredentials):
  """Authorizes an HTTP client with a service account for which we are an actor.

  This class uses the IAM API to sign a JWT with the private key of another
  service account for which we have the "Service Account Actor" role.
  """

  MAX_TOKEN_LIFETIME_SECS = 3600  # 1 hour in seconds
  _SIGN_BLOB_URL = 'https://iam.googleapis.com/v1/%s:signBlob'

  def __init__(self, http, service_account_email, scopes, project='-'):
    """
    Args:
      http: An httplib2.Http object that is authorized by another
        oauth2client.client.OAuth2Credentials with credentials that have the
        service account actor role on the service_account_email.
      service_account_email: The email address of the service account for which
        to obtain an access token.
      scopes: The desired scopes for the token.
      project: The cloud project to which service_account_email belongs.  The
        default of '-' makes the IAM API figure it out for us.

    Raises:
      AssertionError: if oauth2client.util could not be imported.
    """
    if not oauth2client_util_imported:
      raise AssertionError('Failed to import oauth2client.util.')
    super(DelegateServiceAccountCredentials, self).__init__(None)
    self._service_account_email = service_account_email
    self._scopes = util.scopes_to_string(scopes)
    self._http = http
    self._name = 'projects/%s/serviceAccounts/%s' % (
        project, service_account_email)

  def sign_blob(self, blob):
    """Signs `blob` with the delegated service account's key via IAM signBlob.

    Args:
      blob: The bytes to sign.

    Returns:
      A (key_id, signature) tuple taken from the response's 'keyId' and
      'signature' fields. The signature is the base64 text sent by the API.

    Raises:
      AuthError: if the API responds with a status other than 200.
    """
    # b64encode returns bytes on Python 3, which json.dumps cannot serialize;
    # send the (pure ASCII) base64 text instead.
    encoded_blob = base64.b64encode(blob).decode('ascii')
    response, content = self._http.request(
        self._SIGN_BLOB_URL % self._name,
        method='POST',
        body=json.dumps({'bytesToSign': encoded_blob}),
        headers={'Content-Type': 'application/json'})
    if response.status != 200:
      raise AuthError('Failed to sign blob as %s: %d %s' % (
          self._service_account_email, response.status, response.reason))

    if isinstance(content, bytes):
      content = content.decode('utf-8')
    data = json.loads(content)
    return data['keyId'], data['signature']

  def _generate_assertion(self):
    """Builds a signed JWT assertion for this service account.

    Returns:
      The JWT as bytes: base64url(header) '.' base64url(payload) '.'
      base64url(signature), each segment without '=' padding.
    """
    # This is copied with small modifications from
    # oauth2client.service_account._ServiceAccountCredentials.

    header = {
        'alg': 'RS256',
        'typ': 'JWT',
    }

    now = int(time.time())
    payload = {
        'aud': self.token_uri,
        'scope': self._scopes,
        'iat': now,
        'exp': now + self.MAX_TOKEN_LIFETIME_SECS,
        'iss': self._service_account_email,
    }

    assertion_input = (
        self._urlsafe_b64encode(header) + b'.' +
        self._urlsafe_b64encode(payload))

    # Sign the assertion.
    _, signature_b64 = self.sign_blob(assertion_input)
    # The IAM API returns the signature as standard-alphabet base64 text, but
    # a JWT segment must be unpadded base64url bytes (RFC 7515), so decode the
    # raw signature and re-encode it.
    signature = base64.urlsafe_b64encode(
        base64.b64decode(signature_b64)).rstrip(b'=')

    return assertion_input + b'.' + signature

  def _urlsafe_b64encode(self, data):
    """JSON-serializes `data` compactly and returns it as unpadded base64url."""
    # Copied verbatim from oauth2client.service_account.
    return base64.urlsafe_b64encode(
        json.dumps(data, separators=(',', ':')).encode('UTF-8')).rstrip(b'=')
116
117
class RetriableHttp(object):
  """A httplib2.Http wrapper that retries failed requests.

  Reads of attributes RetriableHttp does not define itself, and writes to
  anything other than its own retry settings, are forwarded to the wrapped
  http object, so an instance can stand in for it.
  """

  def __init__(self, http, max_tries=5, backoff_time=1,
               retrying_statuses_fn=None):
    """
    Args:
      http: an httplib2.Http instance
      max_tries: maximum number of attempts, including the first one; must be
        at least 1
      backoff_time: a number of seconds to sleep between retries
      retrying_statuses_fn: a function that returns True if a given status
                            should be retried. Defaults to retrying every 5xx
                            status (500-599).

    Raises:
      ValueError: if max_tries is less than 1.
    """
    if max_tries < 1:
      # With zero attempts request() would have no response to return.
      raise ValueError('max_tries must be at least 1, got %r' % (max_tries,))
    self._http = http
    self._max_tries = max_tries
    self._backoff_time = backoff_time
    # range() excludes its stop value, so 600 is needed to include 599.
    self._retrying_statuses_fn = (
        retrying_statuses_fn or set(range(500, 600)).__contains__)

  def request(self, uri, method='GET', body=None, *args, **kwargs):
    """Sends the request, retrying on transient errors and retriable statuses.

    Takes the same arguments as httplib2.Http.request.

    Returns:
      The (response, content) tuple of the first attempt whose status is not
      retriable, or of the final attempt if every status was retriable.

    Raises:
      The exception raised by the final attempt, if it raised one of the
      exception types caught below.
    """
    for attempt in range(1, self._max_tries + 1):
      is_final = attempt == self._max_tries
      outcome = 'final attempt' if is_final else 'will retry'
      try:
        response, content = self._http.request(uri, method, body, *args,
                                               **kwargs)
      except (ValueError, errors.Error,
              socket.timeout, socket.error, socket.herror, socket.gaierror,
              httplib2.HttpLib2Error) as error:
        logging.info('RetriableHttp: attempt %d received exception: %s, %s',
                     attempt, error, outcome)
        if is_final:
          raise
      else:
        if not self._retrying_statuses_fn(response.status):
          return response, content
        logging.info('RetriableHttp: attempt %d receiving status %d, %s',
                     attempt, response.status, outcome)
        if is_final:
          return response, content
      # Only reached when another attempt follows, so no sleep is wasted
      # after the last one.
      time.sleep(self._backoff_time)

  def __getattr__(self, name):
    # Only called when normal lookup fails. If '_http' itself is missing
    # (e.g. an instance created via __new__ by copy/pickle before its state is
    # restored), fail cleanly instead of recursing forever.
    if name == '_http':
      raise AttributeError(name)
    return getattr(self._http, name)

  def __setattr__(self, name, value):
    if name in ('request', '_http', '_max_tries', '_backoff_time',
                '_retrying_statuses_fn'):
      self.__dict__[name] = value
    else:
      setattr(self._http, name, value)
171
172
class InstrumentedHttp(httplib2.Http):
  """A httplib2.Http that reports ts_mon metrics about each of its requests.

  Every request records its body size, its outcome (an HTTP status or one of
  the http_metrics.STATUS_* error values) and its duration. Requests that
  return a response also record the response size.
  """

  def __init__(self, name, time_fn=time.time, timeout=DEFAULT_TIMEOUT,
               **kwargs):
    """
    Args:
      name: An identifier for the HTTP requests made by this object. It is
        attached to every metric as the 'name' field.
      time_fn: Function returning the current time in seconds. Use for testing
        purposes only.
      timeout: Request timeout in seconds, forwarded to httplib2.Http.
      **kwargs: Any other httplib2.Http constructor arguments.
    """
    super(InstrumentedHttp, self).__init__(timeout=timeout, **kwargs)
    self.fields = dict(name=name, client='httplib2')
    self.time_fn = time_fn

  def _update_metrics(self, status, start_time):
    """Records one request's outcome and how long it took.

    Args:
      status: An HTTP status code or an http_metrics.STATUS_* value.
      start_time: The time_fn() value taken when the request started.
    """
    labels = {'status': status}
    labels.update(self.fields)
    http_metrics.response_status.increment(fields=labels)

    elapsed_ms = 1000 * (self.time_fn() - start_time)
    http_metrics.durations.add(elapsed_ms, fields=self.fields)

  def request(self, uri, method="GET", body=None, *args, **kwargs):
    """Performs the request via httplib2.Http and records metrics for it.

    Takes the same arguments as httplib2.Http.request and returns its
    (response, content) tuple. Exceptions are re-raised after their outcome
    has been recorded.
    """
    sent = 0 if body is None else len(body)
    http_metrics.request_bytes.add(sent, fields=self.fields)

    started_at = self.time_fn()
    try:
      response, content = super(InstrumentedHttp, self).request(
          uri, method, body, *args, **kwargs)
    except socket.timeout:
      # Must come before socket.error, which socket.timeout subclasses, so
      # that timeouts get their own status.
      self._update_metrics(http_metrics.STATUS_TIMEOUT, started_at)
      raise
    except (socket.error, socket.herror, socket.gaierror):
      self._update_metrics(http_metrics.STATUS_ERROR, started_at)
      raise
    except (httplib.HTTPException, httplib2.HttpLib2Error) as ex:
      # Raised on Appengine (gae_override/httplib.py) when a deadline is hit;
      # count that as a timeout rather than a generic exception.
      deadline_hit = (
          'Deadline exceeded while waiting for HTTP response' in str(ex))
      if deadline_hit:
        outcome = http_metrics.STATUS_TIMEOUT
      else:
        outcome = http_metrics.STATUS_EXCEPTION
      self._update_metrics(outcome, started_at)
      raise

    http_metrics.response_bytes.add(len(content), fields=self.fields)
    self._update_metrics(response.status, started_at)
    return response, content
225
226
class HttpMock(object):
  """Mock of httplib2.Http that serves canned responses.

  Every request is recorded in `requests_made`. The requested uri is then
  matched against the configured patterns in order, and the first match's
  canned response is returned.
  """
  HttpCall = collections.namedtuple('HttpCall', ('uri', 'method', 'body',
                                                 'headers'))

  def __init__(self, uris):
    """
    Args:
      uris(list): sequence of (uri, headers, body). `uri` is a regexp for
        matching the requested uri with re.match (so it is anchored at the
        start of the uri only). (headers, body) gives the values returned
        by the mock. Uris are tested in the order from `uris`.
        `headers` is a dict mapping headers to value. The 'status' key is
        mandatory and is converted with int(). `body` is a text or bytes
        string.
        Ex: [('.*', {'status': 200}, 'nicely done.')]

    Raises:
      ValueError: if an entry is not a 3-item list/tuple, or its headers have
        no 'status' key.
      TypeError: if an entry's headers is not a dict or its body is not a
        string.
    """
    self._uris = []
    self.requests_made = []

    for value in uris:
      if not isinstance(value, (list, tuple)) or len(value) != 3:
        raise ValueError("'uris' must be a sequence of (uri, headers, body)")
      uri, headers, body = value
      compiled_uri = re.compile(uri)
      if not isinstance(headers, dict):
        raise TypeError("'headers' must be a dict")
      if 'status' not in headers:
        raise ValueError("'headers' must have 'status' as a key")

      # Copy so that normalizing 'status' does not mutate the caller's dict.
      new_headers = copy.copy(headers)
      new_headers['status'] = int(new_headers['status'])

      # Accept bytes as well as text so callers can supply binary content.
      if not isinstance(body, (six.string_types, six.binary_type)):
        raise TypeError("'body' must be a string, got %s" % type(body))
      self._uris.append((compiled_uri, new_headers, body))

  # pylint: disable=unused-argument
  def request(self, uri,
              method='GET',
              body=None,
              headers=None,
              redirections=1,
              connection_type=None):
    """Records the call and returns the first matching canned response.

    Returns:
      An (httplib2.Response, body) tuple built from the first configured entry
      whose pattern matches `uri`.

    Raises:
      AssertionError: if no configured pattern matches `uri`.
    """
    self.requests_made.append(self.HttpCall(uri, method, body, headers))
    for pattern, response_headers, response_body in self._uris:
      if pattern.match(uri):
        return httplib2.Response(response_headers), response_body
    raise AssertionError("Unexpected request to %s" % uri)
279