# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for task_assignments."""

import http
from unittest import mock
import uuid

from absl.testing import absltest

from google.rpc import code_pb2
from fcp.demo import aggregations
from fcp.demo import http_actions
from fcp.demo import task_assignments
from fcp.protos.federatedcompute import common_pb2
from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
from fcp.protos.federatedcompute import task_assignments_pb2

_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)

POPULATION_NAME = 'test/population'
FORWARDING_INFO = common_pb2.ForwardingInfo(
    target_uri_prefix='https://forwarding.example/')


def _unpack_any(packed, message_type):
  """Unpacks the `Any` message `packed` into a new `message_type` instance."""
  message = message_type()
  packed.Unpack(message)
  return message


def _start_request():
  """Returns a StartTaskAssignmentRequest addressed to POPULATION_NAME."""
  return task_assignments_pb2.StartTaskAssignmentRequest(
      population_name=POPULATION_NAME, session_id='session-id')


def _multiple_assignments_request(task_names,
                                  population_name=POPULATION_NAME):
  """Returns a PerformMultipleTaskAssignmentsRequest for `task_names`."""
  return task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
      population_name=population_name,
      session_id='session-id',
      task_names=task_names)


def _report_request(population_name):
  """Returns a ReportTaskResultRequest carrying an ABORTED status."""
  return task_assignments_pb2.ReportTaskResultRequest(
      population_name=population_name,
      session_id='session-id',
      aggregation_id='aggregation-id',
      computation_status_code=code_pb2.ABORTED)


def _rejected_response():
  """Returns the StartTaskAssignmentResponse expected for a rejected client."""
  return task_assignments_pb2.StartTaskAssignmentResponse(
      rejection_info=common_pb2.RejectionInfo())


def _expected_task_assignment(*, session_id, aggregation_id, token, task_name,
                              plan, init_checkpoint, uri_template):
  """Builds the TaskAssignment the tests expect the service to return.

  The forwarding info is always FORWARDING_INFO and the aggregation info is
  always an empty AggregationInfo; everything else comes from the arguments.
  """
  return task_assignments_pb2.TaskAssignment(
      aggregation_data_forwarding_info=FORWARDING_INFO,
      aggregation_info=task_assignments_pb2.TaskAssignment.AggregationInfo(),
      session_id=session_id,
      aggregation_id=aggregation_id,
      authorization_token=token,
      task_name=task_name,
      plan=plan,
      init_checkpoint=init_checkpoint,
      federated_select_uri_info=task_assignments_pb2.FederatedSelectUriInfo(
          uri_template=uri_template),
  )


class TaskAssignmentsTest(absltest.TestCase):
  """Tests task_assignments.Service with aggregations.Service mocked out."""

  def setUp(self):
    super().setUp()
    service_patcher = mock.patch.object(aggregations, 'Service', autospec=True)
    self.mock_aggregations_service = self.enter_context(service_patcher)
    self.mock_aggregations_service.pre_authorize_clients.return_value = ['']

  def _new_service(self):
    """Returns a fresh Service for POPULATION_NAME using the mocked backend."""
    return task_assignments.Service(
        POPULATION_NAME, lambda: FORWARDING_INFO,
        self.mock_aggregations_service)

  def test_start_task_assignment_with_wrong_population(self):
    # A request naming a different population must fail with NOT_FOUND.
    service = self._new_service()
    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name='other/population', session_id='session-id')
    with self.assertRaises(http_actions.HttpError) as raised:
      service.start_task_assignment(request)
    self.assertEqual(raised.exception.code, http.HTTPStatus.NOT_FOUND)

  @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
  def test_start_task_assignment_with_no_tasks(self, mock_uuid):
    # With no tasks registered, the operation completes with a rejection.
    service = self._new_service()
    operation = service.start_task_assignment(_start_request())
    self.assertEqual(operation.name, f'operations/{mock_uuid.return_value}')
    self.assertTrue(operation.done)

    metadata = _unpack_any(
        operation.metadata, task_assignments_pb2.StartTaskAssignmentMetadata)
    self.assertEqual(metadata,
                     task_assignments_pb2.StartTaskAssignmentMetadata())

    response = _unpack_any(
        operation.response, task_assignments_pb2.StartTaskAssignmentResponse)
    self.assertEqual(response, _rejected_response())

  def test_start_task_assignment_with_multiple_assignment_task(self):
    # The only registered task uses TASK_ASSIGNMENT_MODE_MULTIPLE; the
    # assertions below expect StartTaskAssignment to reject the client.
    service = self._new_service()
    service.add_task(
        'task',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session',
        common_pb2.Resource(uri='https://task.example/plan'),
        common_pb2.Resource(uri='https://task.example/checkpoint'),
        'https://task.example/{key_base10}',
    )

    operation = service.start_task_assignment(_start_request())
    self.assertTrue(operation.done)

    response = _unpack_any(
        operation.response, task_assignments_pb2.StartTaskAssignmentResponse)
    self.assertEqual(response, _rejected_response())

  @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
  def test_start_task_assignment_with_one_task(self, mock_uuid):
    # A single TASK_ASSIGNMENT_MODE_SINGLE task is assigned, and exactly one
    # authorization token is requested for its aggregation session.
    service = self._new_service()
    pre_authorize = self.mock_aggregations_service.pre_authorize_clients
    pre_authorize.return_value = ['token']

    plan = common_pb2.Resource(uri='https://task.example/plan')
    checkpoint = common_pb2.Resource(uri='https://task.example/checkpoint')
    uri_template = 'https://task.example/{key_base10}'
    service.add_task(
        'task',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session',
        plan,
        checkpoint,
        uri_template,
    )

    request = _start_request()
    operation = service.start_task_assignment(request)
    self.assertEqual(operation.name, f'operations/{mock_uuid.return_value}')
    self.assertTrue(operation.done)

    metadata = _unpack_any(
        operation.metadata, task_assignments_pb2.StartTaskAssignmentMetadata)
    self.assertEqual(metadata,
                     task_assignments_pb2.StartTaskAssignmentMetadata())

    response = _unpack_any(
        operation.response, task_assignments_pb2.StartTaskAssignmentResponse)
    expected_response = task_assignments_pb2.StartTaskAssignmentResponse(
        task_assignment=_expected_task_assignment(
            session_id=request.session_id,
            aggregation_id='aggregation-session',
            token='token',
            task_name='task',
            plan=plan,
            init_checkpoint=checkpoint,
            uri_template=uri_template,
        ))
    self.assertEqual(response, expected_response)

    pre_authorize.assert_called_once_with('aggregation-session', num_tokens=1)

  def test_start_task_assignment_with_multiple_tasks(self):
    # Two SINGLE-mode tasks are registered; they are handed out in
    # registration order as earlier ones are removed.
    service = self._new_service()
    pre_authorize = self.mock_aggregations_service.pre_authorize_clients
    pre_authorize.return_value = ['token']

    task1_plan = common_pb2.Resource(uri='https://task1.example/plan')
    task1_checkpoint = common_pb2.Resource(
        uri='https://task1.example/checkpoint')
    task1_uri_template = 'https://task1.example/{key_base10}'
    service.add_task(
        'task1',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session1',
        task1_plan,
        task1_checkpoint,
        task1_uri_template,
    )
    task2_plan = common_pb2.Resource(uri='https://task2.example/plan')
    task2_checkpoint = common_pb2.Resource(
        uri='https://task2.example/checkpoint')
    task2_uri_template = 'https://task2.example/{key_base10}'
    service.add_task(
        'task2',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session2',
        task2_plan,
        task2_checkpoint,
        task2_uri_template,
    )

    request = _start_request()

    # Initially, task1 should be used.
    operation = service.start_task_assignment(request)
    response = _unpack_any(
        operation.response, task_assignments_pb2.StartTaskAssignmentResponse)
    expected_task1 = task_assignments_pb2.StartTaskAssignmentResponse(
        task_assignment=_expected_task_assignment(
            session_id=request.session_id,
            aggregation_id='aggregation-session1',
            token='token',
            task_name='task1',
            plan=task1_plan,
            init_checkpoint=task1_checkpoint,
            uri_template=task1_uri_template,
        ))
    self.assertEqual(response, expected_task1)
    pre_authorize.assert_called_with('aggregation-session1', num_tokens=1)

    # After task1 is removed, task2 should be used.
    service.remove_task('aggregation-session1')
    operation = service.start_task_assignment(request)
    response = _unpack_any(
        operation.response, task_assignments_pb2.StartTaskAssignmentResponse)
    expected_task2 = task_assignments_pb2.StartTaskAssignmentResponse(
        task_assignment=_expected_task_assignment(
            session_id=request.session_id,
            aggregation_id='aggregation-session2',
            token='token',
            task_name='task2',
            plan=task2_plan,
            init_checkpoint=task2_checkpoint,
            uri_template=task2_uri_template,
        ))
    self.assertEqual(response, expected_task2)
    pre_authorize.assert_called_with('aggregation-session2', num_tokens=1)

    # After task2 is removed, the client should be rejected.
    service.remove_task('aggregation-session2')
    operation = service.start_task_assignment(request)
    response = _unpack_any(
        operation.response, task_assignments_pb2.StartTaskAssignmentResponse)
    self.assertEqual(response, _rejected_response())

  def test_perform_multiple_task_assignments_with_wrong_population(self):
    # A request naming a different population must fail with NOT_FOUND.
    service = self._new_service()
    request = _multiple_assignments_request(
        ['task1', 'task2', 'task3'], population_name='other/population')
    with self.assertRaises(http_actions.HttpError) as raised:
      service.perform_multiple_task_assignments(request)
    self.assertEqual(raised.exception.code, http.HTTPStatus.NOT_FOUND)

  def test_perform_multiple_task_assignments_without_tasks(self):
    # With nothing registered, the response is an empty message.
    service = self._new_service()

    request = _multiple_assignments_request(['task1', 'task2', 'task3'])
    self.assertEqual(
        service.perform_multiple_task_assignments(request),
        task_assignments_pb2.PerformMultipleTaskAssignmentsResponse())

  def test_perform_multiple_task_assignments_with_multiple_tasks(self):
    # Tokens are derived from the session id so each expected assignment can
    # be matched to the aggregation session it was authorized for.
    def token_for_session(session_id, num_tokens=1):
      del num_tokens  # Unused.
      return [f'token-for-{session_id}']

    self.mock_aggregations_service.pre_authorize_clients.side_effect = (
        token_for_session)
    service = self._new_service()

    task1_plan = common_pb2.Resource(uri='https://task1.example/plan')
    task1_checkpoint = common_pb2.Resource(
        uri='https://task1.example/checkpoint')
    task1_uri_template = 'https://task1.example/{key_base10}'
    service.add_task(
        'task1',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session1',
        task1_plan,
        task1_checkpoint,
        task1_uri_template,
    )
    task2_plan = common_pb2.Resource(uri='https://task2.example/plan')
    task2_checkpoint = common_pb2.Resource(
        uri='https://task2.example/checkpoint')
    task2_uri_template = 'https://task2.example/{key_base10}'
    service.add_task(
        'task2',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session2',
        task2_plan,
        task2_checkpoint,
        task2_uri_template,
    )
    # Tasks using other TaskAssignmentModes should be skipped.
    task3_plan = common_pb2.Resource(uri='https://task3.example/plan')
    task3_checkpoint = common_pb2.Resource(
        uri='https://task3.example/checkpoint')
    task3_uri_template = 'https://task3.example/{key_base10}'
    service.add_task(
        'task3',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session3',
        task3_plan,
        task3_checkpoint,
        task3_uri_template,
    )

    request = _multiple_assignments_request(['task1', 'task2', 'task3'])
    self.assertCountEqual(
        service.perform_multiple_task_assignments(request).task_assignments,
        [
            _expected_task_assignment(
                session_id=request.session_id,
                aggregation_id='aggregation-session1',
                token='token-for-aggregation-session1',
                task_name='task1',
                plan=task1_plan,
                init_checkpoint=task1_checkpoint,
                uri_template=task1_uri_template,
            ),
            _expected_task_assignment(
                session_id=request.session_id,
                aggregation_id='aggregation-session2',
                token='token-for-aggregation-session2',
                task_name='task2',
                plan=task2_plan,
                init_checkpoint=task2_checkpoint,
                uri_template=task2_uri_template,
            ),
            # 'task3' is omitted: it was registered with
            # TASK_ASSIGNMENT_MODE_SINGLE rather than
            # TASK_ASSIGNMENT_MODE_MULTIPLE.
        ],
    )

  def test_add_task_with_invalid_task_assignment_mode(self):
    # TASK_ASSIGNMENT_MODE_UNSPECIFIED is rejected with ValueError.
    service = self._new_service()
    with self.assertRaises(ValueError):
      service.add_task(
          'task',
          _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_UNSPECIFIED,
          'aggregation-session',
          common_pb2.Resource(uri='https://task.example/plan'),
          common_pb2.Resource(uri='https://task.example/checkpoint'),
          'https://task.example/{key_base10}',
      )

  def test_remove_multiple_assignment_task(self):
    # Once removed, a MULTIPLE-mode task is no longer returned.
    service = self._new_service()
    service.add_task(
        'task',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session',
        common_pb2.Resource(uri='https://task.example/plan'),
        common_pb2.Resource(uri='https://task.example/checkpoint'),
        'https://task.example/{key_base10}',
    )
    service.remove_task('aggregation-session')

    request = _multiple_assignments_request(['task'])
    self.assertEmpty(
        service.perform_multiple_task_assignments(request).task_assignments)

  def test_remove_missing_task(self):
    # Removing an unknown aggregation session raises KeyError.
    service = self._new_service()
    with self.assertRaises(KeyError):
      service.remove_task('does-not-exist')

  def test_report_task_result(self):
    # Reporting a result for the served population returns an empty response.
    service = self._new_service()
    request = _report_request(POPULATION_NAME)
    self.assertEqual(
        service.report_task_result(request),
        task_assignments_pb2.ReportTaskResultResponse())

  def test_report_task_result_with_wrong_population(self):
    # A request naming a different population must fail with NOT_FOUND.
    service = self._new_service()
    request = _report_request('other/population')
    with self.assertRaises(http_actions.HttpError) as raised:
      service.report_task_result(request)
    self.assertEqual(raised.exception.code, http.HTTPStatus.NOT_FOUND)


if __name__ == '__main__':
  absltest.main()