xref: /aosp_15_r20/external/federated-compute/fcp/demo/task_assignments_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for task_assignments."""
15
16import http
17from unittest import mock
18import uuid
19
20from absl.testing import absltest
21
22from google.rpc import code_pb2
23from fcp.demo import aggregations
24from fcp.demo import http_actions
25from fcp.demo import task_assignments
26from fcp.protos.federatedcompute import common_pb2
27from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
28from fcp.protos.federatedcompute import task_assignments_pb2
29
# Short alias for the deeply nested TaskAssignmentMode enum so that the tests
# can refer to TASK_ASSIGNMENT_MODE_* values without repeating the full path.
_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)

# Population name that every Service under test is constructed with; requests
# naming any other population are expected to fail with NOT_FOUND.
POPULATION_NAME = 'test/population'
# Forwarding info returned by the callback handed to the Service constructor;
# the tests expect it verbatim in each TaskAssignment's
# aggregation_data_forwarding_info field.
FORWARDING_INFO = common_pb2.ForwardingInfo(
    target_uri_prefix='https://forwarding.example/')
37
38
class TaskAssignmentsTest(absltest.TestCase):
  """Tests for `task_assignments.Service`.

  Every test builds a Service for `POPULATION_NAME` whose aggregations
  dependency is the autospecced mock installed by `setUp`.
  """

  def setUp(self):
    """Patches `aggregations.Service` with an autospecced mock."""
    super().setUp()
    self.mock_aggregations_service = self.enter_context(
        mock.patch.object(aggregations, 'Service', autospec=True))
    # Default: pre-authorization hands back a single empty token. Tests that
    # check the token override this.
    self.mock_aggregations_service.pre_authorize_clients.return_value = ['']

  # ---------------------------------------------------------------------------
  # Helpers
  # ---------------------------------------------------------------------------

  def _new_service(self):
    """Returns a Service for POPULATION_NAME using the mocked aggregations."""
    return task_assignments.Service(
        POPULATION_NAME, lambda: FORWARDING_INFO, self.mock_aggregations_service
    )

  def _add_task(self, service, task_name, mode, aggregation_id):
    """Adds a task whose resources live under https://<task_name>.example/.

    Args:
      service: The `task_assignments.Service` to register the task with.
      task_name: Name of the task; also used as the host-name prefix of the
        plan, checkpoint and federated select URIs.
      mode: The `_TaskAssignmentMode` value to register the task under.
      aggregation_id: Aggregation session id passed to `add_task`.

    Returns:
      A `(plan, checkpoint, uri_template)` tuple holding the exact values that
      were passed to `service.add_task`.
    """
    plan = common_pb2.Resource(uri=f'https://{task_name}.example/plan')
    checkpoint = common_pb2.Resource(
        uri=f'https://{task_name}.example/checkpoint')
    # Doubled braces keep the literal `{key_base10}` placeholder in the URI.
    uri_template = f'https://{task_name}.example/{{key_base10}}'
    service.add_task(
        task_name, mode, aggregation_id, plan, checkpoint, uri_template)
    return plan, checkpoint, uri_template

  def _expected_assignment(self, session_id, task_name, aggregation_id, token,
                           resources):
    """Builds the TaskAssignment a client is expected to receive.

    Args:
      session_id: Session id of the request that triggered the assignment.
      task_name: Name the task was registered under.
      aggregation_id: Aggregation session id the task was registered with.
      token: Authorization token the mocked aggregations service returns.
      resources: The `(plan, checkpoint, uri_template)` tuple returned by
        `_add_task` for this task.

    Returns:
      The expected `task_assignments_pb2.TaskAssignment` message.
    """
    plan, checkpoint, uri_template = resources
    return task_assignments_pb2.TaskAssignment(
        aggregation_data_forwarding_info=FORWARDING_INFO,
        aggregation_info=(
            task_assignments_pb2.TaskAssignment.AggregationInfo()
        ),
        session_id=session_id,
        aggregation_id=aggregation_id,
        authorization_token=token,
        task_name=task_name,
        plan=plan,
        init_checkpoint=checkpoint,
        federated_select_uri_info=(
            task_assignments_pb2.FederatedSelectUriInfo(
                uri_template=uri_template
            )
        ),
    )

  def _unpack_response(self, operation):
    """Unpacks `operation.response` into a StartTaskAssignmentResponse."""
    response = task_assignments_pb2.StartTaskAssignmentResponse()
    operation.response.Unpack(response)
    return response

  def _rejection(self):
    """Returns the response expected when no task can be assigned."""
    return task_assignments_pb2.StartTaskAssignmentResponse(
        rejection_info=common_pb2.RejectionInfo())

  # ---------------------------------------------------------------------------
  # StartTaskAssignment
  # ---------------------------------------------------------------------------

  def test_start_task_assignment_with_wrong_population(self):
    """A request for another population fails with NOT_FOUND."""
    service = self._new_service()
    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name='other/population', session_id='session-id')
    with self.assertRaises(http_actions.HttpError) as cm:
      service.start_task_assignment(request)
    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)

  @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
  def test_start_task_assignment_with_no_tasks(self, mock_uuid):
    """With no tasks registered, the completed operation is a rejection."""
    service = self._new_service()
    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name=POPULATION_NAME, session_id='session-id')
    operation = service.start_task_assignment(request)
    # The operation name embeds the (patched) uuid4 value.
    self.assertEqual(operation.name, f'operations/{mock_uuid.return_value}')
    self.assertTrue(operation.done)

    metadata = task_assignments_pb2.StartTaskAssignmentMetadata()
    operation.metadata.Unpack(metadata)
    self.assertEqual(metadata,
                     task_assignments_pb2.StartTaskAssignmentMetadata())

    self.assertEqual(self._unpack_response(operation), self._rejection())

  def test_start_task_assignment_with_multiple_assignment_task(self):
    """A MULTIPLE-mode task is not handed out by start_task_assignment."""
    service = self._new_service()
    self._add_task(
        service,
        'task',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session',
    )

    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name=POPULATION_NAME, session_id='session-id'
    )
    operation = service.start_task_assignment(request)
    self.assertTrue(operation.done)
    self.assertEqual(self._unpack_response(operation), self._rejection())

  @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
  def test_start_task_assignment_with_one_task(self, mock_uuid):
    """A single SINGLE-mode task is assigned with a pre-authorized token."""
    service = self._new_service()

    self.mock_aggregations_service.pre_authorize_clients.return_value = [
        'token'
    ]

    task = self._add_task(
        service,
        'task',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session',
    )

    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name=POPULATION_NAME, session_id='session-id')
    operation = service.start_task_assignment(request)
    self.assertEqual(operation.name, f'operations/{mock_uuid.return_value}')
    self.assertTrue(operation.done)

    metadata = task_assignments_pb2.StartTaskAssignmentMetadata()
    operation.metadata.Unpack(metadata)
    self.assertEqual(metadata,
                     task_assignments_pb2.StartTaskAssignmentMetadata())

    self.assertEqual(
        self._unpack_response(operation),
        task_assignments_pb2.StartTaskAssignmentResponse(
            task_assignment=self._expected_assignment(
                request.session_id,
                'task',
                'aggregation-session',
                'token',
                task,
            )
        ),
    )

    self.mock_aggregations_service.pre_authorize_clients.assert_called_once_with(
        'aggregation-session', num_tokens=1)

  def test_start_task_assignment_with_multiple_tasks(self):
    """Tasks are used in registration order as earlier ones are removed."""
    service = self._new_service()

    self.mock_aggregations_service.pre_authorize_clients.return_value = [
        'token'
    ]

    task1 = self._add_task(
        service,
        'task1',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session1',
    )
    task2 = self._add_task(
        service,
        'task2',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session2',
    )

    request = task_assignments_pb2.StartTaskAssignmentRequest(
        population_name=POPULATION_NAME, session_id='session-id')

    # Initially, task1 should be used.
    operation = service.start_task_assignment(request)
    self.assertEqual(
        self._unpack_response(operation),
        task_assignments_pb2.StartTaskAssignmentResponse(
            task_assignment=self._expected_assignment(
                request.session_id,
                'task1',
                'aggregation-session1',
                'token',
                task1,
            )
        ),
    )
    self.mock_aggregations_service.pre_authorize_clients.assert_called_with(
        'aggregation-session1', num_tokens=1)

    # After task1 is removed, task2 should be used.
    service.remove_task('aggregation-session1')
    operation = service.start_task_assignment(request)
    self.assertEqual(
        self._unpack_response(operation),
        task_assignments_pb2.StartTaskAssignmentResponse(
            task_assignment=self._expected_assignment(
                request.session_id,
                'task2',
                'aggregation-session2',
                'token',
                task2,
            )
        ),
    )
    self.mock_aggregations_service.pre_authorize_clients.assert_called_with(
        'aggregation-session2', num_tokens=1)

    # After task2 is removed, the client should be rejected.
    service.remove_task('aggregation-session2')
    operation = service.start_task_assignment(request)
    self.assertEqual(self._unpack_response(operation), self._rejection())

  # ---------------------------------------------------------------------------
  # PerformMultipleTaskAssignments
  # ---------------------------------------------------------------------------

  def test_perform_multiple_task_assignments_with_wrong_population(self):
    """A request for another population fails with NOT_FOUND."""
    service = self._new_service()
    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
        population_name='other/population',
        session_id='session-id',
        task_names=['task1', 'task2', 'task3'])
    with self.assertRaises(http_actions.HttpError) as cm:
      service.perform_multiple_task_assignments(request)
    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)

  def test_perform_multiple_task_assignments_without_tasks(self):
    """With no tasks registered, the response is empty."""
    service = self._new_service()

    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
        population_name=POPULATION_NAME,
        session_id='session-id',
        task_names=['task1', 'task2', 'task3'])
    self.assertEqual(
        service.perform_multiple_task_assignments(request),
        task_assignments_pb2.PerformMultipleTaskAssignmentsResponse())

  def test_perform_multiple_task_assignments_with_multiple_tasks(self):
    """Only the requested MULTIPLE-mode tasks are assigned."""
    # Return a per-session token so each assignment can be matched to the
    # aggregation session it was authorized for.
    self.mock_aggregations_service.pre_authorize_clients.side_effect = (
        lambda session_id, num_tokens=1: [f'token-for-{session_id}'])
    service = self._new_service()

    task1 = self._add_task(
        service,
        'task1',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session1',
    )
    task2 = self._add_task(
        service,
        'task2',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session2',
    )
    # Tasks using other TaskAssignmentModes should be skipped.
    self._add_task(
        service,
        'task3',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
        'aggregation-session3',
    )

    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
        population_name=POPULATION_NAME,
        session_id='session-id',
        task_names=['task1', 'task2', 'task3'])
    self.assertCountEqual(
        service.perform_multiple_task_assignments(request).task_assignments,
        [
            self._expected_assignment(
                request.session_id,
                'task1',
                'aggregation-session1',
                'token-for-aggregation-session1',
                task1,
            ),
            self._expected_assignment(
                request.session_id,
                'task2',
                'aggregation-session2',
                'token-for-aggregation-session2',
                task2,
            ),
            # 'task3' is omitted even though it was requested: it is
            # registered in SINGLE mode, not MULTIPLE.
        ],
    )

  # ---------------------------------------------------------------------------
  # add_task / remove_task
  # ---------------------------------------------------------------------------

  def test_add_task_with_invalid_task_assignment_mode(self):
    """Registering a task with the UNSPECIFIED mode raises ValueError."""
    service = self._new_service()
    with self.assertRaises(ValueError):
      self._add_task(
          service,
          'task',
          _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_UNSPECIFIED,
          'aggregation-session',
      )

  def test_remove_multiple_assignment_task(self):
    """A removed MULTIPLE-mode task is no longer assigned."""
    service = self._new_service()
    self._add_task(
        service,
        'task',
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
        'aggregation-session',
    )
    service.remove_task('aggregation-session')

    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
        population_name=POPULATION_NAME,
        session_id='session-id',
        task_names=['task'],
    )
    self.assertEmpty(
        service.perform_multiple_task_assignments(request).task_assignments
    )

  def test_remove_missing_task(self):
    """Removing an unknown aggregation session raises KeyError."""
    service = self._new_service()
    with self.assertRaises(KeyError):
      service.remove_task('does-not-exist')

  # ---------------------------------------------------------------------------
  # ReportTaskResult
  # ---------------------------------------------------------------------------

  def test_report_task_result(self):
    """Reporting a result for the right population returns an empty response."""
    service = self._new_service()
    request = task_assignments_pb2.ReportTaskResultRequest(
        population_name=POPULATION_NAME,
        session_id='session-id',
        aggregation_id='aggregation-id',
        computation_status_code=code_pb2.ABORTED)
    self.assertEqual(
        service.report_task_result(request),
        task_assignments_pb2.ReportTaskResultResponse())

  def test_report_task_result_with_wrong_population(self):
    """Reporting a result for another population fails with NOT_FOUND."""
    service = self._new_service()
    request = task_assignments_pb2.ReportTaskResultRequest(
        population_name='other/population',
        session_id='session-id',
        aggregation_id='aggregation-id',
        computation_status_code=code_pb2.ABORTED)
    with self.assertRaises(http_actions.HttpError) as cm:
      service.report_task_result(request)
    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
450
451
# Hand control to absltest's test runner when this file is run as a script.
if __name__ == '__main__':
  absltest.main()
454