# Copyright 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for :mod:`google.auth.impersonated_credentials`."""

import datetime
import json
import os

import mock
import pytest
from six.moves import http_client

from google.auth import _helpers
from google.auth import crypt
from google.auth import exceptions
from google.auth import impersonated_credentials
from google.auth import transport
from google.auth.impersonated_credentials import Credentials
from google.oauth2 import credentials
from google.oauth2 import service_account

DATA_DIR = os.path.join(os.path.dirname(__file__), "data")

with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh:
    PRIVATE_KEY_BYTES = fh.read()

SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json")

# A JWT whose payload decodes to aud "https://foo.bar" and an "exp" claim equal
# to ID_TOKEN_EXPIRY; the signature segment is just the placeholder "redacted".
ID_TOKEN_DATA = (
    "eyJhbGciOiJSUzI1NiIsImtpZCI6ImRmMzc1ODkwOGI3OTIyOTNhZDk3N2Ew"
    "Yjk5MWQ5OGE3N2Y0ZWVlY2QiLCJ0eXAiOiJKV1QifQ.eyJhdWQiOiJodHRwc"
    "zovL2Zvby5iYXIiLCJhenAiOiIxMDIxMDE1NTA4MzQyMDA3MDg1NjgiLCJle"
    "HAiOjE1NjQ0NzUwNTEsImlhdCI6MTU2NDQ3MTQ1MSwiaXNzIjoiaHR0cHM6L"
    "y9hY2NvdW50cy5nb29nbGUuY29tIiwic3ViIjoiMTAyMTAxNTUwODM0MjAwN"
    "zA4NTY4In0.redacted"
)
ID_TOKEN_EXPIRY = 1564475051

with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
    SERVICE_ACCOUNT_INFO = json.load(fh)

SIGNER = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1")
TOKEN_URI = "https://example.com/oauth2/token"


@pytest.fixture
def mock_donor_credentials():
    """Stub ``google.oauth2._client.jwt_grant`` for the duration of a test.

    The stub returns a source token that expires 500 seconds from now. This
    presumably lets the service-account source credentials refresh without
    contacting a real token endpoint (assumes their refresh goes through
    ``jwt_grant``).
    """
    with mock.patch("google.oauth2._client.jwt_grant", autospec=True) as grant:
        grant.return_value = (
            "source token",
            _helpers.utcnow() + datetime.timedelta(seconds=500),
            {},
        )
        yield grant


class MockResponse(object):
    """Minimal fake response exposing ``status_code`` and ``json()``.

    Used as the return value of the patched
    ``google.auth.transport.requests.AuthorizedSession.request``.
    """

    def __init__(self, json_data, status_code):
        self.json_data = json_data
        self.status_code = status_code

    def json(self):
        return self.json_data


@pytest.fixture
def mock_authorizedsession_sign():
    """Patch ``AuthorizedSession.request`` to answer with a signed-blob payload.

    ``signedBlob`` is the base64 encoding of ``b"signature"``.
    """
    with mock.patch(
        "google.auth.transport.requests.AuthorizedSession.request", autospec=True
    ) as auth_session:
        data = {"keyId": "1", "signedBlob": "c2lnbmF0dXJl"}
        auth_session.return_value = MockResponse(data, http_client.OK)
        yield auth_session


@pytest.fixture
def mock_authorizedsession_idtoken():
    """Patch ``AuthorizedSession.request`` to answer with ``ID_TOKEN_DATA``."""
    with mock.patch(
        "google.auth.transport.requests.AuthorizedSession.request", autospec=True
    ) as auth_session:
        data = {"token": ID_TOKEN_DATA}
        auth_session.return_value = MockResponse(data, http_client.OK)
        yield auth_session


class TestImpersonatedCredentials(object):

    SERVICE_ACCOUNT_EMAIL = "[email protected]"
    TARGET_PRINCIPAL = "[email protected]"
    TARGET_SCOPES = ["https://www.googleapis.com/auth/devstorage.read_only"]
    DELEGATES = []
    LIFETIME = 3600
    SOURCE_CREDENTIALS = service_account.Credentials(
        SIGNER, SERVICE_ACCOUNT_EMAIL, TOKEN_URI
    )
    USER_SOURCE_CREDENTIALS = credentials.Credentials(token="ABCDE")
    IAM_ENDPOINT_OVERRIDE = (
        "https://us-east1-iamcredentials.googleapis.com/v1/projects/-"
        + "/serviceAccounts/{}:generateAccessToken".format(SERVICE_ACCOUNT_EMAIL)
    )

    def make_credentials(
        self,
        source_credentials=SOURCE_CREDENTIALS,
        lifetime=LIFETIME,
        target_principal=TARGET_PRINCIPAL,
        iam_endpoint_override=None,
    ):
        """Build impersonated ``Credentials`` using the class-level defaults."""
        return Credentials(
            source_credentials=source_credentials,
            target_principal=target_principal,
            target_scopes=self.TARGET_SCOPES,
            delegates=self.DELEGATES,
            lifetime=lifetime,
            iam_endpoint_override=iam_endpoint_override,
        )

    def make_request(
        self,
        data,
        status=http_client.OK,
        headers=None,
        side_effect=None,
        use_data_bytes=True,
    ):
        """Return an autospec'd ``transport.Request`` yielding a canned response.

        Args:
            data: Response body (a ``str``).
            status: HTTP status code placed on the response.
            headers: Response headers; defaults to an empty dict.
            side_effect: Optional side effect for the request mock.
            use_data_bytes: If True the body is converted to ``bytes``,
                otherwise it is left as given, so both body types get covered.
        """
        response = mock.create_autospec(transport.Response, instance=False)
        response.status = status
        response.data = _helpers.to_bytes(data) if use_data_bytes else data
        response.headers = headers or {}

        request = mock.create_autospec(transport.Request, instance=False)
        request.side_effect = side_effect
        request.return_value = response

        return request

    def make_token_request(self, use_data_bytes=True):
        """Return a request mock answering with a successful token grant.

        The grant carries the access token ``"token"`` and an ``expireTime``
        500 seconds from now, with microseconds dropped and a trailing ``Z``.
        """
        expire_time = (
            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
        ).isoformat("T") + "Z"
        response_body = {"accessToken": "token", "expireTime": expire_time}
        return self.make_request(
            data=json.dumps(response_body),
            status=http_client.OK,
            use_data_bytes=use_data_bytes,
        )

    def test_make_from_user_credentials(self):
        credentials = self.make_credentials(
            source_credentials=self.USER_SOURCE_CREDENTIALS
        )
        assert not credentials.valid
        assert credentials.expired

    def test_default_state(self):
        credentials = self.make_credentials()
        assert not credentials.valid
        assert credentials.expired

    @pytest.mark.parametrize("use_data_bytes", [True, False])
    def test_refresh_success(self, use_data_bytes, mock_donor_credentials):
        credentials = self.make_credentials(lifetime=None)
        request = self.make_token_request(use_data_bytes=use_data_bytes)

        credentials.refresh(request)

        assert credentials.valid
        assert not credentials.expired

    @pytest.mark.parametrize("use_data_bytes", [True, False])
    def test_refresh_success_iam_endpoint_override(
        self, use_data_bytes, mock_donor_credentials
    ):
        credentials = self.make_credentials(
            lifetime=None, iam_endpoint_override=self.IAM_ENDPOINT_OVERRIDE
        )
        request = self.make_token_request(use_data_bytes=use_data_bytes)

        credentials.refresh(request)

        assert credentials.valid
        assert not credentials.expired
        # Confirm the override endpoint was used.
        request_kwargs = request.call_args[1]
        assert request_kwargs["url"] == self.IAM_ENDPOINT_OVERRIDE

    @pytest.mark.parametrize("time_skew", [100, -100])
    def test_refresh_source_credentials(self, time_skew):
        credentials = self.make_credentials(lifetime=None)

        # Source credentials are refreshed only if they expire within
        # _helpers.REFRESH_THRESHOLD from now. A time_skew is added to the
        # expiry, so a refresh is expected only when time_skew <= 0.
        credentials._source_credentials.expiry = (
            _helpers.utcnow()
            + _helpers.REFRESH_THRESHOLD
            + datetime.timedelta(seconds=time_skew)
        )
        credentials._source_credentials.token = "Token"

        with mock.patch(
            "google.oauth2.service_account.Credentials.refresh", autospec=True
        ) as source_cred_refresh:
            request = self.make_token_request()

            credentials.refresh(request)

            assert credentials.valid
            assert not credentials.expired

            if time_skew > 0:
                source_cred_refresh.assert_not_called()
            else:
                source_cred_refresh.assert_called_once()

    def test_refresh_failure_malformed_expire_time(self, mock_donor_credentials):
        credentials = self.make_credentials(lifetime=None)
        token = "token"

        # Deliberately keeps microseconds and omits the trailing "Z".
        expire_time = (_helpers.utcnow() + datetime.timedelta(seconds=500)).isoformat(
            "T"
        )
        response_body = {"accessToken": token, "expireTime": expire_time}

        request = self.make_request(
            data=json.dumps(response_body), status=http_client.OK
        )

        with pytest.raises(exceptions.RefreshError) as excinfo:
            credentials.refresh(request)

        assert excinfo.match(impersonated_credentials._REFRESH_ERROR)

        assert not credentials.valid
        assert credentials.expired

    def test_refresh_failure_unauthorzed(self, mock_donor_credentials):
        credentials = self.make_credentials(lifetime=None)

        response_body = {
            "error": {
                "code": 403,
                "message": "The caller does not have permission",
                "status": "PERMISSION_DENIED",
            }
        }

        request = self.make_request(
            data=json.dumps(response_body), status=http_client.UNAUTHORIZED
        )

        with pytest.raises(exceptions.RefreshError) as excinfo:
            credentials.refresh(request)

        assert excinfo.match(impersonated_credentials._REFRESH_ERROR)

        assert not credentials.valid
        assert credentials.expired

    def test_refresh_failure_http_error(self, mock_donor_credentials):
        credentials = self.make_credentials(lifetime=None)

        response_body = {}

        # A real non-OK status code; previously the HTTPException *class* was
        # passed here, which is not an HTTP status at all.
        request = self.make_request(
            data=json.dumps(response_body), status=http_client.INTERNAL_SERVER_ERROR
        )

        with pytest.raises(exceptions.RefreshError) as excinfo:
            credentials.refresh(request)

        assert excinfo.match(impersonated_credentials._REFRESH_ERROR)

        assert not credentials.valid
        assert credentials.expired

    def test_expired(self):
        credentials = self.make_credentials(lifetime=None)
        assert credentials.expired

    def test_signer(self):
        credentials = self.make_credentials()
        assert isinstance(credentials.signer, impersonated_credentials.Credentials)

    def test_signer_email(self):
        credentials = self.make_credentials(target_principal=self.TARGET_PRINCIPAL)
        assert credentials.signer_email == self.TARGET_PRINCIPAL

    def test_service_account_email(self):
        credentials = self.make_credentials(target_principal=self.TARGET_PRINCIPAL)
        assert credentials.service_account_email == self.TARGET_PRINCIPAL

    def test_sign_bytes(self, mock_donor_credentials, mock_authorizedsession_sign):
        credentials = self.make_credentials(lifetime=None)
        request = self.make_token_request()

        credentials.refresh(request)

        assert credentials.valid
        assert not credentials.expired

        signature = credentials.sign_bytes(b"signed bytes")
        assert signature == b"signature"

    def test_sign_bytes_failure(self):
        credentials = self.make_credentials(lifetime=None)

        with mock.patch(
            "google.auth.transport.requests.AuthorizedSession.request", autospec=True
        ) as auth_session:
            data = {"error": {"code": 403, "message": "unauthorized"}}
            auth_session.return_value = MockResponse(data, http_client.FORBIDDEN)

            with pytest.raises(exceptions.TransportError) as excinfo:
                credentials.sign_bytes(b"foo")
            # Must sit outside the raises-block: code after the raising call
            # inside it would never run.
            assert excinfo.match("'code': 403")

    def test_with_quota_project(self):
        credentials = self.make_credentials()

        quota_project_creds = credentials.with_quota_project("project-foo")
        assert quota_project_creds._quota_project_id == "project-foo"

    @pytest.mark.parametrize("use_data_bytes", [True, False])
    def test_with_quota_project_iam_endpoint_override(
        self, use_data_bytes, mock_donor_credentials
    ):
        credentials = self.make_credentials(
            lifetime=None, iam_endpoint_override=self.IAM_ENDPOINT_OVERRIDE
        )
        # iam_endpoint_override should be copied to the derived credentials.
        quota_project_creds = credentials.with_quota_project("project-foo")

        request = self.make_token_request(use_data_bytes=use_data_bytes)

        quota_project_creds.refresh(request)

        assert quota_project_creds.valid
        assert not quota_project_creds.expired
        # Confirm the override endpoint was used.
        request_kwargs = request.call_args[1]
        assert request_kwargs["url"] == self.IAM_ENDPOINT_OVERRIDE

    def test_id_token_success(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        credentials = self.make_credentials(lifetime=None)
        target_audience = "https://foo.bar"
        request = self.make_token_request()

        credentials.refresh(request)

        assert credentials.valid
        assert not credentials.expired

        id_creds = impersonated_credentials.IDTokenCredentials(
            credentials, target_audience=target_audience
        )
        id_creds.refresh(request)

        assert id_creds.token == ID_TOKEN_DATA
        assert id_creds.expiry == datetime.datetime.fromtimestamp(ID_TOKEN_EXPIRY)

    def test_id_token_from_credential(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        credentials = self.make_credentials(lifetime=None)
        target_audience = "https://foo.bar"
        request = self.make_token_request()

        credentials.refresh(request)

        assert credentials.valid
        assert not credentials.expired

        id_creds = impersonated_credentials.IDTokenCredentials(
            credentials, target_audience=target_audience, include_email=True
        )
        id_creds = id_creds.from_credentials(target_credentials=credentials)
        id_creds.refresh(request)

        assert id_creds.token == ID_TOKEN_DATA
        assert id_creds._include_email is True

    def test_id_token_with_target_audience(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        credentials = self.make_credentials(lifetime=None)
        target_audience = "https://foo.bar"
        request = self.make_token_request()

        credentials.refresh(request)

        assert credentials.valid
        assert not credentials.expired

        id_creds = impersonated_credentials.IDTokenCredentials(
            credentials, include_email=True
        )
        id_creds = id_creds.with_target_audience(target_audience=target_audience)
        id_creds.refresh(request)

        assert id_creds.token == ID_TOKEN_DATA
        assert id_creds.expiry == datetime.datetime.fromtimestamp(ID_TOKEN_EXPIRY)
        assert id_creds._include_email is True

    def test_id_token_invalid_cred(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        credentials = None

        with pytest.raises(exceptions.GoogleAuthError) as excinfo:
            impersonated_credentials.IDTokenCredentials(credentials)

        assert excinfo.match("Provided Credential must be" " impersonated_credentials")

    def test_id_token_with_include_email(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        credentials = self.make_credentials(lifetime=None)
        target_audience = "https://foo.bar"
        request = self.make_token_request()

        credentials.refresh(request)

        assert credentials.valid
        assert not credentials.expired

        id_creds = impersonated_credentials.IDTokenCredentials(
            credentials, target_audience=target_audience
        )
        id_creds = id_creds.with_include_email(True)
        id_creds.refresh(request)

        assert id_creds.token == ID_TOKEN_DATA

    def test_id_token_with_quota_project(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        credentials = self.make_credentials(lifetime=None)
        target_audience = "https://foo.bar"
        request = self.make_token_request()

        credentials.refresh(request)

        assert credentials.valid
        assert not credentials.expired

        id_creds = impersonated_credentials.IDTokenCredentials(
            credentials, target_audience=target_audience
        )
        id_creds = id_creds.with_quota_project("project-foo")
        id_creds.refresh(request)

        assert id_creds.quota_project_id == "project-foo"