1# Copyright 2018 Google Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import datetime
16import json
17import os
18
19import mock
20import pytest
21from six.moves import http_client
22
23from google.auth import _helpers
24from google.auth import crypt
25from google.auth import exceptions
26from google.auth import impersonated_credentials
27from google.auth import transport
28from google.auth.impersonated_credentials import Credentials
29from google.oauth2 import credentials
30from google.oauth2 import service_account
31
# Key material and service-account fixtures live in a "data" directory next
# to this test module.
DATA_DIR = os.path.join(os.path.dirname(__file__), "", "data")
SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json")

with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh:
    PRIVATE_KEY_BYTES = fh.read()

with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
    SERVICE_ACCOUNT_INFO = json.load(fh)

# A JWT ID token whose signature segment is the placeholder "redacted".
# Its payload decodes to aud="https://foo.bar" and exp=1564475051, which is
# why ID_TOKEN_EXPIRY below has that value.
ID_TOKEN_DATA = (
    "eyJhbGciOiJSUzI1NiIsImtpZCI6ImRmMzc1ODkwOGI3OTIyOTNhZDk3N2Ew"
    "Yjk5MWQ5OGE3N2Y0ZWVlY2QiLCJ0eXAiOiJKV1QifQ.eyJhdWQiOiJodHRwc"
    "zovL2Zvby5iYXIiLCJhenAiOiIxMDIxMDE1NTA4MzQyMDA3MDg1NjgiLCJle"
    "HAiOjE1NjQ0NzUwNTEsImlhdCI6MTU2NDQ3MTQ1MSwiaXNzIjoiaHR0cHM6L"
    "y9hY2NvdW50cy5nb29nbGUuY29tIiwic3ViIjoiMTAyMTAxNTUwODM0MjAwN"
    "zA4NTY4In0.redacted"
)
ID_TOKEN_EXPIRY = 1564475051

# Signer over the test private key; "1" is passed as the key id.
SIGNER = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1")
TOKEN_URI = "https://example.com/oauth2/token"
54
55
@pytest.fixture
def mock_donor_credentials():
    """Patch ``google.oauth2._client.jwt_grant`` for the duration of a test.

    While the patch is active, every call to ``jwt_grant`` returns the tuple
    ``("source token", <now + 500 seconds>, {})``.

    Yields:
        The autospecced mock standing in for ``jwt_grant``.
    """
    target = "google.oauth2._client.jwt_grant"
    with mock.patch(target, autospec=True) as grant:
        source_expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
        grant.return_value = ("source token", source_expiry, {})
        yield grant
65
66
class MockResponse:
    """Tiny fake HTTP response used as a patched ``AuthorizedSession`` result.

    It exposes only ``status_code`` and a ``json()`` method. ``json()`` returns
    the exact object given to the constructor.
    """

    def __init__(self, json_data, status_code):
        # Both values are stored verbatim; nothing is copied or validated.
        self.status_code = status_code
        self.json_data = json_data

    def json(self):
        """Return the payload supplied at construction time, unchanged."""
        return self.json_data
74
75
@pytest.fixture
def mock_authorizedsession_sign():
    """Patch ``AuthorizedSession.request`` to return a canned signing reply.

    The reply has status ``http_client.OK`` and the JSON body
    ``{"keyId": "1", "signedBlob": "c2lnbmF0dXJl"}``. The ``signedBlob`` value
    is the base64 encoding of ``b"signature"``.

    Yields:
        The autospecced mock standing in for ``AuthorizedSession.request``.
    """
    target = "google.auth.transport.requests.AuthorizedSession.request"
    with mock.patch(target, autospec=True) as auth_session:
        sign_payload = {"keyId": "1", "signedBlob": "c2lnbmF0dXJl"}
        auth_session.return_value = MockResponse(sign_payload, http_client.OK)
        yield auth_session
84
85
@pytest.fixture
def mock_authorizedsession_idtoken():
    """Patch ``AuthorizedSession.request`` to return a canned ID-token reply.

    The reply has status ``http_client.OK`` and the JSON body
    ``{"token": ID_TOKEN_DATA}``.

    Yields:
        The autospecced mock standing in for ``AuthorizedSession.request``.
    """
    target = "google.auth.transport.requests.AuthorizedSession.request"
    with mock.patch(target, autospec=True) as auth_session:
        idtoken_payload = {"token": ID_TOKEN_DATA}
        auth_session.return_value = MockResponse(idtoken_payload, http_client.OK)
        yield auth_session
94
95
class TestImpersonatedCredentials(object):
    """Tests for impersonated access-token and ID-token credentials.

    IAM token responses come from autospecced ``transport.Request`` mocks
    built by :meth:`make_request`. Tests that refresh service-account source
    credentials, or that go through ``AuthorizedSession``, request the
    module's ``mock_donor_credentials`` / ``mock_authorizedsession_*``
    fixtures.
    """

    # The source and target principals are deliberately distinct. This lets
    # the signer_email / service_account_email assertions tell them apart.
    SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
    TARGET_PRINCIPAL = "impersonated@project.iam.gserviceaccount.com"
    TARGET_SCOPES = ["https://www.googleapis.com/auth/devstorage.read_only"]
    DELEGATES = []
    LIFETIME = 3600
    # Audience carried in the "aud" claim of ID_TOKEN_DATA.
    TARGET_AUDIENCE = "https://foo.bar"
    # Kept for backward compatibility. make_credentials() builds an
    # equivalent fresh instance per call rather than reusing this object,
    # because tests such as test_refresh_source_credentials mutate the
    # source credentials' expiry and token.
    SOURCE_CREDENTIALS = service_account.Credentials(
        SIGNER, SERVICE_ACCOUNT_EMAIL, TOKEN_URI
    )
    USER_SOURCE_CREDENTIALS = credentials.Credentials(token="ABCDE")
    IAM_ENDPOINT_OVERRIDE = (
        "https://us-east1-iamcredentials.googleapis.com/v1/projects/-"
        + "/serviceAccounts/{}:generateAccessToken".format(SERVICE_ACCOUNT_EMAIL)
    )

    # -- helpers -------------------------------------------------------------

    def make_credentials(
        self,
        source_credentials=None,
        lifetime=LIFETIME,
        target_principal=TARGET_PRINCIPAL,
        iam_endpoint_override=None,
    ):
        """Build (unrefreshed) impersonated credentials for the test principal.

        Args:
            source_credentials: Credentials to impersonate from. When omitted,
                a new service-account credential equivalent to
                ``SOURCE_CREDENTIALS`` is created, so that state mutated by
                one test never leaks into another.
            lifetime: Requested token lifetime in seconds, or ``None``.
            target_principal: Service account to impersonate.
            iam_endpoint_override: Optional replacement for the IAM
                ``generateAccessToken`` URL.

        Returns:
            impersonated_credentials.Credentials: The new credentials.
        """
        if source_credentials is None:
            source_credentials = service_account.Credentials(
                SIGNER, self.SERVICE_ACCOUNT_EMAIL, TOKEN_URI
            )

        return Credentials(
            source_credentials=source_credentials,
            target_principal=target_principal,
            target_scopes=self.TARGET_SCOPES,
            delegates=self.DELEGATES,
            lifetime=lifetime,
            iam_endpoint_override=iam_endpoint_override,
        )

    def make_request(
        self,
        data,
        status=http_client.OK,
        headers=None,
        side_effect=None,
        use_data_bytes=True,
    ):
        """Return an autospecced ``transport.Request`` with a canned response.

        Args:
            data: Response body. It is passed through ``_helpers.to_bytes``
                when ``use_data_bytes`` is true and used as-is otherwise.
            status: HTTP status placed on the response.
            headers: Response headers; ``None`` becomes an empty dict.
            side_effect: Optional ``side_effect`` for the request mock.
            use_data_bytes: Whether to convert ``data`` to bytes.

        Returns:
            The request mock, whose ``return_value`` is the response mock.
        """
        response = mock.create_autospec(transport.Response, instance=False)
        response.status = status
        response.data = _helpers.to_bytes(data) if use_data_bytes else data
        response.headers = headers or {}

        request = mock.create_autospec(transport.Request, instance=False)
        request.side_effect = side_effect
        request.return_value = response

        return request

    @staticmethod
    def make_expire_time(seconds=500):
        """Return an ``expireTime`` string ``seconds`` from now.

        The timestamp is truncated to whole seconds, rendered with
        ``isoformat("T")`` and suffixed with ``"Z"``.
        """
        expiry = _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(
            seconds=seconds
        )
        return expiry.isoformat("T") + "Z"

    def make_token_request(self, token="token", use_data_bytes=True):
        """Return a request mock whose response is a successful token grant."""
        response_body = {"accessToken": token, "expireTime": self.make_expire_time()}
        return self.make_request(
            data=json.dumps(response_body),
            status=http_client.OK,
            use_data_bytes=use_data_bytes,
        )

    @staticmethod
    def refresh_and_check(creds, request):
        """Refresh ``creds`` with ``request`` and assert they became usable."""
        creds.refresh(request)
        assert creds.valid
        assert not creds.expired

    def make_refreshed_credentials(self, **kwargs):
        """Return ``(creds, request)`` after a successful token refresh.

        ``kwargs`` are forwarded to :meth:`make_credentials`. ``lifetime``
        defaults to ``None`` here.
        """
        kwargs.setdefault("lifetime", None)
        creds = self.make_credentials(**kwargs)
        request = self.make_token_request()
        self.refresh_and_check(creds, request)
        return creds, request

    # -- construction ----------------------------------------------------------

    def test_make_from_user_credentials(self):
        creds = self.make_credentials(
            source_credentials=self.USER_SOURCE_CREDENTIALS
        )
        assert not creds.valid
        assert creds.expired

    def test_default_state(self):
        creds = self.make_credentials()
        assert not creds.valid
        assert creds.expired

    # -- access-token refresh ----------------------------------------------------

    @pytest.mark.parametrize("use_data_bytes", [True, False])
    def test_refresh_success(self, use_data_bytes, mock_donor_credentials):
        creds = self.make_credentials(lifetime=None)
        request = self.make_token_request(use_data_bytes=use_data_bytes)

        self.refresh_and_check(creds, request)

    @pytest.mark.parametrize("use_data_bytes", [True, False])
    def test_refresh_success_iam_endpoint_override(
        self, use_data_bytes, mock_donor_credentials
    ):
        creds = self.make_credentials(
            lifetime=None, iam_endpoint_override=self.IAM_ENDPOINT_OVERRIDE
        )
        request = self.make_token_request(use_data_bytes=use_data_bytes)

        self.refresh_and_check(creds, request)

        # Confirm the override endpoint was used.
        assert request.call_args[1]["url"] == self.IAM_ENDPOINT_OVERRIDE

    @pytest.mark.parametrize("time_skew", [100, -100])
    def test_refresh_source_credentials(self, time_skew):
        creds = self.make_credentials(lifetime=None)

        # Source credentials are refreshed only if they expire within
        # _helpers.REFRESH_THRESHOLD of now. The expiry is offset from that
        # threshold by time_skew, so a refresh is expected only when
        # time_skew <= 0.
        creds._source_credentials.expiry = (
            _helpers.utcnow()
            + _helpers.REFRESH_THRESHOLD
            + datetime.timedelta(seconds=time_skew)
        )
        creds._source_credentials.token = "Token"

        with mock.patch(
            "google.oauth2.service_account.Credentials.refresh", autospec=True
        ) as source_cred_refresh:
            self.refresh_and_check(creds, self.make_token_request())

            if time_skew > 0:
                source_cred_refresh.assert_not_called()
            else:
                source_cred_refresh.assert_called_once()

    def test_refresh_failure_malformed_expire_time(self, mock_donor_credentials):
        creds = self.make_credentials(lifetime=None)

        # Unlike make_expire_time(), this keeps microseconds and omits the
        # trailing "Z"; the refresh is expected to reject it.
        expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
        response_body = {"accessToken": "token", "expireTime": expiry.isoformat("T")}
        request = self.make_request(
            data=json.dumps(response_body), status=http_client.OK
        )

        with pytest.raises(exceptions.RefreshError) as excinfo:
            creds.refresh(request)

        assert excinfo.match(impersonated_credentials._REFRESH_ERROR)
        assert not creds.valid
        assert creds.expired

    def test_refresh_failure_unauthorized(self, mock_donor_credentials):
        creds = self.make_credentials(lifetime=None)

        response_body = {
            "error": {
                "code": 403,
                "message": "The caller does not have permission",
                "status": "PERMISSION_DENIED",
            }
        }
        request = self.make_request(
            data=json.dumps(response_body), status=http_client.UNAUTHORIZED
        )

        with pytest.raises(exceptions.RefreshError) as excinfo:
            creds.refresh(request)

        assert excinfo.match(impersonated_credentials._REFRESH_ERROR)
        assert not creds.valid
        assert creds.expired

    def test_refresh_failure_http_error(self, mock_donor_credentials):
        creds = self.make_credentials(lifetime=None)

        # A non-OK integer status with an empty body. (The original test
        # passed the exception class http_client.HTTPException as the status.)
        request = self.make_request(
            data=json.dumps({}), status=http_client.INTERNAL_SERVER_ERROR
        )

        with pytest.raises(exceptions.RefreshError) as excinfo:
            creds.refresh(request)

        assert excinfo.match(impersonated_credentials._REFRESH_ERROR)
        assert not creds.valid
        assert creds.expired

    def test_expired(self):
        creds = self.make_credentials(lifetime=None)
        assert creds.expired

    # -- signing -------------------------------------------------------------------

    def test_signer(self):
        creds = self.make_credentials()
        assert isinstance(creds.signer, impersonated_credentials.Credentials)

    def test_signer_email(self):
        creds = self.make_credentials(target_principal=self.TARGET_PRINCIPAL)
        assert creds.signer_email == self.TARGET_PRINCIPAL

    def test_service_account_email(self):
        creds = self.make_credentials(target_principal=self.TARGET_PRINCIPAL)
        assert creds.service_account_email == self.TARGET_PRINCIPAL

    def test_sign_bytes(self, mock_donor_credentials, mock_authorizedsession_sign):
        creds, _ = self.make_refreshed_credentials()

        signature = creds.sign_bytes(b"signed bytes")
        assert signature == b"signature"

    def test_sign_bytes_failure(self):
        creds = self.make_credentials(lifetime=None)

        with mock.patch(
            "google.auth.transport.requests.AuthorizedSession.request", autospec=True
        ) as auth_session:
            data = {"error": {"code": 403, "message": "unauthorized"}}
            auth_session.return_value = MockResponse(data, http_client.FORBIDDEN)

            with pytest.raises(exceptions.TransportError) as excinfo:
                creds.sign_bytes(b"foo")
            assert excinfo.match("'code': 403")

    # -- quota project ---------------------------------------------------------------

    def test_with_quota_project(self):
        creds = self.make_credentials()

        quota_project_creds = creds.with_quota_project("project-foo")
        assert quota_project_creds._quota_project_id == "project-foo"

    @pytest.mark.parametrize("use_data_bytes", [True, False])
    def test_with_quota_project_iam_endpoint_override(
        self, use_data_bytes, mock_donor_credentials
    ):
        creds = self.make_credentials(
            lifetime=None, iam_endpoint_override=self.IAM_ENDPOINT_OVERRIDE
        )
        # iam_endpoint_override should be copied to the derived credentials.
        quota_project_creds = creds.with_quota_project("project-foo")
        request = self.make_token_request(use_data_bytes=use_data_bytes)

        self.refresh_and_check(quota_project_creds, request)

        # Confirm the override endpoint was used.
        assert request.call_args[1]["url"] == self.IAM_ENDPOINT_OVERRIDE

    # -- ID tokens -------------------------------------------------------------------

    def test_id_token_success(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        creds, request = self.make_refreshed_credentials()

        id_creds = impersonated_credentials.IDTokenCredentials(
            creds, target_audience=self.TARGET_AUDIENCE
        )
        id_creds.refresh(request)

        assert id_creds.token == ID_TOKEN_DATA
        assert id_creds.expiry == datetime.datetime.fromtimestamp(ID_TOKEN_EXPIRY)

    def test_id_token_from_credential(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        creds, request = self.make_refreshed_credentials()

        id_creds = impersonated_credentials.IDTokenCredentials(
            creds, target_audience=self.TARGET_AUDIENCE, include_email=True
        )
        id_creds = id_creds.from_credentials(target_credentials=creds)
        id_creds.refresh(request)

        assert id_creds.token == ID_TOKEN_DATA
        assert id_creds._include_email is True

    def test_id_token_with_target_audience(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        creds, request = self.make_refreshed_credentials()

        id_creds = impersonated_credentials.IDTokenCredentials(
            creds, include_email=True
        )
        id_creds = id_creds.with_target_audience(
            target_audience=self.TARGET_AUDIENCE
        )
        id_creds.refresh(request)

        assert id_creds.token == ID_TOKEN_DATA
        assert id_creds.expiry == datetime.datetime.fromtimestamp(ID_TOKEN_EXPIRY)
        assert id_creds._include_email is True

    def test_id_token_invalid_cred(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        with pytest.raises(exceptions.GoogleAuthError) as excinfo:
            impersonated_credentials.IDTokenCredentials(None)

        assert excinfo.match("Provided Credential must be impersonated_credentials")

    def test_id_token_with_include_email(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        creds, request = self.make_refreshed_credentials()

        id_creds = impersonated_credentials.IDTokenCredentials(
            creds, target_audience=self.TARGET_AUDIENCE
        )
        id_creds = id_creds.with_include_email(True)
        id_creds.refresh(request)

        assert id_creds.token == ID_TOKEN_DATA

    def test_id_token_with_quota_project(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        creds, request = self.make_refreshed_credentials()

        id_creds = impersonated_credentials.IDTokenCredentials(
            creds, target_audience=self.TARGET_AUDIENCE
        )
        id_creds = id_creds.with_quota_project("project-foo")
        id_creds.refresh(request)

        assert id_creds.quota_project_id == "project-foo"
554