# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for media."""

import http
from unittest import mock
import uuid

from absl.testing import absltest

from fcp.demo import http_actions
from fcp.demo import media
from fcp.protos.federatedcompute import common_pb2


def _new_service():
  """Returns a media.Service using the ForwardingInfo class as its factory."""
  return media.Service(common_pb2.ForwardingInfo)


def _url_parts(url):
  """Returns the last two '/'-separated segments of a download URL.

  Download URLs are built as `group.prefix + name`, where the prefix ends in
  '/', so the result is the group's final path segment followed by the file
  name -- the two values these tests pass to `Service.download`.
  """
  return url.split('/')[-2:]


class MediaTest(absltest.TestCase):
  """Tests for media.Service download groups and upload slots."""

  @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
  def test_create_download_group(self, mock_uuid):
    info = common_pb2.ForwardingInfo(
        target_uri_prefix='https://media.example/')
    service = media.Service(lambda: info)
    with service.create_download_group() as group:
      # The prefix is expected to embed the (mocked) UUID under `data/`.
      self.assertEqual(
          group.prefix,
          f'https://media.example/data/{mock_uuid.return_value}/')
      file_name = 'file-name'
      returned_url = group.add(file_name, b'data')
      self.assertEqual(returned_url, group.prefix + file_name)

  def test_download(self):
    service = _new_service()
    with service.create_download_group() as group:
      payload = b'data'
      url = group.add('name', payload)
      response = service.download(b'', *_url_parts(url))
      # Without an explicit content type, a generic binary type is expected.
      self.assertEqual(
          response,
          http_actions.HttpResponse(
              body=payload,
              headers={
                  'Content-Length': len(payload),
                  'Content-Type': 'application/octet-stream',
              }))

  def test_download_with_content_type(self):
    service = _new_service()
    with service.create_download_group() as group:
      payload = b'data'
      mime_type = 'application/x-test'
      url = group.add('name', payload, content_type=mime_type)
      response = service.download(b'', *_url_parts(url))
      self.assertEqual(
          response,
          http_actions.HttpResponse(
              body=payload,
              headers={
                  'Content-Length': len(payload),
                  'Content-Type': mime_type,
              }))

  def test_download_multiple_files(self):
    service = _new_service()
    # Insertion order is preserved, so files are added and then fetched in
    # the order file1, file2.
    contents_by_name = {'file1': b'data1', 'file2': b'data2'}
    with service.create_download_group() as group:
      urls = {
          file_name: group.add(file_name, contents)
          for file_name, contents in contents_by_name.items()
      }
      for file_name, contents in contents_by_name.items():
        response = service.download(b'', *_url_parts(urls[file_name]))
        self.assertEqual(response.body, contents)

  def test_download_multiple_groups(self):
    service = _new_service()
    with service.create_download_group() as first_group:
      with service.create_download_group() as second_group:
        # Distinct groups must not collide, even for identical file names.
        self.assertNotEqual(first_group.prefix, second_group.prefix)
        first_data = b'data1'
        second_data = b'data2'
        first_url = first_group.add('name', first_data)
        second_url = second_group.add('name', second_data)
        first_response = service.download(b'', *_url_parts(first_url))
        self.assertEqual(first_response.body, first_data)
        second_response = service.download(b'', *_url_parts(second_url))
        self.assertEqual(second_response.body, second_data)

  def test_download_unregistered_group(self):
    service = _new_service()
    with self.assertRaises(http_actions.HttpError) as raised:
      service.download(b'', 'does-not-exist', 'does-not-exist')
    self.assertEqual(raised.exception.code, http.HTTPStatus.NOT_FOUND)

  def test_download_unregistered_file(self):
    service = _new_service()
    with service.create_download_group() as group:
      url = group.add('name', b'data')
      # Use a real group segment but a file name that was never added.
      group_segment = url.split('/')[-2]
      with self.assertRaises(http_actions.HttpError) as raised:
        service.download(b'', group_segment, 'does-not-exist')
      self.assertEqual(raised.exception.code, http.HTTPStatus.NOT_FOUND)

  def test_download_no_longer_registered(self):
    service = _new_service()
    with service.create_download_group() as group:
      url = group.add('name', b'data')
    # Once the group's context has exited, its files are expected to be gone.
    with self.assertRaises(http_actions.HttpError) as raised:
      service.download(b'', *_url_parts(url))
    self.assertEqual(raised.exception.code, http.HTTPStatus.NOT_FOUND)

  def test_register_duplicate_download(self):
    service = _new_service()
    with service.create_download_group() as group:
      original = b'data'
      url = group.add('name', original)
      with self.assertRaises(KeyError):
        group.add('name', b'data2')

      # The rejected duplicate must not replace the file added first.
      response = service.download(b'', *_url_parts(url))
      self.assertEqual(response.body, original)

  def test_upload(self):
    service = _new_service()
    upload_name = service.register_upload()
    payload = b'data'
    self.assertEqual(
        service.upload(payload, upload_name),
        http_actions.HttpResponse(body=b''))
    self.assertEqual(service.finalize_upload(upload_name), payload)

  def test_upload_without_data(self):
    service = _new_service()
    upload_name = service.register_upload()
    # Finalizing a slot that never received data yields None.
    self.assertIsNone(service.finalize_upload(upload_name))

  def test_upload_multiple_times(self):
    service = _new_service()
    upload_name = service.register_upload()

    first_payload = b'data1'
    self.assertEqual(
        service.upload(first_payload, upload_name),
        http_actions.HttpResponse(body=b''))

    # A second upload to the same slot is rejected...
    with self.assertRaises(http_actions.HttpError) as raised:
      service.upload(b'data2', upload_name)
    self.assertEqual(raised.exception.code, http.HTTPStatus.UNAUTHORIZED)

    # ...and the first payload is what gets finalized.
    self.assertEqual(service.finalize_upload(upload_name), first_payload)

  def test_upload_multiple(self):
    service = _new_service()
    first_name = service.register_upload()
    second_name = service.register_upload()

    # Upload out of registration order; each payload is keyed by its name.
    service.upload(b'data2', second_name)
    service.upload(b'data1', first_name)

    self.assertEqual(service.finalize_upload(first_name), b'data1')
    self.assertEqual(service.finalize_upload(second_name), b'data2')

  def test_upload_unregistered(self):
    service = _new_service()
    with self.assertRaises(http_actions.HttpError) as raised:
      service.upload(b'data', 'does-not-exist')
    self.assertEqual(raised.exception.code, http.HTTPStatus.UNAUTHORIZED)

    with self.assertRaises(KeyError):
      service.finalize_upload('does-not-exist')

  def test_upload_no_longer_registered(self):
    service = _new_service()
    upload_name = service.register_upload()
    self.assertIsNone(service.finalize_upload(upload_name))

    # After finalization the slot no longer accepts uploads...
    with self.assertRaises(http_actions.HttpError) as raised:
      service.upload(b'data', upload_name)
    self.assertEqual(raised.exception.code, http.HTTPStatus.UNAUTHORIZED)

    # ...and cannot be finalized a second time.
    with self.assertRaises(KeyError):
      service.finalize_upload(upload_name)


if __name__ == '__main__':
  absltest.main()