xref: /aosp_15_r20/external/federated-compute/fcp/demo/aggregations_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for aggregations."""
15
16import asyncio
17import http
18import unittest
19from unittest import mock
20
21from absl.testing import absltest
22import tensorflow as tf
23
24from fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2
25from fcp.aggregation.protocol.python import aggregation_protocol
26from fcp.aggregation.tensorflow.python import aggregation_protocols
27from fcp.demo import aggregations
28from fcp.demo import http_actions
29from fcp.demo import media
30from fcp.demo import test_utils
31from fcp.protos import plan_pb2
32from fcp.protos.federatedcompute import aggregations_pb2
33from fcp.protos.federatedcompute import common_pb2
34from pybind11_abseil import status as absl_status
35
# Names of the tensors read from client checkpoints and written to the
# aggregate checkpoint by the test plan's server phase.
INPUT_TENSOR = 'in'
OUTPUT_TENSOR = 'out'

# Aggregation requirements shared by most tests: a single `federated_sum` of a
# scalar int32 client tensor (INPUT_TENSOR) into a scalar int32 output tensor
# (OUTPUT_TENSOR), with at least 3 clients required in a published aggregate.
AGGREGATION_REQUIREMENTS = aggregations.AggregationRequirements(
    minimum_clients_in_server_published_aggregate=3,
    plan=plan_pb2.Plan(phase=[
        plan_pb2.Plan.Phase(
            server_phase_v2=plan_pb2.ServerPhaseV2(aggregations=[
                plan_pb2.ServerAggregationConfig(
                    intrinsic_uri='federated_sum',
                    intrinsic_args=[
                        plan_pb2.ServerAggregationConfig.IntrinsicArg(
                            input_tensor=tf.TensorSpec(
                                shape=(), dtype=tf.int32,
                                name=INPUT_TENSOR).experimental_as_proto()),
                    ],
                    output_tensors=[
                        tf.TensorSpec(
                            shape=(), dtype=tf.int32,
                            name=OUTPUT_TENSOR).experimental_as_proto(),
                    ]),
            ])),
    ]))

# Forwarding info handed to aggregations.Service through its callback; tests
# expect it to be echoed back in StartAggregationDataUpload responses.
FORWARDING_INFO = common_pb2.ForwardingInfo(
    target_uri_prefix='https://forwarding.example/')
58
59
class NotOkStatus:
  """Matcher that compares equal to any non-OK `absl_status.Status`.

  Because equality is decided by `__eq__`, an instance can stand in for an
  expected argument in mock assertions, e.g.
  `mock_fn.assert_called_with(NotOkStatus())`.
  """

  def __eq__(self, other: object) -> bool:
    return isinstance(other, absl_status.Status) and not other.ok()

  def __repr__(self) -> str:
    # Mock assertion failures print the repr of expected arguments; show the
    # matcher's intent instead of the default `<NotOkStatus object at 0x...>`.
    return 'NotOkStatus()'
65
66
67class AggregationsTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
68
69  def setUp(self):
70    super().setUp()
71    self.mock_media_service = self.enter_context(
72        mock.patch.object(media, 'Service', autospec=True))
73    self.mock_media_service.register_upload.return_value = 'upload-id'
74    self.mock_media_service.finalize_upload.return_value = (
75        test_utils.create_checkpoint({INPUT_TENSOR: 0}))
76
77  def test_pre_authorize_clients(self):
78    service = aggregations.Service(lambda: FORWARDING_INFO,
79                                   self.mock_media_service)
80    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
81    tokens = service.pre_authorize_clients(session_id, 3)
82    self.assertLen(tokens, 3)
83    # The tokens should all be unique.
84    self.assertLen(set(tokens), 3)
85
86  def test_pre_authorize_clients_with_missing_session_id(self):
87    service = aggregations.Service(lambda: FORWARDING_INFO,
88                                   self.mock_media_service)
89    with self.assertRaises(KeyError):
90      service.pre_authorize_clients('does-not-exist', 1)
91
92  def test_pre_authorize_clients_with_bad_count(self):
93    service = aggregations.Service(lambda: FORWARDING_INFO,
94                                   self.mock_media_service)
95    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
96    self.assertEmpty(service.pre_authorize_clients(session_id, 0))
97    self.assertEmpty(service.pre_authorize_clients(session_id, -2))
98
99  def test_create_session(self):
100    service = aggregations.Service(lambda: FORWARDING_INFO,
101                                   self.mock_media_service)
102    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
103    self.assertEqual(
104        service.get_session_status(session_id),
105        aggregations.SessionStatus(
106            status=aggregations.AggregationStatus.PENDING))
107
108  def test_complete_session(self):
109    service = aggregations.Service(lambda: FORWARDING_INFO,
110                                   self.mock_media_service)
111    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
112
113    # Upload results from the client.
114    num_clients = (
115        AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate)
116    for i in range(num_clients):
117      tokens = service.pre_authorize_clients(session_id, 1)
118
119      self.mock_media_service.register_upload.return_value = f'upload-{i}'
120      operation = service.start_aggregation_data_upload(
121          aggregations_pb2.StartAggregationDataUploadRequest(
122              aggregation_id=session_id, authorization_token=tokens[0]))
123      self.assertTrue(operation.done)
124      start_upload_response = (
125          aggregations_pb2.StartAggregationDataUploadResponse())
126      operation.response.Unpack(start_upload_response)
127
128      self.mock_media_service.finalize_upload.return_value = (
129          test_utils.create_checkpoint({INPUT_TENSOR: i}))
130      service.submit_aggregation_result(
131          aggregations_pb2.SubmitAggregationResultRequest(
132              aggregation_id=session_id,
133              client_token=start_upload_response.client_token,
134              resource_name=start_upload_response.resource.resource_name))
135
136    # Now that all clients have contributed, the aggregation session can be
137    # completed.
138    status, aggregate = service.complete_session(session_id)
139    self.assertEqual(
140        status,
141        aggregations.SessionStatus(
142            status=aggregations.AggregationStatus.COMPLETED,
143            num_clients_completed=num_clients,
144            num_inputs_aggregated_and_included=num_clients))
145    self.assertEqual(
146        test_utils.read_tensor_from_checkpoint(aggregate,
147                                               OUTPUT_TENSOR, tf.int32),
148        sum(range(num_clients)))
149
150    # get_session_status should no longer return results.
151    with self.assertRaises(KeyError):
152      service.get_session_status(session_id)
153
154  @mock.patch.object(
155      aggregation_protocols,
156      'create_simple_aggregation_protocol',
157      autospec=True)
158  def test_complete_session_fails(self, mock_create_simple_agg_protocol):
159    # Use a mock since it's not easy to cause
160    # SimpleAggregationProtocol::Complete to fail.
161    mock_agg_protocol = mock.create_autospec(
162        aggregation_protocol.AggregationProtocol, instance=True)
163    mock_create_simple_agg_protocol.return_value = mock_agg_protocol
164
165    service = aggregations.Service(lambda: FORWARDING_INFO,
166                                   self.mock_media_service)
167    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
168
169    required_clients = (
170        AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate)
171    agg_status = apm_pb2.StatusMessage(
172        num_inputs_aggregated_and_included=required_clients)
173    mock_agg_protocol.GetStatus.side_effect = lambda: agg_status
174
175    def on_complete():
176      agg_status.num_inputs_discarded = (
177          agg_status.num_inputs_aggregated_and_included)
178      agg_status.num_inputs_aggregated_and_included = 0
179      raise absl_status.StatusNotOk(absl_status.unknown_error('message'))
180
181    mock_agg_protocol.Complete.side_effect = on_complete
182
183    status, aggregate = service.complete_session(session_id)
184    self.assertEqual(
185        status,
186        aggregations.SessionStatus(
187            status=aggregations.AggregationStatus.FAILED,
188            num_inputs_discarded=required_clients))
189    self.assertIsNone(aggregate)
190    mock_agg_protocol.Complete.assert_called_once()
191    mock_agg_protocol.Abort.assert_not_called()
192
193  @mock.patch.object(
194      aggregation_protocols,
195      'create_simple_aggregation_protocol',
196      autospec=True)
197  def test_complete_session_aborts(self, mock_create_simple_agg_protocol):
198    # Use a mock since it's not easy to cause
199    # SimpleAggregationProtocol::Complete to trigger a protocol abort.
200    mock_agg_protocol = mock.create_autospec(
201        aggregation_protocol.AggregationProtocol, instance=True)
202    mock_create_simple_agg_protocol.return_value = mock_agg_protocol
203
204    service = aggregations.Service(lambda: FORWARDING_INFO,
205                                   self.mock_media_service)
206    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
207
208    required_clients = (
209        AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate)
210    agg_status = apm_pb2.StatusMessage(
211        num_inputs_aggregated_and_included=required_clients)
212    mock_agg_protocol.GetStatus.side_effect = lambda: agg_status
213
214    def on_complete():
215      agg_status.num_inputs_discarded = (
216          agg_status.num_inputs_aggregated_and_included)
217      agg_status.num_inputs_aggregated_and_included = 0
218      callback = mock_create_simple_agg_protocol.call_args.args[1]
219      callback.OnAbort(absl_status.unknown_error('message'))
220
221    mock_agg_protocol.Complete.side_effect = on_complete
222
223    status, aggregate = service.complete_session(session_id)
224    self.assertEqual(
225        status,
226        aggregations.SessionStatus(
227            status=aggregations.AggregationStatus.FAILED,
228            num_inputs_discarded=required_clients))
229    self.assertIsNone(aggregate)
230    mock_agg_protocol.Complete.assert_called_once()
231    mock_agg_protocol.Abort.assert_not_called()
232
233  def test_complete_session_without_enough_inputs(self):
234    service = aggregations.Service(lambda: FORWARDING_INFO,
235                                   self.mock_media_service)
236    session_id = service.create_session(
237        aggregations.AggregationRequirements(
238            minimum_clients_in_server_published_aggregate=3,
239            plan=AGGREGATION_REQUIREMENTS.plan))
240    tokens = service.pre_authorize_clients(session_id, 2)
241
242    # Upload results for one client.
243    operation = service.start_aggregation_data_upload(
244        aggregations_pb2.StartAggregationDataUploadRequest(
245            aggregation_id=session_id, authorization_token=tokens[0]))
246    self.assertTrue(operation.done)
247    start_upload_response = (
248        aggregations_pb2.StartAggregationDataUploadResponse())
249    operation.response.Unpack(start_upload_response)
250    service.submit_aggregation_result(
251        aggregations_pb2.SubmitAggregationResultRequest(
252            aggregation_id=session_id,
253            client_token=start_upload_response.client_token,
254            resource_name=start_upload_response.resource.resource_name))
255
256    # Complete the session before there are 2 completed clients.
257    status, aggregate = service.complete_session(session_id)
258    self.assertEqual(
259        status,
260        aggregations.SessionStatus(
261            status=aggregations.AggregationStatus.FAILED,
262            num_clients_completed=1,
263            num_inputs_discarded=1))
264    self.assertIsNone(aggregate)
265
266  def test_complete_session_with_missing_session_id(self):
267    service = aggregations.Service(lambda: FORWARDING_INFO,
268                                   self.mock_media_service)
269    with self.assertRaises(KeyError):
270      service.complete_session('does-not-exist')
271
272  def test_abort_session_with_no_uploads(self):
273    service = aggregations.Service(lambda: FORWARDING_INFO,
274                                   self.mock_media_service)
275    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
276    self.assertEqual(
277        service.abort_session(session_id),
278        aggregations.SessionStatus(
279            status=aggregations.AggregationStatus.ABORTED))
280
281  def test_abort_session_with_uploads(self):
282    service = aggregations.Service(lambda: FORWARDING_INFO,
283                                   self.mock_media_service)
284    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
285    tokens = service.pre_authorize_clients(session_id, 3)
286
287    # Upload results for one client.
288    self.mock_media_service.register_upload.return_value = 'upload1'
289    operation = service.start_aggregation_data_upload(
290        aggregations_pb2.StartAggregationDataUploadRequest(
291            aggregation_id=session_id, authorization_token=tokens[0]))
292    self.assertTrue(operation.done)
293    start_upload_response = (
294        aggregations_pb2.StartAggregationDataUploadResponse())
295    operation.response.Unpack(start_upload_response)
296    service.submit_aggregation_result(
297        aggregations_pb2.SubmitAggregationResultRequest(
298            aggregation_id=session_id,
299            client_token=start_upload_response.client_token,
300            resource_name=start_upload_response.resource.resource_name))
301
302    # Start a partial upload from a second client.
303    self.mock_media_service.register_upload.return_value = 'upload2'
304    service.start_aggregation_data_upload(
305        aggregations_pb2.StartAggregationDataUploadRequest(
306            aggregation_id=session_id, authorization_token=tokens[1]))
307
308    # Abort the session. The pending client should be treated as failed.
309    self.assertEqual(
310        service.abort_session(session_id),
311        aggregations.SessionStatus(
312            status=aggregations.AggregationStatus.ABORTED,
313            num_clients_completed=1,
314            num_clients_aborted=1,
315            num_inputs_discarded=1))
316    # The registered upload for the second client should have been finalized.
317    self.mock_media_service.finalize_upload.assert_called_with('upload2')
318    # get_session_status should no longer return results.
319    with self.assertRaises(KeyError):
320      service.get_session_status(session_id)
321
322  def test_abort_session_with_missing_session_id(self):
323    service = aggregations.Service(lambda: FORWARDING_INFO,
324                                   self.mock_media_service)
325    with self.assertRaises(KeyError):
326      service.abort_session('does-not-exist')
327
328  def test_get_session_status_with_missing_session_id(self):
329    service = aggregations.Service(lambda: FORWARDING_INFO,
330                                   self.mock_media_service)
331    with self.assertRaises(KeyError):
332      service.get_session_status('does-not-exist')
333
334  async def test_wait(self):
335    service = aggregations.Service(lambda: FORWARDING_INFO,
336                                   self.mock_media_service)
337    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
338    task = asyncio.create_task(
339        service.wait(session_id, num_inputs_aggregated_and_included=1))
340    # The awaitable should not be done yet.
341    await asyncio.wait([task], timeout=0.1)
342    self.assertFalse(task.done())
343
344    # Upload results for one client.
345    tokens = service.pre_authorize_clients(session_id, 1)
346    self.mock_media_service.register_upload.return_value = 'upload'
347    operation = service.start_aggregation_data_upload(
348        aggregations_pb2.StartAggregationDataUploadRequest(
349            aggregation_id=session_id, authorization_token=tokens[0]))
350    self.assertTrue(operation.done)
351    start_upload_response = (
352        aggregations_pb2.StartAggregationDataUploadResponse())
353    operation.response.Unpack(start_upload_response)
354    service.submit_aggregation_result(
355        aggregations_pb2.SubmitAggregationResultRequest(
356            aggregation_id=session_id,
357            client_token=start_upload_response.client_token,
358            resource_name=start_upload_response.resource.resource_name))
359
360    # The awaitable should now return.
361    await asyncio.wait([task], timeout=1)
362    self.assertTrue(task.done())
363    self.assertEqual(
364        task.result(),
365        aggregations.SessionStatus(
366            status=aggregations.AggregationStatus.PENDING,
367            num_clients_completed=1,
368            num_inputs_aggregated_and_included=1))
369
370  async def test_wait_already_satisfied(self):
371    service = aggregations.Service(lambda: FORWARDING_INFO,
372                                   self.mock_media_service)
373    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
374
375    # Upload results for one client.
376    tokens = service.pre_authorize_clients(session_id, 1)
377    self.mock_media_service.register_upload.return_value = 'upload'
378    operation = service.start_aggregation_data_upload(
379        aggregations_pb2.StartAggregationDataUploadRequest(
380            aggregation_id=session_id, authorization_token=tokens[0]))
381    self.assertTrue(operation.done)
382    start_upload_response = (
383        aggregations_pb2.StartAggregationDataUploadResponse())
384    operation.response.Unpack(start_upload_response)
385    service.submit_aggregation_result(
386        aggregations_pb2.SubmitAggregationResultRequest(
387            aggregation_id=session_id,
388            client_token=start_upload_response.client_token,
389            resource_name=start_upload_response.resource.resource_name))
390
391    # Since a client has already reported, the condition should already be
392    # satisfied.
393    task = asyncio.create_task(
394        service.wait(session_id, num_inputs_aggregated_and_included=1))
395    await asyncio.wait([task], timeout=1)
396    self.assertTrue(task.done())
397    self.assertEqual(
398        task.result(),
399        aggregations.SessionStatus(
400            status=aggregations.AggregationStatus.PENDING,
401            num_clients_completed=1,
402            num_inputs_aggregated_and_included=1))
403
404  async def test_wait_with_abort(self):
405    service = aggregations.Service(lambda: FORWARDING_INFO,
406                                   self.mock_media_service)
407    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
408    task = asyncio.create_task(
409        service.wait(session_id, num_inputs_aggregated_and_included=1))
410    # The awaitable should not be done yet.
411    await asyncio.wait([task], timeout=0.1)
412    self.assertFalse(task.done())
413
414    # The awaitable should return once the session is aborted.
415    status = service.abort_session(session_id)
416    await asyncio.wait([task], timeout=1)
417    self.assertTrue(task.done())
418    self.assertEqual(task.result(), status)
419
420  @mock.patch.object(
421      aggregation_protocols,
422      'create_simple_aggregation_protocol',
423      autospec=True)
424  async def test_wait_with_protocol_abort(self,
425                                          mock_create_simple_agg_protocol):
426    # Use a mock since it's not easy to cause the AggregationProtocol to abort.
427    mock_agg_protocol = mock.create_autospec(
428        aggregation_protocol.AggregationProtocol, instance=True)
429    mock_create_simple_agg_protocol.return_value = mock_agg_protocol
430    mock_agg_protocol.GetStatus.return_value = apm_pb2.StatusMessage(
431        num_clients_aborted=1234)
432
433    service = aggregations.Service(lambda: FORWARDING_INFO,
434                                   self.mock_media_service)
435    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
436    task = asyncio.create_task(
437        service.wait(session_id, num_inputs_aggregated_and_included=1))
438    # The awaitable should not be done yet.
439    await asyncio.wait([task], timeout=0.1)
440    self.assertFalse(task.done())
441
442    # The awaitable should return once the AggregationProtocol aborts.
443    callback = mock_create_simple_agg_protocol.call_args.args[1]
444    callback.OnAbort(absl_status.unknown_error('message'))
445    await asyncio.wait([task], timeout=1)
446    self.assertTrue(task.done())
447    self.assertEqual(
448        task.result(),
449        aggregations.SessionStatus(
450            status=aggregations.AggregationStatus.FAILED,
451            num_clients_aborted=1234))
452
453  async def test_wait_with_complete(self):
454    service = aggregations.Service(lambda: FORWARDING_INFO,
455                                   self.mock_media_service)
456    session_id = service.create_session(
457        aggregations.AggregationRequirements(
458            minimum_clients_in_server_published_aggregate=0,
459            plan=AGGREGATION_REQUIREMENTS.plan))
460    task = asyncio.create_task(
461        service.wait(session_id, num_inputs_aggregated_and_included=1))
462    # The awaitable should not be done yet.
463    await asyncio.wait([task], timeout=0.1)
464    self.assertFalse(task.done())
465
466    # The awaitable should return once the session is completed.
467    status, _ = service.complete_session(session_id)
468    await asyncio.wait([task], timeout=1)
469    self.assertTrue(task.done())
470    self.assertEqual(task.result(), status)
471
472  async def test_wait_without_condition(self):
473    service = aggregations.Service(lambda: FORWARDING_INFO,
474                                   self.mock_media_service)
475    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
476    task = asyncio.create_task(service.wait(session_id))
477    # If there are no conditions, the wait should be trivially satisfied.
478    await asyncio.wait([task], timeout=1)
479    self.assertTrue(task.done())
480    self.assertEqual(task.result(), service.get_session_status(session_id))
481
482  async def test_wait_with_missing_session_id(self):
483    service = aggregations.Service(lambda: FORWARDING_INFO,
484                                   self.mock_media_service)
485    task = asyncio.create_task(service.wait('does-not-exist'))
486    await asyncio.wait([task], timeout=1)
487    self.assertTrue(task.done())
488    self.assertIsInstance(task.exception(), KeyError)
489
490  def test_start_aggregation_data_upload(self):
491    service = aggregations.Service(lambda: FORWARDING_INFO,
492                                   self.mock_media_service)
493    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
494    tokens = service.pre_authorize_clients(session_id, 1)
495    self.mock_media_service.register_upload.return_value = 'upload'
496    operation = service.start_aggregation_data_upload(
497        aggregations_pb2.StartAggregationDataUploadRequest(
498            aggregation_id=session_id, authorization_token=tokens[0]))
499    self.assertNotEmpty(operation.name)
500    self.assertTrue(operation.done)
501
502    metadata = aggregations_pb2.StartAggregationDataUploadMetadata()
503    operation.metadata.Unpack(metadata)
504    self.assertEqual(metadata,
505                     aggregations_pb2.StartAggregationDataUploadMetadata())
506
507    response = aggregations_pb2.StartAggregationDataUploadResponse()
508    operation.response.Unpack(response)
509    # The client token should be set and different from the authorization token.
510    self.assertNotEmpty(response.client_token)
511    self.assertNotEqual(response.client_token, tokens[0])
512    self.assertEqual(
513        response,
514        aggregations_pb2.StartAggregationDataUploadResponse(
515            aggregation_protocol_forwarding_info=FORWARDING_INFO,
516            resource=common_pb2.ByteStreamResource(
517                data_upload_forwarding_info=FORWARDING_INFO,
518                resource_name='upload'),
519            client_token=response.client_token))
520    self.assertEqual(
521        service.get_session_status(session_id),
522        aggregations.SessionStatus(
523            status=aggregations.AggregationStatus.PENDING,
524            num_clients_pending=1))
525
526  def test_start_aggregagation_data_upload_with_missing_session_id(self):
527    service = aggregations.Service(lambda: FORWARDING_INFO,
528                                   self.mock_media_service)
529    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
530    tokens = service.pre_authorize_clients(session_id, 1)
531    with self.assertRaises(http_actions.HttpError) as cm:
532      service.start_aggregation_data_upload(
533          aggregations_pb2.StartAggregationDataUploadRequest(
534              aggregation_id='does-not-exist', authorization_token=tokens[0]))
535    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
536    self.assertEqual(
537        service.get_session_status(session_id),
538        aggregations.SessionStatus(
539            status=aggregations.AggregationStatus.PENDING))
540
541  def test_start_aggregagation_data_upload_with_invalid_token(self):
542    service = aggregations.Service(lambda: FORWARDING_INFO,
543                                   self.mock_media_service)
544    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
545    with self.assertRaises(http_actions.HttpError) as cm:
546      service.start_aggregation_data_upload(
547          aggregations_pb2.StartAggregationDataUploadRequest(
548              aggregation_id=session_id, authorization_token='does-not-exist'))
549    self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
550    self.assertEqual(
551        service.get_session_status(session_id),
552        aggregations.SessionStatus(
553            status=aggregations.AggregationStatus.PENDING))
554
555  def test_start_aggregagation_data_upload_twice(self):
556    service = aggregations.Service(lambda: FORWARDING_INFO,
557                                   self.mock_media_service)
558    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
559    tokens = service.pre_authorize_clients(session_id, 1)
560    service.start_aggregation_data_upload(
561        aggregations_pb2.StartAggregationDataUploadRequest(
562            aggregation_id=session_id, authorization_token=tokens[0]))
563    with self.assertRaises(http_actions.HttpError) as cm:
564      service.start_aggregation_data_upload(
565          aggregations_pb2.StartAggregationDataUploadRequest(
566              aggregation_id=session_id, authorization_token=tokens[0]))
567    self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
568
569  def test_submit_aggregation_result(self):
570    service = aggregations.Service(lambda: FORWARDING_INFO,
571                                   self.mock_media_service)
572    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
573
574    # Upload results from the client.
575    tokens = service.pre_authorize_clients(session_id, 1)
576
577    operation = service.start_aggregation_data_upload(
578        aggregations_pb2.StartAggregationDataUploadRequest(
579            aggregation_id=session_id, authorization_token=tokens[0]))
580    self.assertTrue(operation.done)
581    start_upload_response = (
582        aggregations_pb2.StartAggregationDataUploadResponse())
583    operation.response.Unpack(start_upload_response)
584
585    submit_response = service.submit_aggregation_result(
586        aggregations_pb2.SubmitAggregationResultRequest(
587            aggregation_id=session_id,
588            client_token=start_upload_response.client_token,
589            resource_name=start_upload_response.resource.resource_name))
590    self.assertEqual(submit_response,
591                     aggregations_pb2.SubmitAggregationResultResponse())
592    self.mock_media_service.finalize_upload.assert_called_with(
593        start_upload_response.resource.resource_name)
594    self.assertEqual(
595        service.get_session_status(session_id),
596        aggregations.SessionStatus(
597            status=aggregations.AggregationStatus.PENDING,
598            num_clients_completed=1,
599            num_inputs_aggregated_and_included=1))
600
601  def test_submit_aggregation_result_with_invalid_client_input(self):
602    service = aggregations.Service(lambda: FORWARDING_INFO,
603                                   self.mock_media_service)
604    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
605
606    tokens = service.pre_authorize_clients(session_id, 1)
607    operation = service.start_aggregation_data_upload(
608        aggregations_pb2.StartAggregationDataUploadRequest(
609            aggregation_id=session_id, authorization_token=tokens[0]))
610    self.assertTrue(operation.done)
611    start_upload_response = (
612        aggregations_pb2.StartAggregationDataUploadResponse())
613    operation.response.Unpack(start_upload_response)
614
615    self.mock_media_service.finalize_upload.return_value = b'invalid'
616    with self.assertRaises(http_actions.HttpError):
617      service.submit_aggregation_result(
618          aggregations_pb2.SubmitAggregationResultRequest(
619              aggregation_id=session_id,
620              client_token=start_upload_response.client_token,
621              resource_name=start_upload_response.resource.resource_name))
622    self.assertEqual(
623        service.get_session_status(session_id),
624        aggregations.SessionStatus(
625            status=aggregations.AggregationStatus.PENDING,
626            num_clients_failed=1))
627
628  def test_submit_aggregation_result_with_missing_session_id(self):
629    service = aggregations.Service(lambda: FORWARDING_INFO,
630                                   self.mock_media_service)
631    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
632    with self.assertRaises(http_actions.HttpError) as cm:
633      service.submit_aggregation_result(
634          aggregations_pb2.SubmitAggregationResultRequest(
635              aggregation_id='does-not-exist',
636              client_token='client-token',
637              resource_name='upload-id'))
638    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
639    self.assertEqual(
640        service.get_session_status(session_id),
641        aggregations.SessionStatus(
642            status=aggregations.AggregationStatus.PENDING))
643
644  def test_submit_aggregation_result_with_invalid_token(self):
645    service = aggregations.Service(lambda: FORWARDING_INFO,
646                                   self.mock_media_service)
647    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
648    with self.assertRaises(http_actions.HttpError) as cm:
649      service.submit_aggregation_result(
650          aggregations_pb2.SubmitAggregationResultRequest(
651              aggregation_id=session_id,
652              client_token='does-not-exist',
653              resource_name='upload-id'))
654    self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
655    self.assertEqual(
656        service.get_session_status(session_id),
657        aggregations.SessionStatus(
658            status=aggregations.AggregationStatus.PENDING))
659
660  def test_submit_aggregation_result_with_finalize_upload_error(self):
661    service = aggregations.Service(lambda: FORWARDING_INFO,
662                                   self.mock_media_service)
663    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
664    tokens = service.pre_authorize_clients(session_id, 1)
665    operation = service.start_aggregation_data_upload(
666        aggregations_pb2.StartAggregationDataUploadRequest(
667            aggregation_id=session_id, authorization_token=tokens[0]))
668    self.assertTrue(operation.done)
669    start_upload_response = (
670        aggregations_pb2.StartAggregationDataUploadResponse())
671    operation.response.Unpack(start_upload_response)
672
673    # If the resource_name doesn't correspond to a registered upload,
674    # finalize_upload will raise a KeyError.
675    self.mock_media_service.finalize_upload.side_effect = KeyError()
676    with self.assertRaises(http_actions.HttpError) as cm:
677      service.submit_aggregation_result(
678          aggregations_pb2.SubmitAggregationResultRequest(
679              aggregation_id=session_id,
680              client_token=start_upload_response.client_token,
681              resource_name=start_upload_response.resource.resource_name))
682    self.assertEqual(cm.exception.code, http.HTTPStatus.INTERNAL_SERVER_ERROR)
683    self.assertEqual(
684        service.get_session_status(session_id),
685        aggregations.SessionStatus(
686            status=aggregations.AggregationStatus.PENDING,
687            num_clients_failed=1))
688
689  def test_submit_aggregation_result_with_unuploaded_resource(self):
690    service = aggregations.Service(lambda: FORWARDING_INFO,
691                                   self.mock_media_service)
692    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
693    tokens = service.pre_authorize_clients(session_id, 1)
694    operation = service.start_aggregation_data_upload(
695        aggregations_pb2.StartAggregationDataUploadRequest(
696            aggregation_id=session_id, authorization_token=tokens[0]))
697    self.assertTrue(operation.done)
698    start_upload_response = (
699        aggregations_pb2.StartAggregationDataUploadResponse())
700    operation.response.Unpack(start_upload_response)
701
702    # If the resource_name is valid but no resource was uploaded,
703    # finalize_resource will return None.
704    self.mock_media_service.finalize_upload.return_value = None
705    with self.assertRaises(http_actions.HttpError) as cm:
706      service.submit_aggregation_result(
707          aggregations_pb2.SubmitAggregationResultRequest(
708              aggregation_id=session_id,
709              client_token=start_upload_response.client_token,
710              resource_name=start_upload_response.resource.resource_name))
711    self.assertEqual(cm.exception.code, http.HTTPStatus.BAD_REQUEST)
712    self.assertEqual(
713        service.get_session_status(session_id),
714        aggregations.SessionStatus(
715            status=aggregations.AggregationStatus.PENDING,
716            num_clients_failed=1))
717
718  def test_abort_aggregation(self):
719    service = aggregations.Service(lambda: FORWARDING_INFO,
720                                   self.mock_media_service)
721    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
722    tokens = service.pre_authorize_clients(session_id, 1)
723    operation = service.start_aggregation_data_upload(
724        aggregations_pb2.StartAggregationDataUploadRequest(
725            aggregation_id=session_id, authorization_token=tokens[0]))
726    self.assertTrue(operation.done)
727    start_upload_response = (
728        aggregations_pb2.StartAggregationDataUploadResponse())
729    operation.response.Unpack(start_upload_response)
730    self.assertEqual(
731        service.abort_aggregation(
732            aggregations_pb2.AbortAggregationRequest(
733                aggregation_id=session_id,
734                client_token=start_upload_response.client_token)),
735        aggregations_pb2.AbortAggregationResponse())
736    self.assertEqual(
737        service.get_session_status(session_id),
738        aggregations.SessionStatus(
739            status=aggregations.AggregationStatus.PENDING,
740            num_clients_completed=1,
741            num_inputs_discarded=1))
742
743  def test_abort_aggregation_with_missing_session_id(self):
744    service = aggregations.Service(lambda: FORWARDING_INFO,
745                                   self.mock_media_service)
746    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
747    tokens = service.pre_authorize_clients(session_id, 1)
748    operation = service.start_aggregation_data_upload(
749        aggregations_pb2.StartAggregationDataUploadRequest(
750            aggregation_id=session_id, authorization_token=tokens[0]))
751    self.assertTrue(operation.done)
752    start_upload_response = (
753        aggregations_pb2.StartAggregationDataUploadResponse())
754    operation.response.Unpack(start_upload_response)
755    with self.assertRaises(http_actions.HttpError) as cm:
756      service.abort_aggregation(
757          aggregations_pb2.AbortAggregationRequest(
758              aggregation_id='does-not-exist',
759              client_token=start_upload_response.client_token))
760    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
761    self.assertEqual(
762        service.get_session_status(session_id),
763        aggregations.SessionStatus(
764            status=aggregations.AggregationStatus.PENDING,
765            num_clients_pending=1))
766
767  def test_abort_aggregation_with_invalid_token(self):
768    service = aggregations.Service(lambda: FORWARDING_INFO,
769                                   self.mock_media_service)
770    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
771    with self.assertRaises(http_actions.HttpError) as cm:
772      service.abort_aggregation(
773          aggregations_pb2.AbortAggregationRequest(
774              aggregation_id=session_id, client_token='does-not-exist'))
775    self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
776    self.assertEqual(
777        service.get_session_status(session_id),
778        aggregations.SessionStatus(
779            status=aggregations.AggregationStatus.PENDING))
780
781
if __name__ == '__main__':
  # When this file is executed directly, hand control to absl's test runner,
  # which runs the test cases defined in this module.
  absltest.main()
784