1# Copyright 2016 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import time
16
17import flask
18import pytest
19from pytest_localserver.http import WSGIServer
20from six.moves import http_client
21
22from google.auth import exceptions
23
# A hostname that is guaranteed never to resolve: RFC 2606 (section 2)
# reserves the ".invalid" top-level domain for names that are sure to be
# invalid (https://www.rfc-editor.org/rfc/rfc2606), which makes it a safe
# target for connection-failure tests.
NXDOMAIN = "test.invalid"
26
27
class RequestResponseTests(object):
    """Transport-agnostic compliance tests for HTTP request/response handling.

    This class is a mixin: it does not define ``make_request`` itself, so each
    concrete test class must provide a ``make_request()`` method. The tests
    here call the object it returns with ``url`` and ``method`` keyword
    arguments (plus optional ``headers`` and ``timeout``), read ``status``,
    ``headers`` and ``data`` from the value that call returns, and expect
    :class:`google.auth.exceptions.TransportError` to be raised when a
    request times out or the host cannot be reached.
    """

    @pytest.fixture(scope="module")
    def server(self):
        """Provides a test HTTP server.

        The fixture is module-scoped, so pytest starts the server once and
        reuses it across the tests in a module, stopping it at module teardown
        rather than after each individual test. It serves a small Flask
        application that can be used to verify requests:

        * ``/basic`` responds ``200`` with body ``Basic Content`` and echoes
          the request's ``x-test-header`` value (``"value"`` if absent) in the
          ``X-Test-Header`` response header.
        * ``/server_error`` responds ``500`` with body ``Error``.
        * ``/wait`` sleeps for 3 seconds before responding ``Waited``, so a
          client with a shorter timeout fails.
        """
        app = flask.Flask(__name__)
        app.debug = True

        # pylint: disable=unused-variable
        # (pylint thinks the flask routes are unused.)
        @app.route("/basic")
        def basic():
            header_value = flask.request.headers.get("x-test-header", "value")
            headers = {"X-Test-Header": header_value}
            return "Basic Content", http_client.OK, headers

        @app.route("/server_error")
        def server_error():
            return "Error", http_client.INTERNAL_SERVER_ERROR

        @app.route("/wait")
        def wait():
            time.sleep(3)
            return "Waited"

        # pylint: enable=unused-variable

        server = WSGIServer(application=app.wsgi_app)
        server.start()
        try:
            yield server
        finally:
            # Stop the server thread even if the fixture generator is closed
            # rather than resumed normally by pytest's teardown.
            server.stop()

    @staticmethod
    def _assert_basic_response(response, header_value="value"):
        """Check that ``response`` is the successful reply from ``/basic``.

        Args:
            response: The object returned by the request callable.
            header_value: The ``x-test-header`` value the server is expected
                to echo back; ``"value"`` is the server's default when the
                request does not send the header.
        """
        assert response.status == http_client.OK
        assert response.headers["x-test-header"] == header_value
        assert response.data == b"Basic Content"

    def test_request_basic(self, server):
        request = self.make_request()
        response = request(url=server.url + "/basic", method="GET")

        self._assert_basic_response(response)

    def test_request_with_timeout_success(self, server):
        request = self.make_request()
        # /basic responds without delay, so a 2-second timeout is ample.
        response = request(url=server.url + "/basic", method="GET", timeout=2)

        self._assert_basic_response(response)

    def test_request_with_timeout_failure(self, server):
        request = self.make_request()

        # /wait sleeps for 3 seconds, so a 1-second timeout must trip.
        # NOTE(review): if WSGIServer serves requests on a single thread, the
        # still-sleeping handler may delay the next test's request by ~2s —
        # confirm against pytest_localserver.
        with pytest.raises(exceptions.TransportError):
            request(url=server.url + "/wait", method="GET", timeout=1)

    def test_request_headers(self, server):
        request = self.make_request()
        response = request(
            url=server.url + "/basic",
            method="GET",
            headers={"x-test-header": "hello world"},
        )

        # /basic echoes the request header back instead of its default.
        self._assert_basic_response(response, header_value="hello world")

    def test_request_error(self, server):
        request = self.make_request()
        response = request(url=server.url + "/server_error", method="GET")

        # A 5xx status is returned as an ordinary response, not raised.
        assert response.status == http_client.INTERNAL_SERVER_ERROR
        assert response.data == b"Error"

    def test_connection_error(self):
        request = self.make_request()
        # The host can never resolve, so the transport must report the
        # failure as a TransportError.
        with pytest.raises(exceptions.TransportError):
            request(url="http://{}".format(NXDOMAIN), method="GET")
109