# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for aggregations.""" import asyncio import http import unittest from unittest import mock from absl.testing import absltest import tensorflow as tf from fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2 from fcp.aggregation.protocol.python import aggregation_protocol from fcp.aggregation.tensorflow.python import aggregation_protocols from fcp.demo import aggregations from fcp.demo import http_actions from fcp.demo import media from fcp.demo import test_utils from fcp.protos import plan_pb2 from fcp.protos.federatedcompute import aggregations_pb2 from fcp.protos.federatedcompute import common_pb2 from pybind11_abseil import status as absl_status INPUT_TENSOR = 'in' OUTPUT_TENSOR = 'out' AGGREGATION_REQUIREMENTS = aggregations.AggregationRequirements( minimum_clients_in_server_published_aggregate=3, plan=plan_pb2.Plan(phase=[ plan_pb2.Plan.Phase( server_phase_v2=plan_pb2.ServerPhaseV2(aggregations=[ plan_pb2.ServerAggregationConfig( intrinsic_uri='federated_sum', intrinsic_args=[ plan_pb2.ServerAggregationConfig.IntrinsicArg( input_tensor=tf.TensorSpec(( ), tf.int32, INPUT_TENSOR).experimental_as_proto()) ], output_tensors=[ tf.TensorSpec(( ), tf.int32, OUTPUT_TENSOR).experimental_as_proto(), ]), ])), ])) FORWARDING_INFO = common_pb2.ForwardingInfo( target_uri_prefix='https://forwarding.example/') class NotOkStatus: """Matcher for a not-ok Status.""" def __eq__(self, other) -> bool: return isinstance(other, absl_status.Status) and not other.ok() class AggregationsTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase): def setUp(self): super().setUp() self.mock_media_service = self.enter_context( mock.patch.object(media, 'Service', autospec=True)) self.mock_media_service.register_upload.return_value = 'upload-id' self.mock_media_service.finalize_upload.return_value = ( test_utils.create_checkpoint({INPUT_TENSOR: 0})) def test_pre_authorize_clients(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) tokens = service.pre_authorize_clients(session_id, 3) self.assertLen(tokens, 3) # The tokens should all be unique. self.assertLen(set(tokens), 3) def test_pre_authorize_clients_with_missing_session_id(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) with self.assertRaises(KeyError): service.pre_authorize_clients('does-not-exist', 1) def test_pre_authorize_clients_with_bad_count(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) self.assertEmpty(service.pre_authorize_clients(session_id, 0)) self.assertEmpty(service.pre_authorize_clients(session_id, -2)) def test_create_session(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) self.assertEqual( service.get_session_status(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING)) def test_complete_session(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) # Upload results from the client. num_clients = ( AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate) for i in range(num_clients): tokens = service.pre_authorize_clients(session_id, 1) self.mock_media_service.register_upload.return_value = f'upload-{i}' operation = service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[0])) self.assertTrue(operation.done) start_upload_response = ( aggregations_pb2.StartAggregationDataUploadResponse()) operation.response.Unpack(start_upload_response) self.mock_media_service.finalize_upload.return_value = ( test_utils.create_checkpoint({INPUT_TENSOR: i})) service.submit_aggregation_result( aggregations_pb2.SubmitAggregationResultRequest( aggregation_id=session_id, client_token=start_upload_response.client_token, resource_name=start_upload_response.resource.resource_name)) # Now that all clients have contributed, the aggregation session can be # completed. status, aggregate = service.complete_session(session_id) self.assertEqual( status, aggregations.SessionStatus( status=aggregations.AggregationStatus.COMPLETED, num_clients_completed=num_clients, num_inputs_aggregated_and_included=num_clients)) self.assertEqual( test_utils.read_tensor_from_checkpoint(aggregate, OUTPUT_TENSOR, tf.int32), sum(range(num_clients))) # get_session_status should no longer return results. with self.assertRaises(KeyError): service.get_session_status(session_id) @mock.patch.object( aggregation_protocols, 'create_simple_aggregation_protocol', autospec=True) def test_complete_session_fails(self, mock_create_simple_agg_protocol): # Use a mock since it's not easy to cause # SimpleAggregationProtocol::Complete to fail. mock_agg_protocol = mock.create_autospec( aggregation_protocol.AggregationProtocol, instance=True) mock_create_simple_agg_protocol.return_value = mock_agg_protocol service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) required_clients = ( AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate) agg_status = apm_pb2.StatusMessage( num_inputs_aggregated_and_included=required_clients) mock_agg_protocol.GetStatus.side_effect = lambda: agg_status def on_complete(): agg_status.num_inputs_discarded = ( agg_status.num_inputs_aggregated_and_included) agg_status.num_inputs_aggregated_and_included = 0 raise absl_status.StatusNotOk(absl_status.unknown_error('message')) mock_agg_protocol.Complete.side_effect = on_complete status, aggregate = service.complete_session(session_id) self.assertEqual( status, aggregations.SessionStatus( status=aggregations.AggregationStatus.FAILED, num_inputs_discarded=required_clients)) self.assertIsNone(aggregate) mock_agg_protocol.Complete.assert_called_once() mock_agg_protocol.Abort.assert_not_called() @mock.patch.object( aggregation_protocols, 'create_simple_aggregation_protocol', autospec=True) def test_complete_session_aborts(self, mock_create_simple_agg_protocol): # Use a mock since it's not easy to cause # SimpleAggregationProtocol::Complete to trigger a protocol abort. mock_agg_protocol = mock.create_autospec( aggregation_protocol.AggregationProtocol, instance=True) mock_create_simple_agg_protocol.return_value = mock_agg_protocol service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) required_clients = ( AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate) agg_status = apm_pb2.StatusMessage( num_inputs_aggregated_and_included=required_clients) mock_agg_protocol.GetStatus.side_effect = lambda: agg_status def on_complete(): agg_status.num_inputs_discarded = ( agg_status.num_inputs_aggregated_and_included) agg_status.num_inputs_aggregated_and_included = 0 callback = mock_create_simple_agg_protocol.call_args.args[1] callback.OnAbort(absl_status.unknown_error('message')) mock_agg_protocol.Complete.side_effect = on_complete status, aggregate = service.complete_session(session_id) self.assertEqual( status, aggregations.SessionStatus( status=aggregations.AggregationStatus.FAILED, num_inputs_discarded=required_clients)) self.assertIsNone(aggregate) mock_agg_protocol.Complete.assert_called_once() mock_agg_protocol.Abort.assert_not_called() def test_complete_session_without_enough_inputs(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session( aggregations.AggregationRequirements( minimum_clients_in_server_published_aggregate=3, plan=AGGREGATION_REQUIREMENTS.plan)) tokens = service.pre_authorize_clients(session_id, 2) # Upload results for one client. operation = service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[0])) self.assertTrue(operation.done) start_upload_response = ( aggregations_pb2.StartAggregationDataUploadResponse()) operation.response.Unpack(start_upload_response) service.submit_aggregation_result( aggregations_pb2.SubmitAggregationResultRequest( aggregation_id=session_id, client_token=start_upload_response.client_token, resource_name=start_upload_response.resource.resource_name)) # Complete the session before there are 2 completed clients. status, aggregate = service.complete_session(session_id) self.assertEqual( status, aggregations.SessionStatus( status=aggregations.AggregationStatus.FAILED, num_clients_completed=1, num_inputs_discarded=1)) self.assertIsNone(aggregate) def test_complete_session_with_missing_session_id(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) with self.assertRaises(KeyError): service.complete_session('does-not-exist') def test_abort_session_with_no_uploads(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) self.assertEqual( service.abort_session(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.ABORTED)) def test_abort_session_with_uploads(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) tokens = service.pre_authorize_clients(session_id, 3) # Upload results for one client. self.mock_media_service.register_upload.return_value = 'upload1' operation = service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[0])) self.assertTrue(operation.done) start_upload_response = ( aggregations_pb2.StartAggregationDataUploadResponse()) operation.response.Unpack(start_upload_response) service.submit_aggregation_result( aggregations_pb2.SubmitAggregationResultRequest( aggregation_id=session_id, client_token=start_upload_response.client_token, resource_name=start_upload_response.resource.resource_name)) # Start a partial upload from a second client. self.mock_media_service.register_upload.return_value = 'upload2' service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[1])) # Abort the session. The pending client should be treated as failed. self.assertEqual( service.abort_session(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.ABORTED, num_clients_completed=1, num_clients_aborted=1, num_inputs_discarded=1)) # The registered upload for the second client should have been finalized. self.mock_media_service.finalize_upload.assert_called_with('upload2') # get_session_status should no longer return results. with self.assertRaises(KeyError): service.get_session_status(session_id) def test_abort_session_with_missing_session_id(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) with self.assertRaises(KeyError): service.abort_session('does-not-exist') def test_get_session_status_with_missing_session_id(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) with self.assertRaises(KeyError): service.get_session_status('does-not-exist') async def test_wait(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) task = asyncio.create_task( service.wait(session_id, num_inputs_aggregated_and_included=1)) # The awaitable should not be done yet. await asyncio.wait([task], timeout=0.1) self.assertFalse(task.done()) # Upload results for one client. tokens = service.pre_authorize_clients(session_id, 1) self.mock_media_service.register_upload.return_value = 'upload' operation = service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[0])) self.assertTrue(operation.done) start_upload_response = ( aggregations_pb2.StartAggregationDataUploadResponse()) operation.response.Unpack(start_upload_response) service.submit_aggregation_result( aggregations_pb2.SubmitAggregationResultRequest( aggregation_id=session_id, client_token=start_upload_response.client_token, resource_name=start_upload_response.resource.resource_name)) # The awaitable should now return. await asyncio.wait([task], timeout=1) self.assertTrue(task.done()) self.assertEqual( task.result(), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING, num_clients_completed=1, num_inputs_aggregated_and_included=1)) async def test_wait_already_satisfied(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) # Upload results for one client. tokens = service.pre_authorize_clients(session_id, 1) self.mock_media_service.register_upload.return_value = 'upload' operation = service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[0])) self.assertTrue(operation.done) start_upload_response = ( aggregations_pb2.StartAggregationDataUploadResponse()) operation.response.Unpack(start_upload_response) service.submit_aggregation_result( aggregations_pb2.SubmitAggregationResultRequest( aggregation_id=session_id, client_token=start_upload_response.client_token, resource_name=start_upload_response.resource.resource_name)) # Since a client has already reported, the condition should already be # satisfied. task = asyncio.create_task( service.wait(session_id, num_inputs_aggregated_and_included=1)) await asyncio.wait([task], timeout=1) self.assertTrue(task.done()) self.assertEqual( task.result(), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING, num_clients_completed=1, num_inputs_aggregated_and_included=1)) async def test_wait_with_abort(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) task = asyncio.create_task( service.wait(session_id, num_inputs_aggregated_and_included=1)) # The awaitable should not be done yet. await asyncio.wait([task], timeout=0.1) self.assertFalse(task.done()) # The awaitable should return once the session is aborted. status = service.abort_session(session_id) await asyncio.wait([task], timeout=1) self.assertTrue(task.done()) self.assertEqual(task.result(), status) @mock.patch.object( aggregation_protocols, 'create_simple_aggregation_protocol', autospec=True) async def test_wait_with_protocol_abort(self, mock_create_simple_agg_protocol): # Use a mock since it's not easy to cause the AggregationProtocol to abort. mock_agg_protocol = mock.create_autospec( aggregation_protocol.AggregationProtocol, instance=True) mock_create_simple_agg_protocol.return_value = mock_agg_protocol mock_agg_protocol.GetStatus.return_value = apm_pb2.StatusMessage( num_clients_aborted=1234) service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) task = asyncio.create_task( service.wait(session_id, num_inputs_aggregated_and_included=1)) # The awaitable should not be done yet. await asyncio.wait([task], timeout=0.1) self.assertFalse(task.done()) # The awaitable should return once the AggregationProtocol aborts. callback = mock_create_simple_agg_protocol.call_args.args[1] callback.OnAbort(absl_status.unknown_error('message')) await asyncio.wait([task], timeout=1) self.assertTrue(task.done()) self.assertEqual( task.result(), aggregations.SessionStatus( status=aggregations.AggregationStatus.FAILED, num_clients_aborted=1234)) async def test_wait_with_complete(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session( aggregations.AggregationRequirements( minimum_clients_in_server_published_aggregate=0, plan=AGGREGATION_REQUIREMENTS.plan)) task = asyncio.create_task( service.wait(session_id, num_inputs_aggregated_and_included=1)) # The awaitable should not be done yet. await asyncio.wait([task], timeout=0.1) self.assertFalse(task.done()) # The awaitable should return once the session is completed. status, _ = service.complete_session(session_id) await asyncio.wait([task], timeout=1) self.assertTrue(task.done()) self.assertEqual(task.result(), status) async def test_wait_without_condition(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) task = asyncio.create_task(service.wait(session_id)) # If there are no conditions, the wait should be trivially satisfied. await asyncio.wait([task], timeout=1) self.assertTrue(task.done()) self.assertEqual(task.result(), service.get_session_status(session_id)) async def test_wait_with_missing_session_id(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) task = asyncio.create_task(service.wait('does-not-exist')) await asyncio.wait([task], timeout=1) self.assertTrue(task.done()) self.assertIsInstance(task.exception(), KeyError) def test_start_aggregation_data_upload(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) tokens = service.pre_authorize_clients(session_id, 1) self.mock_media_service.register_upload.return_value = 'upload' operation = service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[0])) self.assertNotEmpty(operation.name) self.assertTrue(operation.done) metadata = aggregations_pb2.StartAggregationDataUploadMetadata() operation.metadata.Unpack(metadata) self.assertEqual(metadata, aggregations_pb2.StartAggregationDataUploadMetadata()) response = aggregations_pb2.StartAggregationDataUploadResponse() operation.response.Unpack(response) # The client token should be set and different from the authorization token. self.assertNotEmpty(response.client_token) self.assertNotEqual(response.client_token, tokens[0]) self.assertEqual( response, aggregations_pb2.StartAggregationDataUploadResponse( aggregation_protocol_forwarding_info=FORWARDING_INFO, resource=common_pb2.ByteStreamResource( data_upload_forwarding_info=FORWARDING_INFO, resource_name='upload'), client_token=response.client_token)) self.assertEqual( service.get_session_status(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING, num_clients_pending=1)) def test_start_aggregagation_data_upload_with_missing_session_id(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) tokens = service.pre_authorize_clients(session_id, 1) with self.assertRaises(http_actions.HttpError) as cm: service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id='does-not-exist', authorization_token=tokens[0])) self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND) self.assertEqual( service.get_session_status(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING)) def test_start_aggregagation_data_upload_with_invalid_token(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) with self.assertRaises(http_actions.HttpError) as cm: service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token='does-not-exist')) self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED) self.assertEqual( service.get_session_status(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING)) def test_start_aggregagation_data_upload_twice(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) tokens = service.pre_authorize_clients(session_id, 1) service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[0])) with self.assertRaises(http_actions.HttpError) as cm: service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[0])) self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED) def test_submit_aggregation_result(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) # Upload results from the client. tokens = service.pre_authorize_clients(session_id, 1) operation = service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[0])) self.assertTrue(operation.done) start_upload_response = ( aggregations_pb2.StartAggregationDataUploadResponse()) operation.response.Unpack(start_upload_response) submit_response = service.submit_aggregation_result( aggregations_pb2.SubmitAggregationResultRequest( aggregation_id=session_id, client_token=start_upload_response.client_token, resource_name=start_upload_response.resource.resource_name)) self.assertEqual(submit_response, aggregations_pb2.SubmitAggregationResultResponse()) self.mock_media_service.finalize_upload.assert_called_with( start_upload_response.resource.resource_name) self.assertEqual( service.get_session_status(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING, num_clients_completed=1, num_inputs_aggregated_and_included=1)) def test_submit_aggregation_result_with_invalid_client_input(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) tokens = service.pre_authorize_clients(session_id, 1) operation = service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[0])) self.assertTrue(operation.done) start_upload_response = ( aggregations_pb2.StartAggregationDataUploadResponse()) operation.response.Unpack(start_upload_response) self.mock_media_service.finalize_upload.return_value = b'invalid' with self.assertRaises(http_actions.HttpError): service.submit_aggregation_result( aggregations_pb2.SubmitAggregationResultRequest( aggregation_id=session_id, client_token=start_upload_response.client_token, resource_name=start_upload_response.resource.resource_name)) self.assertEqual( service.get_session_status(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING, num_clients_failed=1)) def test_submit_aggregation_result_with_missing_session_id(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) with self.assertRaises(http_actions.HttpError) as cm: service.submit_aggregation_result( aggregations_pb2.SubmitAggregationResultRequest( aggregation_id='does-not-exist', client_token='client-token', resource_name='upload-id')) self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND) self.assertEqual( service.get_session_status(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING)) def test_submit_aggregation_result_with_invalid_token(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) with self.assertRaises(http_actions.HttpError) as cm: service.submit_aggregation_result( aggregations_pb2.SubmitAggregationResultRequest( aggregation_id=session_id, client_token='does-not-exist', resource_name='upload-id')) self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED) self.assertEqual( service.get_session_status(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING)) def test_submit_aggregation_result_with_finalize_upload_error(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) tokens = service.pre_authorize_clients(session_id, 1) operation = service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[0])) self.assertTrue(operation.done) start_upload_response = ( aggregations_pb2.StartAggregationDataUploadResponse()) operation.response.Unpack(start_upload_response) # If the resource_name doesn't correspond to a registered upload, # finalize_upload will raise a KeyError. self.mock_media_service.finalize_upload.side_effect = KeyError() with self.assertRaises(http_actions.HttpError) as cm: service.submit_aggregation_result( aggregations_pb2.SubmitAggregationResultRequest( aggregation_id=session_id, client_token=start_upload_response.client_token, resource_name=start_upload_response.resource.resource_name)) self.assertEqual(cm.exception.code, http.HTTPStatus.INTERNAL_SERVER_ERROR) self.assertEqual( service.get_session_status(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING, num_clients_failed=1)) def test_submit_aggregation_result_with_unuploaded_resource(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) tokens = service.pre_authorize_clients(session_id, 1) operation = service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[0])) self.assertTrue(operation.done) start_upload_response = ( aggregations_pb2.StartAggregationDataUploadResponse()) operation.response.Unpack(start_upload_response) # If the resource_name is valid but no resource was uploaded, # finalize_resource will return None. self.mock_media_service.finalize_upload.return_value = None with self.assertRaises(http_actions.HttpError) as cm: service.submit_aggregation_result( aggregations_pb2.SubmitAggregationResultRequest( aggregation_id=session_id, client_token=start_upload_response.client_token, resource_name=start_upload_response.resource.resource_name)) self.assertEqual(cm.exception.code, http.HTTPStatus.BAD_REQUEST) self.assertEqual( service.get_session_status(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING, num_clients_failed=1)) def test_abort_aggregation(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) tokens = service.pre_authorize_clients(session_id, 1) operation = service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[0])) self.assertTrue(operation.done) start_upload_response = ( aggregations_pb2.StartAggregationDataUploadResponse()) operation.response.Unpack(start_upload_response) self.assertEqual( service.abort_aggregation( aggregations_pb2.AbortAggregationRequest( aggregation_id=session_id, client_token=start_upload_response.client_token)), aggregations_pb2.AbortAggregationResponse()) self.assertEqual( service.get_session_status(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING, num_clients_completed=1, num_inputs_discarded=1)) def test_abort_aggregation_with_missing_session_id(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) tokens = service.pre_authorize_clients(session_id, 1) operation = service.start_aggregation_data_upload( aggregations_pb2.StartAggregationDataUploadRequest( aggregation_id=session_id, authorization_token=tokens[0])) self.assertTrue(operation.done) start_upload_response = ( aggregations_pb2.StartAggregationDataUploadResponse()) operation.response.Unpack(start_upload_response) with self.assertRaises(http_actions.HttpError) as cm: service.abort_aggregation( aggregations_pb2.AbortAggregationRequest( aggregation_id='does-not-exist', client_token=start_upload_response.client_token)) self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND) self.assertEqual( service.get_session_status(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING, num_clients_pending=1)) def test_abort_aggregation_with_invalid_token(self): service = aggregations.Service(lambda: FORWARDING_INFO, self.mock_media_service) session_id = service.create_session(AGGREGATION_REQUIREMENTS) with self.assertRaises(http_actions.HttpError) as cm: service.abort_aggregation( aggregations_pb2.AbortAggregationRequest( aggregation_id=session_id, client_token='does-not-exist')) self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED) self.assertEqual( service.get_session_status(session_id), aggregations.SessionStatus( status=aggregations.AggregationStatus.PENDING)) if __name__ == '__main__': absltest.main()