1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import datetime
16import json
17
18import mock
19import pytest
20from six.moves import http_client
21from six.moves import urllib
22
23from google.auth import _helpers
24from google.auth import exceptions
25from google.auth import external_account
26from google.auth import transport
27
28
# Client credentials used when exercising client (basic) authentication.
CLIENT_ID = "username"
CLIENT_SECRET = "password"
# Base64 encoding of "username:password"
BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ="
# Service account embedded in the impersonation URL. The previous value was
# the "[email protected]" placeholder left behind by e-mail obfuscation,
# which is not an address (it contains a space and brackets).
SERVICE_ACCOUNT_EMAIL = "service-1234@service-name.iam.gserviceaccount.com"
# List of valid workforce pool audiences.
TEST_USER_AUDIENCES = [
    "//iam.googleapis.com/locations/global/workforcePools/pool-id/providers/provider-id",
    "//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id",
    "//iam.googleapis.com/locations/eu/workforcePools/workloadIdentityPools/providers/provider-id",
]
# Workload identity pool audiences or invalid workforce pool audiences.
TEST_NON_USER_AUDIENCES = [
    # Legacy K8s audience format.
    "identitynamespace:1f12345:my_provider",
    (
        "//iam.googleapis.com/projects/123456/locations/"
        "global/workloadIdentityPools/pool-id/providers/"
        "provider-id"
    ),
    (
        "//iam.googleapis.com/projects/123456/locations/"
        "eu/workloadIdentityPools/pool-id/providers/"
        "provider-id"
    ),
    # Pool ID with workforcePools string.
    (
        "//iam.googleapis.com/projects/123456/locations/"
        "global/workloadIdentityPools/workforcePools/providers/"
        "provider-id"
    ),
    # Unrealistic / incorrect workforce pool audiences.
    "//iamgoogleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id",
    "//iam.googleapiscom/locations/eu/workforcePools/pool-id/providers/provider-id",
    "//iam.googleapis.com/locations/workforcePools/pool-id/providers/provider-id",
    "//iam.googleapis.com/locations/eu/workforcePool/pool-id/providers/provider-id",
    "//iam.googleapis.com/locations//workforcePool/pool-id/providers/provider-id",
]
67
68
class CredentialsImpl(external_account.Credentials):
    """Concrete ``external_account.Credentials`` subclass used by the tests.

    Subject tokens are generated locally: each retrieval returns a distinct,
    predictable value so tests can tell successive retrievals apart.
    """

    def __init__(
        self,
        audience,
        subject_token_type,
        token_url,
        credential_source,
        service_account_impersonation_url=None,
        client_id=None,
        client_secret=None,
        quota_project_id=None,
        scopes=None,
        default_scopes=None,
        workforce_pool_user_project=None,
    ):
        """Forward every option, by keyword, to the base class initializer."""
        base_options = {
            "audience": audience,
            "subject_token_type": subject_token_type,
            "token_url": token_url,
            "credential_source": credential_source,
            "service_account_impersonation_url": service_account_impersonation_url,
            "client_id": client_id,
            "client_secret": client_secret,
            "quota_project_id": quota_project_id,
            "scopes": scopes,
            "default_scopes": default_scopes,
            "workforce_pool_user_project": workforce_pool_user_project,
        }
        super(CredentialsImpl, self).__init__(**base_options)
        # Number of subject tokens handed out so far.
        self._counter = 0

    def retrieve_subject_token(self, request):
        """Return ``subject_token_<n>`` where ``n`` counts earlier calls from 0.

        The ``request`` argument is not used.
        """
        subject_token = "subject_token_{}".format(self._counter)
        self._counter += 1
        return subject_token
103
104
class TestCredentials(object):
    """Tests for ``external_account.Credentials`` driven through ``CredentialsImpl``."""

    # STS endpoint and the identifiers used to build the audiences below.
    TOKEN_URL = "https://sts.googleapis.com/v1/token"
    PROJECT_NUMBER = "123456"
    POOL_ID = "POOL_ID"
    PROVIDER_ID = "PROVIDER_ID"
    # Workload identity pool audience.
    AUDIENCE = (
        "//iam.googleapis.com/projects/{project}"
        "/locations/global/workloadIdentityPools/{pool}"
        "/providers/{provider}"
    ).format(project=PROJECT_NUMBER, pool=POOL_ID, provider=PROVIDER_ID)
    # Workforce pool audience.
    WORKFORCE_AUDIENCE = (
        "//iam.googleapis.com/locations/global/workforcePools/{pool}/providers/{provider}"
    ).format(pool=POOL_ID, provider=PROVIDER_ID)
    WORKFORCE_POOL_USER_PROJECT = "WORKFORCE_POOL_USER_PROJECT_NUMBER"
    SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt"
    WORKFORCE_SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:id_token"
    CREDENTIAL_SOURCE = {"file": "/var/run/secrets/goog.id/token"}
    # Canned STS token exchange responses.
    SUCCESS_RESPONSE = {
        "access_token": "ACCESS_TOKEN",
        "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
        "token_type": "Bearer",
        "expires_in": 3600,
        "scope": "scope1 scope2",
    }
    ERROR_RESPONSE = {
        "error": "invalid_request",
        "error_description": "Invalid subject token",
        "error_uri": "https://tools.ietf.org/html/rfc6749",
    }
    QUOTA_PROJECT_ID = "QUOTA_PROJECT_ID"
    # generateAccessToken URL for the module-level service account email.
    SERVICE_ACCOUNT_IMPERSONATION_URL = (
        "https://us-east1-iamcredentials.googleapis.com/v1/projects/-"
        "/serviceAccounts/{}:generateAccessToken"
    ).format(SERVICE_ACCOUNT_EMAIL)
    SCOPES = ["scope1", "scope2"]
    # Canned error payload for the service account impersonation call.
    IMPERSONATION_ERROR_RESPONSE = {
        "error": {
            "code": 400,
            "message": "Request contains an invalid argument",
            "status": "INVALID_ARGUMENT",
        }
    }
    PROJECT_ID = "my-proj-id"
    # Cloud Resource Manager endpoint prefix and canned project payload.
    CLOUD_RESOURCE_MANAGER_URL = (
        "https://cloudresourcemanager.googleapis.com/v1/projects/"
    )
    CLOUD_RESOURCE_MANAGER_SUCCESS_RESPONSE = {
        "projectNumber": PROJECT_NUMBER,
        "projectId": PROJECT_ID,
        "lifecycleState": "ACTIVE",
        "name": "project-name",
        "createTime": "2018-11-06T04:42:54.109Z",
        "parent": {"type": "folder", "id": "12345678901"},
    }
159
160    @classmethod
161    def make_credentials(
162        cls,
163        client_id=None,
164        client_secret=None,
165        quota_project_id=None,
166        scopes=None,
167        default_scopes=None,
168        service_account_impersonation_url=None,
169    ):
170        return CredentialsImpl(
171            audience=cls.AUDIENCE,
172            subject_token_type=cls.SUBJECT_TOKEN_TYPE,
173            token_url=cls.TOKEN_URL,
174            service_account_impersonation_url=service_account_impersonation_url,
175            credential_source=cls.CREDENTIAL_SOURCE,
176            client_id=client_id,
177            client_secret=client_secret,
178            quota_project_id=quota_project_id,
179            scopes=scopes,
180            default_scopes=default_scopes,
181        )
182
183    @classmethod
184    def make_workforce_pool_credentials(
185        cls,
186        client_id=None,
187        client_secret=None,
188        quota_project_id=None,
189        scopes=None,
190        default_scopes=None,
191        service_account_impersonation_url=None,
192        workforce_pool_user_project=None,
193    ):
194        return CredentialsImpl(
195            audience=cls.WORKFORCE_AUDIENCE,
196            subject_token_type=cls.WORKFORCE_SUBJECT_TOKEN_TYPE,
197            token_url=cls.TOKEN_URL,
198            service_account_impersonation_url=service_account_impersonation_url,
199            credential_source=cls.CREDENTIAL_SOURCE,
200            client_id=client_id,
201            client_secret=client_secret,
202            quota_project_id=quota_project_id,
203            scopes=scopes,
204            default_scopes=default_scopes,
205            workforce_pool_user_project=workforce_pool_user_project,
206        )
207
208    @classmethod
209    def make_mock_request(
210        cls,
211        status=http_client.OK,
212        data=None,
213        impersonation_status=None,
214        impersonation_data=None,
215        cloud_resource_manager_status=None,
216        cloud_resource_manager_data=None,
217    ):
218        # STS token exchange request.
219        token_response = mock.create_autospec(transport.Response, instance=True)
220        token_response.status = status
221        token_response.data = json.dumps(data).encode("utf-8")
222        responses = [token_response]
223
224        # If service account impersonation is requested, mock the expected response.
225        if impersonation_status:
226            impersonation_response = mock.create_autospec(
227                transport.Response, instance=True
228            )
229            impersonation_response.status = impersonation_status
230            impersonation_response.data = json.dumps(impersonation_data).encode("utf-8")
231            responses.append(impersonation_response)
232
233        # If cloud resource manager is requested, mock the expected response.
234        if cloud_resource_manager_status:
235            cloud_resource_manager_response = mock.create_autospec(
236                transport.Response, instance=True
237            )
238            cloud_resource_manager_response.status = cloud_resource_manager_status
239            cloud_resource_manager_response.data = json.dumps(
240                cloud_resource_manager_data
241            ).encode("utf-8")
242            responses.append(cloud_resource_manager_response)
243
244        request = mock.create_autospec(transport.Request)
245        request.side_effect = responses
246
247        return request
248
249    @classmethod
250    def assert_token_request_kwargs(cls, request_kwargs, headers, request_data):
251        assert request_kwargs["url"] == cls.TOKEN_URL
252        assert request_kwargs["method"] == "POST"
253        assert request_kwargs["headers"] == headers
254        assert request_kwargs["body"] is not None
255        body_tuples = urllib.parse.parse_qsl(request_kwargs["body"])
256        for (k, v) in body_tuples:
257            assert v.decode("utf-8") == request_data[k.decode("utf-8")]
258        assert len(body_tuples) == len(request_data.keys())
259
260    @classmethod
261    def assert_impersonation_request_kwargs(cls, request_kwargs, headers, request_data):
262        assert request_kwargs["url"] == cls.SERVICE_ACCOUNT_IMPERSONATION_URL
263        assert request_kwargs["method"] == "POST"
264        assert request_kwargs["headers"] == headers
265        assert request_kwargs["body"] is not None
266        body_json = json.loads(request_kwargs["body"].decode("utf-8"))
267        assert body_json == request_data
268
269    @classmethod
270    def assert_resource_manager_request_kwargs(
271        cls, request_kwargs, project_number, headers
272    ):
273        assert request_kwargs["url"] == cls.CLOUD_RESOURCE_MANAGER_URL + project_number
274        assert request_kwargs["method"] == "GET"
275        assert request_kwargs["headers"] == headers
276        assert "body" not in request_kwargs
277
278    def test_default_state(self):
279        credentials = self.make_credentials()
280
281        # Not token acquired yet
282        assert not credentials.token
283        assert not credentials.valid
284        # Expiration hasn't been set yet
285        assert not credentials.expiry
286        assert not credentials.expired
287        # Scopes are required
288        assert not credentials.scopes
289        assert credentials.requires_scopes
290        assert not credentials.quota_project_id
291
292    def test_nonworkforce_with_workforce_pool_user_project(self):
293        with pytest.raises(ValueError) as excinfo:
294            CredentialsImpl(
295                audience=self.AUDIENCE,
296                subject_token_type=self.SUBJECT_TOKEN_TYPE,
297                token_url=self.TOKEN_URL,
298                credential_source=self.CREDENTIAL_SOURCE,
299                workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT,
300            )
301
302        assert excinfo.match(
303            "workforce_pool_user_project should not be set for non-workforce "
304            "pool credentials"
305        )
306
307    def test_with_scopes(self):
308        credentials = self.make_credentials()
309
310        assert not credentials.scopes
311        assert credentials.requires_scopes
312
313        scoped_credentials = credentials.with_scopes(["email"])
314
315        assert scoped_credentials.has_scopes(["email"])
316        assert not scoped_credentials.requires_scopes
317
318    def test_with_scopes_workforce_pool(self):
319        credentials = self.make_workforce_pool_credentials(
320            workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT
321        )
322
323        assert not credentials.scopes
324        assert credentials.requires_scopes
325
326        scoped_credentials = credentials.with_scopes(["email"])
327
328        assert scoped_credentials.has_scopes(["email"])
329        assert not scoped_credentials.requires_scopes
330        assert (
331            scoped_credentials.info.get("workforce_pool_user_project")
332            == self.WORKFORCE_POOL_USER_PROJECT
333        )
334
335    def test_with_scopes_using_user_and_default_scopes(self):
336        credentials = self.make_credentials()
337
338        assert not credentials.scopes
339        assert credentials.requires_scopes
340
341        scoped_credentials = credentials.with_scopes(
342            ["email"], default_scopes=["profile"]
343        )
344
345        assert scoped_credentials.has_scopes(["email"])
346        assert not scoped_credentials.has_scopes(["profile"])
347        assert not scoped_credentials.requires_scopes
348        assert scoped_credentials.scopes == ["email"]
349        assert scoped_credentials.default_scopes == ["profile"]
350
351    def test_with_scopes_using_default_scopes_only(self):
352        credentials = self.make_credentials()
353
354        assert not credentials.scopes
355        assert credentials.requires_scopes
356
357        scoped_credentials = credentials.with_scopes(None, default_scopes=["profile"])
358
359        assert scoped_credentials.has_scopes(["profile"])
360        assert not scoped_credentials.requires_scopes
361
362    def test_with_scopes_full_options_propagated(self):
363        credentials = self.make_credentials(
364            client_id=CLIENT_ID,
365            client_secret=CLIENT_SECRET,
366            quota_project_id=self.QUOTA_PROJECT_ID,
367            scopes=self.SCOPES,
368            default_scopes=["default1"],
369            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
370        )
371
372        with mock.patch.object(
373            external_account.Credentials, "__init__", return_value=None
374        ) as mock_init:
375            credentials.with_scopes(["email"], ["default2"])
376
377        # Confirm with_scopes initialized the credential with the expected
378        # parameters and scopes.
379        mock_init.assert_called_once_with(
380            audience=self.AUDIENCE,
381            subject_token_type=self.SUBJECT_TOKEN_TYPE,
382            token_url=self.TOKEN_URL,
383            credential_source=self.CREDENTIAL_SOURCE,
384            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
385            client_id=CLIENT_ID,
386            client_secret=CLIENT_SECRET,
387            quota_project_id=self.QUOTA_PROJECT_ID,
388            scopes=["email"],
389            default_scopes=["default2"],
390            workforce_pool_user_project=None,
391        )
392
393    def test_with_quota_project(self):
394        credentials = self.make_credentials()
395
396        assert not credentials.scopes
397        assert not credentials.quota_project_id
398
399        quota_project_creds = credentials.with_quota_project("project-foo")
400
401        assert quota_project_creds.quota_project_id == "project-foo"
402
403    def test_with_quota_project_workforce_pool(self):
404        credentials = self.make_workforce_pool_credentials(
405            workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT
406        )
407
408        assert not credentials.scopes
409        assert not credentials.quota_project_id
410
411        quota_project_creds = credentials.with_quota_project("project-foo")
412
413        assert quota_project_creds.quota_project_id == "project-foo"
414        assert (
415            quota_project_creds.info.get("workforce_pool_user_project")
416            == self.WORKFORCE_POOL_USER_PROJECT
417        )
418
419    def test_with_quota_project_full_options_propagated(self):
420        credentials = self.make_credentials(
421            client_id=CLIENT_ID,
422            client_secret=CLIENT_SECRET,
423            quota_project_id=self.QUOTA_PROJECT_ID,
424            scopes=self.SCOPES,
425            default_scopes=["default1"],
426            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
427        )
428
429        with mock.patch.object(
430            external_account.Credentials, "__init__", return_value=None
431        ) as mock_init:
432            credentials.with_quota_project("project-foo")
433
434        # Confirm with_quota_project initialized the credential with the
435        # expected parameters and quota project ID.
436        mock_init.assert_called_once_with(
437            audience=self.AUDIENCE,
438            subject_token_type=self.SUBJECT_TOKEN_TYPE,
439            token_url=self.TOKEN_URL,
440            credential_source=self.CREDENTIAL_SOURCE,
441            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
442            client_id=CLIENT_ID,
443            client_secret=CLIENT_SECRET,
444            quota_project_id="project-foo",
445            scopes=self.SCOPES,
446            default_scopes=["default1"],
447            workforce_pool_user_project=None,
448        )
449
450    def test_with_invalid_impersonation_target_principal(self):
451        invalid_url = "https://iamcredentials.googleapis.com/v1/invalid"
452
453        with pytest.raises(exceptions.RefreshError) as excinfo:
454            self.make_credentials(service_account_impersonation_url=invalid_url)
455
456        assert excinfo.match(
457            r"Unable to determine target principal from service account impersonation URL."
458        )
459
460    def test_info(self):
461        credentials = self.make_credentials()
462
463        assert credentials.info == {
464            "type": "external_account",
465            "audience": self.AUDIENCE,
466            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
467            "token_url": self.TOKEN_URL,
468            "credential_source": self.CREDENTIAL_SOURCE.copy(),
469        }
470
471    def test_info_workforce_pool(self):
472        credentials = self.make_workforce_pool_credentials(
473            workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT
474        )
475
476        assert credentials.info == {
477            "type": "external_account",
478            "audience": self.WORKFORCE_AUDIENCE,
479            "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE,
480            "token_url": self.TOKEN_URL,
481            "credential_source": self.CREDENTIAL_SOURCE.copy(),
482            "workforce_pool_user_project": self.WORKFORCE_POOL_USER_PROJECT,
483        }
484
485    def test_info_with_full_options(self):
486        credentials = self.make_credentials(
487            client_id=CLIENT_ID,
488            client_secret=CLIENT_SECRET,
489            quota_project_id=self.QUOTA_PROJECT_ID,
490            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
491        )
492
493        assert credentials.info == {
494            "type": "external_account",
495            "audience": self.AUDIENCE,
496            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
497            "token_url": self.TOKEN_URL,
498            "service_account_impersonation_url": self.SERVICE_ACCOUNT_IMPERSONATION_URL,
499            "credential_source": self.CREDENTIAL_SOURCE.copy(),
500            "quota_project_id": self.QUOTA_PROJECT_ID,
501            "client_id": CLIENT_ID,
502            "client_secret": CLIENT_SECRET,
503        }
504
505    def test_service_account_email_without_impersonation(self):
506        credentials = self.make_credentials()
507
508        assert credentials.service_account_email is None
509
510    def test_service_account_email_with_impersonation(self):
511        credentials = self.make_credentials(
512            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL
513        )
514
515        assert credentials.service_account_email == SERVICE_ACCOUNT_EMAIL
516
517    @pytest.mark.parametrize("audience", TEST_NON_USER_AUDIENCES)
518    def test_is_user_with_non_users(self, audience):
519        credentials = CredentialsImpl(
520            audience=audience,
521            subject_token_type=self.SUBJECT_TOKEN_TYPE,
522            token_url=self.TOKEN_URL,
523            credential_source=self.CREDENTIAL_SOURCE,
524        )
525
526        assert credentials.is_user is False
527
528    @pytest.mark.parametrize("audience", TEST_USER_AUDIENCES)
529    def test_is_user_with_users(self, audience):
530        credentials = CredentialsImpl(
531            audience=audience,
532            subject_token_type=self.SUBJECT_TOKEN_TYPE,
533            token_url=self.TOKEN_URL,
534            credential_source=self.CREDENTIAL_SOURCE,
535        )
536
537        assert credentials.is_user is True
538
539    @pytest.mark.parametrize("audience", TEST_USER_AUDIENCES)
540    def test_is_user_with_users_and_impersonation(self, audience):
541        # Initialize the credentials with service account impersonation.
542        credentials = CredentialsImpl(
543            audience=audience,
544            subject_token_type=self.SUBJECT_TOKEN_TYPE,
545            token_url=self.TOKEN_URL,
546            credential_source=self.CREDENTIAL_SOURCE,
547            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
548        )
549
550        # Even though the audience is for a workforce pool, since service account
551        # impersonation is used, the credentials will represent a service account and
552        # not a user.
553        assert credentials.is_user is False
554
555    @pytest.mark.parametrize("audience", TEST_NON_USER_AUDIENCES)
556    def test_is_workforce_pool_with_non_users(self, audience):
557        credentials = CredentialsImpl(
558            audience=audience,
559            subject_token_type=self.SUBJECT_TOKEN_TYPE,
560            token_url=self.TOKEN_URL,
561            credential_source=self.CREDENTIAL_SOURCE,
562        )
563
564        assert credentials.is_workforce_pool is False
565
566    @pytest.mark.parametrize("audience", TEST_USER_AUDIENCES)
567    def test_is_workforce_pool_with_users(self, audience):
568        credentials = CredentialsImpl(
569            audience=audience,
570            subject_token_type=self.SUBJECT_TOKEN_TYPE,
571            token_url=self.TOKEN_URL,
572            credential_source=self.CREDENTIAL_SOURCE,
573        )
574
575        assert credentials.is_workforce_pool is True
576
577    @pytest.mark.parametrize("audience", TEST_USER_AUDIENCES)
578    def test_is_workforce_pool_with_users_and_impersonation(self, audience):
579        # Initialize the credentials with workforce audience and service account
580        # impersonation.
581        credentials = CredentialsImpl(
582            audience=audience,
583            subject_token_type=self.SUBJECT_TOKEN_TYPE,
584            token_url=self.TOKEN_URL,
585            credential_source=self.CREDENTIAL_SOURCE,
586            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
587        )
588
589        # Even though impersonation is used, is_workforce_pool should still return True.
590        assert credentials.is_workforce_pool is True
591
592    @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
593    def test_refresh_without_client_auth_success(self, unused_utcnow):
594        response = self.SUCCESS_RESPONSE.copy()
595        # Test custom expiration to confirm expiry is set correctly.
596        response["expires_in"] = 2800
597        expected_expiry = datetime.datetime.min + datetime.timedelta(
598            seconds=response["expires_in"]
599        )
600        headers = {"Content-Type": "application/x-www-form-urlencoded"}
601        request_data = {
602            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
603            "audience": self.AUDIENCE,
604            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
605            "subject_token": "subject_token_0",
606            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
607        }
608        request = self.make_mock_request(status=http_client.OK, data=response)
609        credentials = self.make_credentials()
610
611        credentials.refresh(request)
612
613        self.assert_token_request_kwargs(request.call_args[1], headers, request_data)
614        assert credentials.valid
615        assert credentials.expiry == expected_expiry
616        assert not credentials.expired
617        assert credentials.token == response["access_token"]
618
619    @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
620    def test_refresh_workforce_without_client_auth_success(self, unused_utcnow):
621        response = self.SUCCESS_RESPONSE.copy()
622        # Test custom expiration to confirm expiry is set correctly.
623        response["expires_in"] = 2800
624        expected_expiry = datetime.datetime.min + datetime.timedelta(
625            seconds=response["expires_in"]
626        )
627        headers = {"Content-Type": "application/x-www-form-urlencoded"}
628        request_data = {
629            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
630            "audience": self.WORKFORCE_AUDIENCE,
631            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
632            "subject_token": "subject_token_0",
633            "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE,
634            "options": urllib.parse.quote(
635                json.dumps({"userProject": self.WORKFORCE_POOL_USER_PROJECT})
636            ),
637        }
638        request = self.make_mock_request(status=http_client.OK, data=response)
639        credentials = self.make_workforce_pool_credentials(
640            workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT
641        )
642
643        credentials.refresh(request)
644
645        self.assert_token_request_kwargs(request.call_args[1], headers, request_data)
646        assert credentials.valid
647        assert credentials.expiry == expected_expiry
648        assert not credentials.expired
649        assert credentials.token == response["access_token"]
650
651    @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
652    def test_refresh_workforce_with_client_auth_success(self, unused_utcnow):
653        response = self.SUCCESS_RESPONSE.copy()
654        # Test custom expiration to confirm expiry is set correctly.
655        response["expires_in"] = 2800
656        expected_expiry = datetime.datetime.min + datetime.timedelta(
657            seconds=response["expires_in"]
658        )
659        headers = {
660            "Content-Type": "application/x-www-form-urlencoded",
661            "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING),
662        }
663        request_data = {
664            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
665            "audience": self.WORKFORCE_AUDIENCE,
666            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
667            "subject_token": "subject_token_0",
668            "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE,
669        }
670        request = self.make_mock_request(status=http_client.OK, data=response)
671        # Client Auth will have higher priority over workforce_pool_user_project.
672        credentials = self.make_workforce_pool_credentials(
673            client_id=CLIENT_ID,
674            client_secret=CLIENT_SECRET,
675            workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT,
676        )
677
678        credentials.refresh(request)
679
680        self.assert_token_request_kwargs(request.call_args[1], headers, request_data)
681        assert credentials.valid
682        assert credentials.expiry == expected_expiry
683        assert not credentials.expired
684        assert credentials.token == response["access_token"]
685
686    @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
687    def test_refresh_workforce_with_client_auth_and_no_workforce_project_success(
688        self, unused_utcnow
689    ):
690        response = self.SUCCESS_RESPONSE.copy()
691        # Test custom expiration to confirm expiry is set correctly.
692        response["expires_in"] = 2800
693        expected_expiry = datetime.datetime.min + datetime.timedelta(
694            seconds=response["expires_in"]
695        )
696        headers = {
697            "Content-Type": "application/x-www-form-urlencoded",
698            "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING),
699        }
700        request_data = {
701            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
702            "audience": self.WORKFORCE_AUDIENCE,
703            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
704            "subject_token": "subject_token_0",
705            "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE,
706        }
707        request = self.make_mock_request(status=http_client.OK, data=response)
708        # Client Auth will be sufficient for user project determination.
709        credentials = self.make_workforce_pool_credentials(
710            client_id=CLIENT_ID,
711            client_secret=CLIENT_SECRET,
712            workforce_pool_user_project=None,
713        )
714
715        credentials.refresh(request)
716
717        self.assert_token_request_kwargs(request.call_args[1], headers, request_data)
718        assert credentials.valid
719        assert credentials.expiry == expected_expiry
720        assert not credentials.expired
721        assert credentials.token == response["access_token"]
722
723    def test_refresh_impersonation_without_client_auth_success(self):
724        # Simulate service account access token expires in 2800 seconds.
725        expire_time = (
726            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800)
727        ).isoformat("T") + "Z"
728        expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ")
729        # STS token exchange request/response.
730        token_response = self.SUCCESS_RESPONSE.copy()
731        token_headers = {"Content-Type": "application/x-www-form-urlencoded"}
732        token_request_data = {
733            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
734            "audience": self.AUDIENCE,
735            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
736            "subject_token": "subject_token_0",
737            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
738            "scope": "https://www.googleapis.com/auth/iam",
739        }
740        # Service account impersonation request/response.
741        impersonation_response = {
742            "accessToken": "SA_ACCESS_TOKEN",
743            "expireTime": expire_time,
744        }
745        impersonation_headers = {
746            "Content-Type": "application/json",
747            "authorization": "Bearer {}".format(token_response["access_token"]),
748        }
749        impersonation_request_data = {
750            "delegates": None,
751            "scope": self.SCOPES,
752            "lifetime": "3600s",
753        }
754        # Initialize mock request to handle token exchange and service account
755        # impersonation request.
756        request = self.make_mock_request(
757            status=http_client.OK,
758            data=token_response,
759            impersonation_status=http_client.OK,
760            impersonation_data=impersonation_response,
761        )
762        # Initialize credentials with service account impersonation.
763        credentials = self.make_credentials(
764            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
765            scopes=self.SCOPES,
766        )
767
768        credentials.refresh(request)
769
770        # Only 2 requests should be processed.
771        assert len(request.call_args_list) == 2
772        # Verify token exchange request parameters.
773        self.assert_token_request_kwargs(
774            request.call_args_list[0][1], token_headers, token_request_data
775        )
776        # Verify service account impersonation request parameters.
777        self.assert_impersonation_request_kwargs(
778            request.call_args_list[1][1],
779            impersonation_headers,
780            impersonation_request_data,
781        )
782        assert credentials.valid
783        assert credentials.expiry == expected_expiry
784        assert not credentials.expired
785        assert credentials.token == impersonation_response["accessToken"]
786
787    def test_refresh_workforce_impersonation_without_client_auth_success(self):
788        # Simulate service account access token expires in 2800 seconds.
789        expire_time = (
790            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800)
791        ).isoformat("T") + "Z"
792        expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ")
793        # STS token exchange request/response.
794        token_response = self.SUCCESS_RESPONSE.copy()
795        token_headers = {"Content-Type": "application/x-www-form-urlencoded"}
796        token_request_data = {
797            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
798            "audience": self.WORKFORCE_AUDIENCE,
799            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
800            "subject_token": "subject_token_0",
801            "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE,
802            "scope": "https://www.googleapis.com/auth/iam",
803            "options": urllib.parse.quote(
804                json.dumps({"userProject": self.WORKFORCE_POOL_USER_PROJECT})
805            ),
806        }
807        # Service account impersonation request/response.
808        impersonation_response = {
809            "accessToken": "SA_ACCESS_TOKEN",
810            "expireTime": expire_time,
811        }
812        impersonation_headers = {
813            "Content-Type": "application/json",
814            "authorization": "Bearer {}".format(token_response["access_token"]),
815        }
816        impersonation_request_data = {
817            "delegates": None,
818            "scope": self.SCOPES,
819            "lifetime": "3600s",
820        }
821        # Initialize mock request to handle token exchange and service account
822        # impersonation request.
823        request = self.make_mock_request(
824            status=http_client.OK,
825            data=token_response,
826            impersonation_status=http_client.OK,
827            impersonation_data=impersonation_response,
828        )
829        # Initialize credentials with service account impersonation.
830        credentials = self.make_workforce_pool_credentials(
831            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
832            scopes=self.SCOPES,
833            workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT,
834        )
835
836        credentials.refresh(request)
837
838        # Only 2 requests should be processed.
839        assert len(request.call_args_list) == 2
840        # Verify token exchange request parameters.
841        self.assert_token_request_kwargs(
842            request.call_args_list[0][1], token_headers, token_request_data
843        )
844        # Verify service account impersonation request parameters.
845        self.assert_impersonation_request_kwargs(
846            request.call_args_list[1][1],
847            impersonation_headers,
848            impersonation_request_data,
849        )
850        assert credentials.valid
851        assert credentials.expiry == expected_expiry
852        assert not credentials.expired
853        assert credentials.token == impersonation_response["accessToken"]
854
855    def test_refresh_without_client_auth_success_explicit_user_scopes_ignore_default_scopes(
856        self,
857    ):
858        headers = {"Content-Type": "application/x-www-form-urlencoded"}
859        request_data = {
860            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
861            "audience": self.AUDIENCE,
862            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
863            "scope": "scope1 scope2",
864            "subject_token": "subject_token_0",
865            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
866        }
867        request = self.make_mock_request(
868            status=http_client.OK, data=self.SUCCESS_RESPONSE
869        )
870        credentials = self.make_credentials(
871            scopes=["scope1", "scope2"],
872            # Default scopes will be ignored in favor of user scopes.
873            default_scopes=["ignored"],
874        )
875
876        credentials.refresh(request)
877
878        self.assert_token_request_kwargs(request.call_args[1], headers, request_data)
879        assert credentials.valid
880        assert not credentials.expired
881        assert credentials.token == self.SUCCESS_RESPONSE["access_token"]
882        assert credentials.has_scopes(["scope1", "scope2"])
883        assert not credentials.has_scopes(["ignored"])
884
885    def test_refresh_without_client_auth_success_explicit_default_scopes_only(self):
886        headers = {"Content-Type": "application/x-www-form-urlencoded"}
887        request_data = {
888            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
889            "audience": self.AUDIENCE,
890            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
891            "scope": "scope1 scope2",
892            "subject_token": "subject_token_0",
893            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
894        }
895        request = self.make_mock_request(
896            status=http_client.OK, data=self.SUCCESS_RESPONSE
897        )
898        credentials = self.make_credentials(
899            scopes=None,
900            # Default scopes will be used since user scopes are none.
901            default_scopes=["scope1", "scope2"],
902        )
903
904        credentials.refresh(request)
905
906        self.assert_token_request_kwargs(request.call_args[1], headers, request_data)
907        assert credentials.valid
908        assert not credentials.expired
909        assert credentials.token == self.SUCCESS_RESPONSE["access_token"]
910        assert credentials.has_scopes(["scope1", "scope2"])
911
912    def test_refresh_without_client_auth_error(self):
913        request = self.make_mock_request(
914            status=http_client.BAD_REQUEST, data=self.ERROR_RESPONSE
915        )
916        credentials = self.make_credentials()
917
918        with pytest.raises(exceptions.OAuthError) as excinfo:
919            credentials.refresh(request)
920
921        assert excinfo.match(
922            r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749"
923        )
924        assert not credentials.expired
925        assert credentials.token is None
926
927    def test_refresh_impersonation_without_client_auth_error(self):
928        request = self.make_mock_request(
929            status=http_client.OK,
930            data=self.SUCCESS_RESPONSE,
931            impersonation_status=http_client.BAD_REQUEST,
932            impersonation_data=self.IMPERSONATION_ERROR_RESPONSE,
933        )
934        credentials = self.make_credentials(
935            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
936            scopes=self.SCOPES,
937        )
938
939        with pytest.raises(exceptions.RefreshError) as excinfo:
940            credentials.refresh(request)
941
942        assert excinfo.match(r"Unable to acquire impersonated credentials")
943        assert not credentials.expired
944        assert credentials.token is None
945
946    def test_refresh_with_client_auth_success(self):
947        headers = {
948            "Content-Type": "application/x-www-form-urlencoded",
949            "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING),
950        }
951        request_data = {
952            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
953            "audience": self.AUDIENCE,
954            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
955            "subject_token": "subject_token_0",
956            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
957        }
958        request = self.make_mock_request(
959            status=http_client.OK, data=self.SUCCESS_RESPONSE
960        )
961        credentials = self.make_credentials(
962            client_id=CLIENT_ID, client_secret=CLIENT_SECRET
963        )
964
965        credentials.refresh(request)
966
967        self.assert_token_request_kwargs(request.call_args[1], headers, request_data)
968        assert credentials.valid
969        assert not credentials.expired
970        assert credentials.token == self.SUCCESS_RESPONSE["access_token"]
971
972    def test_refresh_impersonation_with_client_auth_success_ignore_default_scopes(self):
973        # Simulate service account access token expires in 2800 seconds.
974        expire_time = (
975            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800)
976        ).isoformat("T") + "Z"
977        expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ")
978        # STS token exchange request/response.
979        token_response = self.SUCCESS_RESPONSE.copy()
980        token_headers = {
981            "Content-Type": "application/x-www-form-urlencoded",
982            "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING),
983        }
984        token_request_data = {
985            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
986            "audience": self.AUDIENCE,
987            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
988            "subject_token": "subject_token_0",
989            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
990            "scope": "https://www.googleapis.com/auth/iam",
991        }
992        # Service account impersonation request/response.
993        impersonation_response = {
994            "accessToken": "SA_ACCESS_TOKEN",
995            "expireTime": expire_time,
996        }
997        impersonation_headers = {
998            "Content-Type": "application/json",
999            "authorization": "Bearer {}".format(token_response["access_token"]),
1000        }
1001        impersonation_request_data = {
1002            "delegates": None,
1003            "scope": self.SCOPES,
1004            "lifetime": "3600s",
1005        }
1006        # Initialize mock request to handle token exchange and service account
1007        # impersonation request.
1008        request = self.make_mock_request(
1009            status=http_client.OK,
1010            data=token_response,
1011            impersonation_status=http_client.OK,
1012            impersonation_data=impersonation_response,
1013        )
1014        # Initialize credentials with service account impersonation and basic auth.
1015        credentials = self.make_credentials(
1016            client_id=CLIENT_ID,
1017            client_secret=CLIENT_SECRET,
1018            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
1019            scopes=self.SCOPES,
1020            # Default scopes will be ignored since user scopes are specified.
1021            default_scopes=["ignored"],
1022        )
1023
1024        credentials.refresh(request)
1025
1026        # Only 2 requests should be processed.
1027        assert len(request.call_args_list) == 2
1028        # Verify token exchange request parameters.
1029        self.assert_token_request_kwargs(
1030            request.call_args_list[0][1], token_headers, token_request_data
1031        )
1032        # Verify service account impersonation request parameters.
1033        self.assert_impersonation_request_kwargs(
1034            request.call_args_list[1][1],
1035            impersonation_headers,
1036            impersonation_request_data,
1037        )
1038        assert credentials.valid
1039        assert credentials.expiry == expected_expiry
1040        assert not credentials.expired
1041        assert credentials.token == impersonation_response["accessToken"]
1042
1043    def test_refresh_impersonation_with_client_auth_success_use_default_scopes(self):
1044        # Simulate service account access token expires in 2800 seconds.
1045        expire_time = (
1046            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800)
1047        ).isoformat("T") + "Z"
1048        expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ")
1049        # STS token exchange request/response.
1050        token_response = self.SUCCESS_RESPONSE.copy()
1051        token_headers = {
1052            "Content-Type": "application/x-www-form-urlencoded",
1053            "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING),
1054        }
1055        token_request_data = {
1056            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
1057            "audience": self.AUDIENCE,
1058            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
1059            "subject_token": "subject_token_0",
1060            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
1061            "scope": "https://www.googleapis.com/auth/iam",
1062        }
1063        # Service account impersonation request/response.
1064        impersonation_response = {
1065            "accessToken": "SA_ACCESS_TOKEN",
1066            "expireTime": expire_time,
1067        }
1068        impersonation_headers = {
1069            "Content-Type": "application/json",
1070            "authorization": "Bearer {}".format(token_response["access_token"]),
1071        }
1072        impersonation_request_data = {
1073            "delegates": None,
1074            "scope": self.SCOPES,
1075            "lifetime": "3600s",
1076        }
1077        # Initialize mock request to handle token exchange and service account
1078        # impersonation request.
1079        request = self.make_mock_request(
1080            status=http_client.OK,
1081            data=token_response,
1082            impersonation_status=http_client.OK,
1083            impersonation_data=impersonation_response,
1084        )
1085        # Initialize credentials with service account impersonation and basic auth.
1086        credentials = self.make_credentials(
1087            client_id=CLIENT_ID,
1088            client_secret=CLIENT_SECRET,
1089            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
1090            scopes=None,
1091            # Default scopes will be used since user specified scopes are none.
1092            default_scopes=self.SCOPES,
1093        )
1094
1095        credentials.refresh(request)
1096
1097        # Only 2 requests should be processed.
1098        assert len(request.call_args_list) == 2
1099        # Verify token exchange request parameters.
1100        self.assert_token_request_kwargs(
1101            request.call_args_list[0][1], token_headers, token_request_data
1102        )
1103        # Verify service account impersonation request parameters.
1104        self.assert_impersonation_request_kwargs(
1105            request.call_args_list[1][1],
1106            impersonation_headers,
1107            impersonation_request_data,
1108        )
1109        assert credentials.valid
1110        assert credentials.expiry == expected_expiry
1111        assert not credentials.expired
1112        assert credentials.token == impersonation_response["accessToken"]
1113
1114    def test_apply_without_quota_project_id(self):
1115        headers = {}
1116        request = self.make_mock_request(
1117            status=http_client.OK, data=self.SUCCESS_RESPONSE
1118        )
1119        credentials = self.make_credentials()
1120
1121        credentials.refresh(request)
1122        credentials.apply(headers)
1123
1124        assert headers == {
1125            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"])
1126        }
1127
1128    def test_apply_workforce_without_quota_project_id(self):
1129        headers = {}
1130        request = self.make_mock_request(
1131            status=http_client.OK, data=self.SUCCESS_RESPONSE
1132        )
1133        credentials = self.make_workforce_pool_credentials(
1134            workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT
1135        )
1136
1137        credentials.refresh(request)
1138        credentials.apply(headers)
1139
1140        assert headers == {
1141            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"])
1142        }
1143
1144    def test_apply_impersonation_without_quota_project_id(self):
1145        expire_time = (
1146            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600)
1147        ).isoformat("T") + "Z"
1148        # Service account impersonation response.
1149        impersonation_response = {
1150            "accessToken": "SA_ACCESS_TOKEN",
1151            "expireTime": expire_time,
1152        }
1153        # Initialize mock request to handle token exchange and service account
1154        # impersonation request.
1155        request = self.make_mock_request(
1156            status=http_client.OK,
1157            data=self.SUCCESS_RESPONSE.copy(),
1158            impersonation_status=http_client.OK,
1159            impersonation_data=impersonation_response,
1160        )
1161        # Initialize credentials with service account impersonation.
1162        credentials = self.make_credentials(
1163            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
1164            scopes=self.SCOPES,
1165        )
1166        headers = {}
1167
1168        credentials.refresh(request)
1169        credentials.apply(headers)
1170
1171        assert headers == {
1172            "authorization": "Bearer {}".format(impersonation_response["accessToken"])
1173        }
1174
1175    def test_apply_with_quota_project_id(self):
1176        headers = {"other": "header-value"}
1177        request = self.make_mock_request(
1178            status=http_client.OK, data=self.SUCCESS_RESPONSE
1179        )
1180        credentials = self.make_credentials(quota_project_id=self.QUOTA_PROJECT_ID)
1181
1182        credentials.refresh(request)
1183        credentials.apply(headers)
1184
1185        assert headers == {
1186            "other": "header-value",
1187            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]),
1188            "x-goog-user-project": self.QUOTA_PROJECT_ID,
1189        }
1190
1191    def test_apply_impersonation_with_quota_project_id(self):
1192        expire_time = (
1193            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600)
1194        ).isoformat("T") + "Z"
1195        # Service account impersonation response.
1196        impersonation_response = {
1197            "accessToken": "SA_ACCESS_TOKEN",
1198            "expireTime": expire_time,
1199        }
1200        # Initialize mock request to handle token exchange and service account
1201        # impersonation request.
1202        request = self.make_mock_request(
1203            status=http_client.OK,
1204            data=self.SUCCESS_RESPONSE.copy(),
1205            impersonation_status=http_client.OK,
1206            impersonation_data=impersonation_response,
1207        )
1208        # Initialize credentials with service account impersonation.
1209        credentials = self.make_credentials(
1210            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
1211            scopes=self.SCOPES,
1212            quota_project_id=self.QUOTA_PROJECT_ID,
1213        )
1214        headers = {"other": "header-value"}
1215
1216        credentials.refresh(request)
1217        credentials.apply(headers)
1218
1219        assert headers == {
1220            "other": "header-value",
1221            "authorization": "Bearer {}".format(impersonation_response["accessToken"]),
1222            "x-goog-user-project": self.QUOTA_PROJECT_ID,
1223        }
1224
1225    def test_before_request(self):
1226        headers = {"other": "header-value"}
1227        request = self.make_mock_request(
1228            status=http_client.OK, data=self.SUCCESS_RESPONSE
1229        )
1230        credentials = self.make_credentials()
1231
1232        # First call should call refresh, setting the token.
1233        credentials.before_request(request, "POST", "https://example.com/api", headers)
1234
1235        assert headers == {
1236            "other": "header-value",
1237            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]),
1238        }
1239
1240        # Second call shouldn't call refresh.
1241        credentials.before_request(request, "POST", "https://example.com/api", headers)
1242
1243        assert headers == {
1244            "other": "header-value",
1245            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]),
1246        }
1247
1248    def test_before_request_workforce(self):
1249        headers = {"other": "header-value"}
1250        request = self.make_mock_request(
1251            status=http_client.OK, data=self.SUCCESS_RESPONSE
1252        )
1253        credentials = self.make_workforce_pool_credentials(
1254            workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT
1255        )
1256
1257        # First call should call refresh, setting the token.
1258        credentials.before_request(request, "POST", "https://example.com/api", headers)
1259
1260        assert headers == {
1261            "other": "header-value",
1262            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]),
1263        }
1264
1265        # Second call shouldn't call refresh.
1266        credentials.before_request(request, "POST", "https://example.com/api", headers)
1267
1268        assert headers == {
1269            "other": "header-value",
1270            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]),
1271        }
1272
1273    def test_before_request_impersonation(self):
1274        expire_time = (
1275            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600)
1276        ).isoformat("T") + "Z"
1277        # Service account impersonation response.
1278        impersonation_response = {
1279            "accessToken": "SA_ACCESS_TOKEN",
1280            "expireTime": expire_time,
1281        }
1282        # Initialize mock request to handle token exchange and service account
1283        # impersonation request.
1284        request = self.make_mock_request(
1285            status=http_client.OK,
1286            data=self.SUCCESS_RESPONSE.copy(),
1287            impersonation_status=http_client.OK,
1288            impersonation_data=impersonation_response,
1289        )
1290        headers = {"other": "header-value"}
1291        credentials = self.make_credentials(
1292            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL
1293        )
1294
1295        # First call should call refresh, setting the token.
1296        credentials.before_request(request, "POST", "https://example.com/api", headers)
1297
1298        assert headers == {
1299            "other": "header-value",
1300            "authorization": "Bearer {}".format(impersonation_response["accessToken"]),
1301        }
1302
1303        # Second call shouldn't call refresh.
1304        credentials.before_request(request, "POST", "https://example.com/api", headers)
1305
1306        assert headers == {
1307            "other": "header-value",
1308            "authorization": "Bearer {}".format(impersonation_response["accessToken"]),
1309        }
1310
1311    @mock.patch("google.auth._helpers.utcnow")
1312    def test_before_request_expired(self, utcnow):
1313        headers = {}
1314        request = self.make_mock_request(
1315            status=http_client.OK, data=self.SUCCESS_RESPONSE
1316        )
1317        credentials = self.make_credentials()
1318        credentials.token = "token"
1319        utcnow.return_value = datetime.datetime.min
1320        # Set the expiration to one second more than now plus the clock skew
1321        # accomodation. These credentials should be valid.
1322        credentials.expiry = (
1323            datetime.datetime.min
1324            + _helpers.REFRESH_THRESHOLD
1325            + datetime.timedelta(seconds=1)
1326        )
1327
1328        assert credentials.valid
1329        assert not credentials.expired
1330
1331        credentials.before_request(request, "POST", "https://example.com/api", headers)
1332
1333        # Cached token should be used.
1334        assert headers == {"authorization": "Bearer token"}
1335
1336        # Next call should simulate 1 second passed.
1337        utcnow.return_value = datetime.datetime.min + datetime.timedelta(seconds=1)
1338
1339        assert not credentials.valid
1340        assert credentials.expired
1341
1342        credentials.before_request(request, "POST", "https://example.com/api", headers)
1343
1344        # New token should be retrieved.
1345        assert headers == {
1346            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"])
1347        }
1348
1349    @mock.patch("google.auth._helpers.utcnow")
1350    def test_before_request_impersonation_expired(self, utcnow):
1351        headers = {}
1352        expire_time = (
1353            datetime.datetime.min + datetime.timedelta(seconds=3601)
1354        ).isoformat("T") + "Z"
1355        # Service account impersonation response.
1356        impersonation_response = {
1357            "accessToken": "SA_ACCESS_TOKEN",
1358            "expireTime": expire_time,
1359        }
1360        # Initialize mock request to handle token exchange and service account
1361        # impersonation request.
1362        request = self.make_mock_request(
1363            status=http_client.OK,
1364            data=self.SUCCESS_RESPONSE.copy(),
1365            impersonation_status=http_client.OK,
1366            impersonation_data=impersonation_response,
1367        )
1368        credentials = self.make_credentials(
1369            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL
1370        )
1371        credentials.token = "token"
1372        utcnow.return_value = datetime.datetime.min
1373        # Set the expiration to one second more than now plus the clock skew
1374        # accomodation. These credentials should be valid.
1375        credentials.expiry = (
1376            datetime.datetime.min
1377            + _helpers.REFRESH_THRESHOLD
1378            + datetime.timedelta(seconds=1)
1379        )
1380
1381        assert credentials.valid
1382        assert not credentials.expired
1383
1384        credentials.before_request(request, "POST", "https://example.com/api", headers)
1385
1386        # Cached token should be used.
1387        assert headers == {"authorization": "Bearer token"}
1388
1389        # Next call should simulate 1 second passed. This will trigger the expiration
1390        # threshold.
1391        utcnow.return_value = datetime.datetime.min + datetime.timedelta(seconds=1)
1392
1393        assert not credentials.valid
1394        assert credentials.expired
1395
1396        credentials.before_request(request, "POST", "https://example.com/api", headers)
1397
1398        # New token should be retrieved.
1399        assert headers == {
1400            "authorization": "Bearer {}".format(impersonation_response["accessToken"])
1401        }
1402
1403    @pytest.mark.parametrize(
1404        "audience",
1405        [
1406            # Legacy K8s audience format.
1407            "identitynamespace:1f12345:my_provider",
1408            # Unrealistic audiences.
1409            "//iam.googleapis.com/projects",
1410            "//iam.googleapis.com/projects/",
1411            "//iam.googleapis.com/project/123456",
1412            "//iam.googleapis.com/projects//123456",
1413            "//iam.googleapis.com/prefix_projects/123456",
1414            "//iam.googleapis.com/projects_suffix/123456",
1415        ],
1416    )
1417    def test_project_number_indeterminable(self, audience):
1418        credentials = CredentialsImpl(
1419            audience=audience,
1420            subject_token_type=self.SUBJECT_TOKEN_TYPE,
1421            token_url=self.TOKEN_URL,
1422            credential_source=self.CREDENTIAL_SOURCE,
1423        )
1424
1425        assert credentials.project_number is None
1426        assert credentials.get_project_id(None) is None
1427
1428    def test_project_number_determinable(self):
1429        credentials = CredentialsImpl(
1430            audience=self.AUDIENCE,
1431            subject_token_type=self.SUBJECT_TOKEN_TYPE,
1432            token_url=self.TOKEN_URL,
1433            credential_source=self.CREDENTIAL_SOURCE,
1434        )
1435
1436        assert credentials.project_number == self.PROJECT_NUMBER
1437
1438    def test_project_number_workforce(self):
1439        credentials = CredentialsImpl(
1440            audience=self.WORKFORCE_AUDIENCE,
1441            subject_token_type=self.WORKFORCE_SUBJECT_TOKEN_TYPE,
1442            token_url=self.TOKEN_URL,
1443            credential_source=self.CREDENTIAL_SOURCE,
1444            workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT,
1445        )
1446
1447        assert credentials.project_number is None
1448
1449    def test_project_id_without_scopes(self):
1450        # Initialize credentials with no scopes.
1451        credentials = CredentialsImpl(
1452            audience=self.AUDIENCE,
1453            subject_token_type=self.SUBJECT_TOKEN_TYPE,
1454            token_url=self.TOKEN_URL,
1455            credential_source=self.CREDENTIAL_SOURCE,
1456        )
1457
1458        assert credentials.get_project_id(None) is None
1459
1460    def test_get_project_id_cloud_resource_manager_success(self):
1461        # STS token exchange request/response.
1462        token_response = self.SUCCESS_RESPONSE.copy()
1463        token_headers = {"Content-Type": "application/x-www-form-urlencoded"}
1464        token_request_data = {
1465            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
1466            "audience": self.AUDIENCE,
1467            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
1468            "subject_token": "subject_token_0",
1469            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
1470            "scope": "https://www.googleapis.com/auth/iam",
1471        }
1472        # Service account impersonation request/response.
1473        expire_time = (
1474            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600)
1475        ).isoformat("T") + "Z"
1476        expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ")
1477        impersonation_response = {
1478            "accessToken": "SA_ACCESS_TOKEN",
1479            "expireTime": expire_time,
1480        }
1481        impersonation_headers = {
1482            "Content-Type": "application/json",
1483            "x-goog-user-project": self.QUOTA_PROJECT_ID,
1484            "authorization": "Bearer {}".format(token_response["access_token"]),
1485        }
1486        impersonation_request_data = {
1487            "delegates": None,
1488            "scope": self.SCOPES,
1489            "lifetime": "3600s",
1490        }
1491        # Initialize mock request to handle token exchange, service account
1492        # impersonation and cloud resource manager request.
1493        request = self.make_mock_request(
1494            status=http_client.OK,
1495            data=self.SUCCESS_RESPONSE.copy(),
1496            impersonation_status=http_client.OK,
1497            impersonation_data=impersonation_response,
1498            cloud_resource_manager_status=http_client.OK,
1499            cloud_resource_manager_data=self.CLOUD_RESOURCE_MANAGER_SUCCESS_RESPONSE,
1500        )
1501        credentials = self.make_credentials(
1502            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
1503            scopes=self.SCOPES,
1504            quota_project_id=self.QUOTA_PROJECT_ID,
1505        )
1506
1507        # Expected project ID from cloud resource manager response should be returned.
1508        project_id = credentials.get_project_id(request)
1509
1510        assert project_id == self.PROJECT_ID
1511        # 3 requests should be processed.
1512        assert len(request.call_args_list) == 3
1513        # Verify token exchange request parameters.
1514        self.assert_token_request_kwargs(
1515            request.call_args_list[0][1], token_headers, token_request_data
1516        )
1517        # Verify service account impersonation request parameters.
1518        self.assert_impersonation_request_kwargs(
1519            request.call_args_list[1][1],
1520            impersonation_headers,
1521            impersonation_request_data,
1522        )
1523        # In the process of getting project ID, an access token should be
1524        # retrieved.
1525        assert credentials.valid
1526        assert credentials.expiry == expected_expiry
1527        assert not credentials.expired
1528        assert credentials.token == impersonation_response["accessToken"]
1529        # Verify cloud resource manager request parameters.
1530        self.assert_resource_manager_request_kwargs(
1531            request.call_args_list[2][1],
1532            self.PROJECT_NUMBER,
1533            {
1534                "x-goog-user-project": self.QUOTA_PROJECT_ID,
1535                "authorization": "Bearer {}".format(
1536                    impersonation_response["accessToken"]
1537                ),
1538            },
1539        )
1540
1541        # Calling get_project_id again should return the cached project_id.
1542        project_id = credentials.get_project_id(request)
1543
1544        assert project_id == self.PROJECT_ID
1545        # No additional requests.
1546        assert len(request.call_args_list) == 3
1547
1548    def test_workforce_pool_get_project_id_cloud_resource_manager_success(self):
1549        # STS token exchange request/response.
1550        token_headers = {"Content-Type": "application/x-www-form-urlencoded"}
1551        token_request_data = {
1552            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
1553            "audience": self.WORKFORCE_AUDIENCE,
1554            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
1555            "subject_token": "subject_token_0",
1556            "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE,
1557            "scope": "scope1 scope2",
1558            "options": urllib.parse.quote(
1559                json.dumps({"userProject": self.WORKFORCE_POOL_USER_PROJECT})
1560            ),
1561        }
1562        # Initialize mock request to handle token exchange and cloud resource
1563        # manager request.
1564        request = self.make_mock_request(
1565            status=http_client.OK,
1566            data=self.SUCCESS_RESPONSE.copy(),
1567            cloud_resource_manager_status=http_client.OK,
1568            cloud_resource_manager_data=self.CLOUD_RESOURCE_MANAGER_SUCCESS_RESPONSE,
1569        )
1570        credentials = self.make_workforce_pool_credentials(
1571            scopes=self.SCOPES,
1572            quota_project_id=self.QUOTA_PROJECT_ID,
1573            workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT,
1574        )
1575
1576        # Expected project ID from cloud resource manager response should be returned.
1577        project_id = credentials.get_project_id(request)
1578
1579        assert project_id == self.PROJECT_ID
1580        # 2 requests should be processed.
1581        assert len(request.call_args_list) == 2
1582        # Verify token exchange request parameters.
1583        self.assert_token_request_kwargs(
1584            request.call_args_list[0][1], token_headers, token_request_data
1585        )
1586        # In the process of getting project ID, an access token should be
1587        # retrieved.
1588        assert credentials.valid
1589        assert not credentials.expired
1590        assert credentials.token == self.SUCCESS_RESPONSE["access_token"]
1591        # Verify cloud resource manager request parameters.
1592        self.assert_resource_manager_request_kwargs(
1593            request.call_args_list[1][1],
1594            self.WORKFORCE_POOL_USER_PROJECT,
1595            {
1596                "x-goog-user-project": self.QUOTA_PROJECT_ID,
1597                "authorization": "Bearer {}".format(
1598                    self.SUCCESS_RESPONSE["access_token"]
1599                ),
1600            },
1601        )
1602
1603        # Calling get_project_id again should return the cached project_id.
1604        project_id = credentials.get_project_id(request)
1605
1606        assert project_id == self.PROJECT_ID
1607        # No additional requests.
1608        assert len(request.call_args_list) == 2
1609
1610    def test_get_project_id_cloud_resource_manager_error(self):
1611        # Simulate resource doesn't have sufficient permissions to access
1612        # cloud resource manager.
1613        request = self.make_mock_request(
1614            status=http_client.OK,
1615            data=self.SUCCESS_RESPONSE.copy(),
1616            cloud_resource_manager_status=http_client.UNAUTHORIZED,
1617        )
1618        credentials = self.make_credentials(scopes=self.SCOPES)
1619
1620        project_id = credentials.get_project_id(request)
1621
1622        assert project_id is None
1623        # Only 2 requests to STS and cloud resource manager should be sent.
1624        assert len(request.call_args_list) == 2
1625