# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for aggregations."""

import asyncio
import http
import unittest
from unittest import mock

from absl.testing import absltest
import tensorflow as tf

from fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2
from fcp.aggregation.protocol.python import aggregation_protocol
from fcp.aggregation.tensorflow.python import aggregation_protocols
from fcp.demo import aggregations
from fcp.demo import http_actions
from fcp.demo import media
from fcp.demo import test_utils
from fcp.protos import plan_pb2
from fcp.protos.federatedcompute import aggregations_pb2
from fcp.protos.federatedcompute import common_pb2
from pybind11_abseil import status as absl_status

INPUT_TENSOR = 'in'
OUTPUT_TENSOR = 'out'
# A single federated_sum of a scalar int32 tensor; sessions created with these
# requirements need at least 3 clients before the aggregate can be published.
AGGREGATION_REQUIREMENTS = aggregations.AggregationRequirements(
    minimum_clients_in_server_published_aggregate=3,
    plan=plan_pb2.Plan(phase=[
        plan_pb2.Plan.Phase(
            server_phase_v2=plan_pb2.ServerPhaseV2(aggregations=[
                plan_pb2.ServerAggregationConfig(
                    intrinsic_uri='federated_sum',
                    intrinsic_args=[
                        plan_pb2.ServerAggregationConfig.IntrinsicArg(
                            input_tensor=tf.TensorSpec(
                                (), tf.int32,
                                INPUT_TENSOR).experimental_as_proto())
                    ],
                    output_tensors=[
                        tf.TensorSpec(
                            (), tf.int32,
                            OUTPUT_TENSOR).experimental_as_proto(),
                    ]),
            ])),
    ]))
FORWARDING_INFO = common_pb2.ForwardingInfo(
    target_uri_prefix='https://forwarding.example/')


class NotOkStatus:
  """Matcher for a not-ok Status.

  Compares equal to any `absl_status.Status` whose `ok()` is False.
  NOTE(review): not referenced by any test in this file; keep only if it is
  used elsewhere.
  """

  def __eq__(self, other) -> bool:
    return isinstance(other, absl_status.Status) and not other.ok()


class AggregationsTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
  """Tests for `aggregations.Service` backed by a mocked `media.Service`."""

  def setUp(self):
    super().setUp()
    self.mock_media_service = self.enter_context(
        mock.patch.object(media, 'Service', autospec=True))
    self.mock_media_service.register_upload.return_value = 'upload-id'
    self.mock_media_service.finalize_upload.return_value = (
        test_utils.create_checkpoint({INPUT_TENSOR: 0}))

  def _create_service(self) -> aggregations.Service:
    """Returns a Service using FORWARDING_INFO and the mocked media service."""
    return aggregations.Service(lambda: FORWARDING_INFO,
                                self.mock_media_service)

  def _start_upload(
      self, service: aggregations.Service, session_id: str,
      token: str) -> aggregations_pb2.StartAggregationDataUploadResponse:
    """Starts a client's data upload and returns the unpacked response.

    Args:
      service: The service under test.
      session_id: The aggregation session to upload to.
      token: An authorization token from `pre_authorize_clients`.

    Returns:
      The StartAggregationDataUploadResponse packed in the returned Operation,
      after asserting that the Operation is already done.
    """
    operation = service.start_aggregation_data_upload(
        aggregations_pb2.StartAggregationDataUploadRequest(
            aggregation_id=session_id, authorization_token=token))
    self.assertTrue(operation.done)
    start_upload_response = (
        aggregations_pb2.StartAggregationDataUploadResponse())
    operation.response.Unpack(start_upload_response)
    return start_upload_response

  def _submit_result(
      self, service: aggregations.Service, session_id: str,
      start_upload_response: aggregations_pb2
      .StartAggregationDataUploadResponse
  ) -> aggregations_pb2.SubmitAggregationResultResponse:
    """Submits the upload described by `start_upload_response`.

    Args:
      service: The service under test.
      session_id: The aggregation session the upload belongs to.
      start_upload_response: The response from `_start_upload`; supplies the
        client token and the uploaded resource name.

    Returns:
      Whatever `service.submit_aggregation_result` returns; any exception it
      raises propagates to the caller.
    """
    return service.submit_aggregation_result(
        aggregations_pb2.SubmitAggregationResultRequest(
            aggregation_id=session_id,
            client_token=start_upload_response.client_token,
            resource_name=start_upload_response.resource.resource_name))

  def test_pre_authorize_clients(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    tokens = service.pre_authorize_clients(session_id, 3)
    self.assertLen(tokens, 3)
    # The tokens should all be unique.
    self.assertLen(set(tokens), 3)

  def test_pre_authorize_clients_with_missing_session_id(self):
    service = self._create_service()
    with self.assertRaises(KeyError):
      service.pre_authorize_clients('does-not-exist', 1)

  def test_pre_authorize_clients_with_bad_count(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    self.assertEmpty(service.pre_authorize_clients(session_id, 0))
    self.assertEmpty(service.pre_authorize_clients(session_id, -2))

  def test_create_session(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    self.assertEqual(
        service.get_session_status(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING))

  def test_complete_session(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)

    # Upload results from each client; client i contributes the value i.
    num_clients = (
        AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate)
    for i in range(num_clients):
      tokens = service.pre_authorize_clients(session_id, 1)

      self.mock_media_service.register_upload.return_value = f'upload-{i}'
      start_upload_response = self._start_upload(service, session_id, tokens[0])

      self.mock_media_service.finalize_upload.return_value = (
          test_utils.create_checkpoint({INPUT_TENSOR: i}))
      self._submit_result(service, session_id, start_upload_response)

    # Now that all clients have contributed, the aggregation session can be
    # completed.
    status, aggregate = service.complete_session(session_id)
    self.assertEqual(
        status,
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.COMPLETED,
            num_clients_completed=num_clients,
            num_inputs_aggregated_and_included=num_clients))
    self.assertEqual(
        test_utils.read_tensor_from_checkpoint(aggregate, OUTPUT_TENSOR,
                                               tf.int32),
        sum(range(num_clients)))

    # get_session_status should no longer return results.
    with self.assertRaises(KeyError):
      service.get_session_status(session_id)

  @mock.patch.object(
      aggregation_protocols,
      'create_simple_aggregation_protocol',
      autospec=True)
  def test_complete_session_fails(self, mock_create_simple_agg_protocol):
    # Use a mock since it's not easy to cause
    # SimpleAggregationProtocol::Complete to fail.
    mock_agg_protocol = mock.create_autospec(
        aggregation_protocol.AggregationProtocol, instance=True)
    mock_create_simple_agg_protocol.return_value = mock_agg_protocol

    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)

    required_clients = (
        AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate)
    agg_status = apm_pb2.StatusMessage(
        num_inputs_aggregated_and_included=required_clients)
    mock_agg_protocol.GetStatus.side_effect = lambda: agg_status

    def on_complete():
      # Simulate Complete() discarding every aggregated input and then failing.
      agg_status.num_inputs_discarded = (
          agg_status.num_inputs_aggregated_and_included)
      agg_status.num_inputs_aggregated_and_included = 0
      raise absl_status.StatusNotOk(absl_status.unknown_error('message'))

    mock_agg_protocol.Complete.side_effect = on_complete

    status, aggregate = service.complete_session(session_id)
    self.assertEqual(
        status,
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.FAILED,
            num_inputs_discarded=required_clients))
    self.assertIsNone(aggregate)
    mock_agg_protocol.Complete.assert_called_once()
    mock_agg_protocol.Abort.assert_not_called()

  @mock.patch.object(
      aggregation_protocols,
      'create_simple_aggregation_protocol',
      autospec=True)
  def test_complete_session_aborts(self, mock_create_simple_agg_protocol):
    # Use a mock since it's not easy to cause
    # SimpleAggregationProtocol::Complete to trigger a protocol abort.
    mock_agg_protocol = mock.create_autospec(
        aggregation_protocol.AggregationProtocol, instance=True)
    mock_create_simple_agg_protocol.return_value = mock_agg_protocol

    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)

    required_clients = (
        AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate)
    agg_status = apm_pb2.StatusMessage(
        num_inputs_aggregated_and_included=required_clients)
    mock_agg_protocol.GetStatus.side_effect = lambda: agg_status

    def on_complete():
      # Simulate Complete() discarding every aggregated input and reporting
      # an abort through the callback passed to the protocol factory.
      agg_status.num_inputs_discarded = (
          agg_status.num_inputs_aggregated_and_included)
      agg_status.num_inputs_aggregated_and_included = 0
      callback = mock_create_simple_agg_protocol.call_args.args[1]
      callback.OnAbort(absl_status.unknown_error('message'))

    mock_agg_protocol.Complete.side_effect = on_complete

    status, aggregate = service.complete_session(session_id)
    self.assertEqual(
        status,
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.FAILED,
            num_inputs_discarded=required_clients))
    self.assertIsNone(aggregate)
    mock_agg_protocol.Complete.assert_called_once()
    mock_agg_protocol.Abort.assert_not_called()

  def test_complete_session_without_enough_inputs(self):
    service = self._create_service()
    session_id = service.create_session(
        aggregations.AggregationRequirements(
            minimum_clients_in_server_published_aggregate=3,
            plan=AGGREGATION_REQUIREMENTS.plan))
    tokens = service.pre_authorize_clients(session_id, 2)

    # Upload results for one client.
    start_upload_response = self._start_upload(service, session_id, tokens[0])
    self._submit_result(service, session_id, start_upload_response)

    # Complete the session before the required 3 clients have completed.
    status, aggregate = service.complete_session(session_id)
    self.assertEqual(
        status,
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.FAILED,
            num_clients_completed=1,
            num_inputs_discarded=1))
    self.assertIsNone(aggregate)

  def test_complete_session_with_missing_session_id(self):
    service = self._create_service()
    with self.assertRaises(KeyError):
      service.complete_session('does-not-exist')

  def test_abort_session_with_no_uploads(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    self.assertEqual(
        service.abort_session(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.ABORTED))

  def test_abort_session_with_uploads(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    tokens = service.pre_authorize_clients(session_id, 3)

    # Upload results for one client.
    self.mock_media_service.register_upload.return_value = 'upload1'
    start_upload_response = self._start_upload(service, session_id, tokens[0])
    self._submit_result(service, session_id, start_upload_response)

    # Start a partial upload from a second client.
    self.mock_media_service.register_upload.return_value = 'upload2'
    service.start_aggregation_data_upload(
        aggregations_pb2.StartAggregationDataUploadRequest(
            aggregation_id=session_id, authorization_token=tokens[1]))

    # Abort the session. The pending client should be treated as failed.
    self.assertEqual(
        service.abort_session(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.ABORTED,
            num_clients_completed=1,
            num_clients_aborted=1,
            num_inputs_discarded=1))
    # The registered upload for the second client should have been finalized.
    self.mock_media_service.finalize_upload.assert_called_with('upload2')
    # get_session_status should no longer return results.
    with self.assertRaises(KeyError):
      service.get_session_status(session_id)

  def test_abort_session_with_missing_session_id(self):
    service = self._create_service()
    with self.assertRaises(KeyError):
      service.abort_session('does-not-exist')

  def test_get_session_status_with_missing_session_id(self):
    service = self._create_service()
    with self.assertRaises(KeyError):
      service.get_session_status('does-not-exist')

  async def test_wait(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    task = asyncio.create_task(
        service.wait(session_id, num_inputs_aggregated_and_included=1))
    # The awaitable should not be done yet.
    await asyncio.wait([task], timeout=0.1)
    self.assertFalse(task.done())

    # Upload results for one client.
    tokens = service.pre_authorize_clients(session_id, 1)
    self.mock_media_service.register_upload.return_value = 'upload'
    start_upload_response = self._start_upload(service, session_id, tokens[0])
    self._submit_result(service, session_id, start_upload_response)

    # The awaitable should now return.
    await asyncio.wait([task], timeout=1)
    self.assertTrue(task.done())
    self.assertEqual(
        task.result(),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING,
            num_clients_completed=1,
            num_inputs_aggregated_and_included=1))

  async def test_wait_already_satisfied(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)

    # Upload results for one client.
    tokens = service.pre_authorize_clients(session_id, 1)
    self.mock_media_service.register_upload.return_value = 'upload'
    start_upload_response = self._start_upload(service, session_id, tokens[0])
    self._submit_result(service, session_id, start_upload_response)

    # Since a client has already reported, the condition should already be
    # satisfied.
    task = asyncio.create_task(
        service.wait(session_id, num_inputs_aggregated_and_included=1))
    await asyncio.wait([task], timeout=1)
    self.assertTrue(task.done())
    self.assertEqual(
        task.result(),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING,
            num_clients_completed=1,
            num_inputs_aggregated_and_included=1))

  async def test_wait_with_abort(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    task = asyncio.create_task(
        service.wait(session_id, num_inputs_aggregated_and_included=1))
    # The awaitable should not be done yet.
    await asyncio.wait([task], timeout=0.1)
    self.assertFalse(task.done())

    # The awaitable should return once the session is aborted.
    status = service.abort_session(session_id)
    await asyncio.wait([task], timeout=1)
    self.assertTrue(task.done())
    self.assertEqual(task.result(), status)

  @mock.patch.object(
      aggregation_protocols,
      'create_simple_aggregation_protocol',
      autospec=True)
  async def test_wait_with_protocol_abort(self,
                                          mock_create_simple_agg_protocol):
    # Use a mock since it's not easy to cause the AggregationProtocol to abort.
    mock_agg_protocol = mock.create_autospec(
        aggregation_protocol.AggregationProtocol, instance=True)
    mock_create_simple_agg_protocol.return_value = mock_agg_protocol
    mock_agg_protocol.GetStatus.return_value = apm_pb2.StatusMessage(
        num_clients_aborted=1234)

    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    task = asyncio.create_task(
        service.wait(session_id, num_inputs_aggregated_and_included=1))
    # The awaitable should not be done yet.
    await asyncio.wait([task], timeout=0.1)
    self.assertFalse(task.done())

    # The awaitable should return once the AggregationProtocol aborts.
    callback = mock_create_simple_agg_protocol.call_args.args[1]
    callback.OnAbort(absl_status.unknown_error('message'))
    await asyncio.wait([task], timeout=1)
    self.assertTrue(task.done())
    self.assertEqual(
        task.result(),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.FAILED,
            num_clients_aborted=1234))

  async def test_wait_with_complete(self):
    service = self._create_service()
    session_id = service.create_session(
        aggregations.AggregationRequirements(
            minimum_clients_in_server_published_aggregate=0,
            plan=AGGREGATION_REQUIREMENTS.plan))
    task = asyncio.create_task(
        service.wait(session_id, num_inputs_aggregated_and_included=1))
    # The awaitable should not be done yet.
    await asyncio.wait([task], timeout=0.1)
    self.assertFalse(task.done())

    # The awaitable should return once the session is completed.
    status, _ = service.complete_session(session_id)
    await asyncio.wait([task], timeout=1)
    self.assertTrue(task.done())
    self.assertEqual(task.result(), status)

  async def test_wait_without_condition(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    task = asyncio.create_task(service.wait(session_id))
    # If there are no conditions, the wait should be trivially satisfied.
    await asyncio.wait([task], timeout=1)
    self.assertTrue(task.done())
    self.assertEqual(task.result(), service.get_session_status(session_id))

  async def test_wait_with_missing_session_id(self):
    service = self._create_service()
    task = asyncio.create_task(service.wait('does-not-exist'))
    await asyncio.wait([task], timeout=1)
    self.assertTrue(task.done())
    self.assertIsInstance(task.exception(), KeyError)

  def test_start_aggregation_data_upload(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    tokens = service.pre_authorize_clients(session_id, 1)
    self.mock_media_service.register_upload.return_value = 'upload'
    operation = service.start_aggregation_data_upload(
        aggregations_pb2.StartAggregationDataUploadRequest(
            aggregation_id=session_id, authorization_token=tokens[0]))
    self.assertNotEmpty(operation.name)
    self.assertTrue(operation.done)

    metadata = aggregations_pb2.StartAggregationDataUploadMetadata()
    operation.metadata.Unpack(metadata)
    self.assertEqual(metadata,
                     aggregations_pb2.StartAggregationDataUploadMetadata())

    response = aggregations_pb2.StartAggregationDataUploadResponse()
    operation.response.Unpack(response)
    # The client token should be set and different from the authorization token.
    self.assertNotEmpty(response.client_token)
    self.assertNotEqual(response.client_token, tokens[0])
    self.assertEqual(
        response,
        aggregations_pb2.StartAggregationDataUploadResponse(
            aggregation_protocol_forwarding_info=FORWARDING_INFO,
            resource=common_pb2.ByteStreamResource(
                data_upload_forwarding_info=FORWARDING_INFO,
                resource_name='upload'),
            client_token=response.client_token))
    self.assertEqual(
        service.get_session_status(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING,
            num_clients_pending=1))

  def test_start_aggregation_data_upload_with_missing_session_id(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    tokens = service.pre_authorize_clients(session_id, 1)
    with self.assertRaises(http_actions.HttpError) as cm:
      service.start_aggregation_data_upload(
          aggregations_pb2.StartAggregationDataUploadRequest(
              aggregation_id='does-not-exist', authorization_token=tokens[0]))
    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
    self.assertEqual(
        service.get_session_status(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING))

  def test_start_aggregation_data_upload_with_invalid_token(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    with self.assertRaises(http_actions.HttpError) as cm:
      service.start_aggregation_data_upload(
          aggregations_pb2.StartAggregationDataUploadRequest(
              aggregation_id=session_id, authorization_token='does-not-exist'))
    self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
    self.assertEqual(
        service.get_session_status(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING))

  def test_start_aggregation_data_upload_twice(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    tokens = service.pre_authorize_clients(session_id, 1)
    service.start_aggregation_data_upload(
        aggregations_pb2.StartAggregationDataUploadRequest(
            aggregation_id=session_id, authorization_token=tokens[0]))
    # An authorization token may only be used once.
    with self.assertRaises(http_actions.HttpError) as cm:
      service.start_aggregation_data_upload(
          aggregations_pb2.StartAggregationDataUploadRequest(
              aggregation_id=session_id, authorization_token=tokens[0]))
    self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)

  def test_submit_aggregation_result(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)

    # Upload results from the client.
    tokens = service.pre_authorize_clients(session_id, 1)
    start_upload_response = self._start_upload(service, session_id, tokens[0])

    submit_response = self._submit_result(service, session_id,
                                          start_upload_response)
    self.assertEqual(submit_response,
                     aggregations_pb2.SubmitAggregationResultResponse())
    self.mock_media_service.finalize_upload.assert_called_with(
        start_upload_response.resource.resource_name)
    self.assertEqual(
        service.get_session_status(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING,
            num_clients_completed=1,
            num_inputs_aggregated_and_included=1))

  def test_submit_aggregation_result_with_invalid_client_input(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)

    tokens = service.pre_authorize_clients(session_id, 1)
    start_upload_response = self._start_upload(service, session_id, tokens[0])

    # The uploaded bytes are not a valid checkpoint.
    self.mock_media_service.finalize_upload.return_value = b'invalid'
    with self.assertRaises(http_actions.HttpError):
      self._submit_result(service, session_id, start_upload_response)
    self.assertEqual(
        service.get_session_status(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING,
            num_clients_failed=1))

  def test_submit_aggregation_result_with_missing_session_id(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    with self.assertRaises(http_actions.HttpError) as cm:
      service.submit_aggregation_result(
          aggregations_pb2.SubmitAggregationResultRequest(
              aggregation_id='does-not-exist',
              client_token='client-token',
              resource_name='upload-id'))
    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
    self.assertEqual(
        service.get_session_status(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING))

  def test_submit_aggregation_result_with_invalid_token(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    with self.assertRaises(http_actions.HttpError) as cm:
      service.submit_aggregation_result(
          aggregations_pb2.SubmitAggregationResultRequest(
              aggregation_id=session_id,
              client_token='does-not-exist',
              resource_name='upload-id'))
    self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
    self.assertEqual(
        service.get_session_status(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING))

  def test_submit_aggregation_result_with_finalize_upload_error(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    tokens = service.pre_authorize_clients(session_id, 1)
    start_upload_response = self._start_upload(service, session_id, tokens[0])

    # If the resource_name doesn't correspond to a registered upload,
    # finalize_upload will raise a KeyError.
    self.mock_media_service.finalize_upload.side_effect = KeyError()
    with self.assertRaises(http_actions.HttpError) as cm:
      self._submit_result(service, session_id, start_upload_response)
    self.assertEqual(cm.exception.code, http.HTTPStatus.INTERNAL_SERVER_ERROR)
    self.assertEqual(
        service.get_session_status(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING,
            num_clients_failed=1))

  def test_submit_aggregation_result_with_unuploaded_resource(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    tokens = service.pre_authorize_clients(session_id, 1)
    start_upload_response = self._start_upload(service, session_id, tokens[0])

    # If the resource_name is valid but no resource was uploaded,
    # finalize_upload will return None.
    self.mock_media_service.finalize_upload.return_value = None
    with self.assertRaises(http_actions.HttpError) as cm:
      self._submit_result(service, session_id, start_upload_response)
    self.assertEqual(cm.exception.code, http.HTTPStatus.BAD_REQUEST)
    self.assertEqual(
        service.get_session_status(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING,
            num_clients_failed=1))

  def test_abort_aggregation(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    tokens = service.pre_authorize_clients(session_id, 1)
    start_upload_response = self._start_upload(service, session_id, tokens[0])
    self.assertEqual(
        service.abort_aggregation(
            aggregations_pb2.AbortAggregationRequest(
                aggregation_id=session_id,
                client_token=start_upload_response.client_token)),
        aggregations_pb2.AbortAggregationResponse())
    self.assertEqual(
        service.get_session_status(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING,
            num_clients_completed=1,
            num_inputs_discarded=1))

  def test_abort_aggregation_with_missing_session_id(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    tokens = service.pre_authorize_clients(session_id, 1)
    start_upload_response = self._start_upload(service, session_id, tokens[0])
    with self.assertRaises(http_actions.HttpError) as cm:
      service.abort_aggregation(
          aggregations_pb2.AbortAggregationRequest(
              aggregation_id='does-not-exist',
              client_token=start_upload_response.client_token))
    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
    self.assertEqual(
        service.get_session_status(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING,
            num_clients_pending=1))

  def test_abort_aggregation_with_invalid_token(self):
    service = self._create_service()
    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
    with self.assertRaises(http_actions.HttpError) as cm:
      service.abort_aggregation(
          aggregations_pb2.AbortAggregationRequest(
              aggregation_id=session_id, client_token='does-not-exist'))
    self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
    self.assertEqual(
        service.get_session_status(session_id),
        aggregations.SessionStatus(
            status=aggregations.AggregationStatus.PENDING))


if __name__ == '__main__':
  absltest.main()