xref: /aosp_15_r20/external/federated-compute/fcp/demo/task_assignments_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Tests for task_assignments."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerimport http
17*14675a02SAndroid Build Coastguard Workerfrom unittest import mock
18*14675a02SAndroid Build Coastguard Workerimport uuid
19*14675a02SAndroid Build Coastguard Worker
20*14675a02SAndroid Build Coastguard Workerfrom absl.testing import absltest
21*14675a02SAndroid Build Coastguard Worker
22*14675a02SAndroid Build Coastguard Workerfrom google.rpc import code_pb2
23*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import aggregations
24*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import http_actions
25*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import task_assignments
26*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import common_pb2
27*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
28*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import task_assignments_pb2
29*14675a02SAndroid Build Coastguard Worker
# Short alias for the nested TaskAssignmentMode enum used when adding tasks.
_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)

# The population every Service under test is constructed for.
POPULATION_NAME = 'test/population'
# Returned by the Service's forwarding-info callback; the tests expect it to be
# echoed back as `aggregation_data_forwarding_info` in each TaskAssignment.
FORWARDING_INFO = common_pb2.ForwardingInfo(
    target_uri_prefix='https://forwarding.example/'
)
37*14675a02SAndroid Build Coastguard Worker
38*14675a02SAndroid Build Coastguard Worker
39*14675a02SAndroid Build Coastguard Workerclass TaskAssignmentsTest(absltest.TestCase):
40*14675a02SAndroid Build Coastguard Worker
41*14675a02SAndroid Build Coastguard Worker  def setUp(self):
42*14675a02SAndroid Build Coastguard Worker    super().setUp()
43*14675a02SAndroid Build Coastguard Worker    self.mock_aggregations_service = self.enter_context(
44*14675a02SAndroid Build Coastguard Worker        mock.patch.object(aggregations, 'Service', autospec=True))
45*14675a02SAndroid Build Coastguard Worker    self.mock_aggregations_service.pre_authorize_clients.return_value = ['']
46*14675a02SAndroid Build Coastguard Worker
47*14675a02SAndroid Build Coastguard Worker  def test_start_task_assignment_with_wrong_population(self):
48*14675a02SAndroid Build Coastguard Worker    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
49*14675a02SAndroid Build Coastguard Worker                                       self.mock_aggregations_service)
50*14675a02SAndroid Build Coastguard Worker    request = task_assignments_pb2.StartTaskAssignmentRequest(
51*14675a02SAndroid Build Coastguard Worker        population_name='other/population', session_id='session-id')
52*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(http_actions.HttpError) as cm:
53*14675a02SAndroid Build Coastguard Worker      service.start_task_assignment(request)
54*14675a02SAndroid Build Coastguard Worker    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
55*14675a02SAndroid Build Coastguard Worker
56*14675a02SAndroid Build Coastguard Worker  @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
57*14675a02SAndroid Build Coastguard Worker  def test_start_task_assignment_with_no_tasks(self, mock_uuid):
58*14675a02SAndroid Build Coastguard Worker    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
59*14675a02SAndroid Build Coastguard Worker                                       self.mock_aggregations_service)
60*14675a02SAndroid Build Coastguard Worker    request = task_assignments_pb2.StartTaskAssignmentRequest(
61*14675a02SAndroid Build Coastguard Worker        population_name=POPULATION_NAME, session_id='session-id')
62*14675a02SAndroid Build Coastguard Worker    operation = service.start_task_assignment(request)
63*14675a02SAndroid Build Coastguard Worker    self.assertEqual(operation.name, f'operations/{mock_uuid.return_value}')
64*14675a02SAndroid Build Coastguard Worker    self.assertTrue(operation.done)
65*14675a02SAndroid Build Coastguard Worker
66*14675a02SAndroid Build Coastguard Worker    metadata = task_assignments_pb2.StartTaskAssignmentMetadata()
67*14675a02SAndroid Build Coastguard Worker    operation.metadata.Unpack(metadata)
68*14675a02SAndroid Build Coastguard Worker    self.assertEqual(metadata,
69*14675a02SAndroid Build Coastguard Worker                     task_assignments_pb2.StartTaskAssignmentMetadata())
70*14675a02SAndroid Build Coastguard Worker
71*14675a02SAndroid Build Coastguard Worker    response = task_assignments_pb2.StartTaskAssignmentResponse()
72*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(response)
73*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
74*14675a02SAndroid Build Coastguard Worker        response,
75*14675a02SAndroid Build Coastguard Worker        task_assignments_pb2.StartTaskAssignmentResponse(
76*14675a02SAndroid Build Coastguard Worker            rejection_info=common_pb2.RejectionInfo()
77*14675a02SAndroid Build Coastguard Worker        ),
78*14675a02SAndroid Build Coastguard Worker    )
79*14675a02SAndroid Build Coastguard Worker
80*14675a02SAndroid Build Coastguard Worker  def test_start_task_assignment_with_multiple_assignment_task(self):
81*14675a02SAndroid Build Coastguard Worker    service = task_assignments.Service(
82*14675a02SAndroid Build Coastguard Worker        POPULATION_NAME, lambda: FORWARDING_INFO, self.mock_aggregations_service
83*14675a02SAndroid Build Coastguard Worker    )
84*14675a02SAndroid Build Coastguard Worker    service.add_task(
85*14675a02SAndroid Build Coastguard Worker        'task',
86*14675a02SAndroid Build Coastguard Worker        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
87*14675a02SAndroid Build Coastguard Worker        'aggregation-session',
88*14675a02SAndroid Build Coastguard Worker        common_pb2.Resource(uri='https://task.example/plan'),
89*14675a02SAndroid Build Coastguard Worker        common_pb2.Resource(uri='https://task.example/checkpoint'),
90*14675a02SAndroid Build Coastguard Worker        'https://task.example/{key_base10}',
91*14675a02SAndroid Build Coastguard Worker    )
92*14675a02SAndroid Build Coastguard Worker
93*14675a02SAndroid Build Coastguard Worker    request = task_assignments_pb2.StartTaskAssignmentRequest(
94*14675a02SAndroid Build Coastguard Worker        population_name=POPULATION_NAME, session_id='session-id'
95*14675a02SAndroid Build Coastguard Worker    )
96*14675a02SAndroid Build Coastguard Worker    operation = service.start_task_assignment(request)
97*14675a02SAndroid Build Coastguard Worker    self.assertTrue(operation.done)
98*14675a02SAndroid Build Coastguard Worker
99*14675a02SAndroid Build Coastguard Worker    response = task_assignments_pb2.StartTaskAssignmentResponse()
100*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(response)
101*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
102*14675a02SAndroid Build Coastguard Worker        response,
103*14675a02SAndroid Build Coastguard Worker        task_assignments_pb2.StartTaskAssignmentResponse(
104*14675a02SAndroid Build Coastguard Worker            rejection_info=common_pb2.RejectionInfo()))
105*14675a02SAndroid Build Coastguard Worker
106*14675a02SAndroid Build Coastguard Worker  @mock.patch.object(uuid, 'uuid4', return_value=uuid.uuid4(), autospec=True)
107*14675a02SAndroid Build Coastguard Worker  def test_start_task_assignment_with_one_task(self, mock_uuid):
108*14675a02SAndroid Build Coastguard Worker    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
109*14675a02SAndroid Build Coastguard Worker                                       self.mock_aggregations_service)
110*14675a02SAndroid Build Coastguard Worker
111*14675a02SAndroid Build Coastguard Worker    self.mock_aggregations_service.pre_authorize_clients.return_value = [
112*14675a02SAndroid Build Coastguard Worker        'token'
113*14675a02SAndroid Build Coastguard Worker    ]
114*14675a02SAndroid Build Coastguard Worker
115*14675a02SAndroid Build Coastguard Worker    task_plan = common_pb2.Resource(uri='https://task.example/plan')
116*14675a02SAndroid Build Coastguard Worker    task_checkpoint = common_pb2.Resource(uri='https://task.example/checkpoint')
117*14675a02SAndroid Build Coastguard Worker    task_federated_select_uri_template = 'https://task.example/{key_base10}'
118*14675a02SAndroid Build Coastguard Worker    service.add_task(
119*14675a02SAndroid Build Coastguard Worker        'task',
120*14675a02SAndroid Build Coastguard Worker        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
121*14675a02SAndroid Build Coastguard Worker        'aggregation-session',
122*14675a02SAndroid Build Coastguard Worker        task_plan,
123*14675a02SAndroid Build Coastguard Worker        task_checkpoint,
124*14675a02SAndroid Build Coastguard Worker        task_federated_select_uri_template,
125*14675a02SAndroid Build Coastguard Worker    )
126*14675a02SAndroid Build Coastguard Worker
127*14675a02SAndroid Build Coastguard Worker    request = task_assignments_pb2.StartTaskAssignmentRequest(
128*14675a02SAndroid Build Coastguard Worker        population_name=POPULATION_NAME, session_id='session-id')
129*14675a02SAndroid Build Coastguard Worker    operation = service.start_task_assignment(request)
130*14675a02SAndroid Build Coastguard Worker    self.assertEqual(operation.name, f'operations/{mock_uuid.return_value}')
131*14675a02SAndroid Build Coastguard Worker    self.assertTrue(operation.done)
132*14675a02SAndroid Build Coastguard Worker
133*14675a02SAndroid Build Coastguard Worker    metadata = task_assignments_pb2.StartTaskAssignmentMetadata()
134*14675a02SAndroid Build Coastguard Worker    operation.metadata.Unpack(metadata)
135*14675a02SAndroid Build Coastguard Worker    self.assertEqual(metadata,
136*14675a02SAndroid Build Coastguard Worker                     task_assignments_pb2.StartTaskAssignmentMetadata())
137*14675a02SAndroid Build Coastguard Worker
138*14675a02SAndroid Build Coastguard Worker    response = task_assignments_pb2.StartTaskAssignmentResponse()
139*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(response)
140*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
141*14675a02SAndroid Build Coastguard Worker        response,
142*14675a02SAndroid Build Coastguard Worker        task_assignments_pb2.StartTaskAssignmentResponse(
143*14675a02SAndroid Build Coastguard Worker            task_assignment=task_assignments_pb2.TaskAssignment(
144*14675a02SAndroid Build Coastguard Worker                aggregation_data_forwarding_info=FORWARDING_INFO,
145*14675a02SAndroid Build Coastguard Worker                aggregation_info=(
146*14675a02SAndroid Build Coastguard Worker                    task_assignments_pb2.TaskAssignment.AggregationInfo()
147*14675a02SAndroid Build Coastguard Worker                ),
148*14675a02SAndroid Build Coastguard Worker                session_id=request.session_id,
149*14675a02SAndroid Build Coastguard Worker                aggregation_id='aggregation-session',
150*14675a02SAndroid Build Coastguard Worker                authorization_token='token',
151*14675a02SAndroid Build Coastguard Worker                task_name='task',
152*14675a02SAndroid Build Coastguard Worker                plan=task_plan,
153*14675a02SAndroid Build Coastguard Worker                init_checkpoint=task_checkpoint,
154*14675a02SAndroid Build Coastguard Worker                federated_select_uri_info=(
155*14675a02SAndroid Build Coastguard Worker                    task_assignments_pb2.FederatedSelectUriInfo(
156*14675a02SAndroid Build Coastguard Worker                        uri_template=task_federated_select_uri_template
157*14675a02SAndroid Build Coastguard Worker                    )
158*14675a02SAndroid Build Coastguard Worker                ),
159*14675a02SAndroid Build Coastguard Worker            )
160*14675a02SAndroid Build Coastguard Worker        ),
161*14675a02SAndroid Build Coastguard Worker    )
162*14675a02SAndroid Build Coastguard Worker
163*14675a02SAndroid Build Coastguard Worker    self.mock_aggregations_service.pre_authorize_clients.assert_called_once_with(
164*14675a02SAndroid Build Coastguard Worker        'aggregation-session', num_tokens=1)
165*14675a02SAndroid Build Coastguard Worker
166*14675a02SAndroid Build Coastguard Worker  def test_start_task_assignment_with_multiple_tasks(self):
167*14675a02SAndroid Build Coastguard Worker    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
168*14675a02SAndroid Build Coastguard Worker                                       self.mock_aggregations_service)
169*14675a02SAndroid Build Coastguard Worker
170*14675a02SAndroid Build Coastguard Worker    self.mock_aggregations_service.pre_authorize_clients.return_value = [
171*14675a02SAndroid Build Coastguard Worker        'token'
172*14675a02SAndroid Build Coastguard Worker    ]
173*14675a02SAndroid Build Coastguard Worker
174*14675a02SAndroid Build Coastguard Worker    task1_plan = common_pb2.Resource(uri='https://task1.example/plan')
175*14675a02SAndroid Build Coastguard Worker    task1_checkpoint = common_pb2.Resource(
176*14675a02SAndroid Build Coastguard Worker        uri='https://task1.example/checkpoint')
177*14675a02SAndroid Build Coastguard Worker    task1_federated_select_uri_template = 'https://task1.example/{key_base10}'
178*14675a02SAndroid Build Coastguard Worker    service.add_task(
179*14675a02SAndroid Build Coastguard Worker        'task1',
180*14675a02SAndroid Build Coastguard Worker        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
181*14675a02SAndroid Build Coastguard Worker        'aggregation-session1',
182*14675a02SAndroid Build Coastguard Worker        task1_plan,
183*14675a02SAndroid Build Coastguard Worker        task1_checkpoint,
184*14675a02SAndroid Build Coastguard Worker        task1_federated_select_uri_template,
185*14675a02SAndroid Build Coastguard Worker    )
186*14675a02SAndroid Build Coastguard Worker    task2_plan = common_pb2.Resource(uri='https://task2.example/plan')
187*14675a02SAndroid Build Coastguard Worker    task2_checkpoint = common_pb2.Resource(
188*14675a02SAndroid Build Coastguard Worker        uri='https://task2.example/checkpoint')
189*14675a02SAndroid Build Coastguard Worker    task2_federated_select_uri_template = 'https://task2.example/{key_base10}'
190*14675a02SAndroid Build Coastguard Worker    service.add_task(
191*14675a02SAndroid Build Coastguard Worker        'task2',
192*14675a02SAndroid Build Coastguard Worker        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
193*14675a02SAndroid Build Coastguard Worker        'aggregation-session2',
194*14675a02SAndroid Build Coastguard Worker        task2_plan,
195*14675a02SAndroid Build Coastguard Worker        task2_checkpoint,
196*14675a02SAndroid Build Coastguard Worker        task2_federated_select_uri_template,
197*14675a02SAndroid Build Coastguard Worker    )
198*14675a02SAndroid Build Coastguard Worker
199*14675a02SAndroid Build Coastguard Worker    request = task_assignments_pb2.StartTaskAssignmentRequest(
200*14675a02SAndroid Build Coastguard Worker        population_name=POPULATION_NAME, session_id='session-id')
201*14675a02SAndroid Build Coastguard Worker
202*14675a02SAndroid Build Coastguard Worker    # Initially, task1 should be used.
203*14675a02SAndroid Build Coastguard Worker    operation = service.start_task_assignment(request)
204*14675a02SAndroid Build Coastguard Worker    response = task_assignments_pb2.StartTaskAssignmentResponse()
205*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(response)
206*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
207*14675a02SAndroid Build Coastguard Worker        response,
208*14675a02SAndroid Build Coastguard Worker        task_assignments_pb2.StartTaskAssignmentResponse(
209*14675a02SAndroid Build Coastguard Worker            task_assignment=task_assignments_pb2.TaskAssignment(
210*14675a02SAndroid Build Coastguard Worker                aggregation_data_forwarding_info=FORWARDING_INFO,
211*14675a02SAndroid Build Coastguard Worker                aggregation_info=(
212*14675a02SAndroid Build Coastguard Worker                    task_assignments_pb2.TaskAssignment.AggregationInfo()
213*14675a02SAndroid Build Coastguard Worker                ),
214*14675a02SAndroid Build Coastguard Worker                session_id=request.session_id,
215*14675a02SAndroid Build Coastguard Worker                aggregation_id='aggregation-session1',
216*14675a02SAndroid Build Coastguard Worker                authorization_token='token',
217*14675a02SAndroid Build Coastguard Worker                task_name='task1',
218*14675a02SAndroid Build Coastguard Worker                plan=task1_plan,
219*14675a02SAndroid Build Coastguard Worker                init_checkpoint=task1_checkpoint,
220*14675a02SAndroid Build Coastguard Worker                federated_select_uri_info=(
221*14675a02SAndroid Build Coastguard Worker                    task_assignments_pb2.FederatedSelectUriInfo(
222*14675a02SAndroid Build Coastguard Worker                        uri_template=task1_federated_select_uri_template
223*14675a02SAndroid Build Coastguard Worker                    )
224*14675a02SAndroid Build Coastguard Worker                ),
225*14675a02SAndroid Build Coastguard Worker            )
226*14675a02SAndroid Build Coastguard Worker        ),
227*14675a02SAndroid Build Coastguard Worker    )
228*14675a02SAndroid Build Coastguard Worker    self.mock_aggregations_service.pre_authorize_clients.assert_called_with(
229*14675a02SAndroid Build Coastguard Worker        'aggregation-session1', num_tokens=1)
230*14675a02SAndroid Build Coastguard Worker
231*14675a02SAndroid Build Coastguard Worker    # After task1 is removed, task2 should be used.
232*14675a02SAndroid Build Coastguard Worker    service.remove_task('aggregation-session1')
233*14675a02SAndroid Build Coastguard Worker    operation = service.start_task_assignment(request)
234*14675a02SAndroid Build Coastguard Worker    response = task_assignments_pb2.StartTaskAssignmentResponse()
235*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(response)
236*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
237*14675a02SAndroid Build Coastguard Worker        response,
238*14675a02SAndroid Build Coastguard Worker        task_assignments_pb2.StartTaskAssignmentResponse(
239*14675a02SAndroid Build Coastguard Worker            task_assignment=task_assignments_pb2.TaskAssignment(
240*14675a02SAndroid Build Coastguard Worker                aggregation_data_forwarding_info=FORWARDING_INFO,
241*14675a02SAndroid Build Coastguard Worker                aggregation_info=(
242*14675a02SAndroid Build Coastguard Worker                    task_assignments_pb2.TaskAssignment.AggregationInfo()
243*14675a02SAndroid Build Coastguard Worker                ),
244*14675a02SAndroid Build Coastguard Worker                session_id=request.session_id,
245*14675a02SAndroid Build Coastguard Worker                aggregation_id='aggregation-session2',
246*14675a02SAndroid Build Coastguard Worker                authorization_token='token',
247*14675a02SAndroid Build Coastguard Worker                task_name='task2',
248*14675a02SAndroid Build Coastguard Worker                plan=task2_plan,
249*14675a02SAndroid Build Coastguard Worker                init_checkpoint=task2_checkpoint,
250*14675a02SAndroid Build Coastguard Worker                federated_select_uri_info=(
251*14675a02SAndroid Build Coastguard Worker                    task_assignments_pb2.FederatedSelectUriInfo(
252*14675a02SAndroid Build Coastguard Worker                        uri_template=task2_federated_select_uri_template
253*14675a02SAndroid Build Coastguard Worker                    )
254*14675a02SAndroid Build Coastguard Worker                ),
255*14675a02SAndroid Build Coastguard Worker            )
256*14675a02SAndroid Build Coastguard Worker        ),
257*14675a02SAndroid Build Coastguard Worker    )
258*14675a02SAndroid Build Coastguard Worker    self.mock_aggregations_service.pre_authorize_clients.assert_called_with(
259*14675a02SAndroid Build Coastguard Worker        'aggregation-session2', num_tokens=1)
260*14675a02SAndroid Build Coastguard Worker
261*14675a02SAndroid Build Coastguard Worker    # After task2 is removed, the client should be rejected.
262*14675a02SAndroid Build Coastguard Worker    service.remove_task('aggregation-session2')
263*14675a02SAndroid Build Coastguard Worker    operation = service.start_task_assignment(request)
264*14675a02SAndroid Build Coastguard Worker    response = task_assignments_pb2.StartTaskAssignmentResponse()
265*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(response)
266*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
267*14675a02SAndroid Build Coastguard Worker        response,
268*14675a02SAndroid Build Coastguard Worker        task_assignments_pb2.StartTaskAssignmentResponse(
269*14675a02SAndroid Build Coastguard Worker            rejection_info=common_pb2.RejectionInfo()))
270*14675a02SAndroid Build Coastguard Worker
271*14675a02SAndroid Build Coastguard Worker  def test_perform_multiple_task_assignments_with_wrong_population(self):
272*14675a02SAndroid Build Coastguard Worker    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
273*14675a02SAndroid Build Coastguard Worker                                       self.mock_aggregations_service)
274*14675a02SAndroid Build Coastguard Worker    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
275*14675a02SAndroid Build Coastguard Worker        population_name='other/population',
276*14675a02SAndroid Build Coastguard Worker        session_id='session-id',
277*14675a02SAndroid Build Coastguard Worker        task_names=['task1', 'task2', 'task3'])
278*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(http_actions.HttpError) as cm:
279*14675a02SAndroid Build Coastguard Worker      service.perform_multiple_task_assignments(request)
280*14675a02SAndroid Build Coastguard Worker    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
281*14675a02SAndroid Build Coastguard Worker
282*14675a02SAndroid Build Coastguard Worker  def test_perform_multiple_task_assignments_without_tasks(self):
283*14675a02SAndroid Build Coastguard Worker    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
284*14675a02SAndroid Build Coastguard Worker                                       self.mock_aggregations_service)
285*14675a02SAndroid Build Coastguard Worker
286*14675a02SAndroid Build Coastguard Worker    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
287*14675a02SAndroid Build Coastguard Worker        population_name=POPULATION_NAME,
288*14675a02SAndroid Build Coastguard Worker        session_id='session-id',
289*14675a02SAndroid Build Coastguard Worker        task_names=['task1', 'task2', 'task3'])
290*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
291*14675a02SAndroid Build Coastguard Worker        service.perform_multiple_task_assignments(request),
292*14675a02SAndroid Build Coastguard Worker        task_assignments_pb2.PerformMultipleTaskAssignmentsResponse())
293*14675a02SAndroid Build Coastguard Worker
294*14675a02SAndroid Build Coastguard Worker  def test_perform_multiple_task_assignments_with_multiple_tasks(self):
295*14675a02SAndroid Build Coastguard Worker    self.mock_aggregations_service.pre_authorize_clients.side_effect = (
296*14675a02SAndroid Build Coastguard Worker        lambda session_id, num_tokens=1: [f'token-for-{session_id}'])
297*14675a02SAndroid Build Coastguard Worker    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
298*14675a02SAndroid Build Coastguard Worker                                       self.mock_aggregations_service)
299*14675a02SAndroid Build Coastguard Worker
300*14675a02SAndroid Build Coastguard Worker    task1_plan = common_pb2.Resource(uri='https://task1.example/plan')
301*14675a02SAndroid Build Coastguard Worker    task1_checkpoint = common_pb2.Resource(
302*14675a02SAndroid Build Coastguard Worker        uri='https://task1.example/checkpoint')
303*14675a02SAndroid Build Coastguard Worker    task1_federated_select_uri_template = 'https://task1.example/{key_base10}'
304*14675a02SAndroid Build Coastguard Worker    service.add_task(
305*14675a02SAndroid Build Coastguard Worker        'task1',
306*14675a02SAndroid Build Coastguard Worker        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
307*14675a02SAndroid Build Coastguard Worker        'aggregation-session1',
308*14675a02SAndroid Build Coastguard Worker        task1_plan,
309*14675a02SAndroid Build Coastguard Worker        task1_checkpoint,
310*14675a02SAndroid Build Coastguard Worker        task1_federated_select_uri_template,
311*14675a02SAndroid Build Coastguard Worker    )
312*14675a02SAndroid Build Coastguard Worker    task2_plan = common_pb2.Resource(uri='https://task2.example/plan')
313*14675a02SAndroid Build Coastguard Worker    task2_checkpoint = common_pb2.Resource(
314*14675a02SAndroid Build Coastguard Worker        uri='https://task2.example/checkpoint')
315*14675a02SAndroid Build Coastguard Worker    task2_federated_select_uri_template = 'https://task2.example/{key_base10}'
316*14675a02SAndroid Build Coastguard Worker    service.add_task(
317*14675a02SAndroid Build Coastguard Worker        'task2',
318*14675a02SAndroid Build Coastguard Worker        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
319*14675a02SAndroid Build Coastguard Worker        'aggregation-session2',
320*14675a02SAndroid Build Coastguard Worker        task2_plan,
321*14675a02SAndroid Build Coastguard Worker        task2_checkpoint,
322*14675a02SAndroid Build Coastguard Worker        task2_federated_select_uri_template,
323*14675a02SAndroid Build Coastguard Worker    )
324*14675a02SAndroid Build Coastguard Worker    # Tasks using other TaskAssignmentModes should be skipped.
325*14675a02SAndroid Build Coastguard Worker    task3_plan = common_pb2.Resource(uri='https://task3.example/plan')
326*14675a02SAndroid Build Coastguard Worker    task3_checkpoint = common_pb2.Resource(
327*14675a02SAndroid Build Coastguard Worker        uri='https://task3.example/checkpoint'
328*14675a02SAndroid Build Coastguard Worker    )
329*14675a02SAndroid Build Coastguard Worker    task3_federated_select_uri_template = 'https://task3.example/{key_base10}'
330*14675a02SAndroid Build Coastguard Worker    service.add_task(
331*14675a02SAndroid Build Coastguard Worker        'task3',
332*14675a02SAndroid Build Coastguard Worker        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE,
333*14675a02SAndroid Build Coastguard Worker        'aggregation-session3',
334*14675a02SAndroid Build Coastguard Worker        task3_plan,
335*14675a02SAndroid Build Coastguard Worker        task3_checkpoint,
336*14675a02SAndroid Build Coastguard Worker        task3_federated_select_uri_template,
337*14675a02SAndroid Build Coastguard Worker    )
338*14675a02SAndroid Build Coastguard Worker
339*14675a02SAndroid Build Coastguard Worker    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
340*14675a02SAndroid Build Coastguard Worker        population_name=POPULATION_NAME,
341*14675a02SAndroid Build Coastguard Worker        session_id='session-id',
342*14675a02SAndroid Build Coastguard Worker        task_names=['task1', 'task2', 'task3'])
343*14675a02SAndroid Build Coastguard Worker    self.assertCountEqual(
344*14675a02SAndroid Build Coastguard Worker        service.perform_multiple_task_assignments(request).task_assignments,
345*14675a02SAndroid Build Coastguard Worker        [
346*14675a02SAndroid Build Coastguard Worker            task_assignments_pb2.TaskAssignment(
347*14675a02SAndroid Build Coastguard Worker                aggregation_data_forwarding_info=FORWARDING_INFO,
348*14675a02SAndroid Build Coastguard Worker                aggregation_info=(
349*14675a02SAndroid Build Coastguard Worker                    task_assignments_pb2.TaskAssignment.AggregationInfo()
350*14675a02SAndroid Build Coastguard Worker                ),
351*14675a02SAndroid Build Coastguard Worker                session_id=request.session_id,
352*14675a02SAndroid Build Coastguard Worker                aggregation_id='aggregation-session1',
353*14675a02SAndroid Build Coastguard Worker                authorization_token='token-for-aggregation-session1',
354*14675a02SAndroid Build Coastguard Worker                task_name='task1',
355*14675a02SAndroid Build Coastguard Worker                plan=task1_plan,
356*14675a02SAndroid Build Coastguard Worker                init_checkpoint=task1_checkpoint,
357*14675a02SAndroid Build Coastguard Worker                federated_select_uri_info=(
358*14675a02SAndroid Build Coastguard Worker                    task_assignments_pb2.FederatedSelectUriInfo(
359*14675a02SAndroid Build Coastguard Worker                        uri_template=task1_federated_select_uri_template
360*14675a02SAndroid Build Coastguard Worker                    )
361*14675a02SAndroid Build Coastguard Worker                ),
362*14675a02SAndroid Build Coastguard Worker            ),
363*14675a02SAndroid Build Coastguard Worker            task_assignments_pb2.TaskAssignment(
364*14675a02SAndroid Build Coastguard Worker                aggregation_data_forwarding_info=FORWARDING_INFO,
365*14675a02SAndroid Build Coastguard Worker                aggregation_info=(
366*14675a02SAndroid Build Coastguard Worker                    task_assignments_pb2.TaskAssignment.AggregationInfo()
367*14675a02SAndroid Build Coastguard Worker                ),
368*14675a02SAndroid Build Coastguard Worker                session_id=request.session_id,
369*14675a02SAndroid Build Coastguard Worker                aggregation_id='aggregation-session2',
370*14675a02SAndroid Build Coastguard Worker                authorization_token='token-for-aggregation-session2',
371*14675a02SAndroid Build Coastguard Worker                task_name='task2',
372*14675a02SAndroid Build Coastguard Worker                plan=task2_plan,
373*14675a02SAndroid Build Coastguard Worker                init_checkpoint=task2_checkpoint,
374*14675a02SAndroid Build Coastguard Worker                federated_select_uri_info=(
375*14675a02SAndroid Build Coastguard Worker                    task_assignments_pb2.FederatedSelectUriInfo(
376*14675a02SAndroid Build Coastguard Worker                        uri_template=task2_federated_select_uri_template
377*14675a02SAndroid Build Coastguard Worker                    )
378*14675a02SAndroid Build Coastguard Worker                ),
379*14675a02SAndroid Build Coastguard Worker            ),
380*14675a02SAndroid Build Coastguard Worker            # 'task3' should be omitted since there isn't a corresponding task.
381*14675a02SAndroid Build Coastguard Worker        ],
382*14675a02SAndroid Build Coastguard Worker    )
383*14675a02SAndroid Build Coastguard Worker
384*14675a02SAndroid Build Coastguard Worker  def test_add_task_with_invalid_task_assignment_mode(self):
385*14675a02SAndroid Build Coastguard Worker    service = task_assignments.Service(
386*14675a02SAndroid Build Coastguard Worker        POPULATION_NAME, lambda: FORWARDING_INFO, self.mock_aggregations_service
387*14675a02SAndroid Build Coastguard Worker    )
388*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(ValueError):
389*14675a02SAndroid Build Coastguard Worker      service.add_task(
390*14675a02SAndroid Build Coastguard Worker          'task',
391*14675a02SAndroid Build Coastguard Worker          _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_UNSPECIFIED,
392*14675a02SAndroid Build Coastguard Worker          'aggregation-session',
393*14675a02SAndroid Build Coastguard Worker          common_pb2.Resource(uri='https://task.example/plan'),
394*14675a02SAndroid Build Coastguard Worker          common_pb2.Resource(uri='https://task.example/checkpoint'),
395*14675a02SAndroid Build Coastguard Worker          'https://task.example/{key_base10}',
396*14675a02SAndroid Build Coastguard Worker      )
397*14675a02SAndroid Build Coastguard Worker
398*14675a02SAndroid Build Coastguard Worker  def test_remove_multiple_assignment_task(self):
399*14675a02SAndroid Build Coastguard Worker    service = task_assignments.Service(
400*14675a02SAndroid Build Coastguard Worker        POPULATION_NAME, lambda: FORWARDING_INFO, self.mock_aggregations_service
401*14675a02SAndroid Build Coastguard Worker    )
402*14675a02SAndroid Build Coastguard Worker    service.add_task(
403*14675a02SAndroid Build Coastguard Worker        'task',
404*14675a02SAndroid Build Coastguard Worker        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
405*14675a02SAndroid Build Coastguard Worker        'aggregation-session',
406*14675a02SAndroid Build Coastguard Worker        common_pb2.Resource(uri='https://task.example/plan'),
407*14675a02SAndroid Build Coastguard Worker        common_pb2.Resource(uri='https://task.example/checkpoint'),
408*14675a02SAndroid Build Coastguard Worker        'https://task.example/{key_base10}',
409*14675a02SAndroid Build Coastguard Worker    )
410*14675a02SAndroid Build Coastguard Worker    service.remove_task('aggregation-session')
411*14675a02SAndroid Build Coastguard Worker
412*14675a02SAndroid Build Coastguard Worker    request = task_assignments_pb2.PerformMultipleTaskAssignmentsRequest(
413*14675a02SAndroid Build Coastguard Worker        population_name=POPULATION_NAME,
414*14675a02SAndroid Build Coastguard Worker        session_id='session-id',
415*14675a02SAndroid Build Coastguard Worker        task_names=['task'],
416*14675a02SAndroid Build Coastguard Worker    )
417*14675a02SAndroid Build Coastguard Worker    self.assertEmpty(
418*14675a02SAndroid Build Coastguard Worker        service.perform_multiple_task_assignments(request).task_assignments
419*14675a02SAndroid Build Coastguard Worker    )
420*14675a02SAndroid Build Coastguard Worker
421*14675a02SAndroid Build Coastguard Worker  def test_remove_missing_task(self):
422*14675a02SAndroid Build Coastguard Worker    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
423*14675a02SAndroid Build Coastguard Worker                                       self.mock_aggregations_service)
424*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(KeyError):
425*14675a02SAndroid Build Coastguard Worker      service.remove_task('does-not-exist')
426*14675a02SAndroid Build Coastguard Worker
427*14675a02SAndroid Build Coastguard Worker  def test_report_task_result(self):
428*14675a02SAndroid Build Coastguard Worker    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
429*14675a02SAndroid Build Coastguard Worker                                       self.mock_aggregations_service)
430*14675a02SAndroid Build Coastguard Worker    request = task_assignments_pb2.ReportTaskResultRequest(
431*14675a02SAndroid Build Coastguard Worker        population_name=POPULATION_NAME,
432*14675a02SAndroid Build Coastguard Worker        session_id='session-id',
433*14675a02SAndroid Build Coastguard Worker        aggregation_id='aggregation-id',
434*14675a02SAndroid Build Coastguard Worker        computation_status_code=code_pb2.ABORTED)
435*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
436*14675a02SAndroid Build Coastguard Worker        service.report_task_result(request),
437*14675a02SAndroid Build Coastguard Worker        task_assignments_pb2.ReportTaskResultResponse())
438*14675a02SAndroid Build Coastguard Worker
439*14675a02SAndroid Build Coastguard Worker  def test_report_task_result_with_wrong_population(self):
440*14675a02SAndroid Build Coastguard Worker    service = task_assignments.Service(POPULATION_NAME, lambda: FORWARDING_INFO,
441*14675a02SAndroid Build Coastguard Worker                                       self.mock_aggregations_service)
442*14675a02SAndroid Build Coastguard Worker    request = task_assignments_pb2.ReportTaskResultRequest(
443*14675a02SAndroid Build Coastguard Worker        population_name='other/population',
444*14675a02SAndroid Build Coastguard Worker        session_id='session-id',
445*14675a02SAndroid Build Coastguard Worker        aggregation_id='aggregation-id',
446*14675a02SAndroid Build Coastguard Worker        computation_status_code=code_pb2.ABORTED)
447*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(http_actions.HttpError) as cm:
448*14675a02SAndroid Build Coastguard Worker      service.report_task_result(request)
449*14675a02SAndroid Build Coastguard Worker    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
450*14675a02SAndroid Build Coastguard Worker
451*14675a02SAndroid Build Coastguard Worker
if __name__ == '__main__':
  # Hand control to absl's test runner when this file is run as a script.
  absltest.main()
454