1# Copyright 2016 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import httplib2
16import mock
17import unittest2
18
19from oauth2client import client
20from oauth2client import transport
21
22
class TestMemoryCache(unittest2.TestCase):
    """Exercise the get / set / delete cycle of ``transport.MemoryCache``."""

    def test_get_set_delete(self):
        # Walk a single key through: absent -> stored -> deleted.
        key, value = 'foo', 'bar'
        memory_cache = transport.MemoryCache()

        # Before anything is stored, both a lookup and a delete of the key
        # return None.
        self.assertIsNone(memory_cache.get(key))
        self.assertIsNone(memory_cache.delete(key))

        # Once set, the stored value is what a lookup returns.
        memory_cache.set(key, value)
        self.assertEqual(value, memory_cache.get(key))

        # After deletion the lookup returns None again.
        memory_cache.delete(key)
        self.assertIsNone(memory_cache.get(key))
33
34
class Test_get_cached_http(unittest2.TestCase):
    """Tests for ``transport.get_cached_http``."""

    def test_global(self):
        # With nothing patched, the helper is expected to return an
        # httplib2.Http whose ``cache`` is a transport.MemoryCache.
        http = transport.get_cached_http()
        self.assertIsInstance(http, httplib2.Http)
        self.assertIsInstance(http.cache, transport.MemoryCache)

    def test_value(self):
        # Replace the module-level ``_CACHED_HTTP`` with a sentinel for the
        # duration of the call and check that the very same object comes back.
        sentinel = object()
        patcher = mock.patch(
            'oauth2client.transport._CACHED_HTTP', new=sentinel)
        with patcher:
            returned = transport.get_cached_http()
        self.assertIs(returned, sentinel)
47
48
class Test_get_http_object(unittest2.TestCase):
    """Tests for ``transport.get_http_object``."""

    @mock.patch.object(httplib2, 'Http', return_value=object())
    def test_it(self, http_klass):
        # ``httplib2.Http`` is replaced by a mock whose return value is a
        # unique sentinel, so the test can tell exactly which object the
        # helper hands back.
        result = transport.get_http_object()
        # The sentinel is a bare ``object()``; identity is the intended check
        # (and matches the ``assertIs`` usage elsewhere in this module).
        self.assertIs(result, http_klass.return_value)
        # Pin down how the constructor is invoked: exactly once, with no
        # arguments.
        http_klass.assert_called_once_with()
55
56
class Test__initialize_headers(unittest2.TestCase):
    """Tests for ``transport._initialize_headers``."""

    def test_null(self):
        # Passing None is expected to produce an empty dict.
        self.assertEqual(transport._initialize_headers(None), {})

    def test_copy(self):
        # A supplied mapping must come back equal in content but as a
        # different object from the one passed in.
        original = {'a': 1, 'b': 2}
        initialized = transport._initialize_headers(original)
        self.assertEqual(initialized, original)
        self.assertIsNot(initialized, original)
68
69
class Test__apply_user_agent(unittest2.TestCase):
    """Tests for ``transport._apply_user_agent``."""

    def test_null(self):
        # With a user agent of None, the very object passed as headers is
        # what comes back.
        opaque_headers = object()
        returned = transport._apply_user_agent(opaque_headers, None)
        self.assertIs(returned, opaque_headers)

    def test_new_agent(self):
        # Starting from empty headers, the agent is stored under the
        # 'user-agent' key of the same dict that was passed in.
        agent = 'foo'
        header_map = {}
        returned = transport._apply_user_agent(header_map, agent)
        self.assertIs(returned, header_map)
        self.assertEqual(returned, {'user-agent': agent})

    def test_append(self):
        # When a 'user-agent' entry already exists, the expected result is
        # the new agent, a single space, then the pre-existing agent.
        existing_agent = 'bar'
        new_agent = 'baz'
        header_map = {'user-agent': existing_agent}
        returned = transport._apply_user_agent(header_map, new_agent)
        self.assertIs(returned, header_map)
        expected = {'user-agent': ' '.join((new_agent, existing_agent))}
        self.assertEqual(returned, expected)
92
93
class Test_clean_headers(unittest2.TestCase):
    """Tests for ``transport.clean_headers``."""

    def test_no_modify(self):
        # Headers that are already bytes come back equal, but in a new dict.
        original = {b'key': b'val'}
        cleaned = transport.clean_headers(original)
        self.assertIsNot(cleaned, original)
        self.assertEqual(cleaned, original)

    def test_cast_unicode(self):
        # ASCII-only unicode keys and values are expected back as bytes.
        original = {u'key': u'val'}
        expected = {b'key': b'val'}
        cleaned = transport.clean_headers(original)
        self.assertIsNot(cleaned, original)
        self.assertEqual(cleaned, expected)

    def test_unicode_failure(self):
        # U+2603 is outside ASCII, so cleaning must raise
        # client.NonAsciiHeaderError.
        non_ascii = {u'key': u'\u2603'}
        self.assertRaises(
            client.NonAsciiHeaderError, transport.clean_headers, non_ascii)

    def test_cast_object(self):
        # A non-string value (here True) is expected back as b'True'.
        original = {b'key': True}
        expected = {b'key': b'True'}
        cleaned = transport.clean_headers(original)
        self.assertIsNot(cleaned, original)
        self.assertEqual(cleaned, expected)
120
121
class Test_wrap_http_for_auth(unittest2.TestCase):
    """Tests for ``transport.wrap_http_for_auth``."""

    def test_wrap(self):
        # The wrapper is expected to return None and to replace
        # ``http.request`` with a new object that exposes the credentials
        # as a ``credentials`` attribute.
        creds = object()
        original_request = object()
        http = mock.Mock()
        http.request = original_request

        outcome = transport.wrap_http_for_auth(creds, http)

        self.assertIsNone(outcome)
        self.assertNotEqual(http.request, original_request)
        self.assertIs(http.request.credentials, creds)
132