xref: /aosp_15_r20/external/federated-compute/fcp/demo/aggregations_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Tests for aggregations."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerimport asyncio
17*14675a02SAndroid Build Coastguard Workerimport http
18*14675a02SAndroid Build Coastguard Workerimport unittest
19*14675a02SAndroid Build Coastguard Workerfrom unittest import mock
20*14675a02SAndroid Build Coastguard Worker
21*14675a02SAndroid Build Coastguard Workerfrom absl.testing import absltest
22*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
23*14675a02SAndroid Build Coastguard Worker
24*14675a02SAndroid Build Coastguard Workerfrom fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2
25*14675a02SAndroid Build Coastguard Workerfrom fcp.aggregation.protocol.python import aggregation_protocol
26*14675a02SAndroid Build Coastguard Workerfrom fcp.aggregation.tensorflow.python import aggregation_protocols
27*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import aggregations
28*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import http_actions
29*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import media
30*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import test_utils
31*14675a02SAndroid Build Coastguard Workerfrom fcp.protos import plan_pb2
32*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import aggregations_pb2
33*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import common_pb2
34*14675a02SAndroid Build Coastguard Workerfrom pybind11_abseil import status as absl_status
35*14675a02SAndroid Build Coastguard Worker
# Names of the tensors read from client checkpoints and written to the
# aggregated checkpoint in these tests.
INPUT_TENSOR = 'in'
OUTPUT_TENSOR = 'out'

# Requirements shared by most tests: one `federated_sum` aggregation over a
# scalar int32 tensor, with a minimum of three clients in the published
# aggregate.
AGGREGATION_REQUIREMENTS = aggregations.AggregationRequirements(
    minimum_clients_in_server_published_aggregate=3,
    plan=plan_pb2.Plan(
        phase=[
            plan_pb2.Plan.Phase(
                server_phase_v2=plan_pb2.ServerPhaseV2(
                    aggregations=[
                        plan_pb2.ServerAggregationConfig(
                            intrinsic_uri='federated_sum',
                            intrinsic_args=[
                                plan_pb2.ServerAggregationConfig.IntrinsicArg(
                                    input_tensor=tf.TensorSpec(
                                        shape=(),
                                        dtype=tf.int32,
                                        name=INPUT_TENSOR,
                                    ).experimental_as_proto(),
                                ),
                            ],
                            output_tensors=[
                                tf.TensorSpec(
                                    shape=(),
                                    dtype=tf.int32,
                                    name=OUTPUT_TENSOR,
                                ).experimental_as_proto(),
                            ],
                        ),
                    ],
                ),
            ),
        ],
    ),
)

# Forwarding info handed to every `aggregations.Service` under test.
FORWARDING_INFO = common_pb2.ForwardingInfo(
    target_uri_prefix='https://forwarding.example/',
)
58*14675a02SAndroid Build Coastguard Worker
59*14675a02SAndroid Build Coastguard Worker
class NotOkStatus:
  """Matcher that compares equal to any non-ok `absl_status.Status`.

  Intended for use as an expected argument in equality-based assertions (for
  example mock call assertions), where only "some error status" matters and
  not the specific code or message.
  """

  def __eq__(self, other) -> bool:
    """Returns True iff `other` is a `Status` whose `ok()` is false."""
    if not isinstance(other, absl_status.Status):
      return False
    return not other.ok()

  def __repr__(self) -> str:
    """Returns a stable name so assertion failure messages are readable."""
    return 'NotOkStatus()'
65*14675a02SAndroid Build Coastguard Worker
66*14675a02SAndroid Build Coastguard Worker
67*14675a02SAndroid Build Coastguard Workerclass AggregationsTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
68*14675a02SAndroid Build Coastguard Worker
69*14675a02SAndroid Build Coastguard Worker  def setUp(self):
70*14675a02SAndroid Build Coastguard Worker    super().setUp()
71*14675a02SAndroid Build Coastguard Worker    self.mock_media_service = self.enter_context(
72*14675a02SAndroid Build Coastguard Worker        mock.patch.object(media, 'Service', autospec=True))
73*14675a02SAndroid Build Coastguard Worker    self.mock_media_service.register_upload.return_value = 'upload-id'
74*14675a02SAndroid Build Coastguard Worker    self.mock_media_service.finalize_upload.return_value = (
75*14675a02SAndroid Build Coastguard Worker        test_utils.create_checkpoint({INPUT_TENSOR: 0}))
76*14675a02SAndroid Build Coastguard Worker
77*14675a02SAndroid Build Coastguard Worker  def test_pre_authorize_clients(self):
78*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
79*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
80*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
81*14675a02SAndroid Build Coastguard Worker    tokens = service.pre_authorize_clients(session_id, 3)
82*14675a02SAndroid Build Coastguard Worker    self.assertLen(tokens, 3)
83*14675a02SAndroid Build Coastguard Worker    # The tokens should all be unique.
84*14675a02SAndroid Build Coastguard Worker    self.assertLen(set(tokens), 3)
85*14675a02SAndroid Build Coastguard Worker
86*14675a02SAndroid Build Coastguard Worker  def test_pre_authorize_clients_with_missing_session_id(self):
87*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
88*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
89*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(KeyError):
90*14675a02SAndroid Build Coastguard Worker      service.pre_authorize_clients('does-not-exist', 1)
91*14675a02SAndroid Build Coastguard Worker
92*14675a02SAndroid Build Coastguard Worker  def test_pre_authorize_clients_with_bad_count(self):
93*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
94*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
95*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
96*14675a02SAndroid Build Coastguard Worker    self.assertEmpty(service.pre_authorize_clients(session_id, 0))
97*14675a02SAndroid Build Coastguard Worker    self.assertEmpty(service.pre_authorize_clients(session_id, -2))
98*14675a02SAndroid Build Coastguard Worker
99*14675a02SAndroid Build Coastguard Worker  def test_create_session(self):
100*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
101*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
102*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
103*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
104*14675a02SAndroid Build Coastguard Worker        service.get_session_status(session_id),
105*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
106*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING))
107*14675a02SAndroid Build Coastguard Worker
108*14675a02SAndroid Build Coastguard Worker  def test_complete_session(self):
109*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
110*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
111*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
112*14675a02SAndroid Build Coastguard Worker
113*14675a02SAndroid Build Coastguard Worker    # Upload results from the client.
114*14675a02SAndroid Build Coastguard Worker    num_clients = (
115*14675a02SAndroid Build Coastguard Worker        AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate)
116*14675a02SAndroid Build Coastguard Worker    for i in range(num_clients):
117*14675a02SAndroid Build Coastguard Worker      tokens = service.pre_authorize_clients(session_id, 1)
118*14675a02SAndroid Build Coastguard Worker
119*14675a02SAndroid Build Coastguard Worker      self.mock_media_service.register_upload.return_value = f'upload-{i}'
120*14675a02SAndroid Build Coastguard Worker      operation = service.start_aggregation_data_upload(
121*14675a02SAndroid Build Coastguard Worker          aggregations_pb2.StartAggregationDataUploadRequest(
122*14675a02SAndroid Build Coastguard Worker              aggregation_id=session_id, authorization_token=tokens[0]))
123*14675a02SAndroid Build Coastguard Worker      self.assertTrue(operation.done)
124*14675a02SAndroid Build Coastguard Worker      start_upload_response = (
125*14675a02SAndroid Build Coastguard Worker          aggregations_pb2.StartAggregationDataUploadResponse())
126*14675a02SAndroid Build Coastguard Worker      operation.response.Unpack(start_upload_response)
127*14675a02SAndroid Build Coastguard Worker
128*14675a02SAndroid Build Coastguard Worker      self.mock_media_service.finalize_upload.return_value = (
129*14675a02SAndroid Build Coastguard Worker          test_utils.create_checkpoint({INPUT_TENSOR: i}))
130*14675a02SAndroid Build Coastguard Worker      service.submit_aggregation_result(
131*14675a02SAndroid Build Coastguard Worker          aggregations_pb2.SubmitAggregationResultRequest(
132*14675a02SAndroid Build Coastguard Worker              aggregation_id=session_id,
133*14675a02SAndroid Build Coastguard Worker              client_token=start_upload_response.client_token,
134*14675a02SAndroid Build Coastguard Worker              resource_name=start_upload_response.resource.resource_name))
135*14675a02SAndroid Build Coastguard Worker
136*14675a02SAndroid Build Coastguard Worker    # Now that all clients have contributed, the aggregation session can be
137*14675a02SAndroid Build Coastguard Worker    # completed.
138*14675a02SAndroid Build Coastguard Worker    status, aggregate = service.complete_session(session_id)
139*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
140*14675a02SAndroid Build Coastguard Worker        status,
141*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
142*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.COMPLETED,
143*14675a02SAndroid Build Coastguard Worker            num_clients_completed=num_clients,
144*14675a02SAndroid Build Coastguard Worker            num_inputs_aggregated_and_included=num_clients))
145*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
146*14675a02SAndroid Build Coastguard Worker        test_utils.read_tensor_from_checkpoint(aggregate,
147*14675a02SAndroid Build Coastguard Worker                                               OUTPUT_TENSOR, tf.int32),
148*14675a02SAndroid Build Coastguard Worker        sum(range(num_clients)))
149*14675a02SAndroid Build Coastguard Worker
150*14675a02SAndroid Build Coastguard Worker    # get_session_status should no longer return results.
151*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(KeyError):
152*14675a02SAndroid Build Coastguard Worker      service.get_session_status(session_id)
153*14675a02SAndroid Build Coastguard Worker
154*14675a02SAndroid Build Coastguard Worker  @mock.patch.object(
155*14675a02SAndroid Build Coastguard Worker      aggregation_protocols,
156*14675a02SAndroid Build Coastguard Worker      'create_simple_aggregation_protocol',
157*14675a02SAndroid Build Coastguard Worker      autospec=True)
158*14675a02SAndroid Build Coastguard Worker  def test_complete_session_fails(self, mock_create_simple_agg_protocol):
159*14675a02SAndroid Build Coastguard Worker    # Use a mock since it's not easy to cause
160*14675a02SAndroid Build Coastguard Worker    # SimpleAggregationProtocol::Complete to fail.
161*14675a02SAndroid Build Coastguard Worker    mock_agg_protocol = mock.create_autospec(
162*14675a02SAndroid Build Coastguard Worker        aggregation_protocol.AggregationProtocol, instance=True)
163*14675a02SAndroid Build Coastguard Worker    mock_create_simple_agg_protocol.return_value = mock_agg_protocol
164*14675a02SAndroid Build Coastguard Worker
165*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
166*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
167*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
168*14675a02SAndroid Build Coastguard Worker
169*14675a02SAndroid Build Coastguard Worker    required_clients = (
170*14675a02SAndroid Build Coastguard Worker        AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate)
171*14675a02SAndroid Build Coastguard Worker    agg_status = apm_pb2.StatusMessage(
172*14675a02SAndroid Build Coastguard Worker        num_inputs_aggregated_and_included=required_clients)
173*14675a02SAndroid Build Coastguard Worker    mock_agg_protocol.GetStatus.side_effect = lambda: agg_status
174*14675a02SAndroid Build Coastguard Worker
175*14675a02SAndroid Build Coastguard Worker    def on_complete():
176*14675a02SAndroid Build Coastguard Worker      agg_status.num_inputs_discarded = (
177*14675a02SAndroid Build Coastguard Worker          agg_status.num_inputs_aggregated_and_included)
178*14675a02SAndroid Build Coastguard Worker      agg_status.num_inputs_aggregated_and_included = 0
179*14675a02SAndroid Build Coastguard Worker      raise absl_status.StatusNotOk(absl_status.unknown_error('message'))
180*14675a02SAndroid Build Coastguard Worker
181*14675a02SAndroid Build Coastguard Worker    mock_agg_protocol.Complete.side_effect = on_complete
182*14675a02SAndroid Build Coastguard Worker
183*14675a02SAndroid Build Coastguard Worker    status, aggregate = service.complete_session(session_id)
184*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
185*14675a02SAndroid Build Coastguard Worker        status,
186*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
187*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.FAILED,
188*14675a02SAndroid Build Coastguard Worker            num_inputs_discarded=required_clients))
189*14675a02SAndroid Build Coastguard Worker    self.assertIsNone(aggregate)
190*14675a02SAndroid Build Coastguard Worker    mock_agg_protocol.Complete.assert_called_once()
191*14675a02SAndroid Build Coastguard Worker    mock_agg_protocol.Abort.assert_not_called()
192*14675a02SAndroid Build Coastguard Worker
193*14675a02SAndroid Build Coastguard Worker  @mock.patch.object(
194*14675a02SAndroid Build Coastguard Worker      aggregation_protocols,
195*14675a02SAndroid Build Coastguard Worker      'create_simple_aggregation_protocol',
196*14675a02SAndroid Build Coastguard Worker      autospec=True)
197*14675a02SAndroid Build Coastguard Worker  def test_complete_session_aborts(self, mock_create_simple_agg_protocol):
198*14675a02SAndroid Build Coastguard Worker    # Use a mock since it's not easy to cause
199*14675a02SAndroid Build Coastguard Worker    # SimpleAggregationProtocol::Complete to trigger a protocol abort.
200*14675a02SAndroid Build Coastguard Worker    mock_agg_protocol = mock.create_autospec(
201*14675a02SAndroid Build Coastguard Worker        aggregation_protocol.AggregationProtocol, instance=True)
202*14675a02SAndroid Build Coastguard Worker    mock_create_simple_agg_protocol.return_value = mock_agg_protocol
203*14675a02SAndroid Build Coastguard Worker
204*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
205*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
206*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
207*14675a02SAndroid Build Coastguard Worker
208*14675a02SAndroid Build Coastguard Worker    required_clients = (
209*14675a02SAndroid Build Coastguard Worker        AGGREGATION_REQUIREMENTS.minimum_clients_in_server_published_aggregate)
210*14675a02SAndroid Build Coastguard Worker    agg_status = apm_pb2.StatusMessage(
211*14675a02SAndroid Build Coastguard Worker        num_inputs_aggregated_and_included=required_clients)
212*14675a02SAndroid Build Coastguard Worker    mock_agg_protocol.GetStatus.side_effect = lambda: agg_status
213*14675a02SAndroid Build Coastguard Worker
214*14675a02SAndroid Build Coastguard Worker    def on_complete():
215*14675a02SAndroid Build Coastguard Worker      agg_status.num_inputs_discarded = (
216*14675a02SAndroid Build Coastguard Worker          agg_status.num_inputs_aggregated_and_included)
217*14675a02SAndroid Build Coastguard Worker      agg_status.num_inputs_aggregated_and_included = 0
218*14675a02SAndroid Build Coastguard Worker      callback = mock_create_simple_agg_protocol.call_args.args[1]
219*14675a02SAndroid Build Coastguard Worker      callback.OnAbort(absl_status.unknown_error('message'))
220*14675a02SAndroid Build Coastguard Worker
221*14675a02SAndroid Build Coastguard Worker    mock_agg_protocol.Complete.side_effect = on_complete
222*14675a02SAndroid Build Coastguard Worker
223*14675a02SAndroid Build Coastguard Worker    status, aggregate = service.complete_session(session_id)
224*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
225*14675a02SAndroid Build Coastguard Worker        status,
226*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
227*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.FAILED,
228*14675a02SAndroid Build Coastguard Worker            num_inputs_discarded=required_clients))
229*14675a02SAndroid Build Coastguard Worker    self.assertIsNone(aggregate)
230*14675a02SAndroid Build Coastguard Worker    mock_agg_protocol.Complete.assert_called_once()
231*14675a02SAndroid Build Coastguard Worker    mock_agg_protocol.Abort.assert_not_called()
232*14675a02SAndroid Build Coastguard Worker
233*14675a02SAndroid Build Coastguard Worker  def test_complete_session_without_enough_inputs(self):
234*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
235*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
236*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(
237*14675a02SAndroid Build Coastguard Worker        aggregations.AggregationRequirements(
238*14675a02SAndroid Build Coastguard Worker            minimum_clients_in_server_published_aggregate=3,
239*14675a02SAndroid Build Coastguard Worker            plan=AGGREGATION_REQUIREMENTS.plan))
240*14675a02SAndroid Build Coastguard Worker    tokens = service.pre_authorize_clients(session_id, 2)
241*14675a02SAndroid Build Coastguard Worker
242*14675a02SAndroid Build Coastguard Worker    # Upload results for one client.
243*14675a02SAndroid Build Coastguard Worker    operation = service.start_aggregation_data_upload(
244*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadRequest(
245*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id, authorization_token=tokens[0]))
246*14675a02SAndroid Build Coastguard Worker    self.assertTrue(operation.done)
247*14675a02SAndroid Build Coastguard Worker    start_upload_response = (
248*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadResponse())
249*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(start_upload_response)
250*14675a02SAndroid Build Coastguard Worker    service.submit_aggregation_result(
251*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.SubmitAggregationResultRequest(
252*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id,
253*14675a02SAndroid Build Coastguard Worker            client_token=start_upload_response.client_token,
254*14675a02SAndroid Build Coastguard Worker            resource_name=start_upload_response.resource.resource_name))
255*14675a02SAndroid Build Coastguard Worker
256*14675a02SAndroid Build Coastguard Worker    # Complete the session before there are 2 completed clients.
257*14675a02SAndroid Build Coastguard Worker    status, aggregate = service.complete_session(session_id)
258*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
259*14675a02SAndroid Build Coastguard Worker        status,
260*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
261*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.FAILED,
262*14675a02SAndroid Build Coastguard Worker            num_clients_completed=1,
263*14675a02SAndroid Build Coastguard Worker            num_inputs_discarded=1))
264*14675a02SAndroid Build Coastguard Worker    self.assertIsNone(aggregate)
265*14675a02SAndroid Build Coastguard Worker
266*14675a02SAndroid Build Coastguard Worker  def test_complete_session_with_missing_session_id(self):
267*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
268*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
269*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(KeyError):
270*14675a02SAndroid Build Coastguard Worker      service.complete_session('does-not-exist')
271*14675a02SAndroid Build Coastguard Worker
272*14675a02SAndroid Build Coastguard Worker  def test_abort_session_with_no_uploads(self):
273*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
274*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
275*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
276*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
277*14675a02SAndroid Build Coastguard Worker        service.abort_session(session_id),
278*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
279*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.ABORTED))
280*14675a02SAndroid Build Coastguard Worker
281*14675a02SAndroid Build Coastguard Worker  def test_abort_session_with_uploads(self):
282*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
283*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
284*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
285*14675a02SAndroid Build Coastguard Worker    tokens = service.pre_authorize_clients(session_id, 3)
286*14675a02SAndroid Build Coastguard Worker
287*14675a02SAndroid Build Coastguard Worker    # Upload results for one client.
288*14675a02SAndroid Build Coastguard Worker    self.mock_media_service.register_upload.return_value = 'upload1'
289*14675a02SAndroid Build Coastguard Worker    operation = service.start_aggregation_data_upload(
290*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadRequest(
291*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id, authorization_token=tokens[0]))
292*14675a02SAndroid Build Coastguard Worker    self.assertTrue(operation.done)
293*14675a02SAndroid Build Coastguard Worker    start_upload_response = (
294*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadResponse())
295*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(start_upload_response)
296*14675a02SAndroid Build Coastguard Worker    service.submit_aggregation_result(
297*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.SubmitAggregationResultRequest(
298*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id,
299*14675a02SAndroid Build Coastguard Worker            client_token=start_upload_response.client_token,
300*14675a02SAndroid Build Coastguard Worker            resource_name=start_upload_response.resource.resource_name))
301*14675a02SAndroid Build Coastguard Worker
302*14675a02SAndroid Build Coastguard Worker    # Start a partial upload from a second client.
303*14675a02SAndroid Build Coastguard Worker    self.mock_media_service.register_upload.return_value = 'upload2'
304*14675a02SAndroid Build Coastguard Worker    service.start_aggregation_data_upload(
305*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadRequest(
306*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id, authorization_token=tokens[1]))
307*14675a02SAndroid Build Coastguard Worker
308*14675a02SAndroid Build Coastguard Worker    # Abort the session. The pending client should be treated as failed.
309*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
310*14675a02SAndroid Build Coastguard Worker        service.abort_session(session_id),
311*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
312*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.ABORTED,
313*14675a02SAndroid Build Coastguard Worker            num_clients_completed=1,
314*14675a02SAndroid Build Coastguard Worker            num_clients_aborted=1,
315*14675a02SAndroid Build Coastguard Worker            num_inputs_discarded=1))
316*14675a02SAndroid Build Coastguard Worker    # The registered upload for the second client should have been finalized.
317*14675a02SAndroid Build Coastguard Worker    self.mock_media_service.finalize_upload.assert_called_with('upload2')
318*14675a02SAndroid Build Coastguard Worker    # get_session_status should no longer return results.
319*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(KeyError):
320*14675a02SAndroid Build Coastguard Worker      service.get_session_status(session_id)
321*14675a02SAndroid Build Coastguard Worker
322*14675a02SAndroid Build Coastguard Worker  def test_abort_session_with_missing_session_id(self):
323*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
324*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
325*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(KeyError):
326*14675a02SAndroid Build Coastguard Worker      service.abort_session('does-not-exist')
327*14675a02SAndroid Build Coastguard Worker
328*14675a02SAndroid Build Coastguard Worker  def test_get_session_status_with_missing_session_id(self):
329*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
330*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
331*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(KeyError):
332*14675a02SAndroid Build Coastguard Worker      service.get_session_status('does-not-exist')
333*14675a02SAndroid Build Coastguard Worker
334*14675a02SAndroid Build Coastguard Worker  async def test_wait(self):
335*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
336*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
337*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
338*14675a02SAndroid Build Coastguard Worker    task = asyncio.create_task(
339*14675a02SAndroid Build Coastguard Worker        service.wait(session_id, num_inputs_aggregated_and_included=1))
340*14675a02SAndroid Build Coastguard Worker    # The awaitable should not be done yet.
341*14675a02SAndroid Build Coastguard Worker    await asyncio.wait([task], timeout=0.1)
342*14675a02SAndroid Build Coastguard Worker    self.assertFalse(task.done())
343*14675a02SAndroid Build Coastguard Worker
344*14675a02SAndroid Build Coastguard Worker    # Upload results for one client.
345*14675a02SAndroid Build Coastguard Worker    tokens = service.pre_authorize_clients(session_id, 1)
346*14675a02SAndroid Build Coastguard Worker    self.mock_media_service.register_upload.return_value = 'upload'
347*14675a02SAndroid Build Coastguard Worker    operation = service.start_aggregation_data_upload(
348*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadRequest(
349*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id, authorization_token=tokens[0]))
350*14675a02SAndroid Build Coastguard Worker    self.assertTrue(operation.done)
351*14675a02SAndroid Build Coastguard Worker    start_upload_response = (
352*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadResponse())
353*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(start_upload_response)
354*14675a02SAndroid Build Coastguard Worker    service.submit_aggregation_result(
355*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.SubmitAggregationResultRequest(
356*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id,
357*14675a02SAndroid Build Coastguard Worker            client_token=start_upload_response.client_token,
358*14675a02SAndroid Build Coastguard Worker            resource_name=start_upload_response.resource.resource_name))
359*14675a02SAndroid Build Coastguard Worker
360*14675a02SAndroid Build Coastguard Worker    # The awaitable should now return.
361*14675a02SAndroid Build Coastguard Worker    await asyncio.wait([task], timeout=1)
362*14675a02SAndroid Build Coastguard Worker    self.assertTrue(task.done())
363*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
364*14675a02SAndroid Build Coastguard Worker        task.result(),
365*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
366*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING,
367*14675a02SAndroid Build Coastguard Worker            num_clients_completed=1,
368*14675a02SAndroid Build Coastguard Worker            num_inputs_aggregated_and_included=1))
369*14675a02SAndroid Build Coastguard Worker
370*14675a02SAndroid Build Coastguard Worker  async def test_wait_already_satisfied(self):
371*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
372*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
373*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
374*14675a02SAndroid Build Coastguard Worker
375*14675a02SAndroid Build Coastguard Worker    # Upload results for one client.
376*14675a02SAndroid Build Coastguard Worker    tokens = service.pre_authorize_clients(session_id, 1)
377*14675a02SAndroid Build Coastguard Worker    self.mock_media_service.register_upload.return_value = 'upload'
378*14675a02SAndroid Build Coastguard Worker    operation = service.start_aggregation_data_upload(
379*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadRequest(
380*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id, authorization_token=tokens[0]))
381*14675a02SAndroid Build Coastguard Worker    self.assertTrue(operation.done)
382*14675a02SAndroid Build Coastguard Worker    start_upload_response = (
383*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadResponse())
384*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(start_upload_response)
385*14675a02SAndroid Build Coastguard Worker    service.submit_aggregation_result(
386*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.SubmitAggregationResultRequest(
387*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id,
388*14675a02SAndroid Build Coastguard Worker            client_token=start_upload_response.client_token,
389*14675a02SAndroid Build Coastguard Worker            resource_name=start_upload_response.resource.resource_name))
390*14675a02SAndroid Build Coastguard Worker
391*14675a02SAndroid Build Coastguard Worker    # Since a client has already reported, the condition should already be
392*14675a02SAndroid Build Coastguard Worker    # satisfied.
393*14675a02SAndroid Build Coastguard Worker    task = asyncio.create_task(
394*14675a02SAndroid Build Coastguard Worker        service.wait(session_id, num_inputs_aggregated_and_included=1))
395*14675a02SAndroid Build Coastguard Worker    await asyncio.wait([task], timeout=1)
396*14675a02SAndroid Build Coastguard Worker    self.assertTrue(task.done())
397*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
398*14675a02SAndroid Build Coastguard Worker        task.result(),
399*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
400*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING,
401*14675a02SAndroid Build Coastguard Worker            num_clients_completed=1,
402*14675a02SAndroid Build Coastguard Worker            num_inputs_aggregated_and_included=1))
403*14675a02SAndroid Build Coastguard Worker
404*14675a02SAndroid Build Coastguard Worker  async def test_wait_with_abort(self):
405*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
406*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
407*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
408*14675a02SAndroid Build Coastguard Worker    task = asyncio.create_task(
409*14675a02SAndroid Build Coastguard Worker        service.wait(session_id, num_inputs_aggregated_and_included=1))
410*14675a02SAndroid Build Coastguard Worker    # The awaitable should not be done yet.
411*14675a02SAndroid Build Coastguard Worker    await asyncio.wait([task], timeout=0.1)
412*14675a02SAndroid Build Coastguard Worker    self.assertFalse(task.done())
413*14675a02SAndroid Build Coastguard Worker
414*14675a02SAndroid Build Coastguard Worker    # The awaitable should return once the session is aborted.
415*14675a02SAndroid Build Coastguard Worker    status = service.abort_session(session_id)
416*14675a02SAndroid Build Coastguard Worker    await asyncio.wait([task], timeout=1)
417*14675a02SAndroid Build Coastguard Worker    self.assertTrue(task.done())
418*14675a02SAndroid Build Coastguard Worker    self.assertEqual(task.result(), status)
419*14675a02SAndroid Build Coastguard Worker
420*14675a02SAndroid Build Coastguard Worker  @mock.patch.object(
421*14675a02SAndroid Build Coastguard Worker      aggregation_protocols,
422*14675a02SAndroid Build Coastguard Worker      'create_simple_aggregation_protocol',
423*14675a02SAndroid Build Coastguard Worker      autospec=True)
424*14675a02SAndroid Build Coastguard Worker  async def test_wait_with_protocol_abort(self,
425*14675a02SAndroid Build Coastguard Worker                                          mock_create_simple_agg_protocol):
426*14675a02SAndroid Build Coastguard Worker    # Use a mock since it's not easy to cause the AggregationProtocol to abort.
427*14675a02SAndroid Build Coastguard Worker    mock_agg_protocol = mock.create_autospec(
428*14675a02SAndroid Build Coastguard Worker        aggregation_protocol.AggregationProtocol, instance=True)
429*14675a02SAndroid Build Coastguard Worker    mock_create_simple_agg_protocol.return_value = mock_agg_protocol
430*14675a02SAndroid Build Coastguard Worker    mock_agg_protocol.GetStatus.return_value = apm_pb2.StatusMessage(
431*14675a02SAndroid Build Coastguard Worker        num_clients_aborted=1234)
432*14675a02SAndroid Build Coastguard Worker
433*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
434*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
435*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
436*14675a02SAndroid Build Coastguard Worker    task = asyncio.create_task(
437*14675a02SAndroid Build Coastguard Worker        service.wait(session_id, num_inputs_aggregated_and_included=1))
438*14675a02SAndroid Build Coastguard Worker    # The awaitable should not be done yet.
439*14675a02SAndroid Build Coastguard Worker    await asyncio.wait([task], timeout=0.1)
440*14675a02SAndroid Build Coastguard Worker    self.assertFalse(task.done())
441*14675a02SAndroid Build Coastguard Worker
442*14675a02SAndroid Build Coastguard Worker    # The awaitable should return once the AggregationProtocol aborts.
443*14675a02SAndroid Build Coastguard Worker    callback = mock_create_simple_agg_protocol.call_args.args[1]
444*14675a02SAndroid Build Coastguard Worker    callback.OnAbort(absl_status.unknown_error('message'))
445*14675a02SAndroid Build Coastguard Worker    await asyncio.wait([task], timeout=1)
446*14675a02SAndroid Build Coastguard Worker    self.assertTrue(task.done())
447*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
448*14675a02SAndroid Build Coastguard Worker        task.result(),
449*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
450*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.FAILED,
451*14675a02SAndroid Build Coastguard Worker            num_clients_aborted=1234))
452*14675a02SAndroid Build Coastguard Worker
453*14675a02SAndroid Build Coastguard Worker  async def test_wait_with_complete(self):
454*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
455*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
456*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(
457*14675a02SAndroid Build Coastguard Worker        aggregations.AggregationRequirements(
458*14675a02SAndroid Build Coastguard Worker            minimum_clients_in_server_published_aggregate=0,
459*14675a02SAndroid Build Coastguard Worker            plan=AGGREGATION_REQUIREMENTS.plan))
460*14675a02SAndroid Build Coastguard Worker    task = asyncio.create_task(
461*14675a02SAndroid Build Coastguard Worker        service.wait(session_id, num_inputs_aggregated_and_included=1))
462*14675a02SAndroid Build Coastguard Worker    # The awaitable should not be done yet.
463*14675a02SAndroid Build Coastguard Worker    await asyncio.wait([task], timeout=0.1)
464*14675a02SAndroid Build Coastguard Worker    self.assertFalse(task.done())
465*14675a02SAndroid Build Coastguard Worker
466*14675a02SAndroid Build Coastguard Worker    # The awaitable should return once the session is completed.
467*14675a02SAndroid Build Coastguard Worker    status, _ = service.complete_session(session_id)
468*14675a02SAndroid Build Coastguard Worker    await asyncio.wait([task], timeout=1)
469*14675a02SAndroid Build Coastguard Worker    self.assertTrue(task.done())
470*14675a02SAndroid Build Coastguard Worker    self.assertEqual(task.result(), status)
471*14675a02SAndroid Build Coastguard Worker
472*14675a02SAndroid Build Coastguard Worker  async def test_wait_without_condition(self):
473*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
474*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
475*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
476*14675a02SAndroid Build Coastguard Worker    task = asyncio.create_task(service.wait(session_id))
477*14675a02SAndroid Build Coastguard Worker    # If there are no conditions, the wait should be trivially satisfied.
478*14675a02SAndroid Build Coastguard Worker    await asyncio.wait([task], timeout=1)
479*14675a02SAndroid Build Coastguard Worker    self.assertTrue(task.done())
480*14675a02SAndroid Build Coastguard Worker    self.assertEqual(task.result(), service.get_session_status(session_id))
481*14675a02SAndroid Build Coastguard Worker
482*14675a02SAndroid Build Coastguard Worker  async def test_wait_with_missing_session_id(self):
483*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
484*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
485*14675a02SAndroid Build Coastguard Worker    task = asyncio.create_task(service.wait('does-not-exist'))
486*14675a02SAndroid Build Coastguard Worker    await asyncio.wait([task], timeout=1)
487*14675a02SAndroid Build Coastguard Worker    self.assertTrue(task.done())
488*14675a02SAndroid Build Coastguard Worker    self.assertIsInstance(task.exception(), KeyError)
489*14675a02SAndroid Build Coastguard Worker
490*14675a02SAndroid Build Coastguard Worker  def test_start_aggregation_data_upload(self):
491*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
492*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
493*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
494*14675a02SAndroid Build Coastguard Worker    tokens = service.pre_authorize_clients(session_id, 1)
495*14675a02SAndroid Build Coastguard Worker    self.mock_media_service.register_upload.return_value = 'upload'
496*14675a02SAndroid Build Coastguard Worker    operation = service.start_aggregation_data_upload(
497*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadRequest(
498*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id, authorization_token=tokens[0]))
499*14675a02SAndroid Build Coastguard Worker    self.assertNotEmpty(operation.name)
500*14675a02SAndroid Build Coastguard Worker    self.assertTrue(operation.done)
501*14675a02SAndroid Build Coastguard Worker
502*14675a02SAndroid Build Coastguard Worker    metadata = aggregations_pb2.StartAggregationDataUploadMetadata()
503*14675a02SAndroid Build Coastguard Worker    operation.metadata.Unpack(metadata)
504*14675a02SAndroid Build Coastguard Worker    self.assertEqual(metadata,
505*14675a02SAndroid Build Coastguard Worker                     aggregations_pb2.StartAggregationDataUploadMetadata())
506*14675a02SAndroid Build Coastguard Worker
507*14675a02SAndroid Build Coastguard Worker    response = aggregations_pb2.StartAggregationDataUploadResponse()
508*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(response)
509*14675a02SAndroid Build Coastguard Worker    # The client token should be set and different from the authorization token.
510*14675a02SAndroid Build Coastguard Worker    self.assertNotEmpty(response.client_token)
511*14675a02SAndroid Build Coastguard Worker    self.assertNotEqual(response.client_token, tokens[0])
512*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
513*14675a02SAndroid Build Coastguard Worker        response,
514*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadResponse(
515*14675a02SAndroid Build Coastguard Worker            aggregation_protocol_forwarding_info=FORWARDING_INFO,
516*14675a02SAndroid Build Coastguard Worker            resource=common_pb2.ByteStreamResource(
517*14675a02SAndroid Build Coastguard Worker                data_upload_forwarding_info=FORWARDING_INFO,
518*14675a02SAndroid Build Coastguard Worker                resource_name='upload'),
519*14675a02SAndroid Build Coastguard Worker            client_token=response.client_token))
520*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
521*14675a02SAndroid Build Coastguard Worker        service.get_session_status(session_id),
522*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
523*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING,
524*14675a02SAndroid Build Coastguard Worker            num_clients_pending=1))
525*14675a02SAndroid Build Coastguard Worker
526*14675a02SAndroid Build Coastguard Worker  def test_start_aggregagation_data_upload_with_missing_session_id(self):
527*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
528*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
529*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
530*14675a02SAndroid Build Coastguard Worker    tokens = service.pre_authorize_clients(session_id, 1)
531*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(http_actions.HttpError) as cm:
532*14675a02SAndroid Build Coastguard Worker      service.start_aggregation_data_upload(
533*14675a02SAndroid Build Coastguard Worker          aggregations_pb2.StartAggregationDataUploadRequest(
534*14675a02SAndroid Build Coastguard Worker              aggregation_id='does-not-exist', authorization_token=tokens[0]))
535*14675a02SAndroid Build Coastguard Worker    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
536*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
537*14675a02SAndroid Build Coastguard Worker        service.get_session_status(session_id),
538*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
539*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING))
540*14675a02SAndroid Build Coastguard Worker
541*14675a02SAndroid Build Coastguard Worker  def test_start_aggregagation_data_upload_with_invalid_token(self):
542*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
543*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
544*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
545*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(http_actions.HttpError) as cm:
546*14675a02SAndroid Build Coastguard Worker      service.start_aggregation_data_upload(
547*14675a02SAndroid Build Coastguard Worker          aggregations_pb2.StartAggregationDataUploadRequest(
548*14675a02SAndroid Build Coastguard Worker              aggregation_id=session_id, authorization_token='does-not-exist'))
549*14675a02SAndroid Build Coastguard Worker    self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
550*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
551*14675a02SAndroid Build Coastguard Worker        service.get_session_status(session_id),
552*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
553*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING))
554*14675a02SAndroid Build Coastguard Worker
555*14675a02SAndroid Build Coastguard Worker  def test_start_aggregagation_data_upload_twice(self):
556*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
557*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
558*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
559*14675a02SAndroid Build Coastguard Worker    tokens = service.pre_authorize_clients(session_id, 1)
560*14675a02SAndroid Build Coastguard Worker    service.start_aggregation_data_upload(
561*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadRequest(
562*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id, authorization_token=tokens[0]))
563*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(http_actions.HttpError) as cm:
564*14675a02SAndroid Build Coastguard Worker      service.start_aggregation_data_upload(
565*14675a02SAndroid Build Coastguard Worker          aggregations_pb2.StartAggregationDataUploadRequest(
566*14675a02SAndroid Build Coastguard Worker              aggregation_id=session_id, authorization_token=tokens[0]))
567*14675a02SAndroid Build Coastguard Worker    self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
568*14675a02SAndroid Build Coastguard Worker
569*14675a02SAndroid Build Coastguard Worker  def test_submit_aggregation_result(self):
570*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
571*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
572*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
573*14675a02SAndroid Build Coastguard Worker
574*14675a02SAndroid Build Coastguard Worker    # Upload results from the client.
575*14675a02SAndroid Build Coastguard Worker    tokens = service.pre_authorize_clients(session_id, 1)
576*14675a02SAndroid Build Coastguard Worker
577*14675a02SAndroid Build Coastguard Worker    operation = service.start_aggregation_data_upload(
578*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadRequest(
579*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id, authorization_token=tokens[0]))
580*14675a02SAndroid Build Coastguard Worker    self.assertTrue(operation.done)
581*14675a02SAndroid Build Coastguard Worker    start_upload_response = (
582*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadResponse())
583*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(start_upload_response)
584*14675a02SAndroid Build Coastguard Worker
585*14675a02SAndroid Build Coastguard Worker    submit_response = service.submit_aggregation_result(
586*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.SubmitAggregationResultRequest(
587*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id,
588*14675a02SAndroid Build Coastguard Worker            client_token=start_upload_response.client_token,
589*14675a02SAndroid Build Coastguard Worker            resource_name=start_upload_response.resource.resource_name))
590*14675a02SAndroid Build Coastguard Worker    self.assertEqual(submit_response,
591*14675a02SAndroid Build Coastguard Worker                     aggregations_pb2.SubmitAggregationResultResponse())
592*14675a02SAndroid Build Coastguard Worker    self.mock_media_service.finalize_upload.assert_called_with(
593*14675a02SAndroid Build Coastguard Worker        start_upload_response.resource.resource_name)
594*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
595*14675a02SAndroid Build Coastguard Worker        service.get_session_status(session_id),
596*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
597*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING,
598*14675a02SAndroid Build Coastguard Worker            num_clients_completed=1,
599*14675a02SAndroid Build Coastguard Worker            num_inputs_aggregated_and_included=1))
600*14675a02SAndroid Build Coastguard Worker
601*14675a02SAndroid Build Coastguard Worker  def test_submit_aggregation_result_with_invalid_client_input(self):
602*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
603*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
604*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
605*14675a02SAndroid Build Coastguard Worker
606*14675a02SAndroid Build Coastguard Worker    tokens = service.pre_authorize_clients(session_id, 1)
607*14675a02SAndroid Build Coastguard Worker    operation = service.start_aggregation_data_upload(
608*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadRequest(
609*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id, authorization_token=tokens[0]))
610*14675a02SAndroid Build Coastguard Worker    self.assertTrue(operation.done)
611*14675a02SAndroid Build Coastguard Worker    start_upload_response = (
612*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadResponse())
613*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(start_upload_response)
614*14675a02SAndroid Build Coastguard Worker
615*14675a02SAndroid Build Coastguard Worker    self.mock_media_service.finalize_upload.return_value = b'invalid'
616*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(http_actions.HttpError):
617*14675a02SAndroid Build Coastguard Worker      service.submit_aggregation_result(
618*14675a02SAndroid Build Coastguard Worker          aggregations_pb2.SubmitAggregationResultRequest(
619*14675a02SAndroid Build Coastguard Worker              aggregation_id=session_id,
620*14675a02SAndroid Build Coastguard Worker              client_token=start_upload_response.client_token,
621*14675a02SAndroid Build Coastguard Worker              resource_name=start_upload_response.resource.resource_name))
622*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
623*14675a02SAndroid Build Coastguard Worker        service.get_session_status(session_id),
624*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
625*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING,
626*14675a02SAndroid Build Coastguard Worker            num_clients_failed=1))
627*14675a02SAndroid Build Coastguard Worker
628*14675a02SAndroid Build Coastguard Worker  def test_submit_aggregation_result_with_missing_session_id(self):
629*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
630*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
631*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
632*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(http_actions.HttpError) as cm:
633*14675a02SAndroid Build Coastguard Worker      service.submit_aggregation_result(
634*14675a02SAndroid Build Coastguard Worker          aggregations_pb2.SubmitAggregationResultRequest(
635*14675a02SAndroid Build Coastguard Worker              aggregation_id='does-not-exist',
636*14675a02SAndroid Build Coastguard Worker              client_token='client-token',
637*14675a02SAndroid Build Coastguard Worker              resource_name='upload-id'))
638*14675a02SAndroid Build Coastguard Worker    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
639*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
640*14675a02SAndroid Build Coastguard Worker        service.get_session_status(session_id),
641*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
642*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING))
643*14675a02SAndroid Build Coastguard Worker
644*14675a02SAndroid Build Coastguard Worker  def test_submit_aggregation_result_with_invalid_token(self):
645*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
646*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
647*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
648*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(http_actions.HttpError) as cm:
649*14675a02SAndroid Build Coastguard Worker      service.submit_aggregation_result(
650*14675a02SAndroid Build Coastguard Worker          aggregations_pb2.SubmitAggregationResultRequest(
651*14675a02SAndroid Build Coastguard Worker              aggregation_id=session_id,
652*14675a02SAndroid Build Coastguard Worker              client_token='does-not-exist',
653*14675a02SAndroid Build Coastguard Worker              resource_name='upload-id'))
654*14675a02SAndroid Build Coastguard Worker    self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
655*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
656*14675a02SAndroid Build Coastguard Worker        service.get_session_status(session_id),
657*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
658*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING))
659*14675a02SAndroid Build Coastguard Worker
660*14675a02SAndroid Build Coastguard Worker  def test_submit_aggregation_result_with_finalize_upload_error(self):
661*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
662*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
663*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
664*14675a02SAndroid Build Coastguard Worker    tokens = service.pre_authorize_clients(session_id, 1)
665*14675a02SAndroid Build Coastguard Worker    operation = service.start_aggregation_data_upload(
666*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadRequest(
667*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id, authorization_token=tokens[0]))
668*14675a02SAndroid Build Coastguard Worker    self.assertTrue(operation.done)
669*14675a02SAndroid Build Coastguard Worker    start_upload_response = (
670*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadResponse())
671*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(start_upload_response)
672*14675a02SAndroid Build Coastguard Worker
673*14675a02SAndroid Build Coastguard Worker    # If the resource_name doesn't correspond to a registered upload,
674*14675a02SAndroid Build Coastguard Worker    # finalize_upload will raise a KeyError.
675*14675a02SAndroid Build Coastguard Worker    self.mock_media_service.finalize_upload.side_effect = KeyError()
676*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(http_actions.HttpError) as cm:
677*14675a02SAndroid Build Coastguard Worker      service.submit_aggregation_result(
678*14675a02SAndroid Build Coastguard Worker          aggregations_pb2.SubmitAggregationResultRequest(
679*14675a02SAndroid Build Coastguard Worker              aggregation_id=session_id,
680*14675a02SAndroid Build Coastguard Worker              client_token=start_upload_response.client_token,
681*14675a02SAndroid Build Coastguard Worker              resource_name=start_upload_response.resource.resource_name))
682*14675a02SAndroid Build Coastguard Worker    self.assertEqual(cm.exception.code, http.HTTPStatus.INTERNAL_SERVER_ERROR)
683*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
684*14675a02SAndroid Build Coastguard Worker        service.get_session_status(session_id),
685*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
686*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING,
687*14675a02SAndroid Build Coastguard Worker            num_clients_failed=1))
688*14675a02SAndroid Build Coastguard Worker
689*14675a02SAndroid Build Coastguard Worker  def test_submit_aggregation_result_with_unuploaded_resource(self):
690*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
691*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
692*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
693*14675a02SAndroid Build Coastguard Worker    tokens = service.pre_authorize_clients(session_id, 1)
694*14675a02SAndroid Build Coastguard Worker    operation = service.start_aggregation_data_upload(
695*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadRequest(
696*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id, authorization_token=tokens[0]))
697*14675a02SAndroid Build Coastguard Worker    self.assertTrue(operation.done)
698*14675a02SAndroid Build Coastguard Worker    start_upload_response = (
699*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadResponse())
700*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(start_upload_response)
701*14675a02SAndroid Build Coastguard Worker
702*14675a02SAndroid Build Coastguard Worker    # If the resource_name is valid but no resource was uploaded,
703*14675a02SAndroid Build Coastguard Worker    # finalize_resource will return None.
704*14675a02SAndroid Build Coastguard Worker    self.mock_media_service.finalize_upload.return_value = None
705*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(http_actions.HttpError) as cm:
706*14675a02SAndroid Build Coastguard Worker      service.submit_aggregation_result(
707*14675a02SAndroid Build Coastguard Worker          aggregations_pb2.SubmitAggregationResultRequest(
708*14675a02SAndroid Build Coastguard Worker              aggregation_id=session_id,
709*14675a02SAndroid Build Coastguard Worker              client_token=start_upload_response.client_token,
710*14675a02SAndroid Build Coastguard Worker              resource_name=start_upload_response.resource.resource_name))
711*14675a02SAndroid Build Coastguard Worker    self.assertEqual(cm.exception.code, http.HTTPStatus.BAD_REQUEST)
712*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
713*14675a02SAndroid Build Coastguard Worker        service.get_session_status(session_id),
714*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
715*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING,
716*14675a02SAndroid Build Coastguard Worker            num_clients_failed=1))
717*14675a02SAndroid Build Coastguard Worker
718*14675a02SAndroid Build Coastguard Worker  def test_abort_aggregation(self):
719*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
720*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
721*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
722*14675a02SAndroid Build Coastguard Worker    tokens = service.pre_authorize_clients(session_id, 1)
723*14675a02SAndroid Build Coastguard Worker    operation = service.start_aggregation_data_upload(
724*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadRequest(
725*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id, authorization_token=tokens[0]))
726*14675a02SAndroid Build Coastguard Worker    self.assertTrue(operation.done)
727*14675a02SAndroid Build Coastguard Worker    start_upload_response = (
728*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadResponse())
729*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(start_upload_response)
730*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
731*14675a02SAndroid Build Coastguard Worker        service.abort_aggregation(
732*14675a02SAndroid Build Coastguard Worker            aggregations_pb2.AbortAggregationRequest(
733*14675a02SAndroid Build Coastguard Worker                aggregation_id=session_id,
734*14675a02SAndroid Build Coastguard Worker                client_token=start_upload_response.client_token)),
735*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.AbortAggregationResponse())
736*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
737*14675a02SAndroid Build Coastguard Worker        service.get_session_status(session_id),
738*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
739*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING,
740*14675a02SAndroid Build Coastguard Worker            num_clients_completed=1,
741*14675a02SAndroid Build Coastguard Worker            num_inputs_discarded=1))
742*14675a02SAndroid Build Coastguard Worker
743*14675a02SAndroid Build Coastguard Worker  def test_abort_aggregation_with_missing_session_id(self):
744*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
745*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
746*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
747*14675a02SAndroid Build Coastguard Worker    tokens = service.pre_authorize_clients(session_id, 1)
748*14675a02SAndroid Build Coastguard Worker    operation = service.start_aggregation_data_upload(
749*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadRequest(
750*14675a02SAndroid Build Coastguard Worker            aggregation_id=session_id, authorization_token=tokens[0]))
751*14675a02SAndroid Build Coastguard Worker    self.assertTrue(operation.done)
752*14675a02SAndroid Build Coastguard Worker    start_upload_response = (
753*14675a02SAndroid Build Coastguard Worker        aggregations_pb2.StartAggregationDataUploadResponse())
754*14675a02SAndroid Build Coastguard Worker    operation.response.Unpack(start_upload_response)
755*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(http_actions.HttpError) as cm:
756*14675a02SAndroid Build Coastguard Worker      service.abort_aggregation(
757*14675a02SAndroid Build Coastguard Worker          aggregations_pb2.AbortAggregationRequest(
758*14675a02SAndroid Build Coastguard Worker              aggregation_id='does-not-exist',
759*14675a02SAndroid Build Coastguard Worker              client_token=start_upload_response.client_token))
760*14675a02SAndroid Build Coastguard Worker    self.assertEqual(cm.exception.code, http.HTTPStatus.NOT_FOUND)
761*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
762*14675a02SAndroid Build Coastguard Worker        service.get_session_status(session_id),
763*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
764*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING,
765*14675a02SAndroid Build Coastguard Worker            num_clients_pending=1))
766*14675a02SAndroid Build Coastguard Worker
767*14675a02SAndroid Build Coastguard Worker  def test_abort_aggregation_with_invalid_token(self):
768*14675a02SAndroid Build Coastguard Worker    service = aggregations.Service(lambda: FORWARDING_INFO,
769*14675a02SAndroid Build Coastguard Worker                                   self.mock_media_service)
770*14675a02SAndroid Build Coastguard Worker    session_id = service.create_session(AGGREGATION_REQUIREMENTS)
771*14675a02SAndroid Build Coastguard Worker    with self.assertRaises(http_actions.HttpError) as cm:
772*14675a02SAndroid Build Coastguard Worker      service.abort_aggregation(
773*14675a02SAndroid Build Coastguard Worker          aggregations_pb2.AbortAggregationRequest(
774*14675a02SAndroid Build Coastguard Worker              aggregation_id=session_id, client_token='does-not-exist'))
775*14675a02SAndroid Build Coastguard Worker    self.assertEqual(cm.exception.code, http.HTTPStatus.UNAUTHORIZED)
776*14675a02SAndroid Build Coastguard Worker    self.assertEqual(
777*14675a02SAndroid Build Coastguard Worker        service.get_session_status(session_id),
778*14675a02SAndroid Build Coastguard Worker        aggregations.SessionStatus(
779*14675a02SAndroid Build Coastguard Worker            status=aggregations.AggregationStatus.PENDING))
780*14675a02SAndroid Build Coastguard Worker
781*14675a02SAndroid Build Coastguard Worker
# Run the tests in this module via absltest when executed as a script.
if __name__ == '__main__':
  absltest.main()
784