1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC 2*14675a02SAndroid Build Coastguard Worker# 3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License"); 4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License. 5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at 6*14675a02SAndroid Build Coastguard Worker# 7*14675a02SAndroid Build Coastguard Worker# http://www.apache.org/licenses/LICENSE-2.0 8*14675a02SAndroid Build Coastguard Worker# 9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software 10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS, 11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and 13*14675a02SAndroid Build Coastguard Worker# limitations under the License. 14*14675a02SAndroid Build Coastguard Worker"""Tests for aggregations.""" 15*14675a02SAndroid Build Coastguard Worker 16*14675a02SAndroid Build Coastguard Workerimport asyncio 17*14675a02SAndroid Build Coastguard Workerimport http 18*14675a02SAndroid Build Coastguard Workerimport unittest 19*14675a02SAndroid Build Coastguard Workerfrom unittest import mock 20*14675a02SAndroid Build Coastguard Worker 21*14675a02SAndroid Build Coastguard Workerfrom absl.testing import absltest 22*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf 23*14675a02SAndroid Build Coastguard Worker 24*14675a02SAndroid Build Coastguard Workerfrom fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2 25*14675a02SAndroid Build Coastguard Workerfrom fcp.aggregation.protocol.python import aggregation_protocol 26*14675a02SAndroid Build Coastguard Workerfrom fcp.aggregation.tensorflow.python import aggregation_protocols 27*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import aggregations 28*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import http_actions 29*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import media 30*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import test_utils 31*14675a02SAndroid Build Coastguard Workerfrom fcp.protos import plan_pb2 32*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import aggregations_pb2 33*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import common_pb2 34*14675a02SAndroid Build Coastguard Workerfrom pybind11_abseil import status as absl_status 35*14675a02SAndroid Build Coastguard Worker 36*14675a02SAndroid Build Coastguard WorkerINPUT_TENSOR = 'in' 37*14675a02SAndroid Build Coastguard WorkerOUTPUT_TENSOR = 'out' 38*14675a02SAndroid Build Coastguard WorkerAGGREGATION_REQUIREMENTS = aggregations.AggregationRequirements( 39*14675a02SAndroid Build Coastguard Worker minimum_clients_in_server_published_aggregate=3, 40*14675a02SAndroid Build Coastguard Worker plan=plan_pb2.Plan(phase=[ 41*14675a02SAndroid Build Coastguard Worker plan_pb2.Plan.Phase( 42*14675a02SAndroid Build Coastguard Worker server_phase_v2=plan_pb2.ServerPhaseV2(aggregations=[ 43*14675a02SAndroid Build Coastguard Worker plan_pb2.ServerAggregationConfig( 44*14675a02SAndroid Build Coastguard Worker intrinsic_uri='federated_sum', 45*14675a02SAndroid Build Coastguard Worker intrinsic_args=[ 46*14675a02SAndroid Build Coastguard Worker plan_pb2.ServerAggregationConfig.IntrinsicArg( 47*14675a02SAndroid Build Coastguard Worker input_tensor=tf.TensorSpec(( 48*14675a02SAndroid Build Coastguard Worker ), tf.int32, INPUT_TENSOR).experimental_as_proto()) 49*14675a02SAndroid Build Coastguard Worker ], 50*14675a02SAndroid Build Coastguard Worker output_tensors=[ 51*14675a02SAndroid Build Coastguard Worker tf.TensorSpec(( 52*14675a02SAndroid Build Coastguard Worker ), tf.int32, OUTPUT_TENSOR).experimental_as_proto(), 53*14675a02SAndroid Build Coastguard Worker ]), 54*14675a02SAndroid Build Coastguard Worker ])), 55*14675a02SAndroid Build Coastguard Worker ])) 56*14675a02SAndroid Build Coastguard WorkerFORWARDING_INFO = common_pb2.ForwardingInfo( 57*14675a02SAndroid Build Coastguard Worker target_uri_prefix='https://forwarding.example/') 58*14675a02SAndroid Build Coastguard Worker 59*14675a02SAndroid Build Coastguard Worker 60*14675a02SAndroid Build Coastguard Workerclass NotOkStatus: 61*14675a02SAndroid Build Coastguard Worker """Matcher for a not-ok Status.""" 62*14675a02SAndroid Build Coastguard Worker 63*14675a02SAndroid Build Coastguard Worker def __eq__(self, other) -> bool: 64*14675a02SAndroid Build Coastguard Worker return isinstance(other, absl_status.Status) and not other.ok() 65*14675a02SAndroid Build Coastguard Worker 66*14675a02SAndroid Build Coastguard Worker 67*14675a02SAndroid Build Coastguard Workerclass AggregationsTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase): 68*14675a02SAndroid Build Coastguard Worker 69*14675a02SAndroid Build Coastguard Worker def setUp(self): 70*14675a02SAndroid Build Coastguard Worker super().setUp() 71*14675a02SAndroid Build Coastguard Worker self.mock_media_service = self.enter_context( 72*14675a02SAndroid Build Coastguard Worker mock.patch.object(media, 'Service', autospec=True)) 73*14675a02SAndroid Build Coastguard Worker self.mock_media_service.register_upload.return_value = 'upload-id' 74*14675a02SAndroid Build Coastguard Worker self.mock_media_service.finalize_upload.return_value = ( 75*14675a02SAndroid Build Coastguard Worker test_utils.create_checkpoint({INPUT_TENSOR: 0})) 76*14675a02SAndroid Build Coastguard Worker 77*14675a02SAndroid Build Coastguard Worker def test_pre_authorize_clients(self): 78*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 79*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 80*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 81*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 3) 82*14675a02SAndroid Build Coastguard Worker self.assertLen(tokens, 3) 83*14675a02SAndroid Build Coastguard Worker # The tokens should all be unique. 84*14675a02SAndroid Build Coastguard Worker self.assertLen(set(tokens), 3) 85*14675a02SAndroid Build Coastguard Worker 86*14675a02SAndroid Build Coastguard Worker def test_pre_authorize_clients_with_missing_session_id(self): 87*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 88*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 89*14675a02SAndroid Build Coastguard Worker with self.assertRaises(KeyError): 90*14675a02SAndroid Build Coastguard Worker service.pre_authorize_clients('does-not-exist', 1) 91*14675a02SAndroid Build Coastguard Worker 92*14675a02SAndroid Build Coastguard Worker def test_pre_authorize_clients_with_bad_count(self): 93*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 94*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 95*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 96*14675a02SAndroid Build Coastguard Worker self.assertEmpty(service.pre_authorize_clients(session_id, 0)) 97*14675a02SAndroid Build Coastguard Worker self.assertEmpty(service.pre_authorize_clients(session_id, -2)) 98*14675a02SAndroid Build Coastguard Worker 99*14675a02SAndroid Build Coastguard Worker def test_create_session(self): 100*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 101*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 102*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 103*14675a02SAndroid Build Coastguard Worker self.assertEqual( 104*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id), 105*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 106*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING)) 107*14675a02SAndroid Build Coastguard Worker 108*14675a02SAndroid Build Coastguard Worker def test_complete_session(self): 109*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 110*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 111*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 112*14675a02SAndroid Build Coastguard Worker 113*14675a02SAndroid Build Coastguard Worker # Upload results from the client. 114*14675a02SAndroid Build Coastguard Worker num_clients = ( 115*14675a02SAndroid Build Coastguard Worker AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate) 116*14675a02SAndroid Build Coastguard Worker for i in range(num_clients): 117*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 1) 118*14675a02SAndroid Build Coastguard Worker 119*14675a02SAndroid Build Coastguard Worker self.mock_media_service.register_upload.return_value = f'upload-{i}' 120*14675a02SAndroid Build Coastguard Worker operation = service.start_aggregation_data_upload( 121*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 122*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[0])) 123*14675a02SAndroid Build Coastguard Worker self.assertTrue(operation.done) 124*14675a02SAndroid Build Coastguard Worker start_upload_response = ( 125*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadResponse()) 126*14675a02SAndroid Build Coastguard Worker operation.response.Unpack(start_upload_response) 127*14675a02SAndroid Build Coastguard Worker 128*14675a02SAndroid Build Coastguard Worker self.mock_media_service.finalize_upload.return_value = ( 129*14675a02SAndroid Build Coastguard Worker test_utils.create_checkpoint({INPUT_TENSOR: i})) 130*14675a02SAndroid Build Coastguard Worker service.submit_aggregation_result( 131*14675a02SAndroid Build Coastguard Worker aggregations_pb2.SubmitAggregationResultRequest( 132*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, 133*14675a02SAndroid Build Coastguard Worker client_token=start_upload_response.client_token, 134*14675a02SAndroid Build Coastguard Worker resource_name=start_upload_response.resource.resource_name)) 135*14675a02SAndroid Build Coastguard Worker 136*14675a02SAndroid Build Coastguard Worker # Now that all clients have contributed, the aggregation session can be 137*14675a02SAndroid Build Coastguard Worker # completed. 138*14675a02SAndroid Build Coastguard Worker status, aggregate = service.complete_session(session_id) 139*14675a02SAndroid Build Coastguard Worker self.assertEqual( 140*14675a02SAndroid Build Coastguard Worker status, 141*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 142*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.COMPLETED, 143*14675a02SAndroid Build Coastguard Worker num_clients_completed=num_clients, 144*14675a02SAndroid Build Coastguard Worker num_inputs_aggregated_and_included=num_clients)) 145*14675a02SAndroid Build Coastguard Worker self.assertEqual( 146*14675a02SAndroid Build Coastguard Worker test_utils.read_tensor_from_checkpoint(aggregate, 147*14675a02SAndroid Build Coastguard Worker OUTPUT_TENSOR, tf.int32), 148*14675a02SAndroid Build Coastguard Worker sum(range(num_clients))) 149*14675a02SAndroid Build Coastguard Worker 150*14675a02SAndroid Build Coastguard Worker # get_session_status should no longer return results. 151*14675a02SAndroid Build Coastguard Worker with self.assertRaises(KeyError): 152*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id) 153*14675a02SAndroid Build Coastguard Worker 154*14675a02SAndroid Build Coastguard Worker @mock.patch.object( 155*14675a02SAndroid Build Coastguard Worker aggregation_protocols, 156*14675a02SAndroid Build Coastguard Worker 'create_simple_aggregation_protocol', 157*14675a02SAndroid Build Coastguard Worker autospec=True) 158*14675a02SAndroid Build Coastguard Worker def test_complete_session_fails(self, mock_create_simple_agg_protocol): 159*14675a02SAndroid Build Coastguard Worker # Use a mock since it's not easy to cause 160*14675a02SAndroid Build Coastguard Worker # SimpleAggregationProtocol::Complete to fail. 161*14675a02SAndroid Build Coastguard Worker mock_agg_protocol = mock.create_autospec( 162*14675a02SAndroid Build Coastguard Worker aggregation_protocol.AggregationProtocol, instance=True) 163*14675a02SAndroid Build Coastguard Worker mock_create_simple_agg_protocol.return_value = mock_agg_protocol 164*14675a02SAndroid Build Coastguard Worker 165*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 166*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 167*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 168*14675a02SAndroid Build Coastguard Worker 169*14675a02SAndroid Build Coastguard Worker required_clients = ( 170*14675a02SAndroid Build Coastguard Worker AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate) 171*14675a02SAndroid Build Coastguard Worker agg_status = apm_pb2.StatusMessage( 172*14675a02SAndroid Build Coastguard Worker num_inputs_aggregated_and_included=required_clients) 173*14675a02SAndroid Build Coastguard Worker mock_agg_protocol.GetStatus.side_effect = lambda: agg_status 174*14675a02SAndroid Build Coastguard Worker 175*14675a02SAndroid Build Coastguard Worker def on_complete(): 176*14675a02SAndroid Build Coastguard Worker agg_status.num_inputs_discarded = ( 177*14675a02SAndroid Build Coastguard Worker agg_status.num_inputs_aggregated_and_included) 178*14675a02SAndroid Build Coastguard Worker agg_status.num_inputs_aggregated_and_included = 0 179*14675a02SAndroid Build Coastguard Worker raise absl_status.StatusNotOk(absl_status.unknown_error('message')) 180*14675a02SAndroid Build Coastguard Worker 181*14675a02SAndroid Build Coastguard Worker mock_agg_protocol.Complete.side_effect = on_complete 182*14675a02SAndroid Build Coastguard Worker 183*14675a02SAndroid Build Coastguard Worker status, aggregate = service.complete_session(session_id) 184*14675a02SAndroid Build Coastguard Worker self.assertEqual( 185*14675a02SAndroid Build Coastguard Worker status, 186*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 187*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.FAILED, 188*14675a02SAndroid Build Coastguard Worker num_inputs_discarded=required_clients)) 189*14675a02SAndroid Build Coastguard Worker self.assertIsNone(aggregate) 190*14675a02SAndroid Build Coastguard Worker mock_agg_protocol.Complete.assert_called_once() 191*14675a02SAndroid Build Coastguard Worker mock_agg_protocol.Abort.assert_not_called() 192*14675a02SAndroid Build Coastguard Worker 193*14675a02SAndroid Build Coastguard Worker @mock.patch.object( 194*14675a02SAndroid Build Coastguard Worker aggregation_protocols, 195*14675a02SAndroid Build Coastguard Worker 'create_simple_aggregation_protocol', 196*14675a02SAndroid Build Coastguard Worker autospec=True) 197*14675a02SAndroid Build Coastguard Worker def test_complete_session_aborts(self, mock_create_simple_agg_protocol): 198*14675a02SAndroid Build Coastguard Worker # Use a mock since it's not easy to cause 199*14675a02SAndroid Build Coastguard Worker # SimpleAggregationProtocol::Complete to trigger a protocol abort. 200*14675a02SAndroid Build Coastguard Worker mock_agg_protocol = mock.create_autospec( 201*14675a02SAndroid Build Coastguard Worker aggregation_protocol.AggregationProtocol, instance=True) 202*14675a02SAndroid Build Coastguard Worker mock_create_simple_agg_protocol.return_value = mock_agg_protocol 203*14675a02SAndroid Build Coastguard Worker 204*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 205*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 206*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 207*14675a02SAndroid Build Coastguard Worker 208*14675a02SAndroid Build Coastguard Worker required_clients = ( 209*14675a02SAndroid Build Coastguard Worker AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate) 210*14675a02SAndroid Build Coastguard Worker agg_status = apm_pb2.StatusMessage( 211*14675a02SAndroid Build Coastguard Worker num_inputs_aggregated_and_included=required_clients) 212*14675a02SAndroid Build Coastguard Worker mock_agg_protocol.GetStatus.side_effect = lambda: agg_status 213*14675a02SAndroid Build Coastguard Worker 214*14675a02SAndroid Build Coastguard Worker def on_complete(): 215*14675a02SAndroid Build Coastguard Worker agg_status.num_inputs_discarded = ( 216*14675a02SAndroid Build Coastguard Worker agg_status.num_inputs_aggregated_and_included) 217*14675a02SAndroid Build Coastguard Worker agg_status.num_inputs_aggregated_and_included = 0 218*14675a02SAndroid Build Coastguard Worker callback = mock_create_simple_agg_protocol.call_args.args[1] 219*14675a02SAndroid Build Coastguard Worker callback.OnAbort(absl_status.unknown_error('message')) 220*14675a02SAndroid Build Coastguard Worker 221*14675a02SAndroid Build Coastguard Worker mock_agg_protocol.Complete.side_effect = on_complete 222*14675a02SAndroid Build Coastguard Worker 223*14675a02SAndroid Build Coastguard Worker status, aggregate = service.complete_session(session_id) 224*14675a02SAndroid Build Coastguard Worker self.assertEqual( 225*14675a02SAndroid Build Coastguard Worker status, 226*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 227*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.FAILED, 228*14675a02SAndroid Build Coastguard Worker num_inputs_discarded=required_clients)) 229*14675a02SAndroid Build Coastguard Worker self.assertIsNone(aggregate) 230*14675a02SAndroid Build Coastguard Worker mock_agg_protocol.Complete.assert_called_once() 231*14675a02SAndroid Build Coastguard Worker mock_agg_protocol.Abort.assert_not_called() 232*14675a02SAndroid Build Coastguard Worker 233*14675a02SAndroid Build Coastguard Worker def test_complete_session_without_enough_inputs(self): 234*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 235*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 236*14675a02SAndroid Build Coastguard Worker session_id = service.create_session( 237*14675a02SAndroid Build Coastguard Worker aggregations.AggregationRequirements( 238*14675a02SAndroid Build Coastguard Worker minimum_clients_in_server_published_aggregate=3, 239*14675a02SAndroid Build Coastguard Worker plan=AGGREGATION_REQUIREMENTS.plan)) 240*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 2) 241*14675a02SAndroid Build Coastguard Worker 242*14675a02SAndroid Build Coastguard Worker # Upload results for one client. 243*14675a02SAndroid Build Coastguard Worker operation = service.start_aggregation_data_upload( 244*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 245*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[0])) 246*14675a02SAndroid Build Coastguard Worker self.assertTrue(operation.done) 247*14675a02SAndroid Build Coastguard Worker start_upload_response = ( 248*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadResponse()) 249*14675a02SAndroid Build Coastguard Worker operation.response.Unpack(start_upload_response) 250*14675a02SAndroid Build Coastguard Worker service.submit_aggregation_result( 251*14675a02SAndroid Build Coastguard Worker aggregations_pb2.SubmitAggregationResultRequest( 252*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, 253*14675a02SAndroid Build Coastguard Worker client_token=start_upload_response.client_token, 254*14675a02SAndroid Build Coastguard Worker resource_name=start_upload_response.resource.resource_name)) 255*14675a02SAndroid Build Coastguard Worker 256*14675a02SAndroid Build Coastguard Worker # Complete the session before there are 2 completed clients. 257*14675a02SAndroid Build Coastguard Worker status, aggregate = service.complete_session(session_id) 258*14675a02SAndroid Build Coastguard Worker self.assertEqual( 259*14675a02SAndroid Build Coastguard Worker status, 260*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 261*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.FAILED, 262*14675a02SAndroid Build Coastguard Worker num_clients_completed=1, 263*14675a02SAndroid Build Coastguard Worker num_inputs_discarded=1)) 264*14675a02SAndroid Build Coastguard Worker self.assertIsNone(aggregate) 265*14675a02SAndroid Build Coastguard Worker 266*14675a02SAndroid Build Coastguard Worker def test_complete_session_with_missing_session_id(self): 267*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 268*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 269*14675a02SAndroid Build Coastguard Worker with self.assertRaises(KeyError): 270*14675a02SAndroid Build Coastguard Worker service.complete_session('does-not-exist') 271*14675a02SAndroid Build Coastguard Worker 272*14675a02SAndroid Build Coastguard Worker def test_abort_session_with_no_uploads(self): 273*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 274*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 275*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 276*14675a02SAndroid Build Coastguard Worker self.assertEqual( 277*14675a02SAndroid Build Coastguard Worker service.abort_session(session_id), 278*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 279*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.ABORTED)) 280*14675a02SAndroid Build Coastguard Worker 281*14675a02SAndroid Build Coastguard Worker def test_abort_session_with_uploads(self): 282*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 283*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 284*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 285*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 3) 286*14675a02SAndroid Build Coastguard Worker 287*14675a02SAndroid Build Coastguard Worker # Upload results for one client. 288*14675a02SAndroid Build Coastguard Worker self.mock_media_service.register_upload.return_value = 'upload1' 289*14675a02SAndroid Build Coastguard Worker operation = service.start_aggregation_data_upload( 290*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 291*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[0])) 292*14675a02SAndroid Build Coastguard Worker self.assertTrue(operation.done) 293*14675a02SAndroid Build Coastguard Worker start_upload_response = ( 294*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadResponse()) 295*14675a02SAndroid Build Coastguard Worker operation.response.Unpack(start_upload_response) 296*14675a02SAndroid Build Coastguard Worker service.submit_aggregation_result( 297*14675a02SAndroid Build Coastguard Worker aggregations_pb2.SubmitAggregationResultRequest( 298*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, 299*14675a02SAndroid Build Coastguard Worker client_token=start_upload_response.client_token, 300*14675a02SAndroid Build Coastguard Worker resource_name=start_upload_response.resource.resource_name)) 301*14675a02SAndroid Build Coastguard Worker 302*14675a02SAndroid Build Coastguard Worker # Start a partial upload from a second client. 303*14675a02SAndroid Build Coastguard Worker self.mock_media_service.register_upload.return_value = 'upload2' 304*14675a02SAndroid Build Coastguard Worker service.start_aggregation_data_upload( 305*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 306*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[1])) 307*14675a02SAndroid Build Coastguard Worker 308*14675a02SAndroid Build Coastguard Worker # Abort the session. The pending client should be treated as failed. 309*14675a02SAndroid Build Coastguard Worker self.assertEqual( 310*14675a02SAndroid Build Coastguard Worker service.abort_session(session_id), 311*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 312*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.ABORTED, 313*14675a02SAndroid Build Coastguard Worker num_clients_completed=1, 314*14675a02SAndroid Build Coastguard Worker num_clients_aborted=1, 315*14675a02SAndroid Build Coastguard Worker num_inputs_discarded=1)) 316*14675a02SAndroid Build Coastguard Worker # The registered upload for the second client should have been finalized. 317*14675a02SAndroid Build Coastguard Worker self.mock_media_service.finalize_upload.assert_called_with('upload2') 318*14675a02SAndroid Build Coastguard Worker # get_session_status should no longer return results. 319*14675a02SAndroid Build Coastguard Worker with self.assertRaises(KeyError): 320*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id) 321*14675a02SAndroid Build Coastguard Worker 322*14675a02SAndroid Build Coastguard Worker def test_abort_session_with_missing_session_id(self): 323*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 324*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 325*14675a02SAndroid Build Coastguard Worker with self.assertRaises(KeyError): 326*14675a02SAndroid Build Coastguard Worker service.abort_session('does-not-exist') 327*14675a02SAndroid Build Coastguard Worker 328*14675a02SAndroid Build Coastguard Worker def test_get_session_status_with_missing_session_id(self): 329*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 330*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 331*14675a02SAndroid Build Coastguard Worker with self.assertRaises(KeyError): 332*14675a02SAndroid Build Coastguard Worker service.get_session_status('does-not-exist') 333*14675a02SAndroid Build Coastguard Worker 334*14675a02SAndroid Build Coastguard Worker async def test_wait(self): 335*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 336*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 337*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 338*14675a02SAndroid Build Coastguard Worker task = asyncio.create_task( 339*14675a02SAndroid Build Coastguard Worker service.wait(session_id, num_inputs_aggregated_and_included=1)) 340*14675a02SAndroid Build Coastguard Worker # The awaitable should not be done yet. 341*14675a02SAndroid Build Coastguard Worker await asyncio.wait([task], timeout=0.1) 342*14675a02SAndroid Build Coastguard Worker self.assertFalse(task.done()) 343*14675a02SAndroid Build Coastguard Worker 344*14675a02SAndroid Build Coastguard Worker # Upload results for one client. 345*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 1) 346*14675a02SAndroid Build Coastguard Worker self.mock_media_service.register_upload.return_value = 'upload' 347*14675a02SAndroid Build Coastguard Worker operation = service.start_aggregation_data_upload( 348*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 349*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[0])) 350*14675a02SAndroid Build Coastguard Worker self.assertTrue(operation.done) 351*14675a02SAndroid Build Coastguard Worker start_upload_response = ( 352*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadResponse()) 353*14675a02SAndroid Build Coastguard Worker operation.response.Unpack(start_upload_response) 354*14675a02SAndroid Build Coastguard Worker service.submit_aggregation_result( 355*14675a02SAndroid Build Coastguard Worker aggregations_pb2.SubmitAggregationResultRequest( 356*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, 357*14675a02SAndroid Build Coastguard Worker client_token=start_upload_response.client_token, 358*14675a02SAndroid Build Coastguard Worker resource_name=start_upload_response.resource.resource_name)) 359*14675a02SAndroid Build Coastguard Worker 360*14675a02SAndroid Build Coastguard Worker # The awaitable should now return. 361*14675a02SAndroid Build Coastguard Worker await asyncio.wait([task], timeout=1) 362*14675a02SAndroid Build Coastguard Worker self.assertTrue(task.done()) 363*14675a02SAndroid Build Coastguard Worker self.assertEqual( 364*14675a02SAndroid Build Coastguard Worker task.result(), 365*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 366*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING, 367*14675a02SAndroid Build Coastguard Worker num_clients_completed=1, 368*14675a02SAndroid Build Coastguard Worker num_inputs_aggregated_and_included=1)) 369*14675a02SAndroid Build Coastguard Worker 370*14675a02SAndroid Build Coastguard Worker async def test_wait_already_satisfied(self): 371*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 372*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 373*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 374*14675a02SAndroid Build Coastguard Worker 375*14675a02SAndroid Build Coastguard Worker # Upload results for one client. 376*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 1) 377*14675a02SAndroid Build Coastguard Worker self.mock_media_service.register_upload.return_value = 'upload' 378*14675a02SAndroid Build Coastguard Worker operation = service.start_aggregation_data_upload( 379*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 380*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[0])) 381*14675a02SAndroid Build Coastguard Worker self.assertTrue(operation.done) 382*14675a02SAndroid Build Coastguard Worker start_upload_response = ( 383*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadResponse()) 384*14675a02SAndroid Build Coastguard Worker operation.response.Unpack(start_upload_response) 385*14675a02SAndroid Build Coastguard Worker service.submit_aggregation_result( 386*14675a02SAndroid Build Coastguard Worker aggregations_pb2.SubmitAggregationResultRequest( 387*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, 388*14675a02SAndroid Build Coastguard Worker client_token=start_upload_response.client_token, 389*14675a02SAndroid Build Coastguard Worker resource_name=start_upload_response.resource.resource_name)) 390*14675a02SAndroid Build Coastguard Worker 391*14675a02SAndroid Build Coastguard Worker # Since a client has already reported, the condition should already be 392*14675a02SAndroid Build Coastguard Worker # satisfied. 393*14675a02SAndroid Build Coastguard Worker task = asyncio.create_task( 394*14675a02SAndroid Build Coastguard Worker service.wait(session_id, num_inputs_aggregated_and_included=1)) 395*14675a02SAndroid Build Coastguard Worker await asyncio.wait([task], timeout=1) 396*14675a02SAndroid Build Coastguard Worker self.assertTrue(task.done()) 397*14675a02SAndroid Build Coastguard Worker self.assertEqual( 398*14675a02SAndroid Build Coastguard Worker task.result(), 399*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 400*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING, 401*14675a02SAndroid Build Coastguard Worker num_clients_completed=1, 402*14675a02SAndroid Build Coastguard Worker num_inputs_aggregated_and_included=1)) 403*14675a02SAndroid Build Coastguard Worker 404*14675a02SAndroid Build Coastguard Worker async def test_wait_with_abort(self): 405*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 406*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 407*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 408*14675a02SAndroid Build Coastguard Worker task = asyncio.create_task( 409*14675a02SAndroid Build Coastguard Worker service.wait(session_id, num_inputs_aggregated_and_included=1)) 410*14675a02SAndroid Build Coastguard Worker # The awaitable should not be done yet. 411*14675a02SAndroid Build Coastguard Worker await asyncio.wait([task], timeout=0.1) 412*14675a02SAndroid Build Coastguard Worker self.assertFalse(task.done()) 413*14675a02SAndroid Build Coastguard Worker 414*14675a02SAndroid Build Coastguard Worker # The awaitable should return once the session is aborted. 415*14675a02SAndroid Build Coastguard Worker status = service.abort_session(session_id) 416*14675a02SAndroid Build Coastguard Worker await asyncio.wait([task], timeout=1) 417*14675a02SAndroid Build Coastguard Worker self.assertTrue(task.done()) 418*14675a02SAndroid Build Coastguard Worker self.assertEqual(task.result(), status) 419*14675a02SAndroid Build Coastguard Worker 420*14675a02SAndroid Build Coastguard Worker @mock.patch.object( 421*14675a02SAndroid Build Coastguard Worker aggregation_protocols, 422*14675a02SAndroid Build Coastguard Worker 'create_simple_aggregation_protocol', 423*14675a02SAndroid Build Coastguard Worker autospec=True) 424*14675a02SAndroid Build Coastguard Worker async def test_wait_with_protocol_abort(self, 425*14675a02SAndroid Build Coastguard Worker mock_create_simple_agg_protocol): 426*14675a02SAndroid Build Coastguard Worker # Use a mock since it's not easy to cause the AggregationProtocol to abort. 427*14675a02SAndroid Build Coastguard Worker mock_agg_protocol = mock.create_autospec( 428*14675a02SAndroid Build Coastguard Worker aggregation_protocol.AggregationProtocol, instance=True) 429*14675a02SAndroid Build Coastguard Worker mock_create_simple_agg_protocol.return_value = mock_agg_protocol 430*14675a02SAndroid Build Coastguard Worker mock_agg_protocol.GetStatus.return_value = apm_pb2.StatusMessage( 431*14675a02SAndroid Build Coastguard Worker num_clients_aborted=1234) 432*14675a02SAndroid Build Coastguard Worker 433*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 434*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 435*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 436*14675a02SAndroid Build Coastguard Worker task = asyncio.create_task( 437*14675a02SAndroid Build Coastguard Worker service.wait(session_id, num_inputs_aggregated_and_included=1)) 438*14675a02SAndroid Build Coastguard Worker # The awaitable should not be done yet. 439*14675a02SAndroid Build Coastguard Worker await asyncio.wait([task], timeout=0.1) 440*14675a02SAndroid Build Coastguard Worker self.assertFalse(task.done()) 441*14675a02SAndroid Build Coastguard Worker 442*14675a02SAndroid Build Coastguard Worker # The awaitable should return once the AggregationProtocol aborts. 443*14675a02SAndroid Build Coastguard Worker callback = mock_create_simple_agg_protocol.call_args.args[1] 444*14675a02SAndroid Build Coastguard Worker callback.OnAbort(absl_status.unknown_error('message')) 445*14675a02SAndroid Build Coastguard Worker await asyncio.wait([task], timeout=1) 446*14675a02SAndroid Build Coastguard Worker self.assertTrue(task.done()) 447*14675a02SAndroid Build Coastguard Worker self.assertEqual( 448*14675a02SAndroid Build Coastguard Worker task.result(), 449*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 450*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.FAILED, 451*14675a02SAndroid Build Coastguard Worker num_clients_aborted=1234)) 452*14675a02SAndroid Build Coastguard Worker 453*14675a02SAndroid Build Coastguard Worker async def test_wait_with_complete(self): 454*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 455*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 456*14675a02SAndroid Build Coastguard Worker session_id = service.create_session( 457*14675a02SAndroid Build Coastguard Worker aggregations.AggregationRequirements( 458*14675a02SAndroid Build Coastguard Worker minimum_clients_in_server_published_aggregate=0, 459*14675a02SAndroid Build Coastguard Worker plan=AGGREGATION_REQUIREMENTS.plan)) 460*14675a02SAndroid Build Coastguard Worker task = asyncio.create_task( 461*14675a02SAndroid Build Coastguard Worker service.wait(session_id, num_inputs_aggregated_and_included=1)) 462*14675a02SAndroid Build Coastguard Worker # The awaitable should not be done yet. 463*14675a02SAndroid Build Coastguard Worker await asyncio.wait([task], timeout=0.1) 464*14675a02SAndroid Build Coastguard Worker self.assertFalse(task.done()) 465*14675a02SAndroid Build Coastguard Worker 466*14675a02SAndroid Build Coastguard Worker # The awaitable should return once the session is completed. 467*14675a02SAndroid Build Coastguard Worker status, _ = service.complete_session(session_id) 468*14675a02SAndroid Build Coastguard Worker await asyncio.wait([task], timeout=1) 469*14675a02SAndroid Build Coastguard Worker self.assertTrue(task.done()) 470*14675a02SAndroid Build Coastguard Worker self.assertEqual(task.result(), status) 471*14675a02SAndroid Build Coastguard Worker 472*14675a02SAndroid Build Coastguard Worker async def test_wait_without_condition(self): 473*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 474*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 475*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 476*14675a02SAndroid Build Coastguard Worker task = asyncio.create_task(service.wait(session_id)) 477*14675a02SAndroid Build Coastguard Worker # If there are no conditions, the wait should be trivially satisfied. 478*14675a02SAndroid Build Coastguard Worker await asyncio.wait([task], timeout=1) 479*14675a02SAndroid Build Coastguard Worker self.assertTrue(task.done()) 480*14675a02SAndroid Build Coastguard Worker self.assertEqual(task.result(), service.get_session_status(session_id)) 481*14675a02SAndroid Build Coastguard Worker 482*14675a02SAndroid Build Coastguard Worker async def test_wait_with_missing_session_id(self): 483*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 484*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 485*14675a02SAndroid Build Coastguard Worker task = asyncio.create_task(service.wait('does-not-exist')) 486*14675a02SAndroid Build Coastguard Worker await asyncio.wait([task], timeout=1) 487*14675a02SAndroid Build Coastguard Worker self.assertTrue(task.done()) 488*14675a02SAndroid Build Coastguard Worker self.assertIsInstance(task.exception(), KeyError) 489*14675a02SAndroid Build Coastguard Worker 490*14675a02SAndroid Build Coastguard Worker def test_start_aggregation_data_upload(self): 491*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 492*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 493*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 494*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 1) 495*14675a02SAndroid Build Coastguard Worker self.mock_media_service.register_upload.return_value = 'upload' 496*14675a02SAndroid Build Coastguard Worker operation = service.start_aggregation_data_upload( 497*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 498*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[0])) 499*14675a02SAndroid Build Coastguard Worker self.assertNotEmpty(operation.name) 500*14675a02SAndroid Build Coastguard Worker self.assertTrue(operation.done) 501*14675a02SAndroid Build Coastguard Worker 502*14675a02SAndroid Build Coastguard Worker metadata = aggregations_pb2.StartAggregationDataUploadMetadata() 503*14675a02SAndroid Build Coastguard Worker operation.metadata.Unpack(metadata) 504*14675a02SAndroid Build Coastguard Worker self.assertEqual(metadata, 505*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadMetadata()) 506*14675a02SAndroid Build Coastguard Worker 507*14675a02SAndroid Build Coastguard Worker response = aggregations_pb2.StartAggregationDataUploadResponse() 508*14675a02SAndroid Build Coastguard Worker operation.response.Unpack(response) 509*14675a02SAndroid Build Coastguard Worker # The client token should be set and different from the authorization token. 510*14675a02SAndroid Build Coastguard Worker self.assertNotEmpty(response.client_token) 511*14675a02SAndroid Build Coastguard Worker self.assertNotEqual(response.client_token, tokens[0]) 512*14675a02SAndroid Build Coastguard Worker self.assertEqual( 513*14675a02SAndroid Build Coastguard Worker response, 514*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadResponse( 515*14675a02SAndroid Build Coastguard Worker aggregation_protocol_forwarding_info=FORWARDING_INFO, 516*14675a02SAndroid Build Coastguard Worker resource=common_pb2.ByteStreamResource( 517*14675a02SAndroid Build Coastguard Worker data_upload_forwarding_info=FORWARDING_INFO, 518*14675a02SAndroid Build Coastguard Worker resource_name='upload'), 519*14675a02SAndroid Build Coastguard Worker client_token=response.client_token)) 520*14675a02SAndroid Build Coastguard Worker self.assertEqual( 521*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id), 522*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 523*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING, 524*14675a02SAndroid Build Coastguard Worker num_clients_pending=1)) 525*14675a02SAndroid Build Coastguard Worker 526*14675a02SAndroid Build Coastguard Worker def test_start_aggregagation_data_upload_with_missing_session_id(self): 527*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 528*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 529*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 530*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 1) 531*14675a02SAndroid Build Coastguard Worker with self.assertRaises(http_actions.HttpError) as cm: 532*14675a02SAndroid Build Coastguard Worker service.start_aggregation_data_upload( 533*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 534*14675a02SAndroid Build Coastguard Worker aggregation_id='does-not-exist', authorization_token=tokens[0])) 535*14675a02SAndroid Build Coastguard Worker self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND) 536*14675a02SAndroid Build Coastguard Worker self.assertEqual( 537*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id), 538*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 539*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING)) 540*14675a02SAndroid Build Coastguard Worker 541*14675a02SAndroid Build Coastguard Worker def test_start_aggregagation_data_upload_with_invalid_token(self): 542*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 543*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 544*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 545*14675a02SAndroid Build Coastguard Worker with self.assertRaises(http_actions.HttpError) as cm: 546*14675a02SAndroid Build Coastguard Worker service.start_aggregation_data_upload( 547*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 548*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token='does-not-exist')) 549*14675a02SAndroid Build Coastguard Worker self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED) 550*14675a02SAndroid Build Coastguard Worker self.assertEqual( 551*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id), 552*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 553*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING)) 554*14675a02SAndroid Build Coastguard Worker 555*14675a02SAndroid Build Coastguard Worker def test_start_aggregagation_data_upload_twice(self): 556*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 557*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 558*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 559*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 1) 560*14675a02SAndroid Build Coastguard Worker service.start_aggregation_data_upload( 561*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 562*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[0])) 563*14675a02SAndroid Build Coastguard Worker with self.assertRaises(http_actions.HttpError) as cm: 564*14675a02SAndroid Build Coastguard Worker service.start_aggregation_data_upload( 565*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 566*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[0])) 567*14675a02SAndroid Build Coastguard Worker self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED) 568*14675a02SAndroid Build Coastguard Worker 569*14675a02SAndroid Build Coastguard Worker def test_submit_aggregation_result(self): 570*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 571*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 572*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 573*14675a02SAndroid Build Coastguard Worker 574*14675a02SAndroid Build Coastguard Worker # Upload results from the client. 575*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 1) 576*14675a02SAndroid Build Coastguard Worker 577*14675a02SAndroid Build Coastguard Worker operation = service.start_aggregation_data_upload( 578*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 579*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[0])) 580*14675a02SAndroid Build Coastguard Worker self.assertTrue(operation.done) 581*14675a02SAndroid Build Coastguard Worker start_upload_response = ( 582*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadResponse()) 583*14675a02SAndroid Build Coastguard Worker operation.response.Unpack(start_upload_response) 584*14675a02SAndroid Build Coastguard Worker 585*14675a02SAndroid Build Coastguard Worker submit_response = service.submit_aggregation_result( 586*14675a02SAndroid Build Coastguard Worker aggregations_pb2.SubmitAggregationResultRequest( 587*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, 588*14675a02SAndroid Build Coastguard Worker client_token=start_upload_response.client_token, 589*14675a02SAndroid Build Coastguard Worker resource_name=start_upload_response.resource.resource_name)) 590*14675a02SAndroid Build Coastguard Worker self.assertEqual(submit_response, 591*14675a02SAndroid Build Coastguard Worker aggregations_pb2.SubmitAggregationResultResponse()) 592*14675a02SAndroid Build Coastguard Worker self.mock_media_service.finalize_upload.assert_called_with( 593*14675a02SAndroid Build Coastguard Worker start_upload_response.resource.resource_name) 594*14675a02SAndroid Build Coastguard Worker self.assertEqual( 595*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id), 596*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 597*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING, 598*14675a02SAndroid Build Coastguard Worker num_clients_completed=1, 599*14675a02SAndroid Build Coastguard Worker num_inputs_aggregated_and_included=1)) 600*14675a02SAndroid Build Coastguard Worker 601*14675a02SAndroid Build Coastguard Worker def test_submit_aggregation_result_with_invalid_client_input(self): 602*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 603*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 604*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 605*14675a02SAndroid Build Coastguard Worker 606*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 1) 607*14675a02SAndroid Build Coastguard Worker operation = service.start_aggregation_data_upload( 608*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 609*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[0])) 610*14675a02SAndroid Build Coastguard Worker self.assertTrue(operation.done) 611*14675a02SAndroid Build Coastguard Worker start_upload_response = ( 612*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadResponse()) 613*14675a02SAndroid Build Coastguard Worker operation.response.Unpack(start_upload_response) 614*14675a02SAndroid Build Coastguard Worker 615*14675a02SAndroid Build Coastguard Worker self.mock_media_service.finalize_upload.return_value = b'invalid' 616*14675a02SAndroid Build Coastguard Worker with self.assertRaises(http_actions.HttpError): 617*14675a02SAndroid Build Coastguard Worker service.submit_aggregation_result( 618*14675a02SAndroid Build Coastguard Worker aggregations_pb2.SubmitAggregationResultRequest( 619*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, 620*14675a02SAndroid Build Coastguard Worker client_token=start_upload_response.client_token, 621*14675a02SAndroid Build Coastguard Worker resource_name=start_upload_response.resource.resource_name)) 622*14675a02SAndroid Build Coastguard Worker self.assertEqual( 623*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id), 624*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 625*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING, 626*14675a02SAndroid Build Coastguard Worker num_clients_failed=1)) 627*14675a02SAndroid Build Coastguard Worker 628*14675a02SAndroid Build Coastguard Worker def test_submit_aggregation_result_with_missing_session_id(self): 629*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 630*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 631*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 632*14675a02SAndroid Build Coastguard Worker with self.assertRaises(http_actions.HttpError) as cm: 633*14675a02SAndroid Build Coastguard Worker service.submit_aggregation_result( 634*14675a02SAndroid Build Coastguard Worker aggregations_pb2.SubmitAggregationResultRequest( 635*14675a02SAndroid Build Coastguard Worker aggregation_id='does-not-exist', 636*14675a02SAndroid Build Coastguard Worker client_token='client-token', 637*14675a02SAndroid Build Coastguard Worker resource_name='upload-id')) 638*14675a02SAndroid Build Coastguard Worker self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND) 639*14675a02SAndroid Build Coastguard Worker self.assertEqual( 640*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id), 641*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 642*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING)) 643*14675a02SAndroid Build Coastguard Worker 644*14675a02SAndroid Build Coastguard Worker def test_submit_aggregation_result_with_invalid_token(self): 645*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 646*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 647*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 648*14675a02SAndroid Build Coastguard Worker with self.assertRaises(http_actions.HttpError) as cm: 649*14675a02SAndroid Build Coastguard Worker service.submit_aggregation_result( 650*14675a02SAndroid Build Coastguard Worker aggregations_pb2.SubmitAggregationResultRequest( 651*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, 652*14675a02SAndroid Build Coastguard Worker client_token='does-not-exist', 653*14675a02SAndroid Build Coastguard Worker resource_name='upload-id')) 654*14675a02SAndroid Build Coastguard Worker self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED) 655*14675a02SAndroid Build Coastguard Worker self.assertEqual( 656*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id), 657*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 658*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING)) 659*14675a02SAndroid Build Coastguard Worker 660*14675a02SAndroid Build Coastguard Worker def test_submit_aggregation_result_with_finalize_upload_error(self): 661*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 662*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 663*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 664*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 1) 665*14675a02SAndroid Build Coastguard Worker operation = service.start_aggregation_data_upload( 666*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 667*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[0])) 668*14675a02SAndroid Build Coastguard Worker self.assertTrue(operation.done) 669*14675a02SAndroid Build Coastguard Worker start_upload_response = ( 670*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadResponse()) 671*14675a02SAndroid Build Coastguard Worker operation.response.Unpack(start_upload_response) 672*14675a02SAndroid Build Coastguard Worker 673*14675a02SAndroid Build Coastguard Worker # If the resource_name doesn't correspond to a registered upload, 674*14675a02SAndroid Build Coastguard Worker # finalize_upload will raise a KeyError. 675*14675a02SAndroid Build Coastguard Worker self.mock_media_service.finalize_upload.side_effect = KeyError() 676*14675a02SAndroid Build Coastguard Worker with self.assertRaises(http_actions.HttpError) as cm: 677*14675a02SAndroid Build Coastguard Worker service.submit_aggregation_result( 678*14675a02SAndroid Build Coastguard Worker aggregations_pb2.SubmitAggregationResultRequest( 679*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, 680*14675a02SAndroid Build Coastguard Worker client_token=start_upload_response.client_token, 681*14675a02SAndroid Build Coastguard Worker resource_name=start_upload_response.resource.resource_name)) 682*14675a02SAndroid Build Coastguard Worker self.assertEqual(cm.exception.code, http.HTTPStatus.INTERNAL_SERVER_ERROR) 683*14675a02SAndroid Build Coastguard Worker self.assertEqual( 684*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id), 685*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 686*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING, 687*14675a02SAndroid Build Coastguard Worker num_clients_failed=1)) 688*14675a02SAndroid Build Coastguard Worker 689*14675a02SAndroid Build Coastguard Worker def test_submit_aggregation_result_with_unuploaded_resource(self): 690*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 691*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 692*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 693*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 1) 694*14675a02SAndroid Build Coastguard Worker operation = service.start_aggregation_data_upload( 695*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 696*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[0])) 697*14675a02SAndroid Build Coastguard Worker self.assertTrue(operation.done) 698*14675a02SAndroid Build Coastguard Worker start_upload_response = ( 699*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadResponse()) 700*14675a02SAndroid Build Coastguard Worker operation.response.Unpack(start_upload_response) 701*14675a02SAndroid Build Coastguard Worker 702*14675a02SAndroid Build Coastguard Worker # If the resource_name is valid but no resource was uploaded, 703*14675a02SAndroid Build Coastguard Worker # finalize_resource will return None. 704*14675a02SAndroid Build Coastguard Worker self.mock_media_service.finalize_upload.return_value = None 705*14675a02SAndroid Build Coastguard Worker with self.assertRaises(http_actions.HttpError) as cm: 706*14675a02SAndroid Build Coastguard Worker service.submit_aggregation_result( 707*14675a02SAndroid Build Coastguard Worker aggregations_pb2.SubmitAggregationResultRequest( 708*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, 709*14675a02SAndroid Build Coastguard Worker client_token=start_upload_response.client_token, 710*14675a02SAndroid Build Coastguard Worker resource_name=start_upload_response.resource.resource_name)) 711*14675a02SAndroid Build Coastguard Worker self.assertEqual(cm.exception.code, http.HTTPStatus.BAD_REQUEST) 712*14675a02SAndroid Build Coastguard Worker self.assertEqual( 713*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id), 714*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 715*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING, 716*14675a02SAndroid Build Coastguard Worker num_clients_failed=1)) 717*14675a02SAndroid Build Coastguard Worker 718*14675a02SAndroid Build Coastguard Worker def test_abort_aggregation(self): 719*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 720*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 721*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 722*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 1) 723*14675a02SAndroid Build Coastguard Worker operation = service.start_aggregation_data_upload( 724*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 725*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[0])) 726*14675a02SAndroid Build Coastguard Worker self.assertTrue(operation.done) 727*14675a02SAndroid Build Coastguard Worker start_upload_response = ( 728*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadResponse()) 729*14675a02SAndroid Build Coastguard Worker operation.response.Unpack(start_upload_response) 730*14675a02SAndroid Build Coastguard Worker self.assertEqual( 731*14675a02SAndroid Build Coastguard Worker service.abort_aggregation( 732*14675a02SAndroid Build Coastguard Worker aggregations_pb2.AbortAggregationRequest( 733*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, 734*14675a02SAndroid Build Coastguard Worker client_token=start_upload_response.client_token)), 735*14675a02SAndroid Build Coastguard Worker aggregations_pb2.AbortAggregationResponse()) 736*14675a02SAndroid Build Coastguard Worker self.assertEqual( 737*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id), 738*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 739*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING, 740*14675a02SAndroid Build Coastguard Worker num_clients_completed=1, 741*14675a02SAndroid Build Coastguard Worker num_inputs_discarded=1)) 742*14675a02SAndroid Build Coastguard Worker 743*14675a02SAndroid Build Coastguard Worker def test_abort_aggregation_with_missing_session_id(self): 744*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 745*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 746*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 747*14675a02SAndroid Build Coastguard Worker tokens = service.pre_authorize_clients(session_id, 1) 748*14675a02SAndroid Build Coastguard Worker operation = service.start_aggregation_data_upload( 749*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadRequest( 750*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, authorization_token=tokens[0])) 751*14675a02SAndroid Build Coastguard Worker self.assertTrue(operation.done) 752*14675a02SAndroid Build Coastguard Worker start_upload_response = ( 753*14675a02SAndroid Build Coastguard Worker aggregations_pb2.StartAggregationDataUploadResponse()) 754*14675a02SAndroid Build Coastguard Worker operation.response.Unpack(start_upload_response) 755*14675a02SAndroid Build Coastguard Worker with self.assertRaises(http_actions.HttpError) as cm: 756*14675a02SAndroid Build Coastguard Worker service.abort_aggregation( 757*14675a02SAndroid Build Coastguard Worker aggregations_pb2.AbortAggregationRequest( 758*14675a02SAndroid Build Coastguard Worker aggregation_id='does-not-exist', 759*14675a02SAndroid Build Coastguard Worker client_token=start_upload_response.client_token)) 760*14675a02SAndroid Build Coastguard Worker self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND) 761*14675a02SAndroid Build Coastguard Worker self.assertEqual( 762*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id), 763*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 764*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING, 765*14675a02SAndroid Build Coastguard Worker num_clients_pending=1)) 766*14675a02SAndroid Build Coastguard Worker 767*14675a02SAndroid Build Coastguard Worker def test_abort_aggregation_with_invalid_token(self): 768*14675a02SAndroid Build Coastguard Worker service = aggregations.Service(lambda: FORWARDING_INFO, 769*14675a02SAndroid Build Coastguard Worker self.mock_media_service) 770*14675a02SAndroid Build Coastguard Worker session_id = service.create_session(AGGREGATION_REQUIREMENTS) 771*14675a02SAndroid Build Coastguard Worker with self.assertRaises(http_actions.HttpError) as cm: 772*14675a02SAndroid Build Coastguard Worker service.abort_aggregation( 773*14675a02SAndroid Build Coastguard Worker aggregations_pb2.AbortAggregationRequest( 774*14675a02SAndroid Build Coastguard Worker aggregation_id=session_id, client_token='does-not-exist')) 775*14675a02SAndroid Build Coastguard Worker self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED) 776*14675a02SAndroid Build Coastguard Worker self.assertEqual( 777*14675a02SAndroid Build Coastguard Worker service.get_session_status(session_id), 778*14675a02SAndroid Build Coastguard Worker aggregations.SessionStatus( 779*14675a02SAndroid Build Coastguard Worker status=aggregations.AggregationStatus.PENDING)) 780*14675a02SAndroid Build Coastguard Worker 781*14675a02SAndroid Build Coastguard Worker 782*14675a02SAndroid Build Coastguard Workerif __name__ == '__main__': 783*14675a02SAndroid Build Coastguard Worker absltest.main() 784