1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/http/http_federated_protocol.h"
17
18 #include <cstdint>
19 #include <memory>
20 #include <optional>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "google/longrunning/operations.pb.h"
26 #include "google/protobuf/any.pb.h"
27 #include "google/protobuf/duration.pb.h"
28 #include "google/rpc/code.pb.h"
29 #include "gmock/gmock.h"
30 #include "gtest/gtest.h"
31 #include "absl/memory/memory.h"
32 #include "absl/random/random.h"
33 #include "absl/status/status.h"
34 #include "absl/status/statusor.h"
35 #include "absl/strings/str_cat.h"
36 #include "absl/synchronization/blocking_counter.h"
37 #include "absl/synchronization/notification.h"
38 #include "absl/time/time.h"
39 #include "fcp/base/clock.h"
40 #include "fcp/base/monitoring.h"
41 #include "fcp/base/platform.h"
42 #include "fcp/base/time_util.h"
43 #include "fcp/base/wall_clock_stopwatch.h"
44 #include "fcp/client/cache/test_helpers.h"
45 #include "fcp/client/diag_codes.pb.h"
46 #include "fcp/client/engine/engine.pb.h"
47 #include "fcp/client/federated_protocol.h"
48 #include "fcp/client/federated_protocol_util.h"
49 #include "fcp/client/http/http_client.h"
50 #include "fcp/client/http/http_client_util.h"
51 #include "fcp/client/http/in_memory_request_response.h"
52 #include "fcp/client/http/testing/test_helpers.h"
53 #include "fcp/client/interruptible_runner.h"
54 #include "fcp/client/stats.h"
55 #include "fcp/client/test_helpers.h"
56 #include "fcp/protos/federated_api.pb.h"
57 #include "fcp/protos/federatedcompute/aggregations.pb.h"
58 #include "fcp/protos/federatedcompute/common.pb.h"
59 #include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h"
60 #include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
61 #include "fcp/protos/federatedcompute/task_assignments.pb.h"
62 #include "fcp/protos/plan.pb.h"
63 #include "fcp/secagg/shared/secagg_messages.pb.h"
64 #include "fcp/testing/testing.h"
65
66 namespace fcp::client::http {
67 namespace {
68
69 using ::fcp::EqualsProto;
70 using ::fcp::IsCode;
71 using ::fcp::client::http::FakeHttpResponse;
72 using ::fcp::client::http::MockableHttpClient;
73 using ::fcp::client::http::MockHttpClient;
74 using ::fcp::client::http::SimpleHttpRequestMatcher;
75 using ::google::internal::federatedcompute::v1::ByteStreamResource;
76 using ::google::internal::federatedcompute::v1::ClientStats;
77 using ::google::internal::federatedcompute::v1::EligibilityEvalTask;
78 using ::google::internal::federatedcompute::v1::EligibilityEvalTaskRequest;
79 using ::google::internal::federatedcompute::v1::EligibilityEvalTaskResponse;
80 using ::google::internal::federatedcompute::v1::ForwardingInfo;
81 using ::google::internal::federatedcompute::v1::PopulationEligibilitySpec;
82 using ::google::internal::federatedcompute::v1::
83 ReportEligibilityEvalTaskResultRequest;
84 using ::google::internal::federatedcompute::v1::ReportTaskResultRequest;
85 using ::google::internal::federatedcompute::v1::ReportTaskResultResponse;
86 using ::google::internal::federatedcompute::v1::Resource;
87 using ::google::internal::federatedcompute::v1::ResourceCompressionFormat;
88 using ::google::internal::federatedcompute::v1::RetryWindow;
89 using ::google::internal::federatedcompute::v1::SecureAggregandExecutionInfo;
90 using ::google::internal::federatedcompute::v1::
91 StartAggregationDataUploadRequest;
92 using ::google::internal::federatedcompute::v1::
93 StartAggregationDataUploadResponse;
94 using ::google::internal::federatedcompute::v1::StartSecureAggregationRequest;
95 using ::google::internal::federatedcompute::v1::StartSecureAggregationResponse;
96 using ::google::internal::federatedcompute::v1::StartTaskAssignmentRequest;
97 using ::google::internal::federatedcompute::v1::StartTaskAssignmentResponse;
98 using ::google::internal::federatedcompute::v1::SubmitAggregationResultRequest;
99 using ::google::internal::federatedcompute::v1::TaskAssignment;
100 using ::google::internal::federatedml::v2::TaskEligibilityInfo;
101 using ::google::internal::federatedml::v2::TaskWeight;
102 using ::google::longrunning::GetOperationRequest;
103 using ::google::longrunning::Operation;
104 using ::testing::_;
105 using ::testing::AllOf;
106 using ::testing::ByMove;
107 using ::testing::DescribeMatcher;
108 using ::testing::DoubleEq;
109 using ::testing::DoubleNear;
110 using ::testing::Eq;
111 using ::testing::ExplainMatchResult;
112 using ::testing::Field;
113 using ::testing::FieldsAre;
114 using ::testing::Ge;
115 using ::testing::Gt;
116 using ::testing::HasSubstr;
117 using ::testing::InSequence;
118 using ::testing::IsEmpty;
119 using ::testing::Lt;
120 using ::testing::MockFunction;
121 using ::testing::NiceMock;
122 using ::testing::Not;
123 using ::testing::Optional;
124 using ::testing::Return;
125 using ::testing::StrEq;
126 using ::testing::StrictMock;
127 using ::testing::UnorderedElementsAre;
128 using ::testing::VariantWith;
129 using ::testing::WithArg;
130
// Fixed values shared by the fake server responses and the expected requests
// built by the helpers in this file.
constexpr char kEntryPointUri[] = "https://initial.uri/";
constexpr char kTaskAssignmentTargetUri[] = "https://taskassignment.uri/";
constexpr char kAggregationTargetUri[] = "https://aggregation.uri/";
constexpr char kSecondStageAggregationTargetUri[] =
    "https://aggregation.second.uri/";
constexpr char kByteStreamTargetUri[] = "https://bytestream.uri/";
constexpr char kApiKey[] = "TEST_APIKEY";
// Note that we include a '/' character in the population name, which allows us
// to verify that it is correctly URL-encoded into "%2F".
constexpr char kPopulationName[] = "TEST/POPULATION";
constexpr char kEligibilityEvalExecutionId[] = "ELIGIBILITY_EXECUTION_ID";
// Note that we include '/' and '#' characters in the session ID, which allows
// us to verify that they are correctly URL-encoded into "%2F" and "%23".
constexpr char kEligibilityEvalSessionId[] = "ELIGIBILITY/SESSION#ID";
constexpr char kPlan[] = "CLIENT_ONLY_PLAN";
constexpr char kInitCheckpoint[] = "INIT_CHECKPOINT";
constexpr char kRetryToken[] = "OLD_RETRY_TOKEN";
constexpr char kClientVersion[] = "CLIENT_VERSION";
constexpr char kAttestationMeasurement[] = "ATTESTATION_MEASUREMENT";
constexpr char kClientSessionId[] = "CLIENT_SESSION_ID";
constexpr char kAggregationSessionId[] = "AGGREGATION_SESSION_ID";
constexpr char kAuthorizationToken[] = "AUTHORIZATION_TOKEN";
constexpr char kTaskName[] = "TASK_NAME";
constexpr char kClientToken[] = "CLIENT_TOKEN";
constexpr char kResourceName[] = "CHECKPOINT_RESOURCE";
constexpr char kFederatedSelectUriTemplate[] = "https://federated.select";
constexpr char kOperationName[] = "my_operation";

// Value returned by the mocked `waiting_period_sec_for_cancellation` flag.
const int32_t kCancellationWaitingPeriodSec = 1;
const int32_t kMinimumClientsInServerVisibleAggregate = 2;
161
162 MATCHER_P(EligibilityEvalTaskRequestMatcher, matcher,
163 absl::StrCat(negation ? "doesn't parse" : "parses",
164 " as an EligibilityEvalTaskRequest, and that ",
165 DescribeMatcher<EligibilityEvalTaskRequest>(matcher,
166 negation))) {
167 EligibilityEvalTaskRequest request;
168 if (!request.ParseFromString(arg)) {
169 return false;
170 }
171 return ExplainMatchResult(matcher, request, result_listener);
172 }
173
174 MATCHER_P(
175 ReportEligibilityEvalTaskResultRequestMatcher, matcher,
176 absl::StrCat(negation ? "doesn't parse" : "parses",
177 " as a ReportEligibilityEvalTaskResultRequest, and that ",
178 DescribeMatcher<ReportEligibilityEvalTaskResultRequest>(
179 matcher, negation))) {
180 ReportEligibilityEvalTaskResultRequest request;
181 if (!request.ParseFromString(arg)) {
182 return false;
183 }
184 return ExplainMatchResult(matcher, request, result_listener);
185 }
186
187 MATCHER_P(StartTaskAssignmentRequestMatcher, matcher,
188 absl::StrCat(negation ? "doesn't parse" : "parses",
189 " as a StartTaskAssignmentRequest, and that ",
190 DescribeMatcher<StartTaskAssignmentRequest>(matcher,
191 negation))) {
192 StartTaskAssignmentRequest request;
193 if (!request.ParseFromString(arg)) {
194 return false;
195 }
196 return ExplainMatchResult(matcher, request, result_listener);
197 }
198
199 MATCHER_P(GetOperationRequestMatcher, matcher,
200 absl::StrCat(negation ? "doesn't parse" : "parses",
201 " as a GetOperationRequest, and that ",
202 DescribeMatcher<GetOperationRequest>(matcher,
203 negation))) {
204 GetOperationRequest request;
205 if (!request.ParseFromString(arg)) {
206 return false;
207 }
208 return ExplainMatchResult(matcher, request, result_listener);
209 }
210
211 MATCHER_P(ReportTaskResultRequestMatcher, matcher,
212 absl::StrCat(negation ? "doesn't parse" : "parses",
213 " as a ReportTaskResultRequest, and that ",
214 DescribeMatcher<ReportTaskResultRequest>(matcher,
215 negation))) {
216 ReportTaskResultRequest request;
217 if (!request.ParseFromString(arg)) {
218 return false;
219 }
220 return ExplainMatchResult(matcher, request, result_listener);
221 }
222
// Retry delay settings that the test fixture's SetUp() returns from the mocked
// flags, together with the delay range each Expect*ErrorRetryWindow() helper
// accepts. The expected bounds equal the base period scaled by (1 -/+ jitter):
// 10s with 10% jitter gives [9, 11), 100s with 20% jitter gives [80, 120).
constexpr int kTransientErrorsRetryPeriodSecs = 10;
constexpr double kTransientErrorsRetryDelayJitterPercent = 0.1;
constexpr double kExpectedTransientErrorsRetryPeriodSecsMin = 9.0;
constexpr double kExpectedTransientErrorsRetryPeriodSecsMax = 11.0;
constexpr int kPermanentErrorsRetryPeriodSecs = 100;
constexpr double kPermanentErrorsRetryDelayJitterPercent = 0.2;
constexpr double kExpectedPermanentErrorsRetryPeriodSecsMin = 80.0;
constexpr double kExpectedPermanentErrorsRetryPeriodSecsMax = 120.0;
231
ExpectTransientErrorRetryWindow(const::google::internal::federatedml::v2::RetryWindow & retry_window)232 void ExpectTransientErrorRetryWindow(
233 const ::google::internal::federatedml::v2::RetryWindow& retry_window) {
234 // The calculated retry delay must lie within the expected transient errors
235 // retry delay range.
236 EXPECT_THAT(retry_window.delay_min().seconds() +
237 retry_window.delay_min().nanos() / 1000000000,
238 AllOf(Ge(kExpectedTransientErrorsRetryPeriodSecsMin),
239 Lt(kExpectedTransientErrorsRetryPeriodSecsMax)));
240 EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
241 }
242
ExpectPermanentErrorRetryWindow(const::google::internal::federatedml::v2::RetryWindow & retry_window)243 void ExpectPermanentErrorRetryWindow(
244 const ::google::internal::federatedml::v2::RetryWindow& retry_window) {
245 // The calculated retry delay must lie within the expected permanent errors
246 // retry delay range.
247 EXPECT_THAT(retry_window.delay_min().seconds() +
248 retry_window.delay_min().nanos() / 1000000000,
249 AllOf(Ge(kExpectedPermanentErrorsRetryPeriodSecsMin),
250 Lt(kExpectedPermanentErrorsRetryPeriodSecsMax)));
251 EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
252 }
253
GetAcceptedRetryWindow()254 RetryWindow GetAcceptedRetryWindow() {
255 // Must not overlap with kTransientErrorsRetryPeriodSecs or
256 // kPermanentErrorsRetryPeriodSecs.
257 RetryWindow retry_window;
258 retry_window.mutable_delay_min()->set_seconds(200L);
259 retry_window.mutable_delay_max()->set_seconds(299L);
260 return retry_window;
261 }
262
ExpectAcceptedRetryWindow(const::google::internal::federatedml::v2::RetryWindow & retry_window)263 void ExpectAcceptedRetryWindow(
264 const ::google::internal::federatedml::v2::RetryWindow& retry_window) {
265 // The calculated retry delay must lie within the expected 'rejected' retry
266 // delay range.
267 EXPECT_THAT(retry_window.delay_min().seconds() +
268 retry_window.delay_min().nanos() / 1000000000,
269 AllOf(Ge(200L), Lt(299L)));
270 EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
271 }
272
GetRejectedRetryWindow()273 RetryWindow GetRejectedRetryWindow() {
274 // Must not overlap with kTransientErrorsRetryPeriodSecs or
275 // kPermanentErrorsRetryPeriodSecs.
276 RetryWindow retry_window;
277 retry_window.mutable_delay_min()->set_seconds(300L);
278 retry_window.mutable_delay_max()->set_seconds(399L);
279 return retry_window;
280 }
281
ExpectRejectedRetryWindow(const::google::internal::federatedml::v2::RetryWindow & retry_window)282 void ExpectRejectedRetryWindow(
283 const ::google::internal::federatedml::v2::RetryWindow& retry_window) {
284 // The calculated retry delay must lie within the expected 'rejected' retry
285 // delay range.
286 EXPECT_THAT(retry_window.delay_min().seconds() +
287 retry_window.delay_min().nanos() / 1000000000,
288 AllOf(Ge(300L), Lt(399L)));
289 EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
290 }
291
GetExpectedEligibilityEvalTaskRequest(bool supports_multiple_task_assignments=false)292 EligibilityEvalTaskRequest GetExpectedEligibilityEvalTaskRequest(
293 bool supports_multiple_task_assignments = false) {
294 EligibilityEvalTaskRequest request;
295 // Note: we don't expect population_name to be set, since it should be set in
296 // the URI instead.
297 request.mutable_client_version()->set_version_code(kClientVersion);
298 request.mutable_attestation_measurement()->set_value(kAttestationMeasurement);
299 request.mutable_resource_capabilities()
300 ->mutable_supported_compression_formats()
301 ->Add(ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
302 request.mutable_eligibility_eval_task_capabilities()
303 ->set_supports_multiple_task_assignment(
304 supports_multiple_task_assignments);
305 return request;
306 }
307
GetFakeEnabledEligibilityEvalTaskResponse(const Resource & plan,const Resource & checkpoint,const std::string & execution_id,std::optional<Resource> population_eligibility_spec=std::nullopt,const RetryWindow & accepted_retry_window=GetAcceptedRetryWindow (),const RetryWindow & rejected_retry_window=GetRejectedRetryWindow ())308 EligibilityEvalTaskResponse GetFakeEnabledEligibilityEvalTaskResponse(
309 const Resource& plan, const Resource& checkpoint,
310 const std::string& execution_id,
311 std::optional<Resource> population_eligibility_spec = std::nullopt,
312 const RetryWindow& accepted_retry_window = GetAcceptedRetryWindow(),
313 const RetryWindow& rejected_retry_window = GetRejectedRetryWindow()) {
314 EligibilityEvalTaskResponse response;
315 response.set_session_id(kEligibilityEvalSessionId);
316 EligibilityEvalTask* eval_task = response.mutable_eligibility_eval_task();
317 *eval_task->mutable_plan() = plan;
318 *eval_task->mutable_init_checkpoint() = checkpoint;
319 if (population_eligibility_spec.has_value()) {
320 *eval_task->mutable_population_eligibility_spec() =
321 population_eligibility_spec.value();
322 }
323 eval_task->set_execution_id(execution_id);
324 ForwardingInfo* forwarding_info =
325 response.mutable_task_assignment_forwarding_info();
326 forwarding_info->set_target_uri_prefix(kTaskAssignmentTargetUri);
327 *response.mutable_retry_window_if_accepted() = accepted_retry_window;
328 *response.mutable_retry_window_if_rejected() = rejected_retry_window;
329 return response;
330 }
331
GetFakeDisabledEligibilityEvalTaskResponse()332 EligibilityEvalTaskResponse GetFakeDisabledEligibilityEvalTaskResponse() {
333 EligibilityEvalTaskResponse response;
334 response.set_session_id(kEligibilityEvalSessionId);
335 response.mutable_no_eligibility_eval_configured();
336 ForwardingInfo* forwarding_info =
337 response.mutable_task_assignment_forwarding_info();
338 forwarding_info->set_target_uri_prefix(kTaskAssignmentTargetUri);
339 *response.mutable_retry_window_if_accepted() = GetAcceptedRetryWindow();
340 *response.mutable_retry_window_if_rejected() = GetRejectedRetryWindow();
341 return response;
342 }
343
GetFakeRejectedEligibilityEvalTaskResponse()344 EligibilityEvalTaskResponse GetFakeRejectedEligibilityEvalTaskResponse() {
345 EligibilityEvalTaskResponse response;
346 response.mutable_rejection_info();
347 *response.mutable_retry_window_if_accepted() = GetAcceptedRetryWindow();
348 *response.mutable_retry_window_if_rejected() = GetRejectedRetryWindow();
349 return response;
350 }
351
GetFakeTaskEligibilityInfo()352 TaskEligibilityInfo GetFakeTaskEligibilityInfo() {
353 TaskEligibilityInfo eligibility_info;
354 TaskWeight* task_weight = eligibility_info.mutable_task_weights()->Add();
355 task_weight->set_task_name("foo");
356 task_weight->set_weight(567.8);
357 return eligibility_info;
358 }
359
GetExpectedStartTaskAssignmentRequest(const std::optional<TaskEligibilityInfo> & task_eligibility_info)360 StartTaskAssignmentRequest GetExpectedStartTaskAssignmentRequest(
361 const std::optional<TaskEligibilityInfo>& task_eligibility_info) {
362 // Note: we don't expect population_name or session_id to be set, since they
363 // should be set in the URI instead.
364 StartTaskAssignmentRequest request;
365 request.mutable_client_version()->set_version_code(kClientVersion);
366 if (task_eligibility_info.has_value()) {
367 *request.mutable_task_eligibility_info() = *task_eligibility_info;
368 }
369 request.mutable_resource_capabilities()
370 ->mutable_supported_compression_formats()
371 ->Add(ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
372 return request;
373 }
374
GetFakeRejectedTaskAssignmentResponse()375 StartTaskAssignmentResponse GetFakeRejectedTaskAssignmentResponse() {
376 StartTaskAssignmentResponse response;
377 response.mutable_rejection_info();
378 return response;
379 }
380
GetFakeTaskAssignmentResponse(const Resource & plan,const Resource & checkpoint,const std::string & federated_select_uri_template,const std::string & aggregation_session_id,int32_t minimum_clients_in_server_visible_aggregate)381 StartTaskAssignmentResponse GetFakeTaskAssignmentResponse(
382 const Resource& plan, const Resource& checkpoint,
383 const std::string& federated_select_uri_template,
384 const std::string& aggregation_session_id,
385 int32_t minimum_clients_in_server_visible_aggregate) {
386 StartTaskAssignmentResponse response;
387 TaskAssignment* task_assignment = response.mutable_task_assignment();
388 ForwardingInfo* forwarding_info =
389 task_assignment->mutable_aggregation_data_forwarding_info();
390 forwarding_info->set_target_uri_prefix(kAggregationTargetUri);
391 task_assignment->set_session_id(kClientSessionId);
392 task_assignment->set_aggregation_id(aggregation_session_id);
393 task_assignment->set_authorization_token(kAuthorizationToken);
394 task_assignment->set_task_name(kTaskName);
395 *task_assignment->mutable_plan() = plan;
396 *task_assignment->mutable_init_checkpoint() = checkpoint;
397 task_assignment->mutable_federated_select_uri_info()->set_uri_template(
398 federated_select_uri_template);
399 if (minimum_clients_in_server_visible_aggregate > 0) {
400 task_assignment->mutable_secure_aggregation_info()
401 ->set_minimum_clients_in_server_visible_aggregate(
402 minimum_clients_in_server_visible_aggregate);
403 } else {
404 task_assignment->mutable_aggregation_info();
405 }
406 return response;
407 }
408
GetExpectedReportTaskResultRequest(absl::string_view aggregation_id,absl::string_view task_name,::google::rpc::Code code,absl::Duration train_duration)409 ReportTaskResultRequest GetExpectedReportTaskResultRequest(
410 absl::string_view aggregation_id, absl::string_view task_name,
411 ::google::rpc::Code code, absl::Duration train_duration) {
412 ReportTaskResultRequest request;
413 request.set_aggregation_id(std::string(aggregation_id));
414 request.set_task_name(std::string(task_name));
415 request.set_computation_status_code(code);
416 ClientStats client_stats;
417 *client_stats.mutable_computation_execution_duration() =
418 TimeUtil::ConvertAbslToProtoDuration(train_duration);
419 *request.mutable_client_stats() = client_stats;
420 return request;
421 }
422
GetFakeStartAggregationDataUploadResponse(absl::string_view aggregation_resource_name,absl::string_view byte_stream_uri_prefix,absl::string_view second_stage_aggregation_uri_prefix)423 StartAggregationDataUploadResponse GetFakeStartAggregationDataUploadResponse(
424 absl::string_view aggregation_resource_name,
425 absl::string_view byte_stream_uri_prefix,
426 absl::string_view second_stage_aggregation_uri_prefix) {
427 StartAggregationDataUploadResponse response;
428 ByteStreamResource* resource = response.mutable_resource();
429 *resource->mutable_resource_name() = aggregation_resource_name;
430 ForwardingInfo* data_upload_forwarding_info =
431 resource->mutable_data_upload_forwarding_info();
432 *data_upload_forwarding_info->mutable_target_uri_prefix() =
433 byte_stream_uri_prefix;
434 ForwardingInfo* aggregation_protocol_forwarding_info =
435 response.mutable_aggregation_protocol_forwarding_info();
436 *aggregation_protocol_forwarding_info->mutable_target_uri_prefix() =
437 second_stage_aggregation_uri_prefix;
438 response.set_client_token(kClientToken);
439 return response;
440 }
441
CreateEmptySuccessHttpResponse()442 FakeHttpResponse CreateEmptySuccessHttpResponse() {
443 return FakeHttpResponse(200, HeaderList(), "");
444 }
445
446 class HttpFederatedProtocolTest : public ::testing::Test {
447 protected:
SetUp()448 void SetUp() override {
449 EXPECT_CALL(mock_flags_,
450 federated_training_transient_errors_retry_delay_secs)
451 .WillRepeatedly(Return(kTransientErrorsRetryPeriodSecs));
452 EXPECT_CALL(mock_flags_,
453 federated_training_transient_errors_retry_delay_jitter_percent)
454 .WillRepeatedly(Return(kTransientErrorsRetryDelayJitterPercent));
455 EXPECT_CALL(mock_flags_,
456 federated_training_permanent_errors_retry_delay_secs)
457 .WillRepeatedly(Return(kPermanentErrorsRetryPeriodSecs));
458 EXPECT_CALL(mock_flags_,
459 federated_training_permanent_errors_retry_delay_jitter_percent)
460 .WillRepeatedly(Return(kPermanentErrorsRetryDelayJitterPercent));
461 EXPECT_CALL(mock_flags_, federated_training_permanent_error_codes)
462 .WillRepeatedly(Return(std::vector<int32_t>{
463 static_cast<int32_t>(absl::StatusCode::kNotFound),
464 static_cast<int32_t>(absl::StatusCode::kInvalidArgument),
465 static_cast<int32_t>(absl::StatusCode::kUnimplemented)}));
466 // Note that we disable compression in test to make it easier to verify the
467 // request body. The compression logic is tested in
468 // in_memory_request_response_test.cc.
469 EXPECT_CALL(mock_flags_, disable_http_request_body_compression)
470 .WillRepeatedly(Return(true));
471 EXPECT_CALL(mock_flags_, waiting_period_sec_for_cancellation)
472 .WillRepeatedly(Return(kCancellationWaitingPeriodSec));
473
474 EXPECT_CALL(mock_flags_, http_protocol_supports_multiple_task_assignments)
475 .WillRepeatedly(Return(false));
476
477 // We only initialize federated_protocol_ in this SetUp method, rather than
478 // in the test's constructor, to ensure that we can set mock flag values
479 // before the HttpFederatedProtocol constructor is called. Using
480 // std::unique_ptr conveniently allows us to assign the field a new value
481 // after construction (which we could not do if the field's type was
482 // HttpFederatedProtocol, since it doesn't have copy or move constructors).
483 federated_protocol_ = std::make_unique<HttpFederatedProtocol>(
484 clock_, &mock_log_manager_, &mock_flags_, &mock_http_client_,
485 absl::WrapUnique(mock_secagg_runner_factory_),
486 &mock_secagg_event_publisher_, kEntryPointUri, kApiKey, kPopulationName,
487 kRetryToken, kClientVersion, kAttestationMeasurement,
488 mock_should_abort_.AsStdFunction(), absl::BitGen(),
489 InterruptibleRunner::TimingConfig{
490 .polling_period = absl::ZeroDuration(),
491 .graceful_shutdown_period = absl::InfiniteDuration(),
492 .extended_shutdown_period = absl::InfiniteDuration()},
493 &mock_resource_cache_);
494 }
495
TearDown()496 void TearDown() override {
497 // Regardless of the outcome of the test (or the protocol interaction being
498 // tested), network usage must always be reflected in the network stats
499 // methods.
500 HttpRequestHandle::SentReceivedBytes sent_received_bytes =
501 mock_http_client_.TotalSentReceivedBytes();
502
503 NetworkStats network_stats = federated_protocol_->GetNetworkStats();
504 EXPECT_EQ(network_stats.bytes_downloaded,
505 sent_received_bytes.received_bytes);
506 EXPECT_EQ(network_stats.bytes_uploaded, sent_received_bytes.sent_bytes);
507 // If any network traffic occurred, we expect to see some time reflected in
508 // the duration.
509 if (network_stats.bytes_uploaded > 0) {
510 EXPECT_THAT(network_stats.network_duration, Gt(absl::ZeroDuration()));
511 }
512 }
513
514 // This function runs a successful EligibilityEvalCheckin() that results in an
515 // eligibility eval payload being returned by the server (if
516 // `eligibility_eval_enabled` is true), or results in a 'no eligibility eval
517 // configured' response (if `eligibility_eval_enabled` is false). This is a
518 // utility function used by Checkin*() tests that depend on a prior,
519 // successful execution of EligibilityEvalCheckin(). It returns a
520 // absl::Status, which the caller should verify is OK using ASSERT_OK.
RunSuccessfulEligibilityEvalCheckin(bool eligibility_eval_enabled=true)521 absl::Status RunSuccessfulEligibilityEvalCheckin(
522 bool eligibility_eval_enabled = true) {
523 EligibilityEvalTaskResponse eval_task_response;
524 if (eligibility_eval_enabled) {
525 // We return a fake response which returns the plan/initial checkpoint
526 // data inline, to keep things simple.
527 std::string expected_plan = kPlan;
528 Resource plan_resource;
529 plan_resource.mutable_inline_resource()->set_data(kPlan);
530 std::string expected_checkpoint = kInitCheckpoint;
531 Resource checkpoint_resource;
532 checkpoint_resource.mutable_inline_resource()->set_data(
533 expected_checkpoint);
534 eval_task_response = GetFakeEnabledEligibilityEvalTaskResponse(
535 plan_resource, checkpoint_resource, kEligibilityEvalExecutionId);
536 } else {
537 eval_task_response = GetFakeDisabledEligibilityEvalTaskResponse();
538 }
539 std::string request_uri =
540 "https://initial.uri/v1/eligibilityevaltasks/"
541 "TEST%2FPOPULATION:request?%24alt=proto";
542 EXPECT_CALL(mock_http_client_,
543 PerformSingleRequest(SimpleHttpRequestMatcher(
544 request_uri, HttpRequest::Method::kPost, _,
545 EligibilityEvalTaskRequestMatcher(
546 EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
547 .WillOnce(Return(FakeHttpResponse(
548 200, HeaderList(), eval_task_response.SerializeAsString())));
549
550 // The 'EET received' callback should be called, even if the task resource
551 // data was available inline.
552 if (eligibility_eval_enabled) {
553 EXPECT_CALL(mock_eet_received_callback_,
554 Call(FieldsAre(FieldsAre("", ""), kEligibilityEvalExecutionId,
555 Eq(std::nullopt))));
556 }
557
558 return federated_protocol_
559 ->EligibilityEvalCheckin(mock_eet_received_callback_.AsStdFunction())
560 .status();
561 }
562
563 // This function runs a successful Checkin() that results in a
564 // task assignment payload being returned by the server. This is a
565 // utility function used by Report*() tests that depend on a prior,
566 // successful execution of Checkin(). It returns a
567 // absl::Status, which the caller should verify is OK using ASSERT_OK.
RunSuccessfulCheckin(bool eligibility_eval_enabled=true)568 absl::Status RunSuccessfulCheckin(bool eligibility_eval_enabled = true) {
569 // We return a fake response which returns the plan/initial checkpoint
570 // data inline, to keep things simple.
571 std::string expected_plan = kPlan;
572 std::string plan_uri = "https://fake.uri/plan";
573 Resource plan_resource;
574 plan_resource.set_uri(plan_uri);
575 std::string expected_checkpoint = kInitCheckpoint;
576 Resource checkpoint_resource;
577 checkpoint_resource.mutable_inline_resource()->set_data(
578 expected_checkpoint);
579 std::string expected_aggregation_session_id = kAggregationSessionId;
580 StartTaskAssignmentResponse task_assignment_response =
581 GetFakeTaskAssignmentResponse(plan_resource, checkpoint_resource,
582 kFederatedSelectUriTemplate,
583 expected_aggregation_session_id, 0);
584
585 std::string request_uri =
586 "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
587 "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
588 TaskEligibilityInfo expected_eligibility_info =
589 GetFakeTaskEligibilityInfo();
590 EXPECT_CALL(mock_http_client_,
591 PerformSingleRequest(SimpleHttpRequestMatcher(
592 request_uri, HttpRequest::Method::kPost, _,
593 StartTaskAssignmentRequestMatcher(
594 EqualsProto(GetExpectedStartTaskAssignmentRequest(
595 expected_eligibility_info))))))
596 .WillOnce(Return(FakeHttpResponse(
597 200, HeaderList(),
598 CreateDoneOperation(kOperationName, task_assignment_response)
599 .SerializeAsString())));
600
601 EXPECT_CALL(mock_http_client_,
602 PerformSingleRequest(SimpleHttpRequestMatcher(
603 plan_uri, HttpRequest::Method::kGet, _, "")))
604 .WillOnce(Return(FakeHttpResponse(200, HeaderList(), expected_plan)));
605
606 if (eligibility_eval_enabled) {
607 std::string report_eet_request_uri =
608 "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
609 "eligibilityevaltasks/"
610 "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
611 ExpectSuccessfulReportEligibilityEvalTaskResultRequest(
612 report_eet_request_uri, absl::OkStatus());
613 }
614
615 return federated_protocol_
616 ->Checkin(expected_eligibility_info,
617 mock_task_received_callback_.AsStdFunction())
618 .status();
619 }
620
ExpectSuccessfulReportEligibilityEvalTaskResultRequest(absl::string_view expected_request_uri,absl::Status eet_status)621 void ExpectSuccessfulReportEligibilityEvalTaskResultRequest(
622 absl::string_view expected_request_uri, absl::Status eet_status) {
623 ReportEligibilityEvalTaskResultRequest report_eet_request;
624 report_eet_request.set_status_code(
625 static_cast<google::rpc::Code>(eet_status.code()));
626 EXPECT_CALL(
627 mock_http_client_,
628 PerformSingleRequest(SimpleHttpRequestMatcher(
629 std::string(expected_request_uri), HttpRequest::Method::kPost, _,
630 ReportEligibilityEvalTaskResultRequestMatcher(
631 EqualsProto(report_eet_request)))))
632 .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
633 }
634
ExpectSuccessfulReportTaskResultRequest(absl::string_view expected_report_result_uri,absl::string_view aggregation_session_id,absl::string_view task_name,absl::Duration plan_duration)635 void ExpectSuccessfulReportTaskResultRequest(
636 absl::string_view expected_report_result_uri,
637 absl::string_view aggregation_session_id, absl::string_view task_name,
638 absl::Duration plan_duration) {
639 ReportTaskResultResponse report_task_result_response;
640 EXPECT_CALL(mock_http_client_,
641 PerformSingleRequest(SimpleHttpRequestMatcher(
642 std::string(expected_report_result_uri),
643 HttpRequest::Method::kPost, _,
644 ReportTaskResultRequestMatcher(
645 EqualsProto(GetExpectedReportTaskResultRequest(
646 aggregation_session_id, task_name,
647 google::rpc::Code::OK, plan_duration))))))
648 .WillOnce(Return(CreateEmptySuccessHttpResponse()));
649 }
650
ExpectSuccessfulStartAggregationDataUploadRequest(absl::string_view expected_start_data_upload_uri,absl::string_view aggregation_resource_name,absl::string_view byte_stream_uri_prefix,absl::string_view second_stage_aggregation_uri_prefix)651 void ExpectSuccessfulStartAggregationDataUploadRequest(
652 absl::string_view expected_start_data_upload_uri,
653 absl::string_view aggregation_resource_name,
654 absl::string_view byte_stream_uri_prefix,
655 absl::string_view second_stage_aggregation_uri_prefix) {
656 Operation pending_operation_response =
657 CreatePendingOperation("operations/foo#bar");
658 EXPECT_CALL(mock_http_client_,
659 PerformSingleRequest(SimpleHttpRequestMatcher(
660 std::string(expected_start_data_upload_uri),
661 HttpRequest::Method::kPost, _,
662 StartAggregationDataUploadRequest().SerializeAsString())))
663 .WillOnce(Return(
664 FakeHttpResponse(200, HeaderList(),
665 pending_operation_response.SerializeAsString())));
666 EXPECT_CALL(
667 mock_http_client_,
668 PerformSingleRequest(SimpleHttpRequestMatcher(
669 // Note that the '#' character is encoded as "%23".
670 "https://aggregation.uri/v1/operations/foo%23bar?%24alt=proto",
671 HttpRequest::Method::kGet, _,
672 GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
673 .WillOnce(Return(FakeHttpResponse(
674 200, HeaderList(),
675 CreateDoneOperation(
676 kOperationName,
677 GetFakeStartAggregationDataUploadResponse(
678 aggregation_resource_name, byte_stream_uri_prefix,
679 second_stage_aggregation_uri_prefix))
680 .SerializeAsString())));
681 }
682
ExpectSuccessfulByteStreamUploadRequest(absl::string_view byte_stream_upload_uri,absl::string_view checkpoint_str)683 void ExpectSuccessfulByteStreamUploadRequest(
684 absl::string_view byte_stream_upload_uri,
685 absl::string_view checkpoint_str) {
686 EXPECT_CALL(
687 mock_http_client_,
688 PerformSingleRequest(SimpleHttpRequestMatcher(
689 std::string(byte_stream_upload_uri), HttpRequest::Method::kPost, _,
690 std::string(checkpoint_str))))
691 .WillOnce(Return(CreateEmptySuccessHttpResponse()));
692 }
693
ExpectSuccessfulSubmitAggregationResultRequest(absl::string_view expected_submit_aggregation_result_uri)694 void ExpectSuccessfulSubmitAggregationResultRequest(
695 absl::string_view expected_submit_aggregation_result_uri) {
696 SubmitAggregationResultRequest submit_aggregation_result_request;
697 submit_aggregation_result_request.set_resource_name(kResourceName);
698 EXPECT_CALL(mock_http_client_,
699 PerformSingleRequest(SimpleHttpRequestMatcher(
700 std::string(expected_submit_aggregation_result_uri),
701 HttpRequest::Method::kPost, _,
702 submit_aggregation_result_request.SerializeAsString())))
703 .WillOnce(Return(CreateEmptySuccessHttpResponse()));
704 }
705
ExpectSuccessfulAbortAggregationRequest(absl::string_view base_uri)706 void ExpectSuccessfulAbortAggregationRequest(absl::string_view base_uri) {
707 EXPECT_CALL(mock_http_client_,
708 PerformSingleRequest(SimpleHttpRequestMatcher(
709 absl::StrCat(base_uri, "/v1/aggregations/",
710 "AGGREGATION_SESSION_ID/clients/"
711 "CLIENT_TOKEN:abort?%24alt=proto"),
712 HttpRequest::Method::kPost, _, _)))
713 .WillOnce(Return(CreateEmptySuccessHttpResponse()));
714 }
715
716 StrictMock<MockHttpClient> mock_http_client_;
717 StrictMock<MockSecAggRunnerFactory>* mock_secagg_runner_factory_ =
718 new StrictMock<MockSecAggRunnerFactory>();
719 StrictMock<MockSecAggEventPublisher> mock_secagg_event_publisher_;
720 StrictMock<MockLogManager> mock_log_manager_;
721 NiceMock<MockFlags> mock_flags_;
722 NiceMock<MockFunction<bool()>> mock_should_abort_;
723 StrictMock<cache::MockResourceCache> mock_resource_cache_;
724 Clock* clock_ = Clock::RealClock();
725 NiceMock<MockFunction<void(
726 const ::fcp::client::FederatedProtocol::EligibilityEvalTask&)>>
727 mock_eet_received_callback_;
728 NiceMock<MockFunction<void(
729 const ::fcp::client::FederatedProtocol::TaskAssignment&)>>
730 mock_task_received_callback_;
731
732 // The class under test.
733 std::unique_ptr<HttpFederatedProtocol> federated_protocol_;
734 };
735
// Alias of the regular fixture under a "*DeathTest" name; used by the TEST_F
// cases declared against HttpFederatedProtocolDeathTest.
using HttpFederatedProtocolDeathTest = HttpFederatedProtocolTest;
737
// Two independently constructed protocol instances must not report the same
// initial (transient-error) retry window. This is a basic check that the
// window is actually randomized and that the random number generator is not
// being misused.
TEST_F(HttpFederatedProtocolTest,
       TestTransientErrorRetryWindowDifferentAcrossDifferentInstances) {
  // Snapshot the first instance's window before that instance is destroyed.
  const ::google::internal::federatedml::v2::RetryWindow first_window =
      federated_protocol_->GetLatestRetryWindow();
  ExpectTransientErrorRetryWindow(first_window);

  // Tear down the first instance and allocate a fresh SecAgg runner factory
  // for the replacement instance to take ownership of.
  federated_protocol_.reset(nullptr);
  mock_secagg_runner_factory_ = new StrictMock<MockSecAggRunnerFactory>();

  const InterruptibleRunner::TimingConfig timing_config{
      .polling_period = absl::ZeroDuration(),
      .graceful_shutdown_period = absl::InfiniteDuration(),
      .extended_shutdown_period = absl::InfiniteDuration()};
  federated_protocol_ = std::make_unique<HttpFederatedProtocol>(
      clock_, &mock_log_manager_, &mock_flags_, &mock_http_client_,
      absl::WrapUnique(mock_secagg_runner_factory_),
      &mock_secagg_event_publisher_, kEntryPointUri, kApiKey, kPopulationName,
      kRetryToken, kClientVersion, kAttestationMeasurement,
      mock_should_abort_.AsStdFunction(), absl::BitGen(), timing_config,
      &mock_resource_cache_);

  const ::google::internal::federatedml::v2::RetryWindow second_window =
      federated_protocol_->GetLatestRetryWindow();
  ExpectTransientErrorRetryWindow(second_window);

  EXPECT_THAT(first_window, Not(EqualsProto(second_window)));
}
768
// A 503 Service Unavailable answer to the eligibility eval protocol request
// must surface as an UNAVAILABLE status and leave a transient-error retry
// window in place.
TEST_F(HttpFederatedProtocolTest,
       TestEligibilityEvalCheckinRequestFailsTransientError) {
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://initial.uri/v1/eligibilityevaltasks/"
                  "TEST%2FPOPULATION:request?%24alt=proto",
                  HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));

  auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());
  const absl::Status& checkin_status = checkin_result.status();

  EXPECT_THAT(checkin_status, IsCode(UNAVAILABLE));
  EXPECT_THAT(checkin_status.message(), HasSubstr("protocol request failed"));
  // The message should also mention the original HTTP response code.
  EXPECT_THAT(checkin_status.message(), HasSubstr("503"));
  // The server never supplied RetryWindows, so the window must be the one
  // generated from the transient errors retry delay flag.
  ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
794
// A 404 Not Found answer to the eligibility eval protocol request must
// surface as a NOT_FOUND status and leave a permanent-error retry window in
// place.
TEST_F(HttpFederatedProtocolTest,
       TestEligibilityEvalCheckinRequestFailsPermanentError) {
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://initial.uri/v1/eligibilityevaltasks/"
                  "TEST%2FPOPULATION:request?%24alt=proto",
                  HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(404, HeaderList(), "")));

  auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());
  const absl::Status& checkin_status = checkin_result.status();

  EXPECT_THAT(checkin_status, IsCode(NOT_FOUND));
  EXPECT_THAT(checkin_status.message(), HasSubstr("protocol request failed"));
  // The message should also mention the original HTTP response code.
  EXPECT_THAT(checkin_status.message(), HasSubstr("404"));
  // The server never supplied RetryWindows, and NOT_FOUND is marked as a
  // permanent error in the flags, so the window must be the one generated from
  // the *permanent* errors retry delay flag.
  ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
821
822 // Tests the case where we get interrupted while waiting for a response to the
823 // protocol request in EligibilityEvalCheckin.
TEST_F(HttpFederatedProtocolTest,TestEligibilityEvalCheckinRequestInterrupted)824 TEST_F(HttpFederatedProtocolTest,
825 TestEligibilityEvalCheckinRequestInterrupted) {
826 absl::Notification request_issued;
827 absl::Notification request_cancelled;
828
829 // Make HttpClient::PerformRequests() block until the counter is decremented.
830 EXPECT_CALL(mock_http_client_,
831 PerformSingleRequest(SimpleHttpRequestMatcher(
832 "https://initial.uri/v1/eligibilityevaltasks/"
833 "TEST%2FPOPULATION:request?%24alt=proto",
834 HttpRequest::Method::kPost, _, _)))
835 .WillOnce([&request_issued, &request_cancelled](
836 MockableHttpClient::SimpleHttpRequest ignored) {
837 request_issued.Notify();
838 request_cancelled.WaitForNotification();
839 return FakeHttpResponse(503, HeaderList(), "");
840 });
841
842 // Make should_abort return false until we know that the request was issued
843 // (i.e. once InterruptibleRunner has actually started running the code it
844 // was given), and then make it return true, triggering an abort sequence and
845 // unblocking the PerformRequests()() call we caused to block above.
846 EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&request_issued] {
847 return request_issued.HasBeenNotified();
848 });
849
850 // When the HttpClient receives a HttpRequestHandle::Cancel call, we let the
851 // request complete.
852 mock_http_client_.SetCancellationListener(
853 [&request_cancelled]() { request_cancelled.Notify(); });
854
855 EXPECT_CALL(mock_log_manager_,
856 LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP));
857
858 auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
859 mock_eet_received_callback_.AsStdFunction());
860
861 EXPECT_THAT(eligibility_checkin_result.status(), IsCode(CANCELLED));
862 // No RetryWindows were received from the server, so we expect to get a
863 // RetryWindow generated based on the transient errors retry delay flag.
864 ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
865 }
866
// When the server answers the eligibility eval request with a rejection, the
// check-in succeeds with a Rejection result and the rejected retry window is
// adopted.
TEST_F(HttpFederatedProtocolTest, TestEligibilityEvalCheckinRejection) {
  const std::string rejection_body =
      GetFakeRejectedEligibilityEvalTaskResponse().SerializeAsString();
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://initial.uri/v1/eligibilityevaltasks/"
                  "TEST%2FPOPULATION:request?%24alt=proto",
                  HttpRequest::Method::kPost, _,
                  EligibilityEvalTaskRequestMatcher(
                      EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), rejection_body)));

  // No EET was handed to the client, so the 'eet received' callback must stay
  // silent.
  EXPECT_CALL(mock_eet_received_callback_, Call(_)).Times(0);

  auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  ASSERT_OK(checkin_result);
  EXPECT_THAT(*checkin_result, VariantWith<FederatedProtocol::Rejection>(_));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
891
// When the server reports that eligibility eval is disabled, the check-in
// succeeds with an EligibilityEvalDisabled result and the rejected retry
// window is adopted.
TEST_F(HttpFederatedProtocolTest, TestEligibilityEvalCheckinDisabled) {
  const std::string disabled_body =
      GetFakeDisabledEligibilityEvalTaskResponse().SerializeAsString();
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://initial.uri/v1/eligibilityevaltasks/"
                  "TEST%2FPOPULATION:request?%24alt=proto",
                  HttpRequest::Method::kPost, _,
                  EligibilityEvalTaskRequestMatcher(
                      EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), disabled_body)));

  // No EET was handed to the client, so the 'eet received' callback must stay
  // silent.
  EXPECT_CALL(mock_eet_received_callback_, Call(_)).Times(0);

  auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  ASSERT_OK(checkin_result);
  EXPECT_THAT(*checkin_result,
              VariantWith<FederatedProtocol::EligibilityEvalDisabled>(_));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
916
// Happy path for an enabled eligibility eval task: the plan is referenced by
// URI (so it needs a separate HTTP fetch) while the checkpoint data arrives
// inline in the response.
TEST_F(HttpFederatedProtocolTest, TestEligibilityEvalCheckinEnabled) {
  const std::string plan_uri = "https://fake.uri/plan";
  const std::string expected_plan = kPlan;
  const std::string expected_checkpoint = kInitCheckpoint;
  const std::string expected_execution_id = kEligibilityEvalExecutionId;

  Resource plan_resource;
  plan_resource.set_uri(plan_uri);
  Resource checkpoint_resource;
  checkpoint_resource.mutable_inline_resource()->set_data(expected_checkpoint);
  const EligibilityEvalTaskResponse eval_task_response =
      GetFakeEnabledEligibilityEvalTaskResponse(
          plan_resource, checkpoint_resource, expected_execution_id);

  // Everything below must happen in exactly this order.
  InSequence seq;
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://initial.uri/v1/eligibilityevaltasks/"
                  "TEST%2FPOPULATION:request?%24alt=proto",
                  HttpRequest::Method::kPost, _,
                  EligibilityEvalTaskRequestMatcher(
                      EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), eval_task_response.SerializeAsString())));

  // The 'EET received' callback fires *before* any task resources are
  // fetched, so its payloads are still empty.
  EXPECT_CALL(mock_eet_received_callback_,
              Call(FieldsAre(FieldsAre("", ""), expected_execution_id,
                             Eq(std::nullopt))));

  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  plan_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), expected_plan)));

  auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  ASSERT_OK(checkin_result);
  EXPECT_THAT(
      *checkin_result,
      VariantWith<FederatedProtocol::EligibilityEvalTask>(FieldsAre(
          AllOf(Field(&FederatedProtocol::PlanAndCheckpointPayloads::plan,
                      absl::Cord(expected_plan)),
                Field(&FederatedProtocol::PlanAndCheckpointPayloads::checkpoint,
                      absl::Cord(expected_checkpoint))),
          expected_execution_id, Eq(std::nullopt))));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
968
// With multiple task assignment support switched on, an enabled eligibility
// eval response that references a PopulationEligibilitySpec by URI results in
// that spec being fetched and returned alongside the plan and checkpoint.
TEST_F(HttpFederatedProtocolTest,
       TestEligibilityEvalCheckinWithPopulationEligibilitySpec) {
  EXPECT_CALL(mock_flags_, http_protocol_supports_multiple_task_assignments)
      .WillRepeatedly(Return(true));

  // The plan is referenced by URI; the checkpoint data is inline.
  const std::string plan_uri = "https://fake.uri/plan";
  const std::string expected_plan = kPlan;
  const std::string expected_checkpoint = kInitCheckpoint;
  const std::string expected_execution_id = kEligibilityEvalExecutionId;
  Resource plan_resource;
  plan_resource.set_uri(plan_uri);
  Resource checkpoint_resource;
  checkpoint_resource.mutable_inline_resource()->set_data(expected_checkpoint);

  // The spec the fake server serves, and the resource pointing at it.
  PopulationEligibilitySpec expected_population_eligibility_spec;
  auto* task_info = expected_population_eligibility_spec.add_task_info();
  task_info->set_task_name("task_1");
  task_info->set_task_assignment_mode(
      PopulationEligibilitySpec::TaskInfo::TASK_ASSIGNMENT_MODE_MULTIPLE);
  const std::string population_eligibility_spec_uri =
      "https://fake.uri/population_eligibility_spec";
  Resource population_eligibility_spec;
  population_eligibility_spec.set_uri(population_eligibility_spec_uri);

  const EligibilityEvalTaskResponse eval_task_response =
      GetFakeEnabledEligibilityEvalTaskResponse(
          plan_resource, checkpoint_resource, expected_execution_id,
          population_eligibility_spec);

  // Everything below must happen in exactly this order.
  InSequence seq;
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://initial.uri/v1/eligibilityevaltasks/"
                  "TEST%2FPOPULATION:request?%24alt=proto",
                  HttpRequest::Method::kPost, _,
                  EligibilityEvalTaskRequestMatcher(
                      EqualsProto(GetExpectedEligibilityEvalTaskRequest(
                          /* supports_multiple_task_assignments= */ true))))))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), eval_task_response.SerializeAsString())));

  // The 'EET received' callback fires *before* any task resources are
  // fetched, so neither payloads nor the spec are populated yet.
  EXPECT_CALL(mock_eet_received_callback_,
              Call(FieldsAre(FieldsAre("", ""), expected_execution_id,
                             Eq(std::nullopt))));

  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  plan_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), expected_plan)));
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  population_eligibility_spec_uri, HttpRequest::Method::kGet,
                  _, "")))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          expected_population_eligibility_spec.SerializeAsString())));

  auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  ASSERT_OK(checkin_result);
  EXPECT_THAT(
      *checkin_result,
      VariantWith<FederatedProtocol::EligibilityEvalTask>(FieldsAre(
          AllOf(Field(&FederatedProtocol::PlanAndCheckpointPayloads::plan,
                      absl::Cord(expected_plan)),
                Field(&FederatedProtocol::PlanAndCheckpointPayloads::checkpoint,
                      absl::Cord(expected_checkpoint))),
          expected_execution_id,
          Optional(EqualsProto(expected_population_eligibility_spec)))));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1042
// With multiple task assignment support switched on, an inline
// PopulationEligibilitySpec whose bytes are "Invalid_spec" makes the check-in
// fail with INVALID_ARGUMENT and a permanent-error retry window.
TEST_F(HttpFederatedProtocolTest,
       TestEligibilityEvalCheckinWithPopulationEligibilitySpecInvalidFormat) {
  EXPECT_CALL(mock_flags_, http_protocol_supports_multiple_task_assignments)
      .WillRepeatedly(Return(true));

  // The plan is referenced by URI; the checkpoint data is inline.
  const std::string plan_uri = "https://fake.uri/plan";
  const std::string expected_plan = kPlan;
  const std::string expected_checkpoint = kInitCheckpoint;
  const std::string expected_execution_id = kEligibilityEvalExecutionId;
  Resource plan_resource;
  plan_resource.set_uri(plan_uri);
  Resource checkpoint_resource;
  checkpoint_resource.mutable_inline_resource()->set_data(expected_checkpoint);

  // The malformed spec is delivered inline in the response.
  Resource population_eligibility_spec;
  population_eligibility_spec.mutable_inline_resource()->set_data(
      "Invalid_spec");

  const EligibilityEvalTaskResponse eval_task_response =
      GetFakeEnabledEligibilityEvalTaskResponse(
          plan_resource, checkpoint_resource, expected_execution_id,
          population_eligibility_spec);

  // Everything below must happen in exactly this order.
  InSequence seq;
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://initial.uri/v1/eligibilityevaltasks/"
                  "TEST%2FPOPULATION:request?%24alt=proto",
                  HttpRequest::Method::kPost, _,
                  EligibilityEvalTaskRequestMatcher(
                      EqualsProto(GetExpectedEligibilityEvalTaskRequest(
                          /* supports_multiple_task_assignments= */ true))))))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), eval_task_response.SerializeAsString())));

  // The 'EET received' callback fires *before* any task resources are
  // fetched.
  EXPECT_CALL(mock_eet_received_callback_,
              Call(FieldsAre(FieldsAre("", ""), expected_execution_id,
                             Eq(std::nullopt))));

  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  plan_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), expected_plan)));

  auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  ASSERT_THAT(checkin_result, IsCode(INVALID_ARGUMENT));
  ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1095
// Verifies that gzip-compressed inline plan and checkpoint resources are
// returned to the caller in their uncompressed form.
TEST_F(HttpFederatedProtocolTest,
       TestEligibilityEvalCheckinEnabledWithCompression) {
  std::string expected_plan = kPlan;
  absl::StatusOr<std::string> compressed_plan =
      internal::CompressWithGzip(expected_plan);
  ASSERT_OK(compressed_plan);
  Resource plan_resource;
  plan_resource.mutable_inline_resource()->set_data(*compressed_plan);
  plan_resource.mutable_inline_resource()->set_compression_format(
      ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);

  std::string expected_checkpoint = kInitCheckpoint;
  absl::StatusOr<std::string> compressed_checkpoint =
      internal::CompressWithGzip(expected_checkpoint);
  // Check the compression result before dereferencing it, exactly as is done
  // for the plan above; otherwise a compression failure would crash the test
  // instead of failing it with a useful message.
  ASSERT_OK(compressed_checkpoint);
  Resource checkpoint_resource;
  checkpoint_resource.mutable_inline_resource()->set_data(
      *compressed_checkpoint);
  checkpoint_resource.mutable_inline_resource()->set_compression_format(
      ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);

  std::string expected_execution_id = kEligibilityEvalExecutionId;
  EligibilityEvalTaskResponse eval_task_response =
      GetFakeEnabledEligibilityEvalTaskResponse(
          plan_resource, checkpoint_resource, expected_execution_id);
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://initial.uri/v1/eligibilityevaltasks/"
                  "TEST%2FPOPULATION:request?%24alt=proto",
                  HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), eval_task_response.SerializeAsString())));

  auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  ASSERT_OK(eligibility_checkin_result);
  // The payloads handed back must be the original, uncompressed bytes.
  EXPECT_THAT(
      *eligibility_checkin_result,
      VariantWith<FederatedProtocol::EligibilityEvalTask>(FieldsAre(
          AllOf(Field(&FederatedProtocol::PlanAndCheckpointPayloads::plan,
                      absl::Cord(expected_plan)),
                Field(&FederatedProtocol::PlanAndCheckpointPayloads::checkpoint,
                      absl::Cord(expected_checkpoint))),
          expected_execution_id, Eq(std::nullopt))));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1140
1141 // Ensures that if the plan resource fails to be downloaded, the error is
1142 // correctly returned from the EligibilityEvalCheckin(...) method.
TEST_F(HttpFederatedProtocolTest,TestEligibilityEvalCheckinEnabledPlanDataFetchFailed)1143 TEST_F(HttpFederatedProtocolTest,
1144 TestEligibilityEvalCheckinEnabledPlanDataFetchFailed) {
1145 std::string plan_uri = "https://fake.uri/plan";
1146 Resource plan_resource;
1147 plan_resource.set_uri(plan_uri);
1148 std::string checkpoint_uri = "https://fake.uri/checkpoint";
1149 Resource checkpoint_resource;
1150 checkpoint_resource.set_uri(checkpoint_uri);
1151 std::string expected_execution_id = kEligibilityEvalExecutionId;
1152 EligibilityEvalTaskResponse eval_task_response =
1153 GetFakeEnabledEligibilityEvalTaskResponse(
1154 plan_resource, checkpoint_resource, expected_execution_id);
1155 EXPECT_CALL(mock_http_client_,
1156 PerformSingleRequest(SimpleHttpRequestMatcher(
1157 "https://initial.uri/v1/eligibilityevaltasks/"
1158 "TEST%2FPOPULATION:request?%24alt=proto",
1159 HttpRequest::Method::kPost, _, _)))
1160 .WillOnce(Return(FakeHttpResponse(
1161 200, HeaderList(), eval_task_response.SerializeAsString())));
1162
1163 EXPECT_CALL(mock_http_client_,
1164 PerformSingleRequest(SimpleHttpRequestMatcher(
1165 checkpoint_uri, HttpRequest::Method::kGet, _, "")))
1166 .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
1167
1168 // Mock a failed plan fetch.
1169 EXPECT_CALL(mock_http_client_,
1170 PerformSingleRequest(SimpleHttpRequestMatcher(
1171 plan_uri, HttpRequest::Method::kGet, _, "")))
1172 .WillOnce(Return(FakeHttpResponse(404, HeaderList(), "")));
1173
1174 auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
1175 mock_eet_received_callback_.AsStdFunction());
1176
1177 // The 404 error for the resource request should be reflected in the return
1178 // value.
1179 EXPECT_THAT(eligibility_checkin_result.status(), IsCode(NOT_FOUND));
1180 EXPECT_THAT(eligibility_checkin_result.status().message(),
1181 HasSubstr("plan fetch failed"));
1182 // The original 404 HTTP response code should be included in the message as
1183 // well.
1184 EXPECT_THAT(eligibility_checkin_result.status().message(), HasSubstr("404"));
1185 // Since the error type is considered a permanent error, we should get a
1186 // permanent error retry window.
1187 ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
1188 }
1189
1190 // Ensures that if the checkpoint resource fails to be downloaded, the error is
1191 // correctly returned from the EligibilityEvalCheckin(...) method.
TEST_F(HttpFederatedProtocolTest,TestEligibilityEvalCheckinEnabledCheckpointDataFetchFailed)1192 TEST_F(HttpFederatedProtocolTest,
1193 TestEligibilityEvalCheckinEnabledCheckpointDataFetchFailed) {
1194 std::string plan_uri = "https://fake.uri/plan";
1195 Resource plan_resource;
1196 plan_resource.set_uri(plan_uri);
1197 std::string checkpoint_uri = "https://fake.uri/checkpoint";
1198 Resource checkpoint_resource;
1199 checkpoint_resource.set_uri(checkpoint_uri);
1200 std::string expected_execution_id = kEligibilityEvalExecutionId;
1201 EligibilityEvalTaskResponse eval_task_response =
1202 GetFakeEnabledEligibilityEvalTaskResponse(
1203 plan_resource, checkpoint_resource, expected_execution_id);
1204 EXPECT_CALL(mock_http_client_,
1205 PerformSingleRequest(SimpleHttpRequestMatcher(
1206 "https://initial.uri/v1/eligibilityevaltasks/"
1207 "TEST%2FPOPULATION:request?%24alt=proto",
1208 HttpRequest::Method::kPost, _, _)))
1209 .WillOnce(Return(FakeHttpResponse(
1210 200, HeaderList(), eval_task_response.SerializeAsString())));
1211
1212 // Mock a failed checkpoint fetch.
1213 EXPECT_CALL(mock_http_client_,
1214 PerformSingleRequest(SimpleHttpRequestMatcher(
1215 checkpoint_uri, HttpRequest::Method::kGet, _, "")))
1216 .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
1217
1218 EXPECT_CALL(mock_http_client_,
1219 PerformSingleRequest(SimpleHttpRequestMatcher(
1220 plan_uri, HttpRequest::Method::kGet, _, "")))
1221 .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
1222
1223 auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
1224 mock_eet_received_callback_.AsStdFunction());
1225
1226 // The 503 error for the resource request should be reflected in the return
1227 // value.
1228 EXPECT_THAT(eligibility_checkin_result.status(), IsCode(UNAVAILABLE));
1229 EXPECT_THAT(eligibility_checkin_result.status().message(),
1230 HasSubstr("checkpoint fetch failed"));
1231 // The original 503 HTTP response code should be included in the message as
1232 // well.
1233 EXPECT_THAT(eligibility_checkin_result.status().message(), HasSubstr("503"));
1234 // RetryWindows were received from the server before the error was received,
1235 // and the error is considered 'transient', so we expect to get a rejected
1236 // RetryWindow.
1237 ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1238 }
1239
// After a successful eligibility eval check-in, reporting a CANCELLED error
// must issue a ReportEligibilityEvalTaskResult POST whose status code is the
// cancelled code.
TEST_F(HttpFederatedProtocolTest, TestReportEligibilityEvalTaskResult) {
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());

  const absl::Status reported_error = absl::CancelledError();
  ReportEligibilityEvalTaskResultRequest expected_request;
  expected_request.set_status_code(
      static_cast<google::rpc::Code>(reported_error.code()));

  const std::string report_eet_request_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  report_eet_request_uri, HttpRequest::Method::kPost, _,
                  ReportEligibilityEvalTaskResultRequestMatcher(
                      EqualsProto(expected_request)))))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));

  federated_protocol_->ReportEligibilityEvalError(reported_error);
}
1258
1259 // Tests that the protocol correctly sanitizes any invalid values it may have
1260 // received from the server.
TEST_F(HttpFederatedProtocolTest,TestNegativeMinMaxRetryDelayValueSanitization)1261 TEST_F(HttpFederatedProtocolTest,
1262 TestNegativeMinMaxRetryDelayValueSanitization) {
1263 RetryWindow retry_window;
1264 retry_window.mutable_delay_min()->set_seconds(-1);
1265 retry_window.mutable_delay_max()->set_seconds(-2);
1266
1267 // The above retry window's negative min/max values should be clamped to 0.
1268 RetryWindow expected_retry_window;
1269 expected_retry_window.mutable_delay_min()->set_seconds(0);
1270 expected_retry_window.mutable_delay_max()->set_seconds(0);
1271
1272 EligibilityEvalTaskResponse eval_task_response =
1273 GetFakeEnabledEligibilityEvalTaskResponse(
1274 Resource(), Resource(), kEligibilityEvalExecutionId, std::nullopt,
1275 retry_window, retry_window);
1276 EXPECT_CALL(mock_http_client_,
1277 PerformSingleRequest(SimpleHttpRequestMatcher(
1278 "https://initial.uri/v1/eligibilityevaltasks/"
1279 "TEST%2FPOPULATION:request?%24alt=proto",
1280 HttpRequest::Method::kPost, _, _)))
1281 .WillOnce(Return(FakeHttpResponse(
1282 200, HeaderList(), eval_task_response.SerializeAsString())));
1283
1284 ASSERT_OK(federated_protocol_->EligibilityEvalCheckin(
1285 mock_eet_received_callback_.AsStdFunction()));
1286
1287 const google::internal::federatedml::v2::RetryWindow& actual_retry_window =
1288 federated_protocol_->GetLatestRetryWindow();
1289 // The above retry window's invalid max value should be clamped to the min
1290 // value (minus some errors introduced by the inaccuracy of double
1291 // multiplication).
1292 EXPECT_THAT(actual_retry_window.delay_min().seconds() +
1293 actual_retry_window.delay_min().nanos() / 1000000000.0,
1294 DoubleEq(0));
1295 EXPECT_THAT(actual_retry_window.delay_max().seconds() +
1296 actual_retry_window.delay_max().nanos() / 1000000000.0,
1297 DoubleEq(0));
1298 }
1299
1300 // Tests that the protocol correctly sanitizes any invalid values it may have
1301 // received from the server.
TEST_F(HttpFederatedProtocolTest, TestInvalidMaxRetryDelayValueSanitization) {
  // Build a retry window whose upper bound is smaller than its lower bound.
  RetryWindow invalid_window;
  invalid_window.mutable_delay_min()->set_seconds(1234);
  invalid_window.mutable_delay_max()->set_seconds(1233);  // < delay_min

  // The same invalid window is passed for both retry window arguments of the
  // fake eligibility eval task response.
  std::string eet_request_uri =
      "https://initial.uri/v1/eligibilityevaltasks/"
      "TEST%2FPOPULATION:request?%24alt=proto";
  std::string serialized_response =
      GetFakeEnabledEligibilityEvalTaskResponse(
          Resource(), Resource(), kEligibilityEvalExecutionId, std::nullopt,
          invalid_window, invalid_window)
          .SerializeAsString();
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  eet_request_uri, HttpRequest::Method::kPost, _, _)))
      .WillOnce(
          Return(FakeHttpResponse(200, HeaderList(), serialized_response)));

  ASSERT_OK(federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction()));

  // Converts a Duration-like proto (seconds + nanos) to fractional seconds.
  const auto as_seconds = [](const auto& duration) {
    return duration.seconds() + duration.nanos() / 1000000000.0;
  };

  const auto& latest_window = federated_protocol_->GetLatestRetryWindow();
  // The invalid max value is expected to have been clamped to the min value.
  // The comparison goes through double arithmetic, for which DoubleEq is too
  // strict, so a small tolerance is allowed via DoubleNear.
  EXPECT_THAT(as_seconds(latest_window.delay_min()),
              DoubleNear(1234.0, 0.015));
  EXPECT_THAT(as_seconds(latest_window.delay_max()),
              DoubleNear(1234.0, 0.015));
}
1335
TEST_F(HttpFederatedProtocolDeathTest,
       TestCheckinAfterFailedEligibilityEvalCheckin) {
  // The eligibility eval checkin request is answered with a 503 Service
  // Unavailable response, which is expected to surface as the call's result.
  std::string eet_request_uri =
      "https://initial.uri/v1/eligibilityevaltasks/"
      "TEST%2FPOPULATION:request?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  eet_request_uri, HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));

  auto eet_checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  EXPECT_THAT(eet_checkin_result.status(), IsCode(UNAVAILABLE));

  // Checkin(...) may only be called after a successful
  // EligibilityEvalCheckin(...), so calling it now is expected to abort.
  ASSERT_DEATH(
      {
        auto unused = federated_protocol_->Checkin(
            std::nullopt, mock_task_received_callback_.AsStdFunction());
      },
      _);
}
1362
TEST_F(HttpFederatedProtocolDeathTest,
       TestCheckinAfterEligibilityEvalCheckinRejection) {
  // Answer the (fully matched) eligibility eval task request with a
  // rejection response.
  std::string eet_request_uri =
      "https://initial.uri/v1/eligibilityevaltasks/"
      "TEST%2FPOPULATION:request?%24alt=proto";
  std::string rejection_body =
      GetFakeRejectedEligibilityEvalTaskResponse().SerializeAsString();
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  eet_request_uri, HttpRequest::Method::kPost, _,
                  EligibilityEvalTaskRequestMatcher(
                      EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), rejection_body)));

  ASSERT_OK(federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction()));

  // Checkin(...) may only follow a successful EligibilityEvalCheckin(...)
  // that was not rejected, so calling it now is expected to abort.
  ASSERT_DEATH(
      {
        auto unused = federated_protocol_->Checkin(
            std::nullopt, mock_task_received_callback_.AsStdFunction());
      },
      _);
}
1389
TEST_F(HttpFederatedProtocolDeathTest,
       TestCheckinWithEligibilityInfoAfterEligibilityEvalCheckinDisabled) {
  // Complete an eligibility eval checkin with eligibility eval disabled.
  ASSERT_OK(
      RunSuccessfulEligibilityEvalCheckin(/*eligibility_eval_enabled=*/false));

  // Supplying a TaskEligibilityInfo is only valid after an
  // EligibilityEvalCheckin(...) whose response contained an eligibility eval
  // task, so passing one here is expected to abort.
  TaskEligibilityInfo unexpected_eligibility_info;
  ASSERT_DEATH(
      {
        auto unused = federated_protocol_->Checkin(
            unexpected_eligibility_info,
            mock_task_received_callback_.AsStdFunction());
      },
      _);
}
1407
TEST_F(HttpFederatedProtocolDeathTest, TestCheckinWithMissingEligibilityInfo) {
  // Complete an eligibility eval checkin with eligibility eval enabled.
  ASSERT_OK(
      RunSuccessfulEligibilityEvalCheckin(/*eligibility_eval_enabled=*/true));

  // Because the eligibility eval checkin response included a plan, the
  // protocol requires a TaskEligibilityInfo to be provided. Omitting it is
  // expected to abort.
  ASSERT_DEATH(
      {
        auto unused = federated_protocol_->Checkin(
            /*task_eligibility_info=*/std::nullopt,
            mock_task_received_callback_.AsStdFunction());
      },
      _);
}
1422
TEST_F(HttpFederatedProtocolDeathTest,
       TestCheckinAfterEligibilityEvalResourceDataFetchFailed) {
  // Serve an eligibility eval task whose plan and checkpoint are referenced
  // by URI, so that they have to be fetched with separate HTTP requests.
  Resource plan_resource;
  plan_resource.set_uri("https://fake.uri/plan");
  Resource checkpoint_resource;
  checkpoint_resource.set_uri("https://fake.uri/checkpoint");
  EligibilityEvalTaskResponse eval_task_response =
      GetFakeEnabledEligibilityEvalTaskResponse(
          plan_resource, checkpoint_resource, kEligibilityEvalExecutionId);
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://initial.uri/v1/eligibilityevaltasks/"
                  "TEST%2FPOPULATION:request?%24alt=proto",
                  HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), eval_task_response.SerializeAsString())));

  // Mock a failed plan/resource fetch: every GET request is answered with a
  // 503 response.
  EXPECT_CALL(mock_http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
                                     _, HttpRequest::Method::kGet, _, "")))
      .WillRepeatedly(Return(FakeHttpResponse(503, HeaderList(), "")));

  auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  // Verify the precondition of this test: the eligibility eval checkin must
  // actually have failed, otherwise the death assertion below could pass for
  // an unrelated reason.
  EXPECT_FALSE(eligibility_checkin_result.ok());

  // A Checkin(...) request should now fail, because Checkin(...) should only
  // be called after a successful EligibilityEvalCheckin(...) request, and the
  // one above failed while fetching the task's resources.
  ASSERT_DEATH(
      {
        auto unused = federated_protocol_->Checkin(
            TaskEligibilityInfo(),
            mock_task_received_callback_.AsStdFunction());
      },
      _);
}
1459
1460 // Ensures that if the HTTP layer returns an error code that maps to a transient
1461 // error, it is handled correctly
TEST_F(HttpFederatedProtocolTest, TestCheckinFailsTransientError) {
  // A successful eligibility eval checkin has to precede the regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  std::string eet_report_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  ExpectSuccessfulReportEligibilityEvalTaskResultRequest(eet_report_uri,
                                                         absl::OkStatus());

  // The first request issued by Checkin(...) is answered with a 503 Service
  // Unavailable response, which is expected to be returned as the result.
  std::string start_task_assignment_uri =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  start_task_assignment_uri, HttpRequest::Method::kPost, _,
                  _)))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));

  auto result = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());

  EXPECT_THAT(result.status(), IsCode(UNAVAILABLE));
  // The status message should also mention the original 503 HTTP code.
  EXPECT_THAT(result.status().message(), HasSubstr("503"));
  // Retry windows were already received from the server during the
  // eligibility eval checkin, so a 'rejected' retry window is expected.
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1495
1496 // Ensures that if the HTTP layer returns an error code that maps to a permanent
1497 // error, it is handled correctly.
TEST_F(HttpFederatedProtocolTest, TestCheckinFailsPermanentErrorFromHttp) {
  // A successful eligibility eval checkin has to precede the regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  std::string eet_report_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  ExpectSuccessfulReportEligibilityEvalTaskResultRequest(eet_report_uri,
                                                         absl::OkStatus());

  // The first request issued by Checkin(...) is answered with a 404 Not Found
  // response, which is expected to be returned as the result.
  std::string start_task_assignment_uri =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  start_task_assignment_uri, HttpRequest::Method::kPost, _,
                  _)))
      .WillOnce(Return(FakeHttpResponse(404, HeaderList(), "")));

  auto result = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());

  EXPECT_THAT(result.status(), IsCode(NOT_FOUND));
  // The status message should also mention the original 404 HTTP code.
  EXPECT_THAT(result.status().message(), HasSubstr("404"));
  // Although retry windows were already received from the server during the
  // eligibility eval checkin, NOT_FOUND is marked as a permanent error in the
  // flags. Permanent errors should always yield a retry window derived from
  // the *permanent* errors retry delay flag, regardless of any windows
  // received earlier.
  ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1535
1536 // Ensures that if the HTTP layer returns a successful response, but it contains
1537 // an Operation proto with a permanent error, that it is handled correctly.
TEST_F(HttpFederatedProtocolTest, TestCheckinFailsPermanentErrorFromOperation) {
  // A successful eligibility eval checkin has to precede the regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  std::string eet_report_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  ExpectSuccessfulReportEligibilityEvalTaskResultRequest(eet_report_uri,
                                                         absl::OkStatus());

  // The HTTP request itself succeeds (200), but its body is an Operation
  // proto carrying a NOT_FOUND error. That error is expected to be returned
  // as the result of Checkin(...).
  std::string start_task_assignment_uri =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
  std::string error_operation_body =
      CreateErrorOperation(kOperationName, absl::StatusCode::kNotFound, "foo")
          .SerializeAsString();
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  start_task_assignment_uri, HttpRequest::Method::kPost, _,
                  _)))
      .WillOnce(
          Return(FakeHttpResponse(200, HeaderList(), error_operation_body)));

  auto result = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());

  EXPECT_THAT(result.status(), IsCode(NOT_FOUND));
  EXPECT_THAT(result.status().message(),
              HasSubstr("Operation my_operation contained error"));
  // The Operation's own error message should be part of the status message.
  EXPECT_THAT(result.status().message(), HasSubstr("foo"));
  // Although retry windows were already received from the server during the
  // eligibility eval checkin, NOT_FOUND is marked as a permanent error in the
  // flags. Permanent errors should always yield a retry window derived from
  // the *permanent* errors retry delay flag, regardless of any windows
  // received earlier.
  ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1580
1581 // Tests the case where we get interrupted while waiting for a response to the
1582 // protocol request in Checkin.
TEST_F(HttpFederatedProtocolTest, TestCheckinInterrupted) {
  // A successful eligibility eval checkin has to precede the regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  std::string eet_report_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  ExpectSuccessfulReportEligibilityEvalTaskResultRequest(eet_report_uri,
                                                         absl::OkStatus());

  absl::Notification issued;
  absl::Notification cancelled;

  // The task assignment request signals that it was issued and then blocks
  // until the cancellation notification fires.
  std::string start_task_assignment_uri =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  start_task_assignment_uri, HttpRequest::Method::kPost, _,
                  _)))
      .WillOnce([&issued, &cancelled](
                    MockableHttpClient::SimpleHttpRequest unused_request) {
        issued.Notify();
        cancelled.WaitForNotification();
        return FakeHttpResponse(503, HeaderList(), "");
      });

  // should_abort reports false until the request above has been issued, and
  // true from then on. That triggers the abort sequence which in turn
  // unblocks the request handler above.
  EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&issued] {
    return issued.HasBeenNotified();
  });

  // Once the HttpClient is told to cancel the request, let the blocked
  // request handler complete.
  mock_http_client_.SetCancellationListener([&cancelled]() {
    if (cancelled.HasBeenNotified()) {
      return;
    }
    cancelled.Notify();
  });

  EXPECT_CALL(mock_log_manager_,
              LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP));

  auto result = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());
  EXPECT_THAT(result.status(), IsCode(CANCELLED));
  // Retry windows were already received from the server during the
  // eligibility eval checkin, so a 'rejected' retry window is expected.
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1637
1638 // Tests the case where we get interrupted during polling of the long running
1639 // operation.
TEST_F(HttpFederatedProtocolTest,
       TestCheckinInterruptedDuringLongRunningOperation) {
  // Issue an eligibility eval checkin first.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  std::string report_eet_request_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
                                                         absl::OkStatus());

  absl::Notification request_issued;
  absl::Notification request_cancelled;

  Operation pending_operation = CreatePendingOperation("operations/foo#bar");
  // The initial task assignment request returns immediately with a pending
  // Operation, which makes the protocol start polling that Operation.
  EXPECT_CALL(
      mock_http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
          "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
          HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), pending_operation.SerializeAsString())));

  // Make should_abort return false until we know that the polling request was
  // issued, and then make it return true, triggering an abort sequence and
  // unblocking the polling request handler set up below.
  EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&request_issued] {
    return request_issued.HasBeenNotified();
  });
  // Each polling request signals (once) that it was issued and then blocks
  // until the cancellation notification fires.
  EXPECT_CALL(
      mock_http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          // Note that the '#' character is encoded as "%23".
          "https://taskassignment.uri/v1/operations/foo%23bar?%24alt=proto",
          HttpRequest::Method::kGet, _,
          GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
      .WillRepeatedly([&request_issued, &request_cancelled, pending_operation](
                          MockableHttpClient::SimpleHttpRequest ignored) {
        if (!request_issued.HasBeenNotified()) {
          request_issued.Notify();
        }
        request_cancelled.WaitForNotification();
        return FakeHttpResponse(200, HeaderList(),
                                pending_operation.SerializeAsString());
      });

  // Once the client is cancelled, a CancelOperationRequest should still be sent
  // out before returning to the caller.
  EXPECT_CALL(
      mock_http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          // Note that the '#' character is encoded as "%23".
          "https://taskassignment.uri/v1/operations/"
          "foo%23bar:cancel?%24alt=proto",
          HttpRequest::Method::kGet, _,
          GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));

  // When the HttpClient receives a HttpRequestHandle::Cancel call, we let the
  // request complete. absl::Notification::Notify() may only be called once,
  // so guard against the listener firing more than once (as the other
  // interruption tests in this file do).
  mock_http_client_.SetCancellationListener([&request_cancelled]() {
    if (!request_cancelled.HasBeenNotified()) {
      request_cancelled.Notify();
    }
  });

  EXPECT_CALL(mock_log_manager_,
              LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP));

  auto checkin_result = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());
  EXPECT_THAT(checkin_result.status(), IsCode(CANCELLED));
  // RetryWindows were already received from the server during the eligibility
  // eval checkin, so we expect to get a 'rejected' retry window.
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1717
1718 // Tests the case where we get interrupted during polling of the long-running
1719 // operation, and the issued cancellation request timed out.
TEST_F(HttpFederatedProtocolTest, TestCheckinInterruptedCancellationTimeout) {
  // A successful eligibility eval checkin has to precede the regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  std::string eet_report_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  ExpectSuccessfulReportEligibilityEvalTaskResultRequest(eet_report_uri,
                                                         absl::OkStatus());

  absl::Notification issued;
  absl::Notification cancelled;

  Operation pending_operation = CreatePendingOperation("operations/foo#bar");
  // The initial task assignment request immediately returns a pending
  // Operation.
  std::string start_task_assignment_uri =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  start_task_assignment_uri, HttpRequest::Method::kPost, _,
                  _)))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), pending_operation.SerializeAsString())));

  // should_abort reports false until the polling request below has been
  // issued, and true from then on. That triggers the abort sequence which in
  // turn unblocks the polling request handler.
  EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&issued] {
    return issued.HasBeenNotified();
  });

  // Each polling request signals (once) that it was issued and then blocks
  // until the cancellation notification fires. Note that the '#' character of
  // the operation name is encoded as "%23" in the URI.
  std::string get_operation_uri =
      "https://taskassignment.uri/v1/operations/foo%23bar?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  get_operation_uri, HttpRequest::Method::kGet, _,
                  GetOperationRequestMatcher(
                      EqualsProto(GetOperationRequest())))))
      .WillRepeatedly([&issued, &cancelled, pending_operation](
                          MockableHttpClient::SimpleHttpRequest
                              unused_request) {
        if (!issued.HasBeenNotified()) {
          issued.Notify();
        }
        cancelled.WaitForNotification();
        return FakeHttpResponse(200, HeaderList(),
                                pending_operation.SerializeAsString());
      });

  // After the client is cancelled, a CancelOperationRequest should still be
  // sent before control returns to the caller. Its response is delayed by two
  // seconds here. Again, '#' is encoded as "%23".
  std::string cancel_operation_uri =
      "https://taskassignment.uri/v1/operations/"
      "foo%23bar:cancel?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  cancel_operation_uri, HttpRequest::Method::kGet, _,
                  GetOperationRequestMatcher(
                      EqualsProto(GetOperationRequest())))))
      .WillOnce([](MockableHttpClient::SimpleHttpRequest unused_request) {
        absl::SleepFor(absl::Seconds(2));
        return FakeHttpResponse(200, HeaderList(), "");
      });

  // Once the HttpClient is told to cancel a request, let the blocked polling
  // handler complete. The notification is only fired the first time.
  mock_http_client_.SetCancellationListener([&cancelled]() {
    if (cancelled.HasBeenNotified()) {
      return;
    }
    cancelled.Notify();
  });

  // The interruption diag code is logged twice: once for the Get operation
  // and once for the Cancel operation.
  EXPECT_CALL(mock_log_manager_,
              LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP))
      .Times(2);
  EXPECT_CALL(mock_log_manager_,
              LogDiag(ProdDiagCode::HTTP_CANCELLATION_OR_ABORT_REQUEST_FAILED));

  auto result = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());
  EXPECT_THAT(result.status(), IsCode(CANCELLED));
  // Retry windows were already received from the server during the
  // eligibility eval checkin, so a 'rejected' retry window is expected.
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1808
1809 // Tests whether 'rejection' responses to the main Checkin(...) request are
1810 // handled correctly.
TEST_F(HttpFederatedProtocolTest, TestCheckinRejectionWithTaskEligibilityInfo) {
  // A successful eligibility eval checkin has to precede the regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  std::string eet_report_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  ExpectSuccessfulReportEligibilityEvalTaskResultRequest(eet_report_uri,
                                                         absl::OkStatus());

  // The task assignment request must carry the eligibility info passed to
  // Checkin(...). It is answered with a completed Operation that wraps a
  // rejection response.
  TaskEligibilityInfo eligibility_info = GetFakeTaskEligibilityInfo();
  std::string start_task_assignment_uri =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
  std::string rejection_operation_body =
      CreateDoneOperation(kOperationName,
                          GetFakeRejectedTaskAssignmentResponse())
          .SerializeAsString();
  EXPECT_CALL(
      mock_http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          start_task_assignment_uri, HttpRequest::Method::kPost, _,
          StartTaskAssignmentRequestMatcher(EqualsProto(
              GetExpectedStartTaskAssignmentRequest(eligibility_info))))))
      .WillOnce(Return(
          FakeHttpResponse(200, HeaderList(), rejection_operation_body)));

  // No task was handed to the client, so the 'task received' callback must
  // never fire.
  EXPECT_CALL(mock_task_received_callback_, Call(_)).Times(0);

  // Issue the regular checkin.
  auto result = federated_protocol_->Checkin(
      eligibility_info, mock_task_received_callback_.AsStdFunction());

  ASSERT_OK(result.status());
  EXPECT_THAT(*result, VariantWith<FederatedProtocol::Rejection>(_));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1849
1850 // Tests whether we can issue a Checkin() request correctly without passing a
1851 // TaskEligibilityInfo, in the case that the eligibility eval checkin didn't
1852 // return any eligibility eval task to run.
TEST_F(HttpFederatedProtocolTest,
       TestCheckinRejectionWithoutTaskEligibilityInfo) {
  // Complete an eligibility eval checkin with eligibility eval disabled.
  ASSERT_OK(
      RunSuccessfulEligibilityEvalCheckin(/*eligibility_eval_enabled=*/false));

  // The task assignment request is expected to be built without eligibility
  // info. It is answered with a completed Operation that wraps a rejection
  // response.
  std::string start_task_assignment_uri =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
  std::string rejection_operation_body =
      CreateDoneOperation(kOperationName,
                          GetFakeRejectedTaskAssignmentResponse())
          .SerializeAsString();
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  start_task_assignment_uri, HttpRequest::Method::kPost, _,
                  StartTaskAssignmentRequestMatcher(EqualsProto(
                      GetExpectedStartTaskAssignmentRequest(std::nullopt))))))
      .WillOnce(Return(
          FakeHttpResponse(200, HeaderList(), rejection_operation_body)));

  // No task was handed to the client, so the 'task received' callback must
  // never fire.
  EXPECT_CALL(mock_task_received_callback_, Call(_)).Times(0);

  // Issue the regular checkin without a TaskEligibilityInfo, since the
  // eligibility eval checkin did not hand out an eligibility eval task.
  auto result = federated_protocol_->Checkin(
      std::nullopt, mock_task_received_callback_.AsStdFunction());

  ASSERT_OK(result.status());
  EXPECT_THAT(*result, VariantWith<FederatedProtocol::Rejection>(_));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1886
1887 // Tests whether a successful task assignment response is handled correctly.
TEST_F(HttpFederatedProtocolTest, TestCheckinTaskAssigned) {
  // A successful eligibility eval checkin has to precede the regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  std::string eet_report_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  ExpectSuccessfulReportEligibilityEvalTaskResultRequest(eet_report_uri,
                                                         absl::OkStatus());

  TaskEligibilityInfo eligibility_info = GetFakeTaskEligibilityInfo();

  // The fake task assignment references the plan by URI (so it must be
  // fetched via HTTP) while the checkpoint data is supplied inline.
  std::string plan_data = kPlan;
  std::string plan_location = "https://fake.uri/plan";
  Resource plan_resource;
  plan_resource.set_uri(plan_location);
  std::string checkpoint_data = kInitCheckpoint;
  Resource checkpoint_resource;
  checkpoint_resource.mutable_inline_resource()->set_data(checkpoint_data);
  std::string federated_select_uri_template = kFederatedSelectUriTemplate;
  std::string aggregation_session_id = kAggregationSessionId;

  // All expectations below must be satisfied in declaration order.
  InSequence in_sequence;

  // Unlike tests that use the '_' matcher for the request body, this test
  // verifies that the task assignment request matches the expected proto.
  std::string start_task_assignment_uri =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
  std::string task_assignment_body =
      CreateDoneOperation(
          kOperationName,
          GetFakeTaskAssignmentResponse(
              plan_resource, checkpoint_resource,
              federated_select_uri_template, aggregation_session_id,
              kMinimumClientsInServerVisibleAggregate))
          .SerializeAsString();
  EXPECT_CALL(
      mock_http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          start_task_assignment_uri, HttpRequest::Method::kPost, _,
          StartTaskAssignmentRequestMatcher(EqualsProto(
              GetExpectedStartTaskAssignmentRequest(eligibility_info))))))
      .WillOnce(
          Return(FakeHttpResponse(200, HeaderList(), task_assignment_body)));

  // The 'task received' callback is expected to fire *before* the task's
  // resources are fetched, hence with empty plan and checkpoint fields.
  EXPECT_CALL(
      mock_task_received_callback_,
      Call(FieldsAre(FieldsAre("", ""), federated_select_uri_template,
                     aggregation_session_id,
                     Optional(FieldsAre(
                         _, Eq(kMinimumClientsInServerVisibleAggregate))))));

  // Finally the plan is fetched from its URI.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  plan_location, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), plan_data)));

  // Issue the regular checkin.
  auto result = federated_protocol_->Checkin(
      eligibility_info, mock_task_received_callback_.AsStdFunction());

  ASSERT_OK(result.status());
  EXPECT_THAT(
      *result,
      VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
          FieldsAre(absl::Cord(plan_data), absl::Cord(checkpoint_data)),
          federated_select_uri_template, aggregation_session_id,
          Optional(
              FieldsAre(_, Eq(kMinimumClientsInServerVisibleAggregate))))));
  // After a successful Checkin the latest retry window should be the
  // 'accepted' window received in the response to the eligibility eval
  // request.
  ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1966
1967 // Ensures that polling the Operation returned by a StartTaskAssignmentRequest
1968 // works as expected. This serves mostly as a high-level check. Further
1969 // polling-specific behavior is tested in more detail in
1970 // ProtocolRequestHelperTest.
TEST_F(HttpFederatedProtocolTest,
       TestCheckinTaskAssignedAfterOperationPolling) {
  // Request URIs this test expects to see. The pending operation created below
  // is named "operations/foo#bar"; the '#' has to show up URL-encoded ("%23")
  // in the polling URI.
  static constexpr char kStartTaskAssignmentUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
  static constexpr char kPollOperationUri[] =
      "https://taskassignment.uri/v1/operations/foo%23bar?%24alt=proto";

  // Run the eligibility eval checkin first, and expect its result to be
  // reported successfully.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  const std::string eet_report_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  ExpectSuccessfulReportEligibilityEvalTaskResultRequest(eet_report_uri,
                                                         absl::OkStatus());

  // The initial StartTaskAssignmentRequest only yields a pending Operation.
  const Operation pending_operation =
      CreatePendingOperation("operations/foo#bar");
  const std::string pending_operation_bytes =
      pending_operation.SerializeAsString();
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kStartTaskAssignmentUri, HttpRequest::Method::kPost, _, _)))
      .WillOnce(
          Return(FakeHttpResponse(200, HeaderList(), pending_operation_bytes)));

  // Task data that the finished operation eventually delivers, all inline.
  const std::string plan_bytes = kPlan;
  const std::string checkpoint_bytes = kInitCheckpoint;
  const std::string federated_select_uri_template =
      kFederatedSelectUriTemplate;
  const std::string aggregation_session_id = kAggregationSessionId;
  Resource inline_plan;
  inline_plan.mutable_inline_resource()->set_data(plan_bytes);
  Resource inline_checkpoint;
  inline_checkpoint.mutable_inline_resource()->set_data(checkpoint_bytes);
  const std::string done_operation_bytes =
      CreateDoneOperation(
          kOperationName,
          GetFakeTaskAssignmentResponse(inline_plan, inline_checkpoint,
                                        federated_select_uri_template,
                                        aggregation_session_id, 0))
          .SerializeAsString();

  // Polling returns "still pending" twice, and the completed operation on the
  // third attempt.
  EXPECT_CALL(
      mock_http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          kPollOperationUri, HttpRequest::Method::kGet, _,
          GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
      .WillOnce(
          Return(FakeHttpResponse(200, HeaderList(), pending_operation_bytes)))
      .WillOnce(
          Return(FakeHttpResponse(200, HeaderList(), pending_operation_bytes)))
      .WillOnce(
          Return(FakeHttpResponse(200, HeaderList(), done_operation_bytes)));

  // The 'task received' callback is expected to fire even though the task
  // resources were delivered inline.
  EXPECT_CALL(
      mock_task_received_callback_,
      Call(FieldsAre(FieldsAre("", ""), federated_select_uri_template,
                     aggregation_session_id, Eq(std::nullopt))));

  // Perform the regular checkin and inspect what it hands back.
  auto checkin_outcome = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());

  ASSERT_OK(checkin_outcome.status());
  EXPECT_THAT(
      *checkin_outcome,
      VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
          FieldsAre(absl::Cord(plan_bytes), absl::Cord(checkpoint_bytes)),
          federated_select_uri_template, aggregation_session_id,
          Eq(std::nullopt))));
  // After Checkin the latest retry window should be the accepted one.
  ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
2051
2052 // Ensures that if the plan resource fails to be downloaded, the error is
2053 // correctly returned from the Checkin(...) method.
TEST_F(HttpFederatedProtocolTest, TestCheckinTaskAssignedPlanDataFetchFailed) {
  static constexpr char kStartTaskAssignmentUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";

  // Run the eligibility eval checkin first, and expect its result to be
  // reported successfully.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  const std::string eet_report_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  ExpectSuccessfulReportEligibilityEvalTaskResultRequest(eet_report_uri,
                                                         absl::OkStatus());

  // Both task resources are referenced by URI, so each needs its own fetch.
  const std::string plan_uri = "https://fake.uri/plan";
  const std::string checkpoint_uri = "https://fake.uri/checkpoint";
  Resource remote_plan;
  remote_plan.set_uri(plan_uri);
  Resource remote_checkpoint;
  remote_checkpoint.set_uri(checkpoint_uri);

  // The task assignment request completes immediately with a done operation.
  const std::string assignment_bytes =
      CreateDoneOperation(
          kOperationName,
          GetFakeTaskAssignmentResponse(remote_plan, remote_checkpoint,
                                        kFederatedSelectUriTemplate,
                                        kAggregationSessionId, 0))
          .SerializeAsString();
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kStartTaskAssignmentUri, HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), assignment_bytes)));

  // The plan download fails with a 404 ...
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  plan_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(404, HeaderList(), "")));

  // ... while the checkpoint download succeeds.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  checkpoint_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));

  // Perform the regular checkin.
  auto checkin_outcome = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());

  // The returned status should carry the 404 from the failed plan fetch.
  const absl::Status fetch_status = checkin_outcome.status();
  EXPECT_THAT(fetch_status, IsCode(NOT_FOUND));
  EXPECT_THAT(fetch_status.message(), HasSubstr("plan fetch failed"));
  EXPECT_THAT(fetch_status.message(), HasSubstr("404"));
  // A 404 maps to a permanent error, so the permanent error retry window is
  // expected afterwards.
  ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
2111
2112 // Ensures that if the checkpoint resource fails to be downloaded, the error is
2113 // correctly returned from the Checkin(...) method.
TEST_F(HttpFederatedProtocolTest,
       TestCheckinTaskAssignedCheckpointDataFetchFailed) {
  static constexpr char kStartTaskAssignmentUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";

  // Run the eligibility eval checkin first, and expect its result to be
  // reported successfully.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  const std::string eet_report_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  ExpectSuccessfulReportEligibilityEvalTaskResultRequest(eet_report_uri,
                                                         absl::OkStatus());

  // Both task resources are referenced by URI, so each needs its own fetch.
  const std::string plan_uri = "https://fake.uri/plan";
  const std::string checkpoint_uri = "https://fake.uri/checkpoint";
  Resource remote_plan;
  remote_plan.set_uri(plan_uri);
  Resource remote_checkpoint;
  remote_checkpoint.set_uri(checkpoint_uri);

  // The task assignment request completes immediately with a done operation.
  const std::string assignment_bytes =
      CreateDoneOperation(
          kOperationName,
          GetFakeTaskAssignmentResponse(remote_plan, remote_checkpoint,
                                        kFederatedSelectUriTemplate,
                                        kAggregationSessionId, 0))
          .SerializeAsString();
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kStartTaskAssignmentUri, HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), assignment_bytes)));

  // The checkpoint download fails with a 503 ...
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  checkpoint_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));

  // ... while the plan download succeeds.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  plan_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));

  // Perform the regular checkin.
  auto checkin_outcome = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());

  // The returned status should carry the 503 from the failed checkpoint fetch.
  const absl::Status fetch_status = checkin_outcome.status();
  EXPECT_THAT(fetch_status, IsCode(UNAVAILABLE));
  EXPECT_THAT(fetch_status.message(), HasSubstr("checkpoint fetch failed"));
  EXPECT_THAT(fetch_status.message(), HasSubstr("503"));
  // After the failed Checkin the latest retry window should be the rejected
  // one.
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
2173
// Checks that ReportCompleted returns OK when the task result report, the
// aggregation data upload start, the byte stream upload and the aggregation
// result submission are all expected to succeed.
TEST_F(HttpFederatedProtocolTest, TestReportCompletedViaSimpleAggSuccess) {
  static constexpr char kReportTaskResultUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto";
  static constexpr char kStartDataUploadUri[] =
      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto";
  static constexpr char kByteStreamUploadUri[] =
      "https://bytestream.uri/upload/v1/media/"
      "CHECKPOINT_RESOURCE?upload_protocol=raw";
  static constexpr char kSubmitResultUri[] =
      "https://aggregation.second.uri/v1/aggregations/"
      "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto";

  // Both checkins have to succeed before results can be reported.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // The result to report: a fake checkpoint made of 32 'X' characters.
  const std::string fake_checkpoint(32, 'X');
  ComputationResults computation_results;
  computation_results.emplace("tensorflow_checkpoint", fake_checkpoint);
  const absl::Duration plan_duration = absl::Minutes(5);

  ExpectSuccessfulReportTaskResultRequest(
      kReportTaskResultUri, kAggregationSessionId, kTaskName, plan_duration);
  ExpectSuccessfulStartAggregationDataUploadRequest(
      kStartDataUploadUri, kResourceName, kByteStreamTargetUri,
      kSecondStageAggregationTargetUri);
  ExpectSuccessfulByteStreamUploadRequest(kByteStreamUploadUri,
                                          fake_checkpoint);
  ExpectSuccessfulSubmitAggregationResultRequest(kSubmitResultUri);

  const absl::Status report_status = federated_protocol_->ReportCompleted(
      std::move(computation_results), plan_duration, std::nullopt);
  EXPECT_OK(report_status);
}
2205
2206 // TODO(team): Remove this test once client_token is always populated in
2207 // StartAggregationDataUploadResponse.
TEST_F(HttpFederatedProtocolTest,
       TestReportCompletedViaSimpleAggWithoutClientToken) {
  static constexpr char kReportTaskResultUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto";
  static constexpr char kStartDataUploadUri[] =
      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto";
  static constexpr char kByteStreamUploadUri[] =
      "https://bytestream.uri/upload/v1/media/"
      "CHECKPOINT_RESOURCE?upload_protocol=raw";
  // With no client token in the response, the submit URI is expected to carry
  // the authorization token instead.
  static constexpr char kSubmitResultUri[] =
      "https://aggregation.second.uri/v1/aggregations/"
      "AGGREGATION_SESSION_ID/clients/AUTHORIZATION_TOKEN:submit?%24alt=proto";

  // Both checkins have to succeed before results can be reported.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // The result to report: a fake checkpoint made of 32 'X' characters.
  const std::string fake_checkpoint(32, 'X');
  ComputationResults computation_results;
  computation_results.emplace("tensorflow_checkpoint", fake_checkpoint);
  const absl::Duration plan_duration = absl::Minutes(5);

  ExpectSuccessfulReportTaskResultRequest(
      kReportTaskResultUri, kAggregationSessionId, kTaskName, plan_duration);

  // Build the usual fake upload response, then strip its client token.
  StartAggregationDataUploadResponse upload_response =
      GetFakeStartAggregationDataUploadResponse(
          kResourceName, kByteStreamTargetUri,
          kSecondStageAggregationTargetUri);
  upload_response.clear_client_token();
  const std::string upload_operation_bytes =
      CreateDoneOperation(kOperationName, upload_response)
          .SerializeAsString();
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kStartDataUploadUri, HttpRequest::Method::kPost, _, _)))
      .WillOnce(
          Return(FakeHttpResponse(200, HeaderList(), upload_operation_bytes)));

  ExpectSuccessfulByteStreamUploadRequest(kByteStreamUploadUri,
                                          fake_checkpoint);
  ExpectSuccessfulSubmitAggregationResultRequest(kSubmitResultUri);

  const absl::Status report_status = federated_protocol_->ReportCompleted(
      std::move(computation_results), plan_duration, std::nullopt);
  EXPECT_OK(report_status);
}
2256
// Checks that ReportCompleted returns OK for results containing a secagg
// tensor: the StartSecureAggregation operation is polled to completion, a
// SecAgg runner is created with the client counts from the response, and
// messages sent from within Run() go to a URI containing the client token.
TEST_F(HttpFederatedProtocolTest, TestReportCompletedViaSecureAgg) {
  static constexpr char kReportTaskResultUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto";
  static constexpr char kStartSecAggUri[] =
      "https://aggregation.uri/v1/secureaggregations/"
      "AGGREGATION_SESSION_ID/clients/"
      "AUTHORIZATION_TOKEN:start?%24alt=proto";
  static constexpr char kPollOperationUri[] =
      "https://aggregation.uri/v1/operations/foo%23bar?%24alt=proto";
  // Client counts placed in the server response and expected to be forwarded
  // to the SecAgg runner factory.
  constexpr int kExpectedClients = 500;
  constexpr int kMinSurvivingClients = 450;

  const absl::Duration plan_duration = absl::Minutes(5);
  // Both checkins have to succeed before results can be reported.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // Assemble the StartSecureAggregationResponse the server will return.
  StartSecureAggregationResponse secagg_response;
  secagg_response.set_client_token(kClientToken);
  const auto fill_upload_resource = [](auto* resource,
                                       const char* resource_name) {
    resource->set_resource_name(resource_name);
    resource->mutable_data_upload_forwarding_info()->set_target_uri_prefix(
        "https://bytestream.uri/");
  };
  fill_upload_resource(secagg_response.mutable_masked_result_resource(),
                       "masked_resource");
  fill_upload_resource(secagg_response.mutable_nonmasked_result_resource(),
                       "nonmasked_resource");
  secagg_response.mutable_secagg_protocol_forwarding_info()
      ->set_target_uri_prefix("https://secure.aggregations.uri/");
  auto* execution_info = secagg_response.mutable_protocol_execution_info();
  execution_info->set_minimum_surviving_clients_for_reconstruction(
      kMinSurvivingClients);
  execution_info->set_expected_number_of_clients(kExpectedClients);
  (*secagg_response.mutable_secure_aggregands())["secagg_tensor"].set_modulus(
      9999);

  ExpectSuccessfulReportTaskResultRequest(
      kReportTaskResultUri, kAggregationSessionId, kTaskName, plan_duration);
  // The start request yields a pending operation ...
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kStartSecAggUri, HttpRequest::Method::kPost, _,
                  StartSecureAggregationRequest().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          CreatePendingOperation("operations/foo#bar").SerializeAsString())));
  // ... which is done on the first poll.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kPollOperationUri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          CreateDoneOperation(kOperationName, secagg_response)
              .SerializeAsString())));

  // The results to report: a fake checkpoint made of 32 'X' characters plus an
  // empty quantized tensor for secure aggregation.
  const std::string fake_checkpoint(32, 'X');
  ComputationResults computation_results;
  computation_results.emplace("tensorflow_checkpoint", fake_checkpoint);
  computation_results.emplace("secagg_tensor", QuantizedTensor());

  EXPECT_CALL(*mock_secagg_runner_factory_,
              CreateSecAggRunner(_, _, _, _, _, kExpectedClients,
                                 kMinSurvivingClients))
      .WillOnce(WithArg<0>([&](auto sender) {
        auto runner = std::make_unique<StrictMock<MockSecAggRunner>>();
        // Only the secagg tensor is expected to be handed to the runner.
        EXPECT_CALL(*runner,
                    Run(UnorderedElementsAre(Pair(
                        "secagg_tensor", VariantWith<QuantizedTensor>(FieldsAre(
                                             IsEmpty(), 0, IsEmpty()))))))
            .WillOnce([this, sender = std::move(sender)] {
              // SecAggSendToServerBase::Send should use the client token.
              // This is checked from inside Run because the sender should not
              // be used outside of it.
              EXPECT_CALL(
                  mock_http_client_,
                  PerformSingleRequest(SimpleHttpRequestMatcher(
                      "https://secure.aggregations.uri/v1/secureaggregations/"
                      "AGGREGATION_SESSION_ID/clients/"
                      "CLIENT_TOKEN:abort?%24alt=proto",
                      _, _, _)))
                  .WillOnce(Return(CreateEmptySuccessHttpResponse()));
              secagg::ClientToServerWrapperMessage abort_message;
              abort_message.mutable_abort();
              sender->Send(&abort_message);

              return absl::OkStatus();
            });
        return runner;
      }));

  const absl::Status report_status = federated_protocol_->ReportCompleted(
      std::move(computation_results), plan_duration, std::nullopt);
  EXPECT_OK(report_status);
}
2356
2357 // TODO(team): Remove this test once client_token is always populated in
2358 // StartSecureAggregationResponse.
TEST_F(HttpFederatedProtocolTest,
       TestReportCompletedViaSecureAggWithoutClientToken) {
  static constexpr char kReportTaskResultUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto";
  static constexpr char kStartSecAggUri[] =
      "https://aggregation.uri/v1/secureaggregations/"
      "AGGREGATION_SESSION_ID/clients/"
      "AUTHORIZATION_TOKEN:start?%24alt=proto";

  const absl::Duration plan_duration = absl::Minutes(5);
  // Both checkins have to succeed before results can be reported.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // Assemble the StartSecureAggregationResponse the server will return.
  // client_token is deliberately left unset.
  StartSecureAggregationResponse secagg_response;
  const auto fill_upload_resource = [](auto* resource,
                                       const char* resource_name) {
    resource->set_resource_name(resource_name);
    resource->mutable_data_upload_forwarding_info()->set_target_uri_prefix(
        "https://bytestream.uri/");
  };
  fill_upload_resource(secagg_response.mutable_masked_result_resource(),
                       "masked_resource");
  fill_upload_resource(secagg_response.mutable_nonmasked_result_resource(),
                       "nonmasked_resource");
  secagg_response.mutable_secagg_protocol_forwarding_info()
      ->set_target_uri_prefix("https://secure.aggregations.uri/");
  auto* execution_info = secagg_response.mutable_protocol_execution_info();
  execution_info->set_minimum_surviving_clients_for_reconstruction(450);
  execution_info->set_expected_number_of_clients(500);
  (*secagg_response.mutable_secure_aggregands())["secagg_tensor"].set_modulus(
      9999);

  ExpectSuccessfulReportTaskResultRequest(
      kReportTaskResultUri, kAggregationSessionId, kTaskName, plan_duration);
  // Here the start request completes right away, without any polling.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kStartSecAggUri, HttpRequest::Method::kPost, _,
                  StartSecureAggregationRequest().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          CreateDoneOperation(kOperationName, secagg_response)
              .SerializeAsString())));

  // The results to report: a fake checkpoint made of 32 'X' characters plus an
  // empty quantized tensor for secure aggregation.
  const std::string fake_checkpoint(32, 'X');
  ComputationResults computation_results;
  computation_results.emplace("tensorflow_checkpoint", fake_checkpoint);
  computation_results.emplace("secagg_tensor", QuantizedTensor());

  EXPECT_CALL(*mock_secagg_runner_factory_,
              CreateSecAggRunner(_, _, _, _, _, _, _))
      .WillOnce(WithArg<0>([&](auto sender) {
        auto runner = std::make_unique<StrictMock<MockSecAggRunner>>();
        EXPECT_CALL(*runner, Run(_))
            .WillOnce([this, sender = std::move(sender)] {
              // SecAggSendToServerBase::Send should reuse the authorization
              // token. This is checked from inside Run because the sender
              // should not be used outside of it.
              EXPECT_CALL(
                  mock_http_client_,
                  PerformSingleRequest(SimpleHttpRequestMatcher(
                      "https://secure.aggregations.uri/v1/secureaggregations/"
                      "AGGREGATION_SESSION_ID/clients/"
                      "AUTHORIZATION_TOKEN:abort?%24alt=proto",
                      _, _, _)))
                  .WillOnce(Return(CreateEmptySuccessHttpResponse()));
              secagg::ClientToServerWrapperMessage abort_message;
              abort_message.mutable_abort();
              sender->Send(&abort_message);

              return absl::OkStatus();
            });
        return runner;
      }));

  const absl::Status report_status = federated_protocol_->ReportCompleted(
      std::move(computation_results), plan_duration, std::nullopt);
  EXPECT_OK(report_status);
}
2448
// Checks that ReportCompleted via secure aggregation still returns OK when the
// ReportTaskResult request gets a 503; the failure is only expected to be
// logged as a diag code.
TEST_F(HttpFederatedProtocolTest,
       TestReportCompletedViaSecureAggReportTaskResultFailed) {
  static constexpr char kReportTaskResultUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto";
  static constexpr char kStartSecAggUri[] =
      "https://aggregation.uri/v1/secureaggregations/"
      "AGGREGATION_SESSION_ID/clients/"
      "AUTHORIZATION_TOKEN:start?%24alt=proto";
  static constexpr char kPollOperationUri[] =
      "https://aggregation.uri/v1/operations/foo%23bar?%24alt=proto";
  // Client counts placed in the server response and expected to be forwarded
  // to the SecAgg runner factory.
  constexpr int kExpectedClients = 500;
  constexpr int kMinSurvivingClients = 450;

  const absl::Duration plan_duration = absl::Minutes(5);
  // Both checkins have to succeed before results can be reported.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // Assemble the StartSecureAggregationResponse the server will return.
  StartSecureAggregationResponse secagg_response;
  secagg_response.set_client_token(kClientToken);
  const auto fill_upload_resource = [](auto* resource,
                                       const char* resource_name) {
    resource->set_resource_name(resource_name);
    resource->mutable_data_upload_forwarding_info()->set_target_uri_prefix(
        "https://bytestream.uri/");
  };
  fill_upload_resource(secagg_response.mutable_masked_result_resource(),
                       "masked_resource");
  fill_upload_resource(secagg_response.mutable_nonmasked_result_resource(),
                       "nonmasked_resource");
  secagg_response.mutable_secagg_protocol_forwarding_info()
      ->set_target_uri_prefix("https://secure.aggregations.uri/");
  auto* execution_info = secagg_response.mutable_protocol_execution_info();
  execution_info->set_minimum_surviving_clients_for_reconstruction(
      kMinSurvivingClients);
  execution_info->set_expected_number_of_clients(kExpectedClients);
  (*secagg_response.mutable_secure_aggregands())["secagg_tensor"].set_modulus(
      9999);

  // The ReportTaskResult request fails with a 503, which should be logged.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kReportTaskResultUri, HttpRequest::Method::kPost, _,
                  ReportTaskResultRequestMatcher(
                      EqualsProto(GetExpectedReportTaskResultRequest(
                          kAggregationSessionId, kTaskName,
                          google::rpc::Code::OK, plan_duration))))))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList())));
  EXPECT_CALL(mock_log_manager_,
              LogDiag(ProdDiagCode::HTTP_REPORT_TASK_RESULT_REQUEST_FAILED));
  // The start request yields a pending operation ...
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kStartSecAggUri, HttpRequest::Method::kPost, _,
                  StartSecureAggregationRequest().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          CreatePendingOperation("operations/foo#bar").SerializeAsString())));
  // ... which is done on the first poll.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kPollOperationUri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          CreateDoneOperation(kOperationName, secagg_response)
              .SerializeAsString())));

  // The results to report: a fake checkpoint made of 32 'X' characters plus an
  // empty quantized tensor for secure aggregation.
  const std::string fake_checkpoint(32, 'X');
  ComputationResults computation_results;
  computation_results.emplace("tensorflow_checkpoint", fake_checkpoint);
  computation_results.emplace("secagg_tensor", QuantizedTensor());

  // The raw pointer is kept so that an expectation can be set on the runner
  // after it has been wrapped in a unique_ptr for the factory to return.
  MockSecAggRunner* secagg_runner = new StrictMock<MockSecAggRunner>();
  EXPECT_CALL(*mock_secagg_runner_factory_,
              CreateSecAggRunner(_, _, _, _, _, kExpectedClients,
                                 kMinSurvivingClients))
      .WillOnce(Return(ByMove(absl::WrapUnique(secagg_runner))));
  EXPECT_CALL(*secagg_runner,
              Run(UnorderedElementsAre(
                  Pair("secagg_tensor", VariantWith<QuantizedTensor>(FieldsAre(
                                            IsEmpty(), 0, IsEmpty()))))))
      .WillOnce(Return(absl::OkStatus()));

  const absl::Status report_status = federated_protocol_->ReportCompleted(
      std::move(computation_results), plan_duration, std::nullopt);
  EXPECT_OK(report_status);
}
2537
// Checks that ReportCompleted returns an INTERNAL status when the
// StartSecureAggregation request returns an operation carrying an INTERNAL
// error.
TEST_F(HttpFederatedProtocolTest, TestReportCompletedStartSecAggFailed) {
  static constexpr char kReportTaskResultUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto";
  static constexpr char kStartSecAggUri[] =
      "https://aggregation.uri/v1/secureaggregations/"
      "AGGREGATION_SESSION_ID/clients/"
      "AUTHORIZATION_TOKEN:start?%24alt=proto";

  const absl::Duration plan_duration = absl::Minutes(5);
  // Both checkins have to succeed before results can be reported.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  ExpectSuccessfulReportTaskResultRequest(
      kReportTaskResultUri, kAggregationSessionId, kTaskName, plan_duration);
  // The HTTP request itself succeeds (200), but the operation in its body
  // carries an error.
  const std::string error_operation_bytes =
      CreateErrorOperation(kOperationName, absl::StatusCode::kInternal,
                           "Request failed.")
          .SerializeAsString();
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kStartSecAggUri, HttpRequest::Method::kPost, _,
                  StartSecureAggregationRequest().SerializeAsString())))
      .WillOnce(
          Return(FakeHttpResponse(200, HeaderList(), error_operation_bytes)));

  // The results to report: a fake checkpoint made of 32 'X' characters plus an
  // empty quantized tensor for secure aggregation.
  const std::string fake_checkpoint(32, 'X');
  ComputationResults computation_results;
  computation_results.emplace("tensorflow_checkpoint", fake_checkpoint);
  computation_results.emplace("secagg_tensor", QuantizedTensor());

  const absl::Status report_status = federated_protocol_->ReportCompleted(
      std::move(computation_results), plan_duration, std::nullopt);
  EXPECT_THAT(report_status, IsCode(absl::StatusCode::kInternal));
}
2571
// Checks that ReportCompleted returns PERMISSION_DENIED when the
// StartSecureAggregation request itself is answered with an HTTP 403.
TEST_F(HttpFederatedProtocolTest,
       TestReportCompletedStartSecAggFailedImmediately) {
  static constexpr char kReportTaskResultUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto";
  static constexpr char kStartSecAggUri[] =
      "https://aggregation.uri/v1/secureaggregations/"
      "AGGREGATION_SESSION_ID/clients/"
      "AUTHORIZATION_TOKEN:start?%24alt=proto";

  const absl::Duration plan_duration = absl::Minutes(5);
  // Both checkins have to succeed before results can be reported.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  ExpectSuccessfulReportTaskResultRequest(
      kReportTaskResultUri, kAggregationSessionId, kTaskName, plan_duration);
  // The start request is rejected at the HTTP level.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kStartSecAggUri, HttpRequest::Method::kPost, _,
                  StartSecureAggregationRequest().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(403, HeaderList(), "")));

  // The results to report: a fake checkpoint made of 32 'X' characters plus an
  // empty quantized tensor for secure aggregation.
  const std::string fake_checkpoint(32, 'X');
  ComputationResults computation_results;
  computation_results.emplace("tensorflow_checkpoint", fake_checkpoint);
  computation_results.emplace("secagg_tensor", QuantizedTensor());

  const absl::Status report_status = federated_protocol_->ReportCompleted(
      std::move(computation_results), plan_duration, std::nullopt);
  EXPECT_THAT(report_status, IsCode(absl::StatusCode::kPermissionDenied));
}
2602
// Checks that ReportCompleted via simple aggregation still returns OK when the
// ReportTaskResult request gets a 503; the failure is only expected to be
// logged as a diag code.
TEST_F(HttpFederatedProtocolTest, TestReportCompletedReportTaskResultFailed) {
  // Issue an eligibility eval checkin first.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  // Issue a regular checkin.
  ASSERT_OK(RunSuccessfulCheckin());

  // Create a fake checkpoint with 32 'X'.
  std::string checkpoint_str(32, 'X');
  ComputationResults results;
  results.emplace("tensorflow_checkpoint", checkpoint_str);
  absl::Duration plan_duration = absl::Minutes(5);

  // Mock a failed ReportTaskResult request. The fake response carries only a
  // status code, so no ReportTaskResultResponse message is needed here.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
                  "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
                  HttpRequest::Method::kPost, _,
                  ReportTaskResultRequestMatcher(
                      EqualsProto(GetExpectedReportTaskResultRequest(
                          kAggregationSessionId, kTaskName,
                          google::rpc::Code::OK, plan_duration))))))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList())));
  // The failure is expected to be recorded through the log manager.
  EXPECT_CALL(mock_log_manager_,
              LogDiag(ProdDiagCode::HTTP_REPORT_TASK_RESULT_REQUEST_FAILED));

  // Every remaining step of the simple aggregation upload succeeds.
  ExpectSuccessfulStartAggregationDataUploadRequest(
      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
      kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
  ExpectSuccessfulByteStreamUploadRequest(
      "https://bytestream.uri/upload/v1/media/"
      "CHECKPOINT_RESOURCE?upload_protocol=raw",
      checkpoint_str);
  ExpectSuccessfulSubmitAggregationResultRequest(
      "https://aggregation.second.uri/v1/aggregations/"
      "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto");

  // Even though the ReportTaskResult request failed, the overall
  // ReportCompleted call is still considered successful because the remaining
  // steps succeed, and ReportTaskResult is just best-effort metric reporting.
  EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
                                                 plan_duration, std::nullopt));
}
2648
// Checks that ReportCompleted returns UNAVAILABLE, with a message naming the
// failed request and the HTTP code, when the StartAggregationDataUpload
// request is answered with an HTTP 503.
TEST_F(HttpFederatedProtocolTest,
       TestReportCompletedStartAggregationFailedImmediately) {
  static constexpr char kReportTaskResultUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto";
  static constexpr char kStartDataUploadUri[] =
      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto";

  // Both checkins have to succeed before results can be reported.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // The result to report: a fake checkpoint made of 32 'X' characters.
  const std::string fake_checkpoint(32, 'X');
  ComputationResults computation_results;
  computation_results.emplace("tensorflow_checkpoint", fake_checkpoint);
  const absl::Duration plan_duration = absl::Minutes(5);

  ExpectSuccessfulReportTaskResultRequest(
      kReportTaskResultUri, kAggregationSessionId, kTaskName, plan_duration);
  // The upload start request is rejected at the HTTP level.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kStartDataUploadUri, HttpRequest::Method::kPost, _,
                  StartAggregationDataUploadRequest().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList())));

  const absl::Status report_status = federated_protocol_->ReportCompleted(
      std::move(computation_results), plan_duration, std::nullopt);
  ASSERT_THAT(report_status, IsCode(absl::StatusCode::kUnavailable));
  EXPECT_THAT(report_status.message(),
              HasSubstr("StartAggregationDataUpload request failed"));
  EXPECT_THAT(report_status.message(), HasSubstr("503"));
}
2682
// Verifies that ReportCompleted() returns an UNAUTHENTICATED error when
// StartAggregationDataUpload yields a pending operation and the subsequent
// poll of that operation is answered with an HTTP 401.
TEST_F(HttpFederatedProtocolTest,
       TestReportCompletedStartAggregationFailedDuringPolling) {
  // Run the eligibility eval checkin first, followed by the regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // The computation result is a fake checkpoint consisting of 32 'X' bytes.
  constexpr size_t kTFCheckpointSize = 32;
  ComputationResults results;
  results.emplace("tensorflow_checkpoint", std::string(kTFCheckpointSize, 'X'));
  const absl::Duration plan_duration = absl::Minutes(5);

  constexpr char kExpectedReportResultUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto";
  constexpr char kExpectedStartUploadUri[] =
      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto";
  // Note that the '#' character of the operation name is encoded as "%23".
  constexpr char kExpectedOperationUri[] =
      "https://aggregation.uri/v1/operations/foo%23bar?%24alt=proto";

  ExpectSuccessfulReportTaskResultRequest(
      kExpectedReportResultUri, kAggregationSessionId, kTaskName,
      plan_duration);

  // The initial request succeeds, but only hands back a pending operation.
  const Operation pending_operation =
      CreatePendingOperation("operations/foo#bar");
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kExpectedStartUploadUri, HttpRequest::Method::kPost, _,
                  StartAggregationDataUploadRequest().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), pending_operation.SerializeAsString())));
  // Polling the pending operation is what fails, with an HTTP 401.
  EXPECT_CALL(
      mock_http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          kExpectedOperationUri, HttpRequest::Method::kGet, _,
          GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
      .WillOnce(Return(FakeHttpResponse(401, HeaderList())));

  const absl::Status report_result = federated_protocol_->ReportCompleted(
      std::move(results), plan_duration, std::nullopt);

  ASSERT_THAT(report_result, IsCode(absl::StatusCode::kUnauthenticated));
  EXPECT_THAT(report_result.message(),
              AllOf(HasSubstr("StartAggregationDataUpload request failed"),
                    HasSubstr("401")));
}
2727
// Verifies that ReportCompleted() returns an UNIMPLEMENTED error when the
// checkpoint upload request is answered with an HTTP 501. The test also
// expects an AbortAggregation request to be issued in that case.
TEST_F(HttpFederatedProtocolTest, TestReportCompletedUploadFailed) {
  // Run the eligibility eval checkin first, followed by the regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // The computation result is a fake checkpoint consisting of 32 'X' bytes.
  constexpr size_t kTFCheckpointSize = 32;
  std::string checkpoint_str(kTFCheckpointSize, 'X');
  ComputationResults results;
  results.emplace("tensorflow_checkpoint", checkpoint_str);
  const absl::Duration plan_duration = absl::Minutes(5);

  constexpr char kExpectedReportResultUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto";
  constexpr char kExpectedStartUploadUri[] =
      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto";
  constexpr char kExpectedByteStreamUploadUri[] =
      "https://bytestream.uri/upload/v1/media/"
      "CHECKPOINT_RESOURCE?upload_protocol=raw";

  ExpectSuccessfulReportTaskResultRequest(
      kExpectedReportResultUri, kAggregationSessionId, kTaskName,
      plan_duration);
  ExpectSuccessfulStartAggregationDataUploadRequest(
      kExpectedStartUploadUri, kResourceName, kByteStreamTargetUri,
      kSecondStageAggregationTargetUri);
  // The upload of the checkpoint bytes gets an HTTP 501 response.
  EXPECT_CALL(
      mock_http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          StrEq(kExpectedByteStreamUploadUri), HttpRequest::Method::kPost, _,
          std::string(checkpoint_str))))
      .WillOnce(Return(FakeHttpResponse(501, HeaderList())));
  ExpectSuccessfulAbortAggregationRequest("https://aggregation.second.uri");

  const absl::Status report_result = federated_protocol_->ReportCompleted(
      std::move(results), plan_duration, std::nullopt);

  ASSERT_THAT(report_result, IsCode(absl::StatusCode::kUnimplemented));
  EXPECT_THAT(report_result.message(),
              AllOf(HasSubstr("Data upload failed"), HasSubstr("501")));
}
2762
// Verifies that ReportCompleted() returns an ABORTED error when the checkpoint
// upload request is answered with an HTTP 409 carrying an ABORTED error
// operation. No AbortAggregation request expectation is set up in this test.
TEST_F(HttpFederatedProtocolTest, TestReportCompletedUploadAbortedByServer) {
  // Run the eligibility eval checkin first, followed by the regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // The computation result is a fake checkpoint consisting of 32 'X' bytes.
  constexpr size_t kTFCheckpointSize = 32;
  std::string checkpoint_str(kTFCheckpointSize, 'X');
  ComputationResults results;
  results.emplace("tensorflow_checkpoint", checkpoint_str);
  const absl::Duration plan_duration = absl::Minutes(5);

  constexpr char kExpectedReportResultUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto";
  constexpr char kExpectedStartUploadUri[] =
      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto";
  constexpr char kExpectedByteStreamUploadUri[] =
      "https://bytestream.uri/upload/v1/media/"
      "CHECKPOINT_RESOURCE?upload_protocol=raw";

  ExpectSuccessfulReportTaskResultRequest(
      kExpectedReportResultUri, kAggregationSessionId, kTaskName,
      plan_duration);
  ExpectSuccessfulStartAggregationDataUploadRequest(
      kExpectedStartUploadUri, kResourceName, kByteStreamTargetUri,
      kSecondStageAggregationTargetUri);
  // The upload is rejected with an HTTP 409 whose body is an error operation
  // with the ABORTED status code.
  EXPECT_CALL(
      mock_http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          StrEq(kExpectedByteStreamUploadUri), HttpRequest::Method::kPost, _,
          std::string(checkpoint_str))))
      .WillOnce(Return(FakeHttpResponse(
          409, HeaderList(),
          CreateErrorOperation(kOperationName, absl::StatusCode::kAborted,
                               "The client update is no longer needed.")
              .SerializeAsString())));

  const absl::Status report_result = federated_protocol_->ReportCompleted(
      std::move(results), plan_duration, std::nullopt);

  ASSERT_THAT(report_result, IsCode(absl::StatusCode::kAborted));
  EXPECT_THAT(report_result.message(),
              AllOf(HasSubstr("Data upload failed"), HasSubstr("409")));
}
2800
// Verifies that ReportCompleted() returns a CANCELLED error when the abort
// callback starts returning true while the checkpoint upload request is in
// flight. The test also expects the interruption diag code to be logged and
// an AbortAggregation request to be issued.
TEST_F(HttpFederatedProtocolTest, TestReportCompletedUploadInterrupted) {
  // Run the eligibility eval checkin first, followed by the regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // The computation result is a fake checkpoint consisting of 32 'X' bytes.
  constexpr size_t kTFCheckpointSize = 32;
  std::string checkpoint_str(kTFCheckpointSize, 'X');
  ComputationResults results;
  results.emplace("tensorflow_checkpoint", checkpoint_str);
  const absl::Duration plan_duration = absl::Minutes(5);

  constexpr char kExpectedReportResultUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto";
  constexpr char kExpectedStartUploadUri[] =
      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto";
  constexpr char kExpectedByteStreamUploadUri[] =
      "https://bytestream.uri/upload/v1/media/"
      "CHECKPOINT_RESOURCE?upload_protocol=raw";

  ExpectSuccessfulReportTaskResultRequest(
      kExpectedReportResultUri, kAggregationSessionId, kTaskName,
      plan_duration);
  ExpectSuccessfulStartAggregationDataUploadRequest(
      kExpectedStartUploadUri, kResourceName, kByteStreamTargetUri,
      kSecondStageAggregationTargetUri);

  absl::Notification upload_started;
  absl::Notification cancel_received;

  // Make the mocked upload request signal that it has started and then block
  // until the cancellation notification fires, at which point it completes
  // with an HTTP 503 response.
  EXPECT_CALL(
      mock_http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          StrEq(kExpectedByteStreamUploadUri), HttpRequest::Method::kPost, _,
          std::string(checkpoint_str))))
      .WillOnce([&upload_started, &cancel_received](
                    MockableHttpClient::SimpleHttpRequest ignored) {
        upload_started.Notify();
        cancel_received.WaitForNotification();
        return FakeHttpResponse(503, HeaderList(), "");
      });

  // Make should_abort return false until we know that the upload request was
  // issued, and true from then on. That triggers the abort sequence, which in
  // turn unblocks the upload request we caused to block above.
  EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&upload_started] {
    return upload_started.HasBeenNotified();
  });

  // When the HttpClient receives a cancellation, let the blocked request
  // complete. The notification is only ever fired once.
  mock_http_client_.SetCancellationListener([&cancel_received]() {
    if (cancel_received.HasBeenNotified()) {
      return;
    }
    cancel_received.Notify();
  });

  EXPECT_CALL(mock_log_manager_,
              LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP));
  ExpectSuccessfulAbortAggregationRequest("https://aggregation.second.uri");

  const absl::Status report_result = federated_protocol_->ReportCompleted(
      std::move(results), plan_duration, std::nullopt);

  ASSERT_THAT(report_result, IsCode(absl::StatusCode::kCancelled));
  EXPECT_THAT(report_result.message(), HasSubstr("Data upload failed"));
}
2861
// Verifies that ReportCompleted() returns an ABORTED error when everything up
// to and including the checkpoint upload succeeds, but the final
// SubmitAggregationResult request is answered with an HTTP 409.
TEST_F(HttpFederatedProtocolTest,
       TestReportCompletedSubmitAggregationResultFailed) {
  // Run the eligibility eval checkin first, followed by the regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // The computation result is a fake checkpoint consisting of 32 'X' bytes.
  constexpr size_t kTFCheckpointSize = 32;
  std::string checkpoint_str(kTFCheckpointSize, 'X');
  ComputationResults results;
  results.emplace("tensorflow_checkpoint", checkpoint_str);
  const absl::Duration plan_duration = absl::Minutes(5);

  constexpr char kExpectedReportResultUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto";
  constexpr char kExpectedStartUploadUri[] =
      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto";
  constexpr char kExpectedByteStreamUploadUri[] =
      "https://bytestream.uri/upload/v1/media/"
      "CHECKPOINT_RESOURCE?upload_protocol=raw";
  constexpr char kExpectedSubmitUri[] =
      "https://aggregation.second.uri/v1/aggregations/"
      "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto";

  ExpectSuccessfulReportTaskResultRequest(
      kExpectedReportResultUri, kAggregationSessionId, kTaskName,
      plan_duration);
  ExpectSuccessfulStartAggregationDataUploadRequest(
      kExpectedStartUploadUri, kResourceName, kByteStreamTargetUri,
      kSecondStageAggregationTargetUri);
  ExpectSuccessfulByteStreamUploadRequest(kExpectedByteStreamUploadUri,
                                          checkpoint_str);

  // The SubmitAggregationResult request is the step that fails, with an
  // HTTP 409.
  SubmitAggregationResultRequest expected_submit_request;
  expected_submit_request.set_resource_name(kResourceName);
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kExpectedSubmitUri, HttpRequest::Method::kPost, _,
                  expected_submit_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(409, HeaderList())));

  const absl::Status report_result = federated_protocol_->ReportCompleted(
      std::move(results), plan_duration, std::nullopt);

  ASSERT_THAT(report_result, IsCode(absl::StatusCode::kAborted));
  EXPECT_THAT(report_result.message(),
              AllOf(HasSubstr("SubmitAggregationResult failed"),
                    HasSubstr("409")));
}
2907
// Verifies that ReportNotCompleted() succeeds when the ReportTaskResult
// request (expected to carry the INTERNAL code) is answered with an HTTP 200.
TEST_F(HttpFederatedProtocolTest, TestReportNotCompletedSuccess) {
  // Run the eligibility eval checkin first, followed by the regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  const absl::Duration plan_duration = absl::Minutes(5);
  constexpr char kExpectedReportResultUri[] =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto";

  // The server replies with a default-constructed ReportTaskResultResponse.
  const std::string serialized_response =
      ReportTaskResultResponse().SerializeAsString();
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  kExpectedReportResultUri, HttpRequest::Method::kPost, _,
                  ReportTaskResultRequestMatcher(
                      EqualsProto(GetExpectedReportTaskResultRequest(
                          kAggregationSessionId, kTaskName,
                          ::google::rpc::Code::INTERNAL, plan_duration))))))
      .WillOnce(
          Return(FakeHttpResponse(200, HeaderList(), serialized_response)));

  ASSERT_OK(federated_protocol_->ReportNotCompleted(
      engine::PhaseOutcome::ERROR, plan_duration, std::nullopt));

  ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
2932
// Verifies that ReportNotCompleted() returns an UNAVAILABLE error mentioning
// the HTTP code when the ReportTaskResult request is answered with an
// HTTP 503, and that the accepted retry window is still reported afterwards.
TEST_F(HttpFederatedProtocolTest, TestReportNotCompletedError) {
  // Issue an eligibility eval checkin first.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  // Issue a regular checkin.
  ASSERT_OK(RunSuccessfulCheckin());
  // Any request body is accepted here; only the URI and method are checked.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
                  "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
                  HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList())));

  absl::Status status = federated_protocol_->ReportNotCompleted(
      engine::PhaseOutcome::ERROR, absl::Minutes(5), std::nullopt);
  EXPECT_THAT(status, IsCode(UNAVAILABLE));
  EXPECT_THAT(
      status.message(),
      AllOf(HasSubstr("ReportTaskResult request failed:"), HasSubstr("503")));
  ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
2954
// Verifies that ReportNotCompleted() returns a NOT_FOUND error mentioning the
// HTTP code when the ReportTaskResult request is answered with an HTTP 404,
// and that the accepted retry window is still reported afterwards.
TEST_F(HttpFederatedProtocolTest, TestReportNotCompletedPermanentError) {
  // Issue an eligibility eval checkin first.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  // Issue a regular checkin.
  ASSERT_OK(RunSuccessfulCheckin());
  // Any request body is accepted here; only the URI and method are checked.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
                  "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
                  HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(404, HeaderList())));

  absl::Status status = federated_protocol_->ReportNotCompleted(
      engine::PhaseOutcome::ERROR, absl::Minutes(5), std::nullopt);
  EXPECT_THAT(status, IsCode(NOT_FOUND));
  EXPECT_THAT(
      status.message(),
      AllOf(HasSubstr("ReportTaskResult request failed:"), HasSubstr("404")));
  ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
2976
// Verifies that both the eligibility eval checkin request and the regular
// checkin request declare gzip as a supported resource compression format.
TEST_F(HttpFederatedProtocolTest,
       TestClientDecodedResourcesEnabledDeclaresSupport) {
  // ---- Part 1: eligibility eval checkin. ----

  // Resources returned inline by the fake eligibility eval task response. The
  // checkpoint resource is reused for the task assignment response below.
  Resource eligibility_plan_resource;
  eligibility_plan_resource.mutable_inline_resource()->set_data(kPlan);
  Resource checkpoint_resource;
  checkpoint_resource.mutable_inline_resource()->set_data(kInitCheckpoint);
  const EligibilityEvalTaskResponse eligibility_response =
      GetFakeEnabledEligibilityEvalTaskResponse(eligibility_plan_resource,
                                                checkpoint_resource,
                                                kEligibilityEvalExecutionId);

  // The request we expect the client to send for the eligibility eval checkin.
  EligibilityEvalTaskRequest expected_eligibility_request;
  expected_eligibility_request.mutable_client_version()->set_version_code(
      kClientVersion);
  expected_eligibility_request.mutable_attestation_measurement()->set_value(
      kAttestationMeasurement);
  // Make sure gzip support is declared in the eligibility eval checkin request.
  expected_eligibility_request.mutable_resource_capabilities()
      ->add_supported_compression_formats(
          ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
  expected_eligibility_request.mutable_eligibility_eval_task_capabilities()
      ->set_supports_multiple_task_assignment(false);

  const std::string eligibility_request_uri =
      "https://initial.uri/v1/eligibilityevaltasks/"
      "TEST%2FPOPULATION:request?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  eligibility_request_uri, HttpRequest::Method::kPost, _,
                  EligibilityEvalTaskRequestMatcher(
                      EqualsProto(expected_eligibility_request)))))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), eligibility_response.SerializeAsString())));

  // Issue the eligibility eval checkin so the expectation above is validated.
  ASSERT_OK(federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction()));

  // ---- Part 2: regular checkin; the field must be set there too. ----

  // The plan is referenced by URI, so the test expects a separate fetch of it.
  const std::string plan_uri = "https://fake.uri/plan";
  Resource plan_resource;
  plan_resource.set_uri(plan_uri);
  const StartTaskAssignmentResponse task_assignment_response =
      GetFakeTaskAssignmentResponse(plan_resource, checkpoint_resource,
                                    kFederatedSelectUriTemplate,
                                    kAggregationSessionId, 0);

  // The request we expect the client to send for the regular checkin.
  const TaskEligibilityInfo eligibility_info = GetFakeTaskEligibilityInfo();
  StartTaskAssignmentRequest expected_task_request;
  expected_task_request.mutable_client_version()->set_version_code(
      kClientVersion);
  *expected_task_request.mutable_task_eligibility_info() = eligibility_info;
  // Make sure gzip support is declared in the regular checkin request.
  expected_task_request.mutable_resource_capabilities()
      ->add_supported_compression_formats(
          ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);

  const std::string task_assignment_uri =
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  task_assignment_uri, HttpRequest::Method::kPost, _,
                  StartTaskAssignmentRequestMatcher(
                      EqualsProto(expected_task_request)))))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          CreateDoneOperation(kOperationName, task_assignment_response)
              .SerializeAsString())));

  // The plan resource is fetched with a GET request that has an empty body.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  plan_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), kPlan)));

  // A successful eligibility eval task result report is expected as well.
  const std::string report_eet_request_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
                                                         absl::OkStatus());

  ASSERT_OK(federated_protocol_->Checkin(
      eligibility_info, mock_task_received_callback_.AsStdFunction()));
}
3060
3061 } // anonymous namespace
3062 } // namespace fcp::client::http
3063