1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import aiohttp
16from aioresponses import aioresponses, core
17import mock
18import pytest
19from tests_async.transport import async_compliance
20
21import google.auth._credentials_async
22from google.auth.transport import _aiohttp_requests as aiohttp_requests
23import google.auth.transport._mtls_helper
24
25
class TestCombinedResponse:
    """Unit tests for ``aiohttp_requests._CombinedResponse``."""

    @pytest.mark.asyncio
    async def test__is_compressed(self):
        # A gzip Content-Encoding header must be reported as compressed.
        gzip_result = core.CallbackResult(headers={"Content-Encoding": "gzip"})
        wrapped = aiohttp_requests._CombinedResponse(gzip_result)
        assert wrapped._is_compressed()

    def test__is_compressed_not(self):
        # An unrecognized Content-Encoding value must not count as compressed.
        plain_result = core.CallbackResult(headers={"Content-Encoding": "not"})
        wrapped = aiohttp_requests._CombinedResponse(plain_result)
        assert not wrapped._is_compressed()

    @pytest.mark.asyncio
    async def test_raw_content(self):
        fake_response = mock.AsyncMock()
        fake_response.content.read.return_value = mock.sentinel.read
        wrapped = aiohttp_requests._CombinedResponse(response=fake_response)

        # First call reads the body from the wrapped response's content stream.
        first = await wrapped.raw_content()
        assert first == mock.sentinel.read

        # When a value is already cached on the wrapper, it is returned as-is.
        wrapped._raw_content = mock.sentinel.stored_raw
        second = await wrapped.raw_content()
        assert second == mock.sentinel.stored_raw

    @pytest.mark.asyncio
    async def test_content(self):
        fake_response = mock.AsyncMock()
        fake_response.content.read.return_value = mock.sentinel.read
        wrapped = aiohttp_requests._CombinedResponse(response=fake_response)
        body = await wrapped.content()
        assert body == mock.sentinel.read

    @mock.patch(
        "google.auth.transport._aiohttp_requests.urllib3.response.MultiDecoder.decompress",
        return_value="decompressed",
        autospec=True,
    )
    @pytest.mark.asyncio
    async def test_content_compressed(self, urllib3_mock):
        match = core.RequestMatch(
            "url", headers={"Content-Encoding": "gzip"}, payload="compressed"
        )
        gzip_response = await match.build_response(core.URL("url"))

        wrapped = aiohttp_requests._CombinedResponse(response=gzip_response)
        body = await wrapped.content()

        # The decoder is patched, so the content is exactly what it returned.
        urllib3_mock.assert_called_once()
        assert body == "decompressed"
79
80
class TestResponse:
    """Unit tests for the ``aiohttp_requests._Response`` property accessors."""

    def test_ctor(self):
        wrapper = aiohttp_requests._Response(mock.sentinel.response)
        assert wrapper._response == mock.sentinel.response

    @pytest.mark.asyncio
    async def test_headers_prop(self):
        match = core.RequestMatch("url", headers={"Content-Encoding": "header prop"})
        built = await match.build_response(core.URL("url"))

        wrapper = aiohttp_requests._Response(built)
        assert wrapper.headers["Content-Encoding"] == "header prop"

    @pytest.mark.asyncio
    async def test_status_prop(self):
        match = core.RequestMatch("url", status=123)
        built = await match.build_response(core.URL("url"))
        wrapper = aiohttp_requests._Response(built)
        assert wrapper.status == 123

    @pytest.mark.asyncio
    async def test_data_prop(self):
        fake_response = mock.AsyncMock()
        fake_response.content.read.return_value = mock.sentinel.read
        wrapper = aiohttp_requests._Response(fake_response)
        # ``data`` exposes a readable stream; reading it yields the mocked body.
        payload = await wrapper.data.read()
        assert payload == mock.sentinel.read
108
109
class TestRequestResponse(async_compliance.RequestResponseTests):
    """Runs the shared transport compliance suite against ``aiohttp_requests``."""

    def make_request(self):
        """Return a ``Request`` constructed without an explicit session."""
        return aiohttp_requests.Request()

    def make_with_parameter_request(self):
        """Return a ``Request`` built around an explicitly supplied session."""
        # Request rejects sessions with auto_decompress=True (see
        # test_unsupported_session), so it is disabled here.
        http = aiohttp.ClientSession(auto_decompress=False)
        return aiohttp_requests.Request(http)

    def test_unsupported_session(self):
        http = aiohttp.ClientSession(auto_decompress=True)
        with pytest.raises(ValueError):
            aiohttp_requests.Request(http)

    @pytest.mark.asyncio
    async def test_timeout(self):
        http = mock.create_autospec(
            aiohttp.ClientSession, instance=True, _auto_decompress=False
        )
        # Request.__call__ is a coroutine that awaits session.request(...).
        # An AsyncMock makes that call awaitable and records its arguments.
        http.request = mock.AsyncMock()
        request = aiohttp_requests.Request(http)

        # The call must be awaited. Calling it without ``await`` only creates
        # a coroutine that is discarded, so nothing is exercised or checked.
        await request(url="http://example.com", method="GET", timeout=5)

        http.request.assert_called_once()
        _, kwargs = http.request.call_args
        assert kwargs["timeout"] == 5
129
130
class CredentialsStub(google.auth._credentials_async.Credentials):
    """Minimal async credentials holding a plain string token for tests."""

    def __init__(self, token="token"):
        super().__init__()
        self.token = token

    def apply(self, headers, token=None):
        # The ``token`` argument is ignored; the stored token is always written.
        headers["authorization"] = self.token

    def refresh(self, request):
        # Each refresh appends "1", so tests can observe that it happened.
        self.token = self.token + "1"
141
142
class TestAuthorizedSession(object):
    """Tests for ``aiohttp_requests.AuthorizedSession``.

    Each async test closes the sessions it creates in a ``finally`` clause, so
    a failing assertion does not leave an open client session behind.
    """

    TEST_URL = "http://example.com/"
    method = "GET"

    def test_constructor(self):
        authed_session = aiohttp_requests.AuthorizedSession(mock.sentinel.credentials)
        assert authed_session.credentials == mock.sentinel.credentials

    def test_constructor_with_auth_request(self):
        http = mock.create_autospec(
            aiohttp.ClientSession, instance=True, _auto_decompress=False
        )
        auth_request = aiohttp_requests.Request(http)

        authed_session = aiohttp_requests.AuthorizedSession(
            mock.sentinel.credentials, auth_request=auth_request
        )

        assert authed_session._auth_request == auth_request

    @pytest.mark.asyncio
    async def test_request(self):
        with aioresponses() as mocked:
            credentials = mock.Mock(wraps=CredentialsStub())

            mocked.get(self.TEST_URL, status=200, body="test")
            session = aiohttp_requests.AuthorizedSession(credentials)
            try:
                resp = await session.request(
                    "GET",
                    self.TEST_URL,
                    headers={"Keep-Alive": "timeout=5, max=1000", "fake": b"bytes"},
                )

                assert resp.status == 200
                assert "test" == await resp.text()
            finally:
                await session.close()

    @pytest.mark.asyncio
    async def test_ctx(self):
        with aioresponses() as mocked:
            credentials = mock.Mock(wraps=CredentialsStub())
            mocked.get("http://test.example.com", payload=dict(foo="bar"))
            session = aiohttp_requests.AuthorizedSession(credentials)
            try:
                resp = await session.request("GET", "http://test.example.com")
                data = await resp.json()

                assert dict(foo="bar") == data
            finally:
                await session.close()

    @pytest.mark.asyncio
    async def test_http_headers(self):
        with aioresponses() as mocked:
            credentials = mock.Mock(wraps=CredentialsStub())
            mocked.post(
                "http://example.com",
                payload=dict(),
                headers=dict(connection="keep-alive"),
            )

            session = aiohttp_requests.AuthorizedSession(credentials)
            try:
                resp = await session.request("POST", "http://example.com")

                assert resp.headers["Connection"] == "keep-alive"
            finally:
                await session.close()

    @pytest.mark.asyncio
    async def test_regexp_example(self):
        # Despite the name, no regexp is involved. Two responses are registered
        # for the same URL, and the test checks they are served in order to
        # two separate sessions.
        with aioresponses() as mocked:
            credentials = mock.Mock(wraps=CredentialsStub())
            mocked.get("http://example.com", status=500)
            mocked.get("http://example.com", status=200)

            session1 = aiohttp_requests.AuthorizedSession(credentials)
            try:
                resp1 = await session1.request("GET", "http://example.com")
                session2 = aiohttp_requests.AuthorizedSession(credentials)
                try:
                    resp2 = await session2.request("GET", "http://example.com")

                    assert resp1.status == 500
                    assert resp2.status == 200
                finally:
                    await session2.close()
            finally:
                await session1.close()

    @pytest.mark.asyncio
    async def test_request_no_refresh(self):
        credentials = mock.Mock(wraps=CredentialsStub())
        with aioresponses() as mocked:
            mocked.get("http://example.com", status=200)
            authed_session = aiohttp_requests.AuthorizedSession(credentials)
            try:
                response = await authed_session.request("GET", "http://example.com")
                assert response.status == 200
                # Credentials are applied before the request, but a 200
                # response must not trigger a refresh.
                assert credentials.before_request.called
                assert not credentials.refresh.called
            finally:
                await authed_session.close()

    @pytest.mark.asyncio
    async def test_request_refresh(self):
        credentials = mock.Mock(wraps=CredentialsStub())
        with aioresponses() as mocked:
            # A 401 first, then a 200: the session is expected to refresh
            # and retry, ending with the second response.
            mocked.get("http://example.com", status=401)
            mocked.get("http://example.com", status=200)
            authed_session = aiohttp_requests.AuthorizedSession(credentials)
            try:
                response = await authed_session.request("GET", "http://example.com")
                assert credentials.refresh.called
                assert response.status == 200
            finally:
                await authed_session.close()
255