xref: /aosp_15_r20/external/federated-compute/fcp/demo/media_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for media."""
15
16import http
17from unittest import mock
18import uuid
19
20from absl.testing import absltest
21
22from fcp.demo import http_actions
23from fcp.demo import media
24from fcp.protos.federatedcompute import common_pb2
25
26
class MediaTest(absltest.TestCase):
  """Tests for the download-group and upload APIs of `media.Service`."""

  @staticmethod
  def _new_service():
    """Returns a `media.Service` constructed with `ForwardingInfo` itself."""
    return media.Service(common_pb2.ForwardingInfo)

  @staticmethod
  def _path_args(url):
    """Returns the final two '/'-separated pieces of `url` as a list."""
    return url.split('/')[-2:]

  def _assert_http_error(self, status, fn, *args):
    """Asserts that `fn(*args)` raises an `HttpError` whose code is `status`."""
    with self.assertRaises(http_actions.HttpError) as raised:
      fn(*args)
    self.assertEqual(raised.exception.code, status)

  @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
  def test_create_download_group(self, mock_uuid):
    """The group prefix embeds the URI prefix and the (patched) uuid4 value."""
    info = common_pb2.ForwardingInfo(
        target_uri_prefix='https://media.example/')
    service = media.Service(lambda: info)
    with service.create_download_group() as group:
      self.assertEqual(group.prefix,
                       f'https://media.example/data/{mock_uuid.return_value}/')
      file_name = 'file-name'
      # `add` is expected to hand back the full URL: prefix plus file name.
      self.assertEqual(group.add(file_name, b'data'), group.prefix + file_name)

  def test_download(self):
    """Without a content type, the response uses application/octet-stream."""
    service = self._new_service()
    with service.create_download_group() as group:
      payload = b'data'
      url = group.add('name', payload)
      response = service.download(b'', *self._path_args(url))
      expected = http_actions.HttpResponse(
          body=payload,
          headers={
              'Content-Length': len(payload),
              'Content-Type': 'application/octet-stream',
          })
      self.assertEqual(response, expected)

  def test_download_with_content_type(self):
    """A content type passed to `add` is echoed in the download response."""
    service = self._new_service()
    with service.create_download_group() as group:
      payload = b'data'
      mime_type = 'application/x-test'
      url = group.add('name', payload, content_type=mime_type)
      response = service.download(b'', *self._path_args(url))
      expected = http_actions.HttpResponse(
          body=payload,
          headers={
              'Content-Length': len(payload),
              'Content-Type': mime_type,
          })
      self.assertEqual(response, expected)

  def test_download_multiple_files(self):
    """Two files in one group are each served with their own contents."""
    service = self._new_service()
    with service.create_download_group() as group:
      contents = {'file1': b'data1', 'file2': b'data2'}
      # Register everything first, then download in the same order.
      urls = {name: group.add(name, data) for name, data in contents.items()}
      for name, data in contents.items():
        response = service.download(b'', *self._path_args(urls[name]))
        self.assertEqual(response.body, data)

  def test_download_multiple_groups(self):
    """Equal file names in different groups do not collide."""
    service = self._new_service()
    with service.create_download_group() as first_group:
      with service.create_download_group() as second_group:
        self.assertNotEqual(first_group.prefix, second_group.prefix)
        first_data = b'data1'
        second_data = b'data2'
        first_url = first_group.add('name', first_data)
        second_url = second_group.add('name', second_data)
        first_response = service.download(b'', *self._path_args(first_url))
        self.assertEqual(first_response.body, first_data)
        second_response = service.download(b'', *self._path_args(second_url))
        self.assertEqual(second_response.body, second_data)

  def test_download_unregistered_group(self):
    """Downloading from an unknown group fails with NOT_FOUND."""
    service = self._new_service()
    self._assert_http_error(http.HTTPStatus.NOT_FOUND, service.download, b'',
                            'does-not-exist', 'does-not-exist')

  def test_download_unregistered_file(self):
    """Downloading an unknown file from a live group fails with NOT_FOUND."""
    service = self._new_service()
    with service.create_download_group() as group:
      url = group.add('name', b'data')
      group_id = url.split('/')[-2]
      self._assert_http_error(http.HTTPStatus.NOT_FOUND, service.download, b'',
                              group_id, 'does-not-exist')

  def test_download_no_longer_registered(self):
    """Once the group's `with` block exits, its files yield NOT_FOUND."""
    service = self._new_service()
    with service.create_download_group() as group:
      url = group.add('name', b'data')
    self._assert_http_error(http.HTTPStatus.NOT_FOUND, service.download, b'',
                            *self._path_args(url))

  def test_register_duplicate_download(self):
    """Re-adding an existing name raises KeyError and keeps the first file."""
    service = self._new_service()
    with service.create_download_group() as group:
      original = b'data'
      url = group.add('name', original)
      with self.assertRaises(KeyError):
        group.add('name', b'data2')

      # The original file should still be downloadable.
      response = service.download(b'', *self._path_args(url))
      self.assertEqual(response.body, original)

  def test_upload(self):
    """Uploaded bytes are returned by `finalize_upload`."""
    service = self._new_service()
    slot = service.register_upload()
    payload = b'data'
    response = service.upload(payload, slot)
    self.assertEqual(response, http_actions.HttpResponse(body=b''))
    self.assertEqual(service.finalize_upload(slot), payload)

  def test_upload_without_data(self):
    """Finalizing a registered upload that received no data gives None."""
    service = self._new_service()
    slot = service.register_upload()
    self.assertIsNone(service.finalize_upload(slot))

  def test_upload_multiple_times(self):
    """A second upload to one name is UNAUTHORIZED; the first data is kept."""
    service = self._new_service()
    slot = service.register_upload()

    first_payload = b'data1'
    response = service.upload(first_payload, slot)
    self.assertEqual(response, http_actions.HttpResponse(body=b''))

    self._assert_http_error(http.HTTPStatus.UNAUTHORIZED, service.upload,
                            b'data2', slot)

    self.assertEqual(service.finalize_upload(slot), first_payload)

  def test_upload_multiple(self):
    """Independent uploads keep their own data regardless of upload order."""
    service = self._new_service()
    first_slot = service.register_upload()
    second_slot = service.register_upload()

    # Order shouldn't matter.
    service.upload(b'data2', second_slot)
    service.upload(b'data1', first_slot)

    self.assertEqual(service.finalize_upload(first_slot), b'data1')
    self.assertEqual(service.finalize_upload(second_slot), b'data2')

  def test_upload_unregistered(self):
    """Unknown names: upload is UNAUTHORIZED, finalize raises KeyError."""
    service = self._new_service()
    self._assert_http_error(http.HTTPStatus.UNAUTHORIZED, service.upload,
                            b'data', 'does-not-exist')

    with self.assertRaises(KeyError):
      service.finalize_upload('does-not-exist')

  def test_upload_no_longer_registered(self):
    """After finalization the name is rejected by upload and finalize."""
    service = self._new_service()
    slot = service.register_upload()
    self.assertIsNone(service.finalize_upload(slot))

    self._assert_http_error(http.HTTPStatus.UNAUTHORIZED, service.upload,
                            b'data', slot)

    with self.assertRaises(KeyError):
      service.finalize_upload(slot)
185
186
if __name__ == '__main__':
  # Run as a script (not imported): hand control to the absl test runner.
  absltest.main()
189