# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for task_assignments."""

import http
from unittest import mock
import uuid

from absl.testing import absltest

from google.rpc import code_pb2
from fcp.demo import aggregations
from fcp.demo import http_actions
from fcp.demo import task_assignments
from fcp.protos.federatedcompute import common_pb2
from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
from fcp.protos.federatedcompute import task_assignments_pb2

_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)

POPULATION_NAME = 'test/population'
FORWARDING_INFO = common_pb2.ForwardingInfo(
    target_uri_prefix='https://forwarding.example/')


class TaskAssignmentsTest(absltest.TestCase):
  """Tests for `task_assignments.Service`."""

  def setUp(self):
    super().setUp()
    # Replace the aggregations service with an autospec'd mock so tests can
    # control (and inspect) the authorization tokens it hands out. By default
    # it returns a single empty token.
    self.mock_aggregations_service = self.enter_context(
        mock.patch.object(aggregations, 'Service', autospec=True))
    self.mock_aggregations_service.pre_authorize_clients.return_value = ['']

  def test_start_task_assignment_with_wrong_population(self):
    """Requests for another population are rejected with NOT_FOUND."""
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)
    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name='other/population', session_id='session-id')
    with self.assertRaises(http_actions.HttpError) as cm:
      service.start_task_assignment(request)
    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)

  @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
  def test_start_task_assignment_with_no_tasks(self, mock_uuid):
    """With no tasks registered, the client receives a rejection."""
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)
    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name=POPULATION_NAME, session_id='session-id')
    operation = service.start_task_assignment(request)
    self.assertEqual(operation.name, f'operations/{mock_uuid.return_value}')
    self.assertTrue(operation.done)

    metadata = task_assignments_pb2.StartTaskAssignmentMetadata()
    operation.metadata.Unpack(metadata)
    self.assertEqual(metadata,
                     task_assignments_pb2.StartTaskAssignmentMetadata())

    response = task_assignments_pb2.StartTaskAssignmentResponse()
    operation.response.Unpack(response)
    self.assertEqual(
        response,
        task_assignments_pb2.StartTaskAssignmentResponse(
            rejection_info=common_pb2.RejectionInfo()
        ),
    )

  def test_start_task_assignment_with_multiple_assignment_task(self):
    """StartTaskAssignment does not hand out TASK_ASSIGNMENT_MODE_MULTIPLE tasks."""
    service = task_assignments.Service(
        POPULATION_NAME, lambda: FORWARDING_INFO, self.mock_aggregations_service
    )
    service.add_task(
        'task',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session',
        common_pb2.Resource(uri='https://task.example/plan'),
        common_pb2.Resource(uri='https://task.example/checkpoint'),
        'https://task.example/{key_base10}',
    )

    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name=POPULATION_NAME, session_id='session-id'
    )
    operation = service.start_task_assignment(request)
    self.assertTrue(operation.done)

    response = task_assignments_pb2.StartTaskAssignmentResponse()
    operation.response.Unpack(response)
    self.assertEqual(
        response,
        task_assignments_pb2.StartTaskAssignmentResponse(
            rejection_info=common_pb2.RejectionInfo()))

  @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
  def test_start_task_assignment_with_one_task(self, mock_uuid):
    """A single TASK_ASSIGNMENT_MODE_SINGLE task is assigned to the client."""
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)

    self.mock_aggregations_service.pre_authorize_clients.return_value = [
        'token'
    ]

    task_plan = common_pb2.Resource(uri='https://task.example/plan')
    task_checkpoint = common_pb2.Resource(uri='https://task.example/checkpoint')
    task_federated_select_uri_template = 'https://task.example/{key_base10}'
    service.add_task(
        'task',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session',
        task_plan,
        task_checkpoint,
        task_federated_select_uri_template,
    )

    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name=POPULATION_NAME, session_id='session-id')
    operation = service.start_task_assignment(request)
    self.assertEqual(operation.name, f'operations/{mock_uuid.return_value}')
    self.assertTrue(operation.done)

    metadata = task_assignments_pb2.StartTaskAssignmentMetadata()
    operation.metadata.Unpack(metadata)
    self.assertEqual(metadata,
                     task_assignments_pb2.StartTaskAssignmentMetadata())

    response = task_assignments_pb2.StartTaskAssignmentResponse()
    operation.response.Unpack(response)
    self.assertEqual(
        response,
        task_assignments_pb2.StartTaskAssignmentResponse(
            task_assignment=task_assignments_pb2.TaskAssignment(
                aggregation_data_forwarding_info=FORWARDING_INFO,
                aggregation_info=(
                    task_assignments_pb2.TaskAssignment.AggregationInfo()
                ),
                session_id=request.session_id,
                aggregation_id='aggregation-session',
                authorization_token='token',
                task_name='task',
                plan=task_plan,
                init_checkpoint=task_checkpoint,
                federated_select_uri_info=(
                    task_assignments_pb2.FederatedSelectUriInfo(
                        uri_template=task_federated_select_uri_template
                    )
                ),
            )
        ),
    )

    self.mock_aggregations_service.pre_authorize_clients.assert_called_once_with(
        'aggregation-session', num_tokens=1)

  def test_start_task_assignment_with_multiple_tasks(self):
    """Tasks are assigned in order and fall through as earlier ones are removed."""
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)

    self.mock_aggregations_service.pre_authorize_clients.return_value = [
        'token'
    ]

    task1_plan = common_pb2.Resource(uri='https://task1.example/plan')
    task1_checkpoint = common_pb2.Resource(
        uri='https://task1.example/checkpoint')
    task1_federated_select_uri_template = 'https://task1.example/{key_base10}'
    service.add_task(
        'task1',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session1',
        task1_plan,
        task1_checkpoint,
        task1_federated_select_uri_template,
    )
    task2_plan = common_pb2.Resource(uri='https://task2.example/plan')
    task2_checkpoint = common_pb2.Resource(
        uri='https://task2.example/checkpoint')
    task2_federated_select_uri_template = 'https://task2.example/{key_base10}'
    service.add_task(
        'task2',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session2',
        task2_plan,
        task2_checkpoint,
        task2_federated_select_uri_template,
    )

    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name=POPULATION_NAME, session_id='session-id')

    # Initially, task1 should be used.
    operation = service.start_task_assignment(request)
    response = task_assignments_pb2.StartTaskAssignmentResponse()
    operation.response.Unpack(response)
    self.assertEqual(
        response,
        task_assignments_pb2.StartTaskAssignmentResponse(
            task_assignment=task_assignments_pb2.TaskAssignment(
                aggregation_data_forwarding_info=FORWARDING_INFO,
                aggregation_info=(
                    task_assignments_pb2.TaskAssignment.AggregationInfo()
                ),
                session_id=request.session_id,
                aggregation_id='aggregation-session1',
                authorization_token='token',
                task_name='task1',
                plan=task1_plan,
                init_checkpoint=task1_checkpoint,
                federated_select_uri_info=(
                    task_assignments_pb2.FederatedSelectUriInfo(
                        uri_template=task1_federated_select_uri_template
                    )
                ),
            )
        ),
    )
    self.mock_aggregations_service.pre_authorize_clients.assert_called_with(
        'aggregation-session1', num_tokens=1)

    # After task1 is removed, task2 should be used.
    service.remove_task('aggregation-session1')
    operation = service.start_task_assignment(request)
    response = task_assignments_pb2.StartTaskAssignmentResponse()
    operation.response.Unpack(response)
    self.assertEqual(
        response,
        task_assignments_pb2.StartTaskAssignmentResponse(
            task_assignment=task_assignments_pb2.TaskAssignment(
                aggregation_data_forwarding_info=FORWARDING_INFO,
                aggregation_info=(
                    task_assignments_pb2.TaskAssignment.AggregationInfo()
                ),
                session_id=request.session_id,
                aggregation_id='aggregation-session2',
                authorization_token='token',
                task_name='task2',
                plan=task2_plan,
                init_checkpoint=task2_checkpoint,
                federated_select_uri_info=(
                    task_assignments_pb2.FederatedSelectUriInfo(
                        uri_template=task2_federated_select_uri_template
                    )
                ),
            )
        ),
    )
    self.mock_aggregations_service.pre_authorize_clients.assert_called_with(
        'aggregation-session2', num_tokens=1)

    # After task2 is removed, the client should be rejected.
    service.remove_task('aggregation-session2')
    operation = service.start_task_assignment(request)
    response = task_assignments_pb2.StartTaskAssignmentResponse()
    operation.response.Unpack(response)
    self.assertEqual(
        response,
        task_assignments_pb2.StartTaskAssignmentResponse(
            rejection_info=common_pb2.RejectionInfo()))

  def test_perform_multiple_task_assignments_with_wrong_population(self):
    """Requests for another population are rejected with NOT_FOUND."""
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)
    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
        population_name='other/population',
        session_id='session-id',
        task_names=['task1', 'task2', 'task3'])
    with self.assertRaises(http_actions.HttpError) as cm:
      service.perform_multiple_task_assignments(request)
    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)

  def test_perform_multiple_task_assignments_without_tasks(self):
    """With no tasks registered, the response contains no assignments."""
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)

    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
        population_name=POPULATION_NAME,
        session_id='session-id',
        task_names=['task1', 'task2', 'task3'])
    self.assertEqual(
        service.perform_multiple_task_assignments(request),
        task_assignments_pb2.PerformMultipleTaskAssignmentsResponse())

  def test_perform_multiple_task_assignments_with_multiple_tasks(self):
    """Only requested TASK_ASSIGNMENT_MODE_MULTIPLE tasks are assigned."""
    # Give each aggregation session a distinct token so the assertions below
    # can tell which session each assignment was authorized against.
    self.mock_aggregations_service.pre_authorize_clients.side_effect = (
        lambda session_id, num_tokens=1: [f'token-for-{session_id}'])
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)

    task1_plan = common_pb2.Resource(uri='https://task1.example/plan')
    task1_checkpoint = common_pb2.Resource(
        uri='https://task1.example/checkpoint')
    task1_federated_select_uri_template = 'https://task1.example/{key_base10}'
    service.add_task(
        'task1',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session1',
        task1_plan,
        task1_checkpoint,
        task1_federated_select_uri_template,
    )
    task2_plan = common_pb2.Resource(uri='https://task2.example/plan')
    task2_checkpoint = common_pb2.Resource(
        uri='https://task2.example/checkpoint')
    task2_federated_select_uri_template = 'https://task2.example/{key_base10}'
    service.add_task(
        'task2',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session2',
        task2_plan,
        task2_checkpoint,
        task2_federated_select_uri_template,
    )
    # Tasks using other TaskAssignmentModes should be skipped.
    task3_plan = common_pb2.Resource(uri='https://task3.example/plan')
    task3_checkpoint = common_pb2.Resource(
        uri='https://task3.example/checkpoint'
    )
    task3_federated_select_uri_template = 'https://task3.example/{key_base10}'
    service.add_task(
        'task3',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session3',
        task3_plan,
        task3_checkpoint,
        task3_federated_select_uri_template,
    )

    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
        population_name=POPULATION_NAME,
        session_id='session-id',
        task_names=['task1', 'task2', 'task3'])
    self.assertCountEqual(
        service.perform_multiple_task_assignments(request).task_assignments,
        [
            task_assignments_pb2.TaskAssignment(
                aggregation_data_forwarding_info=FORWARDING_INFO,
                aggregation_info=(
                    task_assignments_pb2.TaskAssignment.AggregationInfo()
                ),
                session_id=request.session_id,
                aggregation_id='aggregation-session1',
                authorization_token='token-for-aggregation-session1',
                task_name='task1',
                plan=task1_plan,
                init_checkpoint=task1_checkpoint,
                federated_select_uri_info=(
                    task_assignments_pb2.FederatedSelectUriInfo(
                        uri_template=task1_federated_select_uri_template
                    )
                ),
            ),
            task_assignments_pb2.TaskAssignment(
                aggregation_data_forwarding_info=FORWARDING_INFO,
                aggregation_info=(
                    task_assignments_pb2.TaskAssignment.AggregationInfo()
                ),
                session_id=request.session_id,
                aggregation_id='aggregation-session2',
                authorization_token='token-for-aggregation-session2',
                task_name='task2',
                plan=task2_plan,
                init_checkpoint=task2_checkpoint,
                federated_select_uri_info=(
                    task_assignments_pb2.FederatedSelectUriInfo(
                        uri_template=task2_federated_select_uri_template
                    )
                ),
            ),
            # 'task3' should be omitted since it was added with
            # TASK_ASSIGNMENT_MODE_SINGLE rather than
            # TASK_ASSIGNMENT_MODE_MULTIPLE.
        ],
    )

  def test_add_task_with_invalid_task_assignment_mode(self):
    """Adding a task with an UNSPECIFIED assignment mode raises ValueError."""
    service = task_assignments.Service(
        POPULATION_NAME, lambda: FORWARDING_INFO, self.mock_aggregations_service
    )
    with self.assertRaises(ValueError):
      service.add_task(
          'task',
          _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_UNSPECIFIED,
          'aggregation-session',
          common_pb2.Resource(uri='https://task.example/plan'),
          common_pb2.Resource(uri='https://task.example/checkpoint'),
          'https://task.example/{key_base10}',
      )

  def test_remove_multiple_assignment_task(self):
    """A removed multiple-assignment task is no longer assigned."""
    service = task_assignments.Service(
        POPULATION_NAME, lambda: FORWARDING_INFO, self.mock_aggregations_service
    )
    service.add_task(
        'task',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session',
        common_pb2.Resource(uri='https://task.example/plan'),
        common_pb2.Resource(uri='https://task.example/checkpoint'),
        'https://task.example/{key_base10}',
    )
    service.remove_task('aggregation-session')

    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
        population_name=POPULATION_NAME,
        session_id='session-id',
        task_names=['task'],
    )
    self.assertEmpty(
        service.perform_multiple_task_assignments(request).task_assignments
    )

  def test_remove_missing_task(self):
    """Removing an unknown aggregation session raises KeyError."""
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)
    with self.assertRaises(KeyError):
      service.remove_task('does-not-exist')

  def test_report_task_result(self):
    """Reporting a task result returns an empty ReportTaskResultResponse."""
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)
    request = task_assignments_pb2.ReportTaskResultRequest(
        population_name=POPULATION_NAME,
        session_id='session-id',
        aggregation_id='aggregation-id',
        computation_status_code=code_pb2.ABORTED)
    self.assertEqual(
        service.report_task_result(request),
        task_assignments_pb2.ReportTaskResultResponse())

  def test_report_task_result_with_wrong_population(self):
    """Results reported for another population are rejected with NOT_FOUND."""
    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
                                       self.mock_aggregations_service)
    request = task_assignments_pb2.ReportTaskResultRequest(
        population_name='other/population',
        session_id='session-id',
        aggregation_id='aggregation-id',
        computation_status_code=code_pb2.ABORTED)
    with self.assertRaises(http_actions.HttpError) as cm:
      service.report_task_result(request)
    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)


if __name__ == '__main__':
  absltest.main()