xref: /aosp_15_r20/external/federated-compute/fcp/client/http/http_federated_protocol_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/client/http/http_federated_protocol.h"
17 
18 #include <cstdint>
19 #include <memory>
20 #include <optional>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "google/longrunning/operations.pb.h"
26 #include "google/protobuf/any.pb.h"
27 #include "google/protobuf/duration.pb.h"
28 #include "google/rpc/code.pb.h"
29 #include "gmock/gmock.h"
30 #include "gtest/gtest.h"
31 #include "absl/memory/memory.h"
32 #include "absl/random/random.h"
33 #include "absl/status/status.h"
34 #include "absl/status/statusor.h"
35 #include "absl/strings/str_cat.h"
36 #include "absl/synchronization/blocking_counter.h"
37 #include "absl/synchronization/notification.h"
38 #include "absl/time/time.h"
39 #include "fcp/base/clock.h"
40 #include "fcp/base/monitoring.h"
41 #include "fcp/base/platform.h"
42 #include "fcp/base/time_util.h"
43 #include "fcp/base/wall_clock_stopwatch.h"
44 #include "fcp/client/cache/test_helpers.h"
45 #include "fcp/client/diag_codes.pb.h"
46 #include "fcp/client/engine/engine.pb.h"
47 #include "fcp/client/federated_protocol.h"
48 #include "fcp/client/federated_protocol_util.h"
49 #include "fcp/client/http/http_client.h"
50 #include "fcp/client/http/http_client_util.h"
51 #include "fcp/client/http/in_memory_request_response.h"
52 #include "fcp/client/http/testing/test_helpers.h"
53 #include "fcp/client/interruptible_runner.h"
54 #include "fcp/client/stats.h"
55 #include "fcp/client/test_helpers.h"
56 #include "fcp/protos/federated_api.pb.h"
57 #include "fcp/protos/federatedcompute/aggregations.pb.h"
58 #include "fcp/protos/federatedcompute/common.pb.h"
59 #include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h"
60 #include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
61 #include "fcp/protos/federatedcompute/task_assignments.pb.h"
62 #include "fcp/protos/plan.pb.h"
63 #include "fcp/secagg/shared/secagg_messages.pb.h"
64 #include "fcp/testing/testing.h"
65 
66 namespace fcp::client::http {
67 namespace {
68 
69 using ::fcp::EqualsProto;
70 using ::fcp::IsCode;
71 using ::fcp::client::http::FakeHttpResponse;
72 using ::fcp::client::http::MockableHttpClient;
73 using ::fcp::client::http::MockHttpClient;
74 using ::fcp::client::http::SimpleHttpRequestMatcher;
75 using ::google::internal::federatedcompute::v1::ByteStreamResource;
76 using ::google::internal::federatedcompute::v1::ClientStats;
77 using ::google::internal::federatedcompute::v1::EligibilityEvalTask;
78 using ::google::internal::federatedcompute::v1::EligibilityEvalTaskRequest;
79 using ::google::internal::federatedcompute::v1::EligibilityEvalTaskResponse;
80 using ::google::internal::federatedcompute::v1::ForwardingInfo;
81 using ::google::internal::federatedcompute::v1::PopulationEligibilitySpec;
82 using ::google::internal::federatedcompute::v1::
83     ReportEligibilityEvalTaskResultRequest;
84 using ::google::internal::federatedcompute::v1::ReportTaskResultRequest;
85 using ::google::internal::federatedcompute::v1::ReportTaskResultResponse;
86 using ::google::internal::federatedcompute::v1::Resource;
87 using ::google::internal::federatedcompute::v1::ResourceCompressionFormat;
88 using ::google::internal::federatedcompute::v1::RetryWindow;
89 using ::google::internal::federatedcompute::v1::SecureAggregandExecutionInfo;
90 using ::google::internal::federatedcompute::v1::
91     StartAggregationDataUploadRequest;
92 using ::google::internal::federatedcompute::v1::
93     StartAggregationDataUploadResponse;
94 using ::google::internal::federatedcompute::v1::StartSecureAggregationRequest;
95 using ::google::internal::federatedcompute::v1::StartSecureAggregationResponse;
96 using ::google::internal::federatedcompute::v1::StartTaskAssignmentRequest;
97 using ::google::internal::federatedcompute::v1::StartTaskAssignmentResponse;
98 using ::google::internal::federatedcompute::v1::SubmitAggregationResultRequest;
99 using ::google::internal::federatedcompute::v1::TaskAssignment;
100 using ::google::internal::federatedml::v2::TaskEligibilityInfo;
101 using ::google::internal::federatedml::v2::TaskWeight;
102 using ::google::longrunning::GetOperationRequest;
103 using ::google::longrunning::Operation;
104 using ::testing::_;
105 using ::testing::AllOf;
106 using ::testing::ByMove;
107 using ::testing::DescribeMatcher;
108 using ::testing::DoubleEq;
109 using ::testing::DoubleNear;
110 using ::testing::Eq;
111 using ::testing::ExplainMatchResult;
112 using ::testing::Field;
113 using ::testing::FieldsAre;
114 using ::testing::Ge;
115 using ::testing::Gt;
116 using ::testing::HasSubstr;
117 using ::testing::InSequence;
118 using ::testing::IsEmpty;
119 using ::testing::Lt;
120 using ::testing::MockFunction;
121 using ::testing::NiceMock;
122 using ::testing::Not;
123 using ::testing::Optional;
124 using ::testing::Return;
125 using ::testing::StrEq;
126 using ::testing::StrictMock;
127 using ::testing::UnorderedElementsAre;
128 using ::testing::VariantWith;
129 using ::testing::WithArg;
130 
// Endpoints and URI prefixes handed out by the fake server.
constexpr char kEntryPointUri[] = "https://initial.uri/";
constexpr char kTaskAssignmentTargetUri[] = "https://taskassignment.uri/";
constexpr char kAggregationTargetUri[] = "https://aggregation.uri/";
constexpr char kSecondStageAggregationTargetUri[] =
    "https://aggregation.second.uri/";
constexpr char kByteStreamTargetUri[] = "https://bytestream.uri/";
constexpr char kApiKey[] = "TEST_APIKEY";
// Note that we include a '/' character in the population name, which allows us
// to verify that it is correctly URL-encoded into "%2F".
constexpr char kPopulationName[] = "TEST/POPULATION";
constexpr char kEligibilityEvalExecutionId[] = "ELIGIBILITY_EXECUTION_ID";
// Note that we include '/' and '#' characters in the eligibility eval session
// ID, which allows us to verify that it is correctly URL-encoded into "%2F"
// and "%23".
constexpr char kEligibilityEvalSessionId[] = "ELIGIBILITY/SESSION#ID";
constexpr char kPlan[] = "CLIENT_ONLY_PLAN";
constexpr char kInitCheckpoint[] = "INIT_CHECKPOINT";
constexpr char kRetryToken[] = "OLD_RETRY_TOKEN";
constexpr char kClientVersion[] = "CLIENT_VERSION";
constexpr char kAttestationMeasurement[] = "ATTESTATION_MEASUREMENT";
constexpr char kClientSessionId[] = "CLIENT_SESSION_ID";
constexpr char kAggregationSessionId[] = "AGGREGATION_SESSION_ID";
constexpr char kAuthorizationToken[] = "AUTHORIZATION_TOKEN";
constexpr char kTaskName[] = "TASK_NAME";
constexpr char kClientToken[] = "CLIENT_TOKEN";
constexpr char kResourceName[] = "CHECKPOINT_RESOURCE";
constexpr char kFederatedSelectUriTemplate[] = "https://federated.select";
constexpr char kOperationName[] = "my_operation";

// constexpr (rather than plain const) for consistency with the constants above.
constexpr int32_t kCancellationWaitingPeriodSec = 1;
constexpr int32_t kMinimumClientsInServerVisibleAggregate = 2;
161 
162 MATCHER_P(EligibilityEvalTaskRequestMatcher, matcher,
163           absl::StrCat(negation ? "doesn't parse" : "parses",
164                        " as an EligibilityEvalTaskRequest, and that ",
165                        DescribeMatcher<EligibilityEvalTaskRequest>(matcher,
166                                                                    negation))) {
167   EligibilityEvalTaskRequest request;
168   if (!request.ParseFromString(arg)) {
169     return false;
170   }
171   return ExplainMatchResult(matcher, request, result_listener);
172 }
173 
174 MATCHER_P(
175     ReportEligibilityEvalTaskResultRequestMatcher, matcher,
176     absl::StrCat(negation ? "doesn't parse" : "parses",
177                  " as a ReportEligibilityEvalTaskResultRequest, and that ",
178                  DescribeMatcher<ReportEligibilityEvalTaskResultRequest>(
179                      matcher, negation))) {
180   ReportEligibilityEvalTaskResultRequest request;
181   if (!request.ParseFromString(arg)) {
182     return false;
183   }
184   return ExplainMatchResult(matcher, request, result_listener);
185 }
186 
187 MATCHER_P(StartTaskAssignmentRequestMatcher, matcher,
188           absl::StrCat(negation ? "doesn't parse" : "parses",
189                        " as a StartTaskAssignmentRequest, and that ",
190                        DescribeMatcher<StartTaskAssignmentRequest>(matcher,
191                                                                    negation))) {
192   StartTaskAssignmentRequest request;
193   if (!request.ParseFromString(arg)) {
194     return false;
195   }
196   return ExplainMatchResult(matcher, request, result_listener);
197 }
198 
199 MATCHER_P(GetOperationRequestMatcher, matcher,
200           absl::StrCat(negation ? "doesn't parse" : "parses",
201                        " as a GetOperationRequest, and that ",
202                        DescribeMatcher<GetOperationRequest>(matcher,
203                                                             negation))) {
204   GetOperationRequest request;
205   if (!request.ParseFromString(arg)) {
206     return false;
207   }
208   return ExplainMatchResult(matcher, request, result_listener);
209 }
210 
211 MATCHER_P(ReportTaskResultRequestMatcher, matcher,
212           absl::StrCat(negation ? "doesn't parse" : "parses",
213                        " as a ReportTaskResultRequest, and that ",
214                        DescribeMatcher<ReportTaskResultRequest>(matcher,
215                                                                 negation))) {
216   ReportTaskResultRequest request;
217   if (!request.ParseFromString(arg)) {
218     return false;
219   }
220   return ExplainMatchResult(matcher, request, result_listener);
221 }
222 
// Base retry delays (seconds) returned by the mocked flags.
constexpr int kTransientErrorsRetryPeriodSecs = 10;
constexpr int kPermanentErrorsRetryPeriodSecs = 100;

// Jitter fractions returned by the mocked flags.
constexpr double kTransientErrorsRetryDelayJitterPercent = 0.1;
constexpr double kPermanentErrorsRetryDelayJitterPercent = 0.2;

// Expected delay bounds: period * (1 -/+ jitter).
// Transient: 10 * 0.9 = 9.0 and 10 * 1.1 = 11.0.
constexpr double kExpectedTransientErrorsRetryPeriodSecsMin = 9.0;
constexpr double kExpectedTransientErrorsRetryPeriodSecsMax = 11.0;
// Permanent: 100 * 0.8 = 80.0 and 100 * 1.2 = 120.0.
constexpr double kExpectedPermanentErrorsRetryPeriodSecsMin = 80.0;
constexpr double kExpectedPermanentErrorsRetryPeriodSecsMax = 120.0;
231 
ExpectTransientErrorRetryWindow(const::google::internal::federatedml::v2::RetryWindow & retry_window)232 void ExpectTransientErrorRetryWindow(
233     const ::google::internal::federatedml::v2::RetryWindow& retry_window) {
234   // The calculated retry delay must lie within the expected transient errors
235   // retry delay range.
236   EXPECT_THAT(retry_window.delay_min().seconds() +
237                   retry_window.delay_min().nanos() / 1000000000,
238               AllOf(Ge(kExpectedTransientErrorsRetryPeriodSecsMin),
239                     Lt(kExpectedTransientErrorsRetryPeriodSecsMax)));
240   EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
241 }
242 
ExpectPermanentErrorRetryWindow(const::google::internal::federatedml::v2::RetryWindow & retry_window)243 void ExpectPermanentErrorRetryWindow(
244     const ::google::internal::federatedml::v2::RetryWindow& retry_window) {
245   // The calculated retry delay must lie within the expected permanent errors
246   // retry delay range.
247   EXPECT_THAT(retry_window.delay_min().seconds() +
248                   retry_window.delay_min().nanos() / 1000000000,
249               AllOf(Ge(kExpectedPermanentErrorsRetryPeriodSecsMin),
250                     Lt(kExpectedPermanentErrorsRetryPeriodSecsMax)));
251   EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
252 }
253 
GetAcceptedRetryWindow()254 RetryWindow GetAcceptedRetryWindow() {
255   // Must not overlap with kTransientErrorsRetryPeriodSecs or
256   // kPermanentErrorsRetryPeriodSecs.
257   RetryWindow retry_window;
258   retry_window.mutable_delay_min()->set_seconds(200L);
259   retry_window.mutable_delay_max()->set_seconds(299L);
260   return retry_window;
261 }
262 
ExpectAcceptedRetryWindow(const::google::internal::federatedml::v2::RetryWindow & retry_window)263 void ExpectAcceptedRetryWindow(
264     const ::google::internal::federatedml::v2::RetryWindow& retry_window) {
265   // The calculated retry delay must lie within the expected 'rejected' retry
266   // delay range.
267   EXPECT_THAT(retry_window.delay_min().seconds() +
268                   retry_window.delay_min().nanos() / 1000000000,
269               AllOf(Ge(200L), Lt(299L)));
270   EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
271 }
272 
GetRejectedRetryWindow()273 RetryWindow GetRejectedRetryWindow() {
274   // Must not overlap with kTransientErrorsRetryPeriodSecs or
275   // kPermanentErrorsRetryPeriodSecs.
276   RetryWindow retry_window;
277   retry_window.mutable_delay_min()->set_seconds(300L);
278   retry_window.mutable_delay_max()->set_seconds(399L);
279   return retry_window;
280 }
281 
ExpectRejectedRetryWindow(const::google::internal::federatedml::v2::RetryWindow & retry_window)282 void ExpectRejectedRetryWindow(
283     const ::google::internal::federatedml::v2::RetryWindow& retry_window) {
284   // The calculated retry delay must lie within the expected 'rejected' retry
285   // delay range.
286   EXPECT_THAT(retry_window.delay_min().seconds() +
287                   retry_window.delay_min().nanos() / 1000000000,
288               AllOf(Ge(300L), Lt(399L)));
289   EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
290 }
291 
GetExpectedEligibilityEvalTaskRequest(bool supports_multiple_task_assignments=false)292 EligibilityEvalTaskRequest GetExpectedEligibilityEvalTaskRequest(
293     bool supports_multiple_task_assignments = false) {
294   EligibilityEvalTaskRequest request;
295   // Note: we don't expect population_name to be set, since it should be set in
296   // the URI instead.
297   request.mutable_client_version()->set_version_code(kClientVersion);
298   request.mutable_attestation_measurement()->set_value(kAttestationMeasurement);
299   request.mutable_resource_capabilities()
300       ->mutable_supported_compression_formats()
301       ->Add(ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
302   request.mutable_eligibility_eval_task_capabilities()
303       ->set_supports_multiple_task_assignment(
304           supports_multiple_task_assignments);
305   return request;
306 }
307 
GetFakeEnabledEligibilityEvalTaskResponse(const Resource & plan,const Resource & checkpoint,const std::string & execution_id,std::optional<Resource> population_eligibility_spec=std::nullopt,const RetryWindow & accepted_retry_window=GetAcceptedRetryWindow (),const RetryWindow & rejected_retry_window=GetRejectedRetryWindow ())308 EligibilityEvalTaskResponse GetFakeEnabledEligibilityEvalTaskResponse(
309     const Resource& plan, const Resource& checkpoint,
310     const std::string& execution_id,
311     std::optional<Resource> population_eligibility_spec = std::nullopt,
312     const RetryWindow& accepted_retry_window = GetAcceptedRetryWindow(),
313     const RetryWindow& rejected_retry_window = GetRejectedRetryWindow()) {
314   EligibilityEvalTaskResponse response;
315   response.set_session_id(kEligibilityEvalSessionId);
316   EligibilityEvalTask* eval_task = response.mutable_eligibility_eval_task();
317   *eval_task->mutable_plan() = plan;
318   *eval_task->mutable_init_checkpoint() = checkpoint;
319   if (population_eligibility_spec.has_value()) {
320     *eval_task->mutable_population_eligibility_spec() =
321         population_eligibility_spec.value();
322   }
323   eval_task->set_execution_id(execution_id);
324   ForwardingInfo* forwarding_info =
325       response.mutable_task_assignment_forwarding_info();
326   forwarding_info->set_target_uri_prefix(kTaskAssignmentTargetUri);
327   *response.mutable_retry_window_if_accepted() = accepted_retry_window;
328   *response.mutable_retry_window_if_rejected() = rejected_retry_window;
329   return response;
330 }
331 
GetFakeDisabledEligibilityEvalTaskResponse()332 EligibilityEvalTaskResponse GetFakeDisabledEligibilityEvalTaskResponse() {
333   EligibilityEvalTaskResponse response;
334   response.set_session_id(kEligibilityEvalSessionId);
335   response.mutable_no_eligibility_eval_configured();
336   ForwardingInfo* forwarding_info =
337       response.mutable_task_assignment_forwarding_info();
338   forwarding_info->set_target_uri_prefix(kTaskAssignmentTargetUri);
339   *response.mutable_retry_window_if_accepted() = GetAcceptedRetryWindow();
340   *response.mutable_retry_window_if_rejected() = GetRejectedRetryWindow();
341   return response;
342 }
343 
GetFakeRejectedEligibilityEvalTaskResponse()344 EligibilityEvalTaskResponse GetFakeRejectedEligibilityEvalTaskResponse() {
345   EligibilityEvalTaskResponse response;
346   response.mutable_rejection_info();
347   *response.mutable_retry_window_if_accepted() = GetAcceptedRetryWindow();
348   *response.mutable_retry_window_if_rejected() = GetRejectedRetryWindow();
349   return response;
350 }
351 
GetFakeTaskEligibilityInfo()352 TaskEligibilityInfo GetFakeTaskEligibilityInfo() {
353   TaskEligibilityInfo eligibility_info;
354   TaskWeight* task_weight = eligibility_info.mutable_task_weights()->Add();
355   task_weight->set_task_name("foo");
356   task_weight->set_weight(567.8);
357   return eligibility_info;
358 }
359 
GetExpectedStartTaskAssignmentRequest(const std::optional<TaskEligibilityInfo> & task_eligibility_info)360 StartTaskAssignmentRequest GetExpectedStartTaskAssignmentRequest(
361     const std::optional<TaskEligibilityInfo>& task_eligibility_info) {
362   // Note: we don't expect population_name or session_id to be set, since they
363   // should be set in the URI instead.
364   StartTaskAssignmentRequest request;
365   request.mutable_client_version()->set_version_code(kClientVersion);
366   if (task_eligibility_info.has_value()) {
367     *request.mutable_task_eligibility_info() = *task_eligibility_info;
368   }
369   request.mutable_resource_capabilities()
370       ->mutable_supported_compression_formats()
371       ->Add(ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
372   return request;
373 }
374 
GetFakeRejectedTaskAssignmentResponse()375 StartTaskAssignmentResponse GetFakeRejectedTaskAssignmentResponse() {
376   StartTaskAssignmentResponse response;
377   response.mutable_rejection_info();
378   return response;
379 }
380 
GetFakeTaskAssignmentResponse(const Resource & plan,const Resource & checkpoint,const std::string & federated_select_uri_template,const std::string & aggregation_session_id,int32_t minimum_clients_in_server_visible_aggregate)381 StartTaskAssignmentResponse GetFakeTaskAssignmentResponse(
382     const Resource& plan, const Resource& checkpoint,
383     const std::string& federated_select_uri_template,
384     const std::string& aggregation_session_id,
385     int32_t minimum_clients_in_server_visible_aggregate) {
386   StartTaskAssignmentResponse response;
387   TaskAssignment* task_assignment = response.mutable_task_assignment();
388   ForwardingInfo* forwarding_info =
389       task_assignment->mutable_aggregation_data_forwarding_info();
390   forwarding_info->set_target_uri_prefix(kAggregationTargetUri);
391   task_assignment->set_session_id(kClientSessionId);
392   task_assignment->set_aggregation_id(aggregation_session_id);
393   task_assignment->set_authorization_token(kAuthorizationToken);
394   task_assignment->set_task_name(kTaskName);
395   *task_assignment->mutable_plan() = plan;
396   *task_assignment->mutable_init_checkpoint() = checkpoint;
397   task_assignment->mutable_federated_select_uri_info()->set_uri_template(
398       federated_select_uri_template);
399   if (minimum_clients_in_server_visible_aggregate > 0) {
400     task_assignment->mutable_secure_aggregation_info()
401         ->set_minimum_clients_in_server_visible_aggregate(
402             minimum_clients_in_server_visible_aggregate);
403   } else {
404     task_assignment->mutable_aggregation_info();
405   }
406   return response;
407 }
408 
GetExpectedReportTaskResultRequest(absl::string_view aggregation_id,absl::string_view task_name,::google::rpc::Code code,absl::Duration train_duration)409 ReportTaskResultRequest GetExpectedReportTaskResultRequest(
410     absl::string_view aggregation_id, absl::string_view task_name,
411     ::google::rpc::Code code, absl::Duration train_duration) {
412   ReportTaskResultRequest request;
413   request.set_aggregation_id(std::string(aggregation_id));
414   request.set_task_name(std::string(task_name));
415   request.set_computation_status_code(code);
416   ClientStats client_stats;
417   *client_stats.mutable_computation_execution_duration() =
418       TimeUtil::ConvertAbslToProtoDuration(train_duration);
419   *request.mutable_client_stats() = client_stats;
420   return request;
421 }
422 
GetFakeStartAggregationDataUploadResponse(absl::string_view aggregation_resource_name,absl::string_view byte_stream_uri_prefix,absl::string_view second_stage_aggregation_uri_prefix)423 StartAggregationDataUploadResponse GetFakeStartAggregationDataUploadResponse(
424     absl::string_view aggregation_resource_name,
425     absl::string_view byte_stream_uri_prefix,
426     absl::string_view second_stage_aggregation_uri_prefix) {
427   StartAggregationDataUploadResponse response;
428   ByteStreamResource* resource = response.mutable_resource();
429   *resource->mutable_resource_name() = aggregation_resource_name;
430   ForwardingInfo* data_upload_forwarding_info =
431       resource->mutable_data_upload_forwarding_info();
432   *data_upload_forwarding_info->mutable_target_uri_prefix() =
433       byte_stream_uri_prefix;
434   ForwardingInfo* aggregation_protocol_forwarding_info =
435       response.mutable_aggregation_protocol_forwarding_info();
436   *aggregation_protocol_forwarding_info->mutable_target_uri_prefix() =
437       second_stage_aggregation_uri_prefix;
438   response.set_client_token(kClientToken);
439   return response;
440 }
441 
CreateEmptySuccessHttpResponse()442 FakeHttpResponse CreateEmptySuccessHttpResponse() {
443   return FakeHttpResponse(200, HeaderList(), "");
444 }
445 
446 class HttpFederatedProtocolTest : public ::testing::Test {
447  protected:
SetUp()448   void SetUp() override {
449     EXPECT_CALL(mock_flags_,
450                 federated_training_transient_errors_retry_delay_secs)
451         .WillRepeatedly(Return(kTransientErrorsRetryPeriodSecs));
452     EXPECT_CALL(mock_flags_,
453                 federated_training_transient_errors_retry_delay_jitter_percent)
454         .WillRepeatedly(Return(kTransientErrorsRetryDelayJitterPercent));
455     EXPECT_CALL(mock_flags_,
456                 federated_training_permanent_errors_retry_delay_secs)
457         .WillRepeatedly(Return(kPermanentErrorsRetryPeriodSecs));
458     EXPECT_CALL(mock_flags_,
459                 federated_training_permanent_errors_retry_delay_jitter_percent)
460         .WillRepeatedly(Return(kPermanentErrorsRetryDelayJitterPercent));
461     EXPECT_CALL(mock_flags_, federated_training_permanent_error_codes)
462         .WillRepeatedly(Return(std::vector<int32_t>{
463             static_cast<int32_t>(absl::StatusCode::kNotFound),
464             static_cast<int32_t>(absl::StatusCode::kInvalidArgument),
465             static_cast<int32_t>(absl::StatusCode::kUnimplemented)}));
466     // Note that we disable compression in test to make it easier to verify the
467     // request body. The compression logic is tested in
468     // in_memory_request_response_test.cc.
469     EXPECT_CALL(mock_flags_, disable_http_request_body_compression)
470         .WillRepeatedly(Return(true));
471     EXPECT_CALL(mock_flags_, waiting_period_sec_for_cancellation)
472         .WillRepeatedly(Return(kCancellationWaitingPeriodSec));
473 
474     EXPECT_CALL(mock_flags_, http_protocol_supports_multiple_task_assignments)
475         .WillRepeatedly(Return(false));
476 
477     // We only initialize federated_protocol_ in this SetUp method, rather than
478     // in the test's constructor, to ensure that we can set mock flag values
479     // before the HttpFederatedProtocol constructor is called. Using
480     // std::unique_ptr conveniently allows us to assign the field a new value
481     // after construction (which we could not do if the field's type was
482     // HttpFederatedProtocol, since it doesn't have copy or move constructors).
483     federated_protocol_ = std::make_unique<HttpFederatedProtocol>(
484         clock_, &mock_log_manager_, &mock_flags_, &mock_http_client_,
485         absl::WrapUnique(mock_secagg_runner_factory_),
486         &mock_secagg_event_publisher_, kEntryPointUri, kApiKey, kPopulationName,
487         kRetryToken, kClientVersion, kAttestationMeasurement,
488         mock_should_abort_.AsStdFunction(), absl::BitGen(),
489         InterruptibleRunner::TimingConfig{
490             .polling_period = absl::ZeroDuration(),
491             .graceful_shutdown_period = absl::InfiniteDuration(),
492             .extended_shutdown_period = absl::InfiniteDuration()},
493         &mock_resource_cache_);
494   }
495 
TearDown()496   void TearDown() override {
497     // Regardless of the outcome of the test (or the protocol interaction being
498     // tested), network usage must always be reflected in the network stats
499     // methods.
500     HttpRequestHandle::SentReceivedBytes sent_received_bytes =
501         mock_http_client_.TotalSentReceivedBytes();
502 
503     NetworkStats network_stats = federated_protocol_->GetNetworkStats();
504     EXPECT_EQ(network_stats.bytes_downloaded,
505               sent_received_bytes.received_bytes);
506     EXPECT_EQ(network_stats.bytes_uploaded, sent_received_bytes.sent_bytes);
507     // If any network traffic occurred, we expect to see some time reflected in
508     // the duration.
509     if (network_stats.bytes_uploaded > 0) {
510       EXPECT_THAT(network_stats.network_duration, Gt(absl::ZeroDuration()));
511     }
512   }
513 
514   // This function runs a successful EligibilityEvalCheckin() that results in an
515   // eligibility eval payload being returned by the server (if
516   // `eligibility_eval_enabled` is true), or results in a 'no eligibility eval
517   // configured' response (if `eligibility_eval_enabled` is false). This is a
518   // utility function used by Checkin*() tests that depend on a prior,
  // successful execution of EligibilityEvalCheckin(). It returns an
  // absl::Status, which the caller should verify is OK using ASSERT_OK.
RunSuccessfulEligibilityEvalCheckin(bool eligibility_eval_enabled=true)521   absl::Status RunSuccessfulEligibilityEvalCheckin(
522       bool eligibility_eval_enabled = true) {
523     EligibilityEvalTaskResponse eval_task_response;
524     if (eligibility_eval_enabled) {
525       // We return a fake response which returns the plan/initial checkpoint
526       // data inline, to keep things simple.
527       std::string expected_plan = kPlan;
528       Resource plan_resource;
529       plan_resource.mutable_inline_resource()->set_data(kPlan);
530       std::string expected_checkpoint = kInitCheckpoint;
531       Resource checkpoint_resource;
532       checkpoint_resource.mutable_inline_resource()->set_data(
533           expected_checkpoint);
534       eval_task_response = GetFakeEnabledEligibilityEvalTaskResponse(
535           plan_resource, checkpoint_resource, kEligibilityEvalExecutionId);
536     } else {
537       eval_task_response = GetFakeDisabledEligibilityEvalTaskResponse();
538     }
539     std::string request_uri =
540         "https://initial.uri/v1/eligibilityevaltasks/"
541         "TEST%2FPOPULATION:request?%24alt=proto";
542     EXPECT_CALL(mock_http_client_,
543                 PerformSingleRequest(SimpleHttpRequestMatcher(
544                     request_uri, HttpRequest::Method::kPost, _,
545                     EligibilityEvalTaskRequestMatcher(
546                         EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
547         .WillOnce(Return(FakeHttpResponse(
548             200, HeaderList(), eval_task_response.SerializeAsString())));
549 
550     // The 'EET received' callback should be called, even if the task resource
551     // data was available inline.
552     if (eligibility_eval_enabled) {
553       EXPECT_CALL(mock_eet_received_callback_,
554                   Call(FieldsAre(FieldsAre("", ""), kEligibilityEvalExecutionId,
555                                  Eq(std::nullopt))));
556     }
557 
558     return federated_protocol_
559         ->EligibilityEvalCheckin(mock_eet_received_callback_.AsStdFunction())
560         .status();
561   }
562 
563   // This function runs a successful Checkin() that results in a
564   // task assignment payload being returned by the server. This is a
565   // utility function used by Report*() tests that depend on a prior,
  // successful execution of Checkin(). It returns an
  // absl::Status, which the caller should verify is OK using ASSERT_OK.
RunSuccessfulCheckin(bool eligibility_eval_enabled=true)568   absl::Status RunSuccessfulCheckin(bool eligibility_eval_enabled = true) {
569     // We return a fake response which returns the plan/initial checkpoint
570     // data inline, to keep things simple.
571     std::string expected_plan = kPlan;
572     std::string plan_uri = "https://fake.uri/plan";
573     Resource plan_resource;
574     plan_resource.set_uri(plan_uri);
575     std::string expected_checkpoint = kInitCheckpoint;
576     Resource checkpoint_resource;
577     checkpoint_resource.mutable_inline_resource()->set_data(
578         expected_checkpoint);
579     std::string expected_aggregation_session_id = kAggregationSessionId;
580     StartTaskAssignmentResponse task_assignment_response =
581         GetFakeTaskAssignmentResponse(plan_resource, checkpoint_resource,
582                                       kFederatedSelectUriTemplate,
583                                       expected_aggregation_session_id, 0);
584 
585     std::string request_uri =
586         "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
587         "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
588     TaskEligibilityInfo expected_eligibility_info =
589         GetFakeTaskEligibilityInfo();
590     EXPECT_CALL(mock_http_client_,
591                 PerformSingleRequest(SimpleHttpRequestMatcher(
592                     request_uri, HttpRequest::Method::kPost, _,
593                     StartTaskAssignmentRequestMatcher(
594                         EqualsProto(GetExpectedStartTaskAssignmentRequest(
595                             expected_eligibility_info))))))
596         .WillOnce(Return(FakeHttpResponse(
597             200, HeaderList(),
598             CreateDoneOperation(kOperationName, task_assignment_response)
599                 .SerializeAsString())));
600 
601     EXPECT_CALL(mock_http_client_,
602                 PerformSingleRequest(SimpleHttpRequestMatcher(
603                     plan_uri, HttpRequest::Method::kGet, _, "")))
604         .WillOnce(Return(FakeHttpResponse(200, HeaderList(), expected_plan)));
605 
606     if (eligibility_eval_enabled) {
607       std::string report_eet_request_uri =
608           "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
609           "eligibilityevaltasks/"
610           "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
611       ExpectSuccessfulReportEligibilityEvalTaskResultRequest(
612           report_eet_request_uri, absl::OkStatus());
613     }
614 
615     return federated_protocol_
616         ->Checkin(expected_eligibility_info,
617                   mock_task_received_callback_.AsStdFunction())
618         .status();
619   }
620 
ExpectSuccessfulReportEligibilityEvalTaskResultRequest(absl::string_view expected_request_uri,absl::Status eet_status)621   void ExpectSuccessfulReportEligibilityEvalTaskResultRequest(
622       absl::string_view expected_request_uri, absl::Status eet_status) {
623     ReportEligibilityEvalTaskResultRequest report_eet_request;
624     report_eet_request.set_status_code(
625         static_cast<google::rpc::Code>(eet_status.code()));
626     EXPECT_CALL(
627         mock_http_client_,
628         PerformSingleRequest(SimpleHttpRequestMatcher(
629             std::string(expected_request_uri), HttpRequest::Method::kPost, _,
630             ReportEligibilityEvalTaskResultRequestMatcher(
631                 EqualsProto(report_eet_request)))))
632         .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
633   }
634 
ExpectSuccessfulReportTaskResultRequest(absl::string_view expected_report_result_uri,absl::string_view aggregation_session_id,absl::string_view task_name,absl::Duration plan_duration)635   void ExpectSuccessfulReportTaskResultRequest(
636       absl::string_view expected_report_result_uri,
637       absl::string_view aggregation_session_id, absl::string_view task_name,
638       absl::Duration plan_duration) {
639     ReportTaskResultResponse report_task_result_response;
640     EXPECT_CALL(mock_http_client_,
641                 PerformSingleRequest(SimpleHttpRequestMatcher(
642                     std::string(expected_report_result_uri),
643                     HttpRequest::Method::kPost, _,
644                     ReportTaskResultRequestMatcher(
645                         EqualsProto(GetExpectedReportTaskResultRequest(
646                             aggregation_session_id, task_name,
647                             google::rpc::Code::OK, plan_duration))))))
648         .WillOnce(Return(CreateEmptySuccessHttpResponse()));
649   }
650 
ExpectSuccessfulStartAggregationDataUploadRequest(absl::string_view expected_start_data_upload_uri,absl::string_view aggregation_resource_name,absl::string_view byte_stream_uri_prefix,absl::string_view second_stage_aggregation_uri_prefix)651   void ExpectSuccessfulStartAggregationDataUploadRequest(
652       absl::string_view expected_start_data_upload_uri,
653       absl::string_view aggregation_resource_name,
654       absl::string_view byte_stream_uri_prefix,
655       absl::string_view second_stage_aggregation_uri_prefix) {
656     Operation pending_operation_response =
657         CreatePendingOperation("operations/foo#bar");
658     EXPECT_CALL(mock_http_client_,
659                 PerformSingleRequest(SimpleHttpRequestMatcher(
660                     std::string(expected_start_data_upload_uri),
661                     HttpRequest::Method::kPost, _,
662                     StartAggregationDataUploadRequest().SerializeAsString())))
663         .WillOnce(Return(
664             FakeHttpResponse(200, HeaderList(),
665                              pending_operation_response.SerializeAsString())));
666     EXPECT_CALL(
667         mock_http_client_,
668         PerformSingleRequest(SimpleHttpRequestMatcher(
669             // Note that the '#' character is encoded as "%23".
670             "https://aggregation.uri/v1/operations/foo%23bar?%24alt=proto",
671             HttpRequest::Method::kGet, _,
672             GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
673         .WillOnce(Return(FakeHttpResponse(
674             200, HeaderList(),
675             CreateDoneOperation(
676                 kOperationName,
677                 GetFakeStartAggregationDataUploadResponse(
678                     aggregation_resource_name, byte_stream_uri_prefix,
679                     second_stage_aggregation_uri_prefix))
680                 .SerializeAsString())));
681   }
682 
ExpectSuccessfulByteStreamUploadRequest(absl::string_view byte_stream_upload_uri,absl::string_view checkpoint_str)683   void ExpectSuccessfulByteStreamUploadRequest(
684       absl::string_view byte_stream_upload_uri,
685       absl::string_view checkpoint_str) {
686     EXPECT_CALL(
687         mock_http_client_,
688         PerformSingleRequest(SimpleHttpRequestMatcher(
689             std::string(byte_stream_upload_uri), HttpRequest::Method::kPost, _,
690             std::string(checkpoint_str))))
691         .WillOnce(Return(CreateEmptySuccessHttpResponse()));
692   }
693 
ExpectSuccessfulSubmitAggregationResultRequest(absl::string_view expected_submit_aggregation_result_uri)694   void ExpectSuccessfulSubmitAggregationResultRequest(
695       absl::string_view expected_submit_aggregation_result_uri) {
696     SubmitAggregationResultRequest submit_aggregation_result_request;
697     submit_aggregation_result_request.set_resource_name(kResourceName);
698     EXPECT_CALL(mock_http_client_,
699                 PerformSingleRequest(SimpleHttpRequestMatcher(
700                     std::string(expected_submit_aggregation_result_uri),
701                     HttpRequest::Method::kPost, _,
702                     submit_aggregation_result_request.SerializeAsString())))
703         .WillOnce(Return(CreateEmptySuccessHttpResponse()));
704   }
705 
ExpectSuccessfulAbortAggregationRequest(absl::string_view base_uri)706   void ExpectSuccessfulAbortAggregationRequest(absl::string_view base_uri) {
707     EXPECT_CALL(mock_http_client_,
708                 PerformSingleRequest(SimpleHttpRequestMatcher(
709                     absl::StrCat(base_uri, "/v1/aggregations/",
710                                  "AGGREGATION_SESSION_ID/clients/"
711                                  "CLIENT_TOKEN:abort?%24alt=proto"),
712                     HttpRequest::Method::kPost, _, _)))
713         .WillOnce(Return(CreateEmptySuccessHttpResponse()));
714   }
715 
// Strict mock: any PerformSingleRequest call without a matching EXPECT_CALL
// fails the test.
StrictMock<MockHttpClient> mock_http_client_;
// Raw pointer whose ownership is handed to an HttpFederatedProtocol via
// absl::WrapUnique (see
// TestTransientErrorRetryWindowDifferentAcrossDifferentInstances). A fresh
// instance therefore has to be allocated for every protocol object created.
StrictMock<MockSecAggRunnerFactory>* mock_secagg_runner_factory_ =
    new StrictMock<MockSecAggRunnerFactory>();
StrictMock<MockSecAggEventPublisher> mock_secagg_event_publisher_;
StrictMock<MockLogManager> mock_log_manager_;
// Nice mocks: uninteresting calls are allowed and return default values
// unless a test sets explicit expectations on them.
NiceMock<MockFlags> mock_flags_;
NiceMock<MockFunction<bool()>> mock_should_abort_;
StrictMock<cache::MockResourceCache> mock_resource_cache_;
Clock* clock_ = Clock::RealClock();
// Callbacks passed (via AsStdFunction()) to EligibilityEvalCheckin(...) and
// Checkin(...) respectively.
NiceMock<MockFunction<void(
    const ::fcp::client::FederatedProtocol::EligibilityEvalTask&)>>
    mock_eet_received_callback_;
NiceMock<MockFunction<void(
    const ::fcp::client::FederatedProtocol::TaskAssignment&)>>
    mock_task_received_callback_;

// The class under test. Declared last so that it is destroyed first (members
// are destroyed in reverse declaration order), i.e. before the mocks whose
// addresses it is constructed with.
std::unique_ptr<HttpFederatedProtocol> federated_protocol_;
734 };
735 
// The same fixture exposed under a `*DeathTest` suite name, following
// googletest's naming convention for suites that contain death tests.
using HttpFederatedProtocolDeathTest = HttpFederatedProtocolTest;
737 
TEST_F(HttpFederatedProtocolTest,
       TestTransientErrorRetryWindowDifferentAcrossDifferentInstances) {
  // Hold the first retry window by value. The instance that produced it is
  // destroyed below, so the value must not depend on that object's lifetime;
  // a reference would dangle if GetLatestRetryWindow() ever returned a
  // reference into the protocol's state.
  const ::google::internal::federatedml::v2::RetryWindow retry_window1 =
      federated_protocol_->GetLatestRetryWindow();
  ExpectTransientErrorRetryWindow(retry_window1);

  // Tear down the first instance. The MockSecAggRunnerFactory it was given
  // (presumably via absl::WrapUnique, as done below) goes away with it, so
  // allocate a fresh factory for the replacement instance.
  federated_protocol_.reset();
  mock_secagg_runner_factory_ = new StrictMock<MockSecAggRunnerFactory>();

  // Create a new HttpFederatedProtocol instance. It should not produce the same
  // retry window value as the one we just got. This is a simple correctness
  // check to ensure that the value is at least randomly generated (and that we
  // don't accidentally use the random number generator incorrectly).
  federated_protocol_ = std::make_unique<HttpFederatedProtocol>(
      clock_, &mock_log_manager_, &mock_flags_, &mock_http_client_,
      absl::WrapUnique(mock_secagg_runner_factory_),
      &mock_secagg_event_publisher_, kEntryPointUri, kApiKey, kPopulationName,
      kRetryToken, kClientVersion, kAttestationMeasurement,
      mock_should_abort_.AsStdFunction(), absl::BitGen(),
      InterruptibleRunner::TimingConfig{
          .polling_period = absl::ZeroDuration(),
          .graceful_shutdown_period = absl::InfiniteDuration(),
          .extended_shutdown_period = absl::InfiniteDuration()},
      &mock_resource_cache_);

  const ::google::internal::federatedml::v2::RetryWindow retry_window2 =
      federated_protocol_->GetLatestRetryWindow();
  ExpectTransientErrorRetryWindow(retry_window2);

  EXPECT_THAT(retry_window1, Not(EqualsProto(retry_window2)));
}
768 
TEST_F(HttpFederatedProtocolTest,
       TestEligibilityEvalCheckinRequestFailsTransientError) {
  // Answer the control-protocol POST issued by EligibilityEvalCheckin(...)
  // with HTTP 503 (Service Unavailable); that failure must become the result.
  const std::string eet_request_uri =
      "https://initial.uri/v1/eligibilityevaltasks/"
      "TEST%2FPOPULATION:request?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  eet_request_uri, HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));

  auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());
  const absl::Status& checkin_status = checkin_result.status();

  EXPECT_THAT(checkin_status, IsCode(UNAVAILABLE));
  EXPECT_THAT(checkin_status.message(), HasSubstr("protocol request failed"));
  // The message must also mention the original HTTP status code.
  EXPECT_THAT(checkin_status.message(), HasSubstr("503"));
  // The server sent no RetryWindow, so the expected window is the one derived
  // from the transient errors retry delay flag.
  ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
794 
TEST_F(HttpFederatedProtocolTest,
       TestEligibilityEvalCheckinRequestFailsPermanentError) {
  // Answer the control-protocol POST issued by EligibilityEvalCheckin(...)
  // with HTTP 404 (Not Found); that failure must become the result.
  const std::string eet_request_uri =
      "https://initial.uri/v1/eligibilityevaltasks/"
      "TEST%2FPOPULATION:request?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  eet_request_uri, HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(404, HeaderList(), "")));

  auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());
  const absl::Status& checkin_status = checkin_result.status();

  EXPECT_THAT(checkin_status, IsCode(NOT_FOUND));
  EXPECT_THAT(checkin_status.message(), HasSubstr("protocol request failed"));
  // The message must also mention the original HTTP status code.
  EXPECT_THAT(checkin_status.message(), HasSubstr("404"));
  // With no server-provided RetryWindow, the window comes from the *permanent*
  // errors retry delay flag, since NOT_FOUND is marked as a permanent error in
  // the flags.
  ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
821 
822 // Tests the case where we get interrupted while waiting for a response to the
823 // protocol request in EligibilityEvalCheckin.
TEST_F(HttpFederatedProtocolTest,
       TestEligibilityEvalCheckinRequestInterrupted) {
  // Signals that the mocked request has started executing.
  absl::Notification request_issued;
  // Signals that the client has cancelled the in-flight request.
  absl::Notification request_cancelled;

  // Make PerformSingleRequest() announce that the request was issued and then
  // block until the request has been cancelled. The 503 it finally returns is
  // never surfaced: the checkin is expected to end up CANCELLED.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://initial.uri/v1/eligibilityevaltasks/"
                  "TEST%2FPOPULATION:request?%24alt=proto",
                  HttpRequest::Method::kPost, _, _)))
      .WillOnce([&request_issued, &request_cancelled](
                    MockableHttpClient::SimpleHttpRequest /*request*/) {
        request_issued.Notify();
        request_cancelled.WaitForNotification();
        return FakeHttpResponse(503, HeaderList(), "");
      });

  // Make should_abort return false until we know that the request was issued
  // (i.e. once InterruptibleRunner has actually started running the code it
  // was given), and true afterwards. This triggers the abort sequence, whose
  // cancellation unblocks the PerformSingleRequest() call made to block above.
  EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&request_issued] {
    return request_issued.HasBeenNotified();
  });

  // When the HttpClient receives a HttpRequestHandle::Cancel call, let the
  // blocked request complete.
  mock_http_client_.SetCancellationListener(
      [&request_cancelled]() { request_cancelled.Notify(); });

  // The interruption must be logged with this diag code.
  EXPECT_CALL(mock_log_manager_,
              LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP));

  auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  EXPECT_THAT(eligibility_checkin_result.status(), IsCode(CANCELLED));
  // No RetryWindows were received from the server, so we expect to get a
  // RetryWindow generated based on the transient errors retry delay flag.
  ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
866 
TEST_F(HttpFederatedProtocolTest, TestEligibilityEvalCheckinRejection) {
  // The server replies to the EligibilityEvalTaskRequest with the fake
  // rejection response.
  const std::string eet_request_uri =
      "https://initial.uri/v1/eligibilityevaltasks/"
      "TEST%2FPOPULATION:request?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  eet_request_uri, HttpRequest::Method::kPost, _,
                  EligibilityEvalTaskRequestMatcher(
                      EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          GetFakeRejectedEligibilityEvalTaskResponse().SerializeAsString())));

  // No EET reaches the client, so the 'EET received' callback must never run.
  EXPECT_CALL(mock_eet_received_callback_, Call(_)).Times(0);

  auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());
  ASSERT_OK(checkin_result);

  const auto& checkin_outcome = *checkin_result;
  EXPECT_THAT(checkin_outcome, VariantWith<FederatedProtocol::Rejection>(_));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
891 
TEST_F(HttpFederatedProtocolTest, TestEligibilityEvalCheckinDisabled) {
  // The server replies to the EligibilityEvalTaskRequest with the fake
  // "eligibility eval disabled" response.
  const std::string eet_request_uri =
      "https://initial.uri/v1/eligibilityevaltasks/"
      "TEST%2FPOPULATION:request?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  eet_request_uri, HttpRequest::Method::kPost, _,
                  EligibilityEvalTaskRequestMatcher(
                      EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          GetFakeDisabledEligibilityEvalTaskResponse().SerializeAsString())));

  // No EET reaches the client, so the 'EET received' callback must never run.
  EXPECT_CALL(mock_eet_received_callback_, Call(_)).Times(0);

  auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());
  ASSERT_OK(checkin_result);

  const auto& checkin_outcome = *checkin_result;
  EXPECT_THAT(checkin_outcome,
              VariantWith<FederatedProtocol::EligibilityEvalDisabled>(_));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
916 
TEST_F(HttpFederatedProtocolTest, TestEligibilityEvalCheckinEnabled) {
  // The fake EET response references the plan by URI, so the plan is fetched
  // over HTTP, while the checkpoint is carried inline.
  std::string plan_uri = "https://fake.uri/plan";
  std::string plan_data = kPlan;
  std::string checkpoint_data = kInitCheckpoint;
  std::string execution_id = kEligibilityEvalExecutionId;

  Resource plan_resource;
  plan_resource.set_uri(plan_uri);
  Resource checkpoint_resource;
  checkpoint_resource.mutable_inline_resource()->set_data(checkpoint_data);
  EligibilityEvalTaskResponse eval_task_response =
      GetFakeEnabledEligibilityEvalTaskResponse(
          plan_resource, checkpoint_resource, execution_id);

  InSequence seq;
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://initial.uri/v1/eligibilityevaltasks/"
                  "TEST%2FPOPULATION:request?%24alt=proto",
                  HttpRequest::Method::kPost, _,
                  EligibilityEvalTaskRequestMatcher(
                      EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), eval_task_response.SerializeAsString())));

  // InSequence forces the 'EET received' callback to run *before* the plan is
  // fetched; at that point its payloads are still empty.
  EXPECT_CALL(mock_eet_received_callback_,
              Call(FieldsAre(FieldsAre("", ""), execution_id,
                             Eq(std::nullopt))));

  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  plan_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), plan_data)));

  auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  ASSERT_OK(checkin_result);
  auto payloads_match = AllOf(
      Field(&FederatedProtocol::PlanAndCheckpointPayloads::plan,
            absl::Cord(plan_data)),
      Field(&FederatedProtocol::PlanAndCheckpointPayloads::checkpoint,
            absl::Cord(checkpoint_data)));
  EXPECT_THAT(*checkin_result,
              VariantWith<FederatedProtocol::EligibilityEvalTask>(FieldsAre(
                  payloads_match, execution_id, Eq(std::nullopt))));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
968 
TEST_F(HttpFederatedProtocolTest,TestEligibilityEvalCheckinWithPopulationEligibilitySpec)969 TEST_F(HttpFederatedProtocolTest,
970        TestEligibilityEvalCheckinWithPopulationEligibilitySpec) {
971   EXPECT_CALL(mock_flags_, http_protocol_supports_multiple_task_assignments)
972       .WillRepeatedly(Return(true));
973   // We return a fake response which requires fetching the plan via HTTP,
974   // but which has the checkpoint data available inline.
975   std::string expected_plan = kPlan;
976   std::string plan_uri = "https://fake.uri/plan";
977   Resource plan_resource;
978   plan_resource.set_uri(plan_uri);
979   std::string expected_checkpoint = kInitCheckpoint;
980   Resource checkpoint_resource;
981   checkpoint_resource.mutable_inline_resource()->set_data(expected_checkpoint);
982 
983   PopulationEligibilitySpec expected_population_eligibility_spec;
984   auto task_info = expected_population_eligibility_spec.add_task_info();
985   task_info->set_task_name("task_1");
986   task_info->set_task_assignment_mode(
987       PopulationEligibilitySpec::TaskInfo::TASK_ASSIGNMENT_MODE_MULTIPLE);
988   std::string population_eligibility_spec_uri =
989       "https://fake.uri/population_eligibility_spec";
990   Resource population_eligibility_spec;
991   population_eligibility_spec.set_uri(population_eligibility_spec_uri);
992   std::string expected_execution_id = kEligibilityEvalExecutionId;
993   EligibilityEvalTaskResponse eval_task_response =
994       GetFakeEnabledEligibilityEvalTaskResponse(
995           plan_resource, checkpoint_resource, expected_execution_id,
996           population_eligibility_spec);
997 
998   InSequence seq;
999   EXPECT_CALL(mock_http_client_,
1000               PerformSingleRequest(SimpleHttpRequestMatcher(
1001                   "https://initial.uri/v1/eligibilityevaltasks/"
1002                   "TEST%2FPOPULATION:request?%24alt=proto",
1003                   HttpRequest::Method::kPost, _,
1004                   EligibilityEvalTaskRequestMatcher(
1005                       EqualsProto(GetExpectedEligibilityEvalTaskRequest(
1006                           /* supports_multiple_task_assignments= */ true))))))
1007       .WillOnce(Return(FakeHttpResponse(
1008           200, HeaderList(), eval_task_response.SerializeAsString())));
1009 
1010   // The 'EET received' callback should be called *before* the actual task
1011   // resources are fetched.
1012   EXPECT_CALL(mock_eet_received_callback_,
1013               Call(FieldsAre(FieldsAre("", ""), expected_execution_id,
1014                              Eq(std::nullopt))));
1015 
1016   EXPECT_CALL(mock_http_client_,
1017               PerformSingleRequest(SimpleHttpRequestMatcher(
1018                   plan_uri, HttpRequest::Method::kGet, _, "")))
1019       .WillOnce(Return(FakeHttpResponse(200, HeaderList(), expected_plan)));
1020   EXPECT_CALL(mock_http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
1021                                      population_eligibility_spec_uri,
1022                                      HttpRequest::Method::kGet, _, "")))
1023       .WillOnce(Return(FakeHttpResponse(
1024           200, HeaderList(),
1025           expected_population_eligibility_spec.SerializeAsString())));
1026 
1027   auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
1028       mock_eet_received_callback_.AsStdFunction());
1029 
1030   ASSERT_OK(eligibility_checkin_result);
1031   EXPECT_THAT(
1032       *eligibility_checkin_result,
1033       VariantWith<FederatedProtocol::EligibilityEvalTask>(FieldsAre(
1034           AllOf(Field(&FederatedProtocol::PlanAndCheckpointPayloads::plan,
1035                       absl::Cord(expected_plan)),
1036                 Field(&FederatedProtocol::PlanAndCheckpointPayloads::checkpoint,
1037                       absl::Cord(expected_checkpoint))),
1038           expected_execution_id,
1039           Optional(EqualsProto(expected_population_eligibility_spec)))));
1040   ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1041 }
1042 
TEST_F(HttpFederatedProtocolTest,
       TestEligibilityEvalCheckinWithPopulationEligibilitySpecInvalidFormat) {
  EXPECT_CALL(mock_flags_, http_protocol_supports_multiple_task_assignments)
      .WillRepeatedly(Return(true));

  // The plan is referenced by URI (fetched over HTTP). The checkpoint and an
  // unparseable population eligibility spec are both carried inline, so no
  // spec fetch is expected.
  std::string plan_uri = "https://fake.uri/plan";
  std::string plan_data = kPlan;
  Resource plan_resource;
  plan_resource.set_uri(plan_uri);

  std::string checkpoint_data = kInitCheckpoint;
  Resource checkpoint_resource;
  checkpoint_resource.mutable_inline_resource()->set_data(checkpoint_data);

  Resource invalid_spec_resource;
  invalid_spec_resource.mutable_inline_resource()->set_data("Invalid_spec");

  std::string execution_id = kEligibilityEvalExecutionId;
  EligibilityEvalTaskResponse eval_task_response =
      GetFakeEnabledEligibilityEvalTaskResponse(plan_resource,
                                                checkpoint_resource,
                                                execution_id,
                                                invalid_spec_resource);

  InSequence seq;
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://initial.uri/v1/eligibilityevaltasks/"
                  "TEST%2FPOPULATION:request?%24alt=proto",
                  HttpRequest::Method::kPost, _,
                  EligibilityEvalTaskRequestMatcher(
                      EqualsProto(GetExpectedEligibilityEvalTaskRequest(
                          /* supports_multiple_task_assignments= */ true))))))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), eval_task_response.SerializeAsString())));

  // InSequence forces the 'EET received' callback to run *before* the plan is
  // fetched.
  EXPECT_CALL(mock_eet_received_callback_,
              Call(FieldsAre(FieldsAre("", ""), execution_id,
                             Eq(std::nullopt))));

  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  plan_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), plan_data)));

  auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  // The malformed spec must fail the checkin with INVALID_ARGUMENT, which
  // yields a permanent-error retry window.
  ASSERT_THAT(checkin_result.status(), IsCode(INVALID_ARGUMENT));
  ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1095 
TEST_F(HttpFederatedProtocolTest,
       TestEligibilityEvalCheckinEnabledWithCompression) {
  // Both the plan and the checkpoint are delivered inline but gzip-compressed;
  // the returned EET must carry the decompressed bytes.
  std::string expected_plan = kPlan;
  absl::StatusOr<std::string> compressed_plan =
      internal::CompressWithGzip(expected_plan);
  ASSERT_OK(compressed_plan);
  Resource plan_resource;
  plan_resource.mutable_inline_resource()->set_data(*compressed_plan);
  plan_resource.mutable_inline_resource()->set_compression_format(
      ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);

  std::string expected_checkpoint = kInitCheckpoint;
  absl::StatusOr<std::string> compressed_checkpoint =
      internal::CompressWithGzip(expected_checkpoint);
  // As for the plan above, verify compression succeeded before dereferencing:
  // operator* on a non-OK absl::StatusOr is undefined behavior.
  ASSERT_OK(compressed_checkpoint);
  Resource checkpoint_resource;
  checkpoint_resource.mutable_inline_resource()->set_data(
      *compressed_checkpoint);
  checkpoint_resource.mutable_inline_resource()->set_compression_format(
      ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);

  std::string expected_execution_id = kEligibilityEvalExecutionId;
  EligibilityEvalTaskResponse eval_task_response =
      GetFakeEnabledEligibilityEvalTaskResponse(
          plan_resource, checkpoint_resource, expected_execution_id);
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://initial.uri/v1/eligibilityevaltasks/"
                  "TEST%2FPOPULATION:request?%24alt=proto",
                  HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), eval_task_response.SerializeAsString())));

  auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  ASSERT_OK(eligibility_checkin_result);
  EXPECT_THAT(
      *eligibility_checkin_result,
      VariantWith<FederatedProtocol::EligibilityEvalTask>(FieldsAre(
          AllOf(Field(&FederatedProtocol::PlanAndCheckpointPayloads::plan,
                      absl::Cord(expected_plan)),
                Field(&FederatedProtocol::PlanAndCheckpointPayloads::checkpoint,
                      absl::Cord(expected_checkpoint))),
          expected_execution_id, Eq(std::nullopt))));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1140 
1141 // Ensures that if the plan resource fails to be downloaded, the error is
1142 // correctly returned from the EligibilityEvalCheckin(...) method.
TEST_F(HttpFederatedProtocolTest,
       TestEligibilityEvalCheckinEnabledPlanDataFetchFailed) {
  // Both resources are referenced by URI, so each needs its own HTTP fetch.
  std::string plan_uri = "https://fake.uri/plan";
  std::string checkpoint_uri = "https://fake.uri/checkpoint";
  Resource plan_resource;
  plan_resource.set_uri(plan_uri);
  Resource checkpoint_resource;
  checkpoint_resource.set_uri(checkpoint_uri);
  std::string execution_id = kEligibilityEvalExecutionId;
  EligibilityEvalTaskResponse eval_task_response =
      GetFakeEnabledEligibilityEvalTaskResponse(
          plan_resource, checkpoint_resource, execution_id);

  const std::string eet_request_uri =
      "https://initial.uri/v1/eligibilityevaltasks/"
      "TEST%2FPOPULATION:request?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  eet_request_uri, HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), eval_task_response.SerializeAsString())));

  // The checkpoint fetch succeeds (with an empty body)...
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  checkpoint_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));

  // ...but the plan fetch fails with HTTP 404.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  plan_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(404, HeaderList(), "")));

  auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());
  const absl::Status& checkin_status = checkin_result.status();

  // The resource request's 404 must be reflected in the returned status...
  EXPECT_THAT(checkin_status, IsCode(NOT_FOUND));
  EXPECT_THAT(checkin_status.message(), HasSubstr("plan fetch failed"));
  // ...and the original HTTP status code must appear in the message too.
  EXPECT_THAT(checkin_status.message(), HasSubstr("404"));
  // This error type counts as permanent, so a permanent-error retry window is
  // expected.
  ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1189 
1190 // Ensures that if the checkpoint resource fails to be downloaded, the error is
1191 // correctly returned from the EligibilityEvalCheckin(...) method.
TEST_F(HttpFederatedProtocolTest,TestEligibilityEvalCheckinEnabledCheckpointDataFetchFailed)1192 TEST_F(HttpFederatedProtocolTest,
1193        TestEligibilityEvalCheckinEnabledCheckpointDataFetchFailed) {
1194   std::string plan_uri = "https://fake.uri/plan";
1195   Resource plan_resource;
1196   plan_resource.set_uri(plan_uri);
1197   std::string checkpoint_uri = "https://fake.uri/checkpoint";
1198   Resource checkpoint_resource;
1199   checkpoint_resource.set_uri(checkpoint_uri);
1200   std::string expected_execution_id = kEligibilityEvalExecutionId;
1201   EligibilityEvalTaskResponse eval_task_response =
1202       GetFakeEnabledEligibilityEvalTaskResponse(
1203           plan_resource, checkpoint_resource, expected_execution_id);
1204   EXPECT_CALL(mock_http_client_,
1205               PerformSingleRequest(SimpleHttpRequestMatcher(
1206                   "https://initial.uri/v1/eligibilityevaltasks/"
1207                   "TEST%2FPOPULATION:request?%24alt=proto",
1208                   HttpRequest::Method::kPost, _, _)))
1209       .WillOnce(Return(FakeHttpResponse(
1210           200, HeaderList(), eval_task_response.SerializeAsString())));
1211 
1212   // Mock a failed checkpoint fetch.
1213   EXPECT_CALL(mock_http_client_,
1214               PerformSingleRequest(SimpleHttpRequestMatcher(
1215                   checkpoint_uri, HttpRequest::Method::kGet, _, "")))
1216       .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
1217 
1218   EXPECT_CALL(mock_http_client_,
1219               PerformSingleRequest(SimpleHttpRequestMatcher(
1220                   plan_uri, HttpRequest::Method::kGet, _, "")))
1221       .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
1222 
1223   auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
1224       mock_eet_received_callback_.AsStdFunction());
1225 
1226   // The 503 error for the resource request should be reflected in the return
1227   // value.
1228   EXPECT_THAT(eligibility_checkin_result.status(), IsCode(UNAVAILABLE));
1229   EXPECT_THAT(eligibility_checkin_result.status().message(),
1230               HasSubstr("checkpoint fetch failed"));
1231   // The original 503 HTTP response code should be included in the message as
1232   // well.
1233   EXPECT_THAT(eligibility_checkin_result.status().message(), HasSubstr("503"));
1234   // RetryWindows were received from the server before the error was received,
1235   // and the error is considered 'transient', so we expect to get a rejected
1236   // RetryWindow.
1237   ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1238 }
1239 
// Verifies that ReportEligibilityEvalError(...) sends a ReportResult request
// whose status_code mirrors the reported absl status.
TEST_F(HttpFederatedProtocolTest, TestReportEligibilityEvalTaskResult) {
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());

  // The request body is expected to carry CANCELLED as its status code.
  ReportEligibilityEvalTaskResultRequest expected_report;
  expected_report.set_status_code(
      static_cast<google::rpc::Code>(absl::StatusCode::kCancelled));

  const std::string report_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  report_uri, HttpRequest::Method::kPost, _,
                  ReportEligibilityEvalTaskResultRequestMatcher(
                      EqualsProto(expected_report)))))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));

  federated_protocol_->ReportEligibilityEvalError(absl::CancelledError());
}
1258 
1259 // Tests that the protocol correctly sanitizes any invalid values it may have
1260 // received from the server.
TEST_F(HttpFederatedProtocolTest,TestNegativeMinMaxRetryDelayValueSanitization)1261 TEST_F(HttpFederatedProtocolTest,
1262        TestNegativeMinMaxRetryDelayValueSanitization) {
1263   RetryWindow retry_window;
1264   retry_window.mutable_delay_min()->set_seconds(-1);
1265   retry_window.mutable_delay_max()->set_seconds(-2);
1266 
1267   // The above retry window's negative min/max values should be clamped to 0.
1268   RetryWindow expected_retry_window;
1269   expected_retry_window.mutable_delay_min()->set_seconds(0);
1270   expected_retry_window.mutable_delay_max()->set_seconds(0);
1271 
1272   EligibilityEvalTaskResponse eval_task_response =
1273       GetFakeEnabledEligibilityEvalTaskResponse(
1274           Resource(), Resource(), kEligibilityEvalExecutionId, std::nullopt,
1275           retry_window, retry_window);
1276   EXPECT_CALL(mock_http_client_,
1277               PerformSingleRequest(SimpleHttpRequestMatcher(
1278                   "https://initial.uri/v1/eligibilityevaltasks/"
1279                   "TEST%2FPOPULATION:request?%24alt=proto",
1280                   HttpRequest::Method::kPost, _, _)))
1281       .WillOnce(Return(FakeHttpResponse(
1282           200, HeaderList(), eval_task_response.SerializeAsString())));
1283 
1284   ASSERT_OK(federated_protocol_->EligibilityEvalCheckin(
1285       mock_eet_received_callback_.AsStdFunction()));
1286 
1287   const google::internal::federatedml::v2::RetryWindow& actual_retry_window =
1288       federated_protocol_->GetLatestRetryWindow();
1289   // The above retry window's invalid max value should be clamped to the min
1290   // value (minus some errors introduced by the inaccuracy of double
1291   // multiplication).
1292   EXPECT_THAT(actual_retry_window.delay_min().seconds() +
1293                   actual_retry_window.delay_min().nanos() / 1000000000.0,
1294               DoubleEq(0));
1295   EXPECT_THAT(actual_retry_window.delay_max().seconds() +
1296                   actual_retry_window.delay_max().nanos() / 1000000000.0,
1297               DoubleEq(0));
1298 }
1299 
1300 // Tests that the protocol correctly sanitizes any invalid values it may have
1301 // received from the server.
TEST_F(HttpFederatedProtocolTest,TestInvalidMaxRetryDelayValueSanitization)1302 TEST_F(HttpFederatedProtocolTest, TestInvalidMaxRetryDelayValueSanitization) {
1303   RetryWindow retry_window;
1304   retry_window.mutable_delay_min()->set_seconds(1234);
1305   retry_window.mutable_delay_max()->set_seconds(1233);  // less than delay_min
1306 
1307   EligibilityEvalTaskResponse eval_task_response =
1308       GetFakeEnabledEligibilityEvalTaskResponse(
1309           Resource(), Resource(), kEligibilityEvalExecutionId, std::nullopt,
1310           retry_window, retry_window);
1311   EXPECT_CALL(mock_http_client_,
1312               PerformSingleRequest(SimpleHttpRequestMatcher(
1313                   "https://initial.uri/v1/eligibilityevaltasks/"
1314                   "TEST%2FPOPULATION:request?%24alt=proto",
1315                   HttpRequest::Method::kPost, _, _)))
1316       .WillOnce(Return(FakeHttpResponse(
1317           200, HeaderList(), eval_task_response.SerializeAsString())));
1318 
1319   ASSERT_OK(federated_protocol_->EligibilityEvalCheckin(
1320       mock_eet_received_callback_.AsStdFunction()));
1321 
1322   const google::internal::federatedml::v2::RetryWindow& actual_retry_window =
1323       federated_protocol_->GetLatestRetryWindow();
1324   // The above retry window's invalid max value should be clamped to the min
1325   // value (minus some errors introduced by the inaccuracy of double
1326   // multiplication). Note that DoubleEq enforces too precise of bounds, so we
1327   // use DoubleNear instead.
1328   EXPECT_THAT(actual_retry_window.delay_min().seconds() +
1329                   actual_retry_window.delay_min().nanos() / 1000000000.0,
1330               DoubleNear(1234.0, 0.015));
1331   EXPECT_THAT(actual_retry_window.delay_max().seconds() +
1332                   actual_retry_window.delay_max().nanos() / 1000000000.0,
1333               DoubleNear(1234.0, 0.015));
1334 }
1335 
// Verifies that Checkin(...) cannot be called after EligibilityEvalCheckin(...)
// has failed.
TEST_F(HttpFederatedProtocolDeathTest,
       TestCheckinAfterFailedEligibilityEvalCheckin) {
  // The eligibility eval checkin request is answered with a 503 Service
  // Unavailable, which EligibilityEvalCheckin(...) must return as its result.
  const std::string eet_request_uri =
      "https://initial.uri/v1/eligibilityevaltasks/"
      "TEST%2FPOPULATION:request?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  eet_request_uri, HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));

  auto eet_checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());
  EXPECT_THAT(eet_checkin_result.status(), IsCode(UNAVAILABLE));

  // Checkin(...) may only be called after a successful
  // EligibilityEvalCheckin(...), so calling it now must crash.
  ASSERT_DEATH(
      {
        auto unused = federated_protocol_->Checkin(
            std::nullopt, mock_task_received_callback_.AsStdFunction());
      },
      _);
}
1362 
// Verifies that Checkin(...) cannot be called after the server rejected the
// client during EligibilityEvalCheckin(...).
TEST_F(HttpFederatedProtocolDeathTest,
       TestCheckinAfterEligibilityEvalCheckinRejection) {
  // The eligibility eval checkin succeeds at the HTTP level, but its payload
  // is a rejection response.
  const std::string eet_request_uri =
      "https://initial.uri/v1/eligibilityevaltasks/"
      "TEST%2FPOPULATION:request?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  eet_request_uri, HttpRequest::Method::kPost, _,
                  EligibilityEvalTaskRequestMatcher(
                      EqualsProto(GetExpectedEligibilityEvalTaskRequest())))))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          GetFakeRejectedEligibilityEvalTaskResponse().SerializeAsString())));

  ASSERT_OK(federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction()));

  // Checkin(...) requires a successful, non-rejected EligibilityEvalCheckin(...),
  // so calling it now must crash.
  ASSERT_DEATH(
      {
        auto unused = federated_protocol_->Checkin(
            std::nullopt, mock_task_received_callback_.AsStdFunction());
      },
      _);
}
1389 
// Verifies that a TaskEligibilityInfo cannot be passed to Checkin(...) when
// the eligibility eval checkin did not produce an eligibility eval task.
TEST_F(HttpFederatedProtocolDeathTest,
       TestCheckinWithEligibilityInfoAfterEligibilityEvalCheckinDisabled) {
  // Run a successful eligibility eval checkin, with eligibility eval disabled.
  ASSERT_OK(
      RunSuccessfulEligibilityEvalCheckin(/*eligibility_eval_enabled=*/false));

  // A TaskEligibilityInfo may only be passed after a successful
  // EligibilityEvalCheckin(...) whose response contained an eligibility eval
  // task. No such task was received here, so Checkin(...) must crash.
  ASSERT_DEATH(
      {
        auto unused = federated_protocol_->Checkin(
            TaskEligibilityInfo(),
            mock_task_received_callback_.AsStdFunction());
      },
      _);
}
1407 
// Verifies that Checkin(...) requires a TaskEligibilityInfo once the
// eligibility eval checkin returned an eligibility eval task.
TEST_F(HttpFederatedProtocolDeathTest, TestCheckinWithMissingEligibilityInfo) {
  // Run a successful eligibility eval checkin, with eligibility eval enabled.
  ASSERT_OK(
      RunSuccessfulEligibilityEvalCheckin(/*eligibility_eval_enabled=*/true));

  // The protocol requires a TaskEligibilityInfo derived from the plan included
  // in the eligibility eval checkin response. Omitting it must crash.
  ASSERT_DEATH(
      {
        auto unused = federated_protocol_->Checkin(
            std::nullopt, mock_task_received_callback_.AsStdFunction());
      },
      _);
}
1422 
// Verifies that Checkin(...) cannot be called when the eligibility eval
// checkin failed because its plan/checkpoint resources could not be fetched.
TEST_F(HttpFederatedProtocolDeathTest,
       TestCheckinAfterEligibilityEvalResourceDataFetchFailed) {
  Resource plan_resource;
  plan_resource.set_uri("https://fake.uri/plan");
  Resource checkpoint_resource;
  checkpoint_resource.set_uri("https://fake.uri/checkpoint");
  EligibilityEvalTaskResponse eval_task_response =
      GetFakeEnabledEligibilityEvalTaskResponse(
          plan_resource, checkpoint_resource, kEligibilityEvalExecutionId);
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://initial.uri/v1/eligibilityevaltasks/"
                  "TEST%2FPOPULATION:request?%24alt=proto",
                  HttpRequest::Method::kPost, _, _)))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), eval_task_response.SerializeAsString())));

  // Mock a failed plan/checkpoint fetch: every resource GET returns a 503.
  EXPECT_CALL(mock_http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
                                     _, HttpRequest::Method::kGet, _, "")))
      .WillRepeatedly(Return(FakeHttpResponse(503, HeaderList(), "")));

  auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());
  // Check this test's precondition: the eligibility eval checkin must have
  // failed. The 503 resource-fetch errors map to UNAVAILABLE.
  EXPECT_THAT(eligibility_checkin_result.status(), IsCode(UNAVAILABLE));

  // Checkin(...) may only be called after a successful
  // EligibilityEvalCheckin(...), so calling it now must crash.
  ASSERT_DEATH(
      {
        auto unused = federated_protocol_->Checkin(
            TaskEligibilityInfo(),
            mock_task_received_callback_.AsStdFunction());
      },
      _);
}
1459 
1460 // Ensures that if the HTTP layer returns an error code that maps to a transient
1461 // error, it is handled correctly
TEST_F(HttpFederatedProtocolTest,TestCheckinFailsTransientError)1462 TEST_F(HttpFederatedProtocolTest, TestCheckinFailsTransientError) {
1463   // Issue an eligibility eval checkin first.
1464   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1465   std::string report_eet_request_uri =
1466       "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
1467       "eligibilityevaltasks/"
1468       "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
1469   ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
1470                                                          absl::OkStatus());
1471 
1472   // Make the HTTP request return an 503 Service Unavailable error when the
1473   // Checkin(...) code tries to send its first request. This should result in
1474   // the error being returned as the result.
1475   EXPECT_CALL(
1476       mock_http_client_,
1477       PerformSingleRequest(SimpleHttpRequestMatcher(
1478           "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
1479           "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
1480           HttpRequest::Method::kPost, _, _)))
1481       .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
1482 
1483   auto checkin_result = federated_protocol_->Checkin(
1484       GetFakeTaskEligibilityInfo(),
1485       mock_task_received_callback_.AsStdFunction());
1486 
1487   EXPECT_THAT(checkin_result.status(), IsCode(UNAVAILABLE));
1488   // The original 503 HTTP response code should be included in the message as
1489   // well.
1490   EXPECT_THAT(checkin_result.status().message(), HasSubstr("503"));
1491   // RetryWindows were already received from the server during the eligibility
1492   // eval checkin, so we expect to get a 'rejected' retry window.
1493   ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1494 }
1495 
1496 // Ensures that if the HTTP layer returns an error code that maps to a permanent
1497 // error, it is handled correctly.
TEST_F(HttpFederatedProtocolTest,TestCheckinFailsPermanentErrorFromHttp)1498 TEST_F(HttpFederatedProtocolTest, TestCheckinFailsPermanentErrorFromHttp) {
1499   // Issue an eligibility eval checkin first.
1500   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1501   std::string report_eet_request_uri =
1502       "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
1503       "eligibilityevaltasks/"
1504       "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
1505   ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
1506                                                          absl::OkStatus());
1507 
1508   // Make the HTTP request return an 404 Not Found error when the Checkin(...)
1509   // code tries to send its first request. This should result in the error being
1510   // returned as the result.
1511   EXPECT_CALL(
1512       mock_http_client_,
1513       PerformSingleRequest(SimpleHttpRequestMatcher(
1514           "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
1515           "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
1516           HttpRequest::Method::kPost, _, _)))
1517       .WillOnce(Return(FakeHttpResponse(404, HeaderList(), "")));
1518 
1519   auto checkin_result = federated_protocol_->Checkin(
1520       GetFakeTaskEligibilityInfo(),
1521       mock_task_received_callback_.AsStdFunction());
1522 
1523   EXPECT_THAT(checkin_result.status(), IsCode(NOT_FOUND));
1524   // The original 503 HTTP response code should be included in the message as
1525   // well.
1526   EXPECT_THAT(checkin_result.status().message(), HasSubstr("404"));
1527   // Even though RetryWindows were already received from the server during the
1528   // eligibility eval checkin, we expect a RetryWindow generated based on the
1529   // *permanent* errors retry delay flag, since NOT_FOUND is marked as a
1530   // permanent error in the flags, and permanent errors should always result in
1531   // permanent error windows (regardless of whether retry windows were already
1532   // received).
1533   ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
1534 }
1535 
1536 // Ensures that if the HTTP layer returns a successful response, but it contains
1537 // an Operation proto with a permanent error, that it is handled correctly.
TEST_F(HttpFederatedProtocolTest,TestCheckinFailsPermanentErrorFromOperation)1538 TEST_F(HttpFederatedProtocolTest, TestCheckinFailsPermanentErrorFromOperation) {
1539   // Issue an eligibility eval checkin first.
1540   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1541   std::string report_eet_request_uri =
1542       "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
1543       "eligibilityevaltasks/"
1544       "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
1545   ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
1546                                                          absl::OkStatus());
1547 
1548   // Make the HTTP request return successfully, but make it contain an Operation
1549   // proto that itself contains a permanent error. This should result in the
1550   // error being returned as the result.
1551   EXPECT_CALL(
1552       mock_http_client_,
1553       PerformSingleRequest(SimpleHttpRequestMatcher(
1554           "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
1555           "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
1556           HttpRequest::Method::kPost, _, _)))
1557       .WillOnce(Return(FakeHttpResponse(
1558           200, HeaderList(),
1559           CreateErrorOperation(kOperationName, absl::StatusCode::kNotFound,
1560                                "foo")
1561               .SerializeAsString())));
1562 
1563   auto checkin_result = federated_protocol_->Checkin(
1564       GetFakeTaskEligibilityInfo(),
1565       mock_task_received_callback_.AsStdFunction());
1566 
1567   EXPECT_THAT(checkin_result.status(), IsCode(NOT_FOUND));
1568   EXPECT_THAT(checkin_result.status().message(),
1569               HasSubstr("Operation my_operation contained error"));
1570   // The original error message should be included in the message as well.
1571   EXPECT_THAT(checkin_result.status().message(), HasSubstr("foo"));
1572   // Even though RetryWindows were already received from the server during the
1573   // eligibility eval checkin, we expect a RetryWindow generated based on the
1574   // *permanent* errors retry delay flag, since NOT_FOUND is marked as a
1575   // permanent error in the flags, and permanent errors should always result in
1576   // permanent error windows (regardless of whether retry windows were already
1577   // received).
1578   ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
1579 }
1580 
1581 // Tests the case where we get interrupted while waiting for a response to the
1582 // protocol request in Checkin.
TEST_F(HttpFederatedProtocolTest,TestCheckinInterrupted)1583 TEST_F(HttpFederatedProtocolTest, TestCheckinInterrupted) {
1584   // Issue an eligibility eval checkin first.
1585   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1586   std::string report_eet_request_uri =
1587       "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
1588       "eligibilityevaltasks/"
1589       "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
1590   ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
1591                                                          absl::OkStatus());
1592 
1593   absl::Notification request_issued;
1594   absl::Notification request_cancelled;
1595 
1596   // Make HttpClient::PerformRequests() block until the counter is decremented.
1597   EXPECT_CALL(
1598       mock_http_client_,
1599       PerformSingleRequest(SimpleHttpRequestMatcher(
1600           "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
1601           "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
1602           HttpRequest::Method::kPost, _, _)))
1603       .WillOnce([&request_issued, &request_cancelled](
1604                     MockableHttpClient::SimpleHttpRequest ignored) {
1605         request_issued.Notify();
1606         request_cancelled.WaitForNotification();
1607         return FakeHttpResponse(503, HeaderList(), "");
1608       });
1609 
1610   // Make should_abort return false until we know that the request was issued
1611   // (i.e. once InterruptibleRunner has actually started running the code it
1612   // was given), and then make it return true, triggering an abort sequence and
1613   // unblocking the PerformRequests()() call we caused to block above.
1614   EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&request_issued] {
1615     return request_issued.HasBeenNotified();
1616   });
1617 
1618   // When the HttpClient receives a HttpRequestHandle::Cancel call, we let the
1619   // request complete.
1620   mock_http_client_.SetCancellationListener([&request_cancelled]() {
1621     if (!request_cancelled.HasBeenNotified()) {
1622       request_cancelled.Notify();
1623     }
1624   });
1625 
1626   EXPECT_CALL(mock_log_manager_,
1627               LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP));
1628 
1629   auto checkin_result = federated_protocol_->Checkin(
1630       GetFakeTaskEligibilityInfo(),
1631       mock_task_received_callback_.AsStdFunction());
1632   EXPECT_THAT(checkin_result.status(), IsCode(CANCELLED));
1633   // RetryWindows were already received from the server during the eligibility
1634   // eval checkin, so we expect to get a 'rejected' retry window.
1635   ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1636 }
1637 
1638 // Tests the case where we get interrupted during polling of the long running
1639 // operation.
TEST_F(HttpFederatedProtocolTest,TestCheckinInterruptedDuringLongRunningOperation)1640 TEST_F(HttpFederatedProtocolTest,
1641        TestCheckinInterruptedDuringLongRunningOperation) {
1642   // Issue an eligibility eval checkin first.
1643   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1644   std::string report_eet_request_uri =
1645       "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
1646       "eligibilityevaltasks/"
1647       "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
1648   ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
1649                                                          absl::OkStatus());
1650 
1651   absl::Notification request_issued;
1652   absl::Notification request_cancelled;
1653 
1654   Operation pending_operation = CreatePendingOperation("operations/foo#bar");
1655   // Make HttpClient::PerformRequests() block until the counter is decremented.
1656   EXPECT_CALL(
1657       mock_http_client_,
1658       PerformSingleRequest(SimpleHttpRequestMatcher(
1659           "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
1660           "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
1661           HttpRequest::Method::kPost, _, _)))
1662       .WillOnce(Return(FakeHttpResponse(
1663           200, HeaderList(), pending_operation.SerializeAsString())));
1664 
1665   // Make should_abort return false until we know that the request was issued
1666   // (i.e. once InterruptibleRunner has actually started running the code it
1667   // was given), and then make it return true, triggering an abort sequence and
1668   // unblocking the PerformRequests()() call we caused to block above.
1669   EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&request_issued] {
1670     return request_issued.HasBeenNotified();
1671   });
1672   EXPECT_CALL(
1673       mock_http_client_,
1674       PerformSingleRequest(SimpleHttpRequestMatcher(
1675           // Note that the '#' character is encoded as "%23".
1676           "https://taskassignment.uri/v1/operations/foo%23bar?%24alt=proto",
1677           HttpRequest::Method::kGet, _,
1678           GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
1679       .WillRepeatedly([&request_issued, &request_cancelled, pending_operation](
1680                           MockableHttpClient::SimpleHttpRequest ignored) {
1681         if (!request_issued.HasBeenNotified()) {
1682           request_issued.Notify();
1683         }
1684         request_cancelled.WaitForNotification();
1685         return FakeHttpResponse(200, HeaderList(),
1686                                 pending_operation.SerializeAsString());
1687       });
1688 
1689   // Once the client is cancelled, a CancelOperationRequest should still be sent
1690   // out before returning to the caller."
1691   EXPECT_CALL(
1692       mock_http_client_,
1693       PerformSingleRequest(SimpleHttpRequestMatcher(
1694           // Note that the '#' character is encoded as "%23".
1695           "https://taskassignment.uri/v1/operations/"
1696           "foo%23bar:cancel?%24alt=proto",
1697           HttpRequest::Method::kGet, _,
1698           GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
1699       .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
1700 
1701   // When the HttpClient receives a HttpRequestHandle::Cancel call, we let the
1702   // request complete.
1703   mock_http_client_.SetCancellationListener(
1704       [&request_cancelled]() { request_cancelled.Notify(); });
1705 
1706   EXPECT_CALL(mock_log_manager_,
1707               LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP));
1708 
1709   auto checkin_result = federated_protocol_->Checkin(
1710       GetFakeTaskEligibilityInfo(),
1711       mock_task_received_callback_.AsStdFunction());
1712   EXPECT_THAT(checkin_result.status(), IsCode(CANCELLED));
1713   // RetryWindows were already received from the server during the eligibility
1714   // eval checkin, so we expect to get a 'rejected' retry window.
1715   ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1716 }
1717 
1718 // Tests the case where we get interrupted during polling of the long-running
1719 // operation, and the issued cancellation request timed out.
TEST_F(HttpFederatedProtocolTest,TestCheckinInterruptedCancellationTimeout)1720 TEST_F(HttpFederatedProtocolTest, TestCheckinInterruptedCancellationTimeout) {
1721   // Issue an eligibility eval checkin first.
1722   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1723   std::string report_eet_request_uri =
1724       "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
1725       "eligibilityevaltasks/"
1726       "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
1727   ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
1728                                                          absl::OkStatus());
1729 
1730   absl::Notification request_issued;
1731   absl::Notification request_cancelled;
1732 
1733   Operation pending_operation = CreatePendingOperation("operations/foo#bar");
1734   // Make HttpClient::PerformRequests() block until the counter is decremented.
1735   EXPECT_CALL(
1736       mock_http_client_,
1737       PerformSingleRequest(SimpleHttpRequestMatcher(
1738           "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
1739           "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
1740           HttpRequest::Method::kPost, _, _)))
1741       .WillOnce(Return(FakeHttpResponse(
1742           200, HeaderList(), pending_operation.SerializeAsString())));
1743 
1744   // Make should_abort return false until we know that the request was issued
1745   // (i.e. once InterruptibleRunner has actually started running the code it
1746   // was given), and then make it return true, triggering an abort sequence and
1747   // unblocking the PerformRequests()() call we caused to block above.
1748   EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&request_issued] {
1749     return request_issued.HasBeenNotified();
1750   });
1751   EXPECT_CALL(
1752       mock_http_client_,
1753       PerformSingleRequest(SimpleHttpRequestMatcher(
1754           // Note that the '#' character is encoded as "%23".
1755           "https://taskassignment.uri/v1/operations/foo%23bar?%24alt=proto",
1756           HttpRequest::Method::kGet, _,
1757           GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
1758       .WillRepeatedly([&request_issued, &request_cancelled, pending_operation](
1759                           MockableHttpClient::SimpleHttpRequest ignored) {
1760         if (!request_issued.HasBeenNotified()) {
1761           request_issued.Notify();
1762         }
1763         request_cancelled.WaitForNotification();
1764         return FakeHttpResponse(200, HeaderList(),
1765                                 pending_operation.SerializeAsString());
1766       });
1767 
1768   // Once the client is cancelled, a CancelOperationRequest should still be sent
1769   // out before returning to the caller."
1770   EXPECT_CALL(
1771       mock_http_client_,
1772       PerformSingleRequest(SimpleHttpRequestMatcher(
1773           // Note that the '#' character is encoded as "%23".
1774           "https://taskassignment.uri/v1/operations/"
1775           "foo%23bar:cancel?%24alt=proto",
1776           HttpRequest::Method::kGet, _,
1777           GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
1778       .WillOnce([](MockableHttpClient::SimpleHttpRequest ignored) {
1779         // Sleep for 2 seconds before returning the response.
1780         absl::SleepFor(absl::Seconds(2));
1781         return FakeHttpResponse(200, HeaderList(), "");
1782       });
1783 
1784   // When the HttpClient receives a HttpRequestHandle::Cancel call, we let the
1785   // request complete.
1786   mock_http_client_.SetCancellationListener([&request_cancelled]() {
1787     if (!request_cancelled.HasBeenNotified()) {
1788       request_cancelled.Notify();
1789     }
1790   });
1791 
  // The interruption diag code is expected to be logged twice: once for the
  // GetOperation request and once for the CancelOperation request.
1794   EXPECT_CALL(mock_log_manager_,
1795               LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP))
1796       .Times(2);
1797   EXPECT_CALL(mock_log_manager_,
1798               LogDiag(ProdDiagCode::HTTP_CANCELLATION_OR_ABORT_REQUEST_FAILED));
1799 
1800   auto checkin_result = federated_protocol_->Checkin(
1801       GetFakeTaskEligibilityInfo(),
1802       mock_task_received_callback_.AsStdFunction());
1803   EXPECT_THAT(checkin_result.status(), IsCode(CANCELLED));
1804   // RetryWindows were already received from the server during the eligibility
1805   // eval checkin, so we expect to get a 'rejected' retry window.
1806   ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1807 }
1808 
1809 // Tests whether 'rejection' responses to the main Checkin(...) request are
1810 // handled correctly.
TEST_F(HttpFederatedProtocolTest, TestCheckinRejectionWithTaskEligibilityInfo) {
  // Precondition: a successful eligibility eval checkin, whose result is then
  // expected to be reported back to the server.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  std::string report_uri =
      "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
      "eligibilityevaltasks/"
      "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
  ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_uri,
                                                         absl::OkStatus());

  // The StartTaskAssignmentRequest must carry the eligibility info; the server
  // answers it with an already-completed Operation holding a rejection.
  TaskEligibilityInfo eligibility_info = GetFakeTaskEligibilityInfo();
  std::string rejected_operation_body =
      CreateDoneOperation(kOperationName,
                          GetFakeRejectedTaskAssignmentResponse())
          .SerializeAsString();
  EXPECT_CALL(
      mock_http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
          "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
          HttpRequest::Method::kPost, _,
          StartTaskAssignmentRequestMatcher(EqualsProto(
              GetExpectedStartTaskAssignmentRequest(eligibility_info))))))
      .WillOnce(
          Return(FakeHttpResponse(200, HeaderList(), rejected_operation_body)));

  // No task is handed out, so the 'task received' callback must never fire.
  EXPECT_CALL(mock_task_received_callback_, Call(_)).Times(0);

  // Run the regular checkin.
  auto result = federated_protocol_->Checkin(
      eligibility_info, mock_task_received_callback_.AsStdFunction());

  // The rejection comes back as an OK status holding a Rejection variant.
  ASSERT_OK(result.status());
  EXPECT_THAT(*result, VariantWith<FederatedProtocol::Rejection>(_));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1849 
1850 // Tests whether we can issue a Checkin() request correctly without passing a
1851 // TaskEligibilityInfo, in the case that the eligibility eval checkin didn't
1852 // return any eligibility eval task to run.
TEST_F(HttpFederatedProtocolTest,
       TestCheckinRejectionWithoutTaskEligibilityInfo) {
  // Precondition: an eligibility eval checkin that yields no eligibility eval
  // task (so there is no eligibility result to report afterwards).
  ASSERT_OK(
      RunSuccessfulEligibilityEvalCheckin(/*eligibility_eval_enabled=*/false));

  // Without eligibility info, the StartTaskAssignmentRequest must match the
  // request built for std::nullopt; the server rejects the client right away.
  std::string rejected_operation_body =
      CreateDoneOperation(kOperationName,
                          GetFakeRejectedTaskAssignmentResponse())
          .SerializeAsString();
  EXPECT_CALL(
      mock_http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
          "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
          HttpRequest::Method::kPost, _,
          StartTaskAssignmentRequestMatcher(EqualsProto(
              GetExpectedStartTaskAssignmentRequest(std::nullopt))))))
      .WillOnce(
          Return(FakeHttpResponse(200, HeaderList(), rejected_operation_body)));

  // A rejected client receives no task, so the callback must stay silent.
  EXPECT_CALL(mock_task_received_callback_, Call(_)).Times(0);

  // Run the regular checkin with no TaskEligibilityInfo at all.
  auto result = federated_protocol_->Checkin(
      std::nullopt, mock_task_received_callback_.AsStdFunction());

  ASSERT_OK(result.status());
  EXPECT_THAT(*result, VariantWith<FederatedProtocol::Rejection>(_));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1886 
1887 // Tests whether a successful task assignment response is handled correctly.
TEST_F(HttpFederatedProtocolTest,TestCheckinTaskAssigned)1888 TEST_F(HttpFederatedProtocolTest, TestCheckinTaskAssigned) {
1889   // Issue an eligibility eval checkin first.
1890   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1891   std::string report_eet_request_uri =
1892       "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
1893       "eligibilityevaltasks/"
1894       "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
1895   ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
1896                                                          absl::OkStatus());
1897 
1898   TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
1899   // We return a fake response which requires fetching the plan via HTTP, but
1900   // which has the checkpoint data available inline.
1901   std::string expected_plan = kPlan;
1902   std::string plan_uri = "https://fake.uri/plan";
1903   Resource plan_resource;
1904   plan_resource.set_uri(plan_uri);
1905   std::string expected_checkpoint = kInitCheckpoint;
1906   Resource checkpoint_resource;
1907   checkpoint_resource.mutable_inline_resource()->set_data(expected_checkpoint);
1908   std::string expected_federated_select_uri_template =
1909       kFederatedSelectUriTemplate;
1910   std::string expected_aggregation_session_id = kAggregationSessionId;
1911 
1912   InSequence seq;
1913   // Note that in this particular test we check that the CheckinRequest is as
1914   // expected (in all prior tests we just use the '_' matcher, because the
1915   // request isn't really relevant to the test).
1916   EXPECT_CALL(
1917       mock_http_client_,
1918       PerformSingleRequest(SimpleHttpRequestMatcher(
1919           "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
1920           "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
1921           HttpRequest::Method::kPost, _,
1922           StartTaskAssignmentRequestMatcher(
1923               EqualsProto(GetExpectedStartTaskAssignmentRequest(
1924                   expected_eligibility_info))))))
1925       .WillOnce(Return(FakeHttpResponse(
1926           200, HeaderList(),
1927           CreateDoneOperation(kOperationName,
1928                               GetFakeTaskAssignmentResponse(
1929                                   plan_resource, checkpoint_resource,
1930                                   expected_federated_select_uri_template,
1931                                   expected_aggregation_session_id,
1932                                   kMinimumClientsInServerVisibleAggregate))
1933               .SerializeAsString())));
1934 
1935   // The 'task received' callback should be called *before* the actual task
1936   // resources are fetched.
1937   EXPECT_CALL(
1938       mock_task_received_callback_,
1939       Call(FieldsAre(FieldsAre("", ""), expected_federated_select_uri_template,
1940                      expected_aggregation_session_id,
1941                      Optional(FieldsAre(
1942                          _, Eq(kMinimumClientsInServerVisibleAggregate))))));
1943 
1944   EXPECT_CALL(mock_http_client_,
1945               PerformSingleRequest(SimpleHttpRequestMatcher(
1946                   plan_uri, HttpRequest::Method::kGet, _, "")))
1947       .WillOnce(Return(FakeHttpResponse(200, HeaderList(), expected_plan)));
1948 
1949   // Issue the regular checkin.
1950   auto checkin_result = federated_protocol_->Checkin(
1951       expected_eligibility_info, mock_task_received_callback_.AsStdFunction());
1952 
1953   ASSERT_OK(checkin_result.status());
1954   EXPECT_THAT(
1955       *checkin_result,
1956       VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
1957           FieldsAre(absl::Cord(expected_plan), absl::Cord(expected_checkpoint)),
1958           expected_federated_select_uri_template,
1959           expected_aggregation_session_id,
1960           Optional(
1961               FieldsAre(_, Eq(kMinimumClientsInServerVisibleAggregate))))));
1962   // The Checkin call is expected to return the accepted retry window from the
1963   // response to the first eligibility eval request.
1964   ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1965 }
1966 
1967 // Ensures that polling the Operation returned by a StartTaskAssignmentRequest
1968 // works as expected. This serves mostly as a high-level check. Further
1969 // polling-specific behavior is tested in more detail in
1970 // ProtocolRequestHelperTest.
TEST_F(HttpFederatedProtocolTest,TestCheckinTaskAssignedAfterOperationPolling)1971 TEST_F(HttpFederatedProtocolTest,
1972        TestCheckinTaskAssignedAfterOperationPolling) {
1973   // Issue an eligibility eval checkin first.
1974   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1975   std::string report_eet_request_uri =
1976       "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
1977       "eligibilityevaltasks/"
1978       "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
1979   ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
1980                                                          absl::OkStatus());
1981 
1982   // Make the initial StartTaskAssignmentRequest return a pending Operation
1983   // result. Note that we use a '#' character in the operation name to allow us
1984   // to verify that it is properly URL-encoded.
1985   Operation pending_operation_response =
1986       CreatePendingOperation("operations/foo#bar");
1987   EXPECT_CALL(
1988       mock_http_client_,
1989       PerformSingleRequest(SimpleHttpRequestMatcher(
1990           "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
1991           "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
1992           HttpRequest::Method::kPost, _, _)))
1993       .WillOnce(Return(FakeHttpResponse(
1994           200, HeaderList(), pending_operation_response.SerializeAsString())));
1995 
1996   // Then, after letting the operation get polled twice more, eventually return
1997   // a fake response.
1998   std::string expected_plan = kPlan;
1999   Resource plan_resource;
2000   plan_resource.mutable_inline_resource()->set_data(expected_plan);
2001   std::string expected_checkpoint = kInitCheckpoint;
2002   Resource checkpoint_resource;
2003   checkpoint_resource.mutable_inline_resource()->set_data(expected_checkpoint);
2004   std::string expected_federated_select_uri_template =
2005       kFederatedSelectUriTemplate;
2006   std::string expected_aggregation_session_id = kAggregationSessionId;
2007 
2008   EXPECT_CALL(
2009       mock_http_client_,
2010       PerformSingleRequest(SimpleHttpRequestMatcher(
2011           // Note that the '#' character is encoded as "%23".
2012           "https://taskassignment.uri/v1/operations/foo%23bar?%24alt=proto",
2013           HttpRequest::Method::kGet, _,
2014           GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
2015       .WillOnce(Return(FakeHttpResponse(
2016           200, HeaderList(), pending_operation_response.SerializeAsString())))
2017       .WillOnce(Return(FakeHttpResponse(
2018           200, HeaderList(), pending_operation_response.SerializeAsString())))
2019       .WillOnce(Return(FakeHttpResponse(
2020           200, HeaderList(),
2021           CreateDoneOperation(kOperationName,
2022                               GetFakeTaskAssignmentResponse(
2023                                   plan_resource, checkpoint_resource,
2024                                   expected_federated_select_uri_template,
2025                                   expected_aggregation_session_id, 0))
2026               .SerializeAsString())));
2027 
2028   // The 'task received' callback should be called, even if the task resource
2029   // data was available inline.
2030   EXPECT_CALL(
2031       mock_task_received_callback_,
2032       Call(FieldsAre(FieldsAre("", ""), expected_federated_select_uri_template,
2033                      expected_aggregation_session_id, Eq(std::nullopt))));
2034 
2035   // Issue the regular checkin.
2036   auto checkin_result = federated_protocol_->Checkin(
2037       GetFakeTaskEligibilityInfo(),
2038       mock_task_received_callback_.AsStdFunction());
2039 
2040   ASSERT_OK(checkin_result.status());
2041   EXPECT_THAT(
2042       *checkin_result,
2043       VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
2044           FieldsAre(absl::Cord(expected_plan), absl::Cord(expected_checkpoint)),
2045           expected_federated_select_uri_template,
2046           expected_aggregation_session_id, Eq(std::nullopt))));
2047   // The Checkin call is expected to return the accepted retry window from the
2048   // response to the first eligibility eval request.
2049   ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
2050 }
2051 
2052 // Ensures that if the plan resource fails to be downloaded, the error is
2053 // correctly returned from the Checkin(...) method.
TEST_F(HttpFederatedProtocolTest,TestCheckinTaskAssignedPlanDataFetchFailed)2054 TEST_F(HttpFederatedProtocolTest, TestCheckinTaskAssignedPlanDataFetchFailed) {
2055   // Issue an eligibility eval checkin first.
2056   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
2057   std::string report_eet_request_uri =
2058       "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
2059       "eligibilityevaltasks/"
2060       "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
2061   ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
2062                                                          absl::OkStatus());
2063 
2064   std::string plan_uri = "https://fake.uri/plan";
2065   Resource plan_resource;
2066   plan_resource.set_uri(plan_uri);
2067   std::string checkpoint_uri = "https://fake.uri/checkpoint";
2068   Resource checkpoint_resource;
2069   checkpoint_resource.set_uri(checkpoint_uri);
2070   EXPECT_CALL(
2071       mock_http_client_,
2072       PerformSingleRequest(SimpleHttpRequestMatcher(
2073           "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
2074           "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
2075           HttpRequest::Method::kPost, _, _)))
2076       .WillOnce(Return(FakeHttpResponse(
2077           200, HeaderList(),
2078           CreateDoneOperation(
2079               kOperationName,
2080               GetFakeTaskAssignmentResponse(plan_resource, checkpoint_resource,
2081                                             kFederatedSelectUriTemplate,
2082                                             kAggregationSessionId, 0))
2083               .SerializeAsString())));
2084 
2085   // Mock a failed plan fetch.
2086   EXPECT_CALL(mock_http_client_,
2087               PerformSingleRequest(SimpleHttpRequestMatcher(
2088                   plan_uri, HttpRequest::Method::kGet, _, "")))
2089       .WillOnce(Return(FakeHttpResponse(404, HeaderList(), "")));
2090 
2091   EXPECT_CALL(mock_http_client_,
2092               PerformSingleRequest(SimpleHttpRequestMatcher(
2093                   checkpoint_uri, HttpRequest::Method::kGet, _, "")))
2094       .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
2095 
2096   // Issue the regular checkin.
2097   auto checkin_result = federated_protocol_->Checkin(
2098       GetFakeTaskEligibilityInfo(),
2099       mock_task_received_callback_.AsStdFunction());
2100 
2101   // The 404 error for the resource request should be reflected in the return
2102   // value.
2103   EXPECT_THAT(checkin_result.status(), IsCode(NOT_FOUND));
2104   EXPECT_THAT(checkin_result.status().message(),
2105               HasSubstr("plan fetch failed"));
2106   EXPECT_THAT(checkin_result.status().message(), HasSubstr("404"));
2107   // The Checkin call is expected to return the permanent error retry window,
2108   // since 404 maps to a permanent error.
2109   ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
2110 }
2111 
2112 // Ensures that if the checkpoint resource fails to be downloaded, the error is
2113 // correctly returned from the Checkin(...) method.
TEST_F(HttpFederatedProtocolTest,TestCheckinTaskAssignedCheckpointDataFetchFailed)2114 TEST_F(HttpFederatedProtocolTest,
2115        TestCheckinTaskAssignedCheckpointDataFetchFailed) {
2116   // Issue an eligibility eval checkin first.
2117   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
2118   std::string report_eet_request_uri =
2119       "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
2120       "eligibilityevaltasks/"
2121       "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
2122   ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
2123                                                          absl::OkStatus());
2124 
2125   std::string plan_uri = "https://fake.uri/plan";
2126   Resource plan_resource;
2127   plan_resource.set_uri(plan_uri);
2128   std::string checkpoint_uri = "https://fake.uri/checkpoint";
2129   Resource checkpoint_resource;
2130   checkpoint_resource.set_uri(checkpoint_uri);
2131 
2132   EXPECT_CALL(
2133       mock_http_client_,
2134       PerformSingleRequest(SimpleHttpRequestMatcher(
2135           "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
2136           "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto",
2137           HttpRequest::Method::kPost, _, _)))
2138       .WillOnce(Return(FakeHttpResponse(
2139           200, HeaderList(),
2140           CreateDoneOperation(
2141               kOperationName,
2142               GetFakeTaskAssignmentResponse(plan_resource, checkpoint_resource,
2143                                             kFederatedSelectUriTemplate,
2144                                             kAggregationSessionId, 0))
2145               .SerializeAsString())));
2146 
2147   // Mock a failed checkpoint fetch.
2148   EXPECT_CALL(mock_http_client_,
2149               PerformSingleRequest(SimpleHttpRequestMatcher(
2150                   checkpoint_uri, HttpRequest::Method::kGet, _, "")))
2151       .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
2152 
2153   EXPECT_CALL(mock_http_client_,
2154               PerformSingleRequest(SimpleHttpRequestMatcher(
2155                   plan_uri, HttpRequest::Method::kGet, _, "")))
2156       .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
2157 
2158   // Issue the regular checkin.
2159   auto checkin_result = federated_protocol_->Checkin(
2160       GetFakeTaskEligibilityInfo(),
2161       mock_task_received_callback_.AsStdFunction());
2162 
2163   // The 503 error for the resource request should be reflected in the return
2164   // value.
2165   EXPECT_THAT(checkin_result.status(), IsCode(UNAVAILABLE));
2166   EXPECT_THAT(checkin_result.status().message(),
2167               HasSubstr("checkpoint fetch failed"));
2168   EXPECT_THAT(checkin_result.status().message(), HasSubstr("503"));
2169   // The Checkin call is expected to return the rejected retry window from the
2170   // response to the first eligibility eval request.
2171   ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
2172 }
2173 
TEST_F(HttpFederatedProtocolTest, TestReportCompletedViaSimpleAggSuccess) {
  // Get to the point where a task has been assigned: an eligibility eval
  // checkin followed by a regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // The only computation result is a fake 32-byte checkpoint of 'X's.
  std::string fake_checkpoint(32, 'X');
  ComputationResults results;
  results.emplace("tensorflow_checkpoint", fake_checkpoint);
  absl::Duration plan_duration = absl::Minutes(5);

  // Expected request flow: report the task result, start the data upload,
  // upload the checkpoint via ByteStream, then submit the aggregation result
  // (whose URI carries CLIENT_TOKEN).
  ExpectSuccessfulReportTaskResultRequest(
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
      kAggregationSessionId, kTaskName, plan_duration);
  ExpectSuccessfulStartAggregationDataUploadRequest(
      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
      kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
  ExpectSuccessfulByteStreamUploadRequest(
      "https://bytestream.uri/upload/v1/media/"
      "CHECKPOINT_RESOURCE?upload_protocol=raw",
      fake_checkpoint);
  ExpectSuccessfulSubmitAggregationResultRequest(
      "https://aggregation.second.uri/v1/aggregations/"
      "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto");

  EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
                                                 plan_duration, std::nullopt));
}
2205 
2206 // TODO(team): Remove this test once client_token is always populated in
2207 // StartAggregationDataUploadResponse.
TEST_F(HttpFederatedProtocolTest,
       TestReportCompletedViaSimpleAggWithoutClientToken) {
  // Get to the point where a task has been assigned: an eligibility eval
  // checkin followed by a regular checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // The only computation result is a fake 32-byte checkpoint of 'X's.
  std::string fake_checkpoint(32, 'X');
  ComputationResults results;
  results.emplace("tensorflow_checkpoint", fake_checkpoint);
  absl::Duration plan_duration = absl::Minutes(5);

  ExpectSuccessfulReportTaskResultRequest(
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
      kAggregationSessionId, kTaskName, plan_duration);

  // The StartAggregationDataUpload response deliberately lacks a client token.
  StartAggregationDataUploadResponse upload_response =
      GetFakeStartAggregationDataUploadResponse(
          kResourceName, kByteStreamTargetUri,
          kSecondStageAggregationTargetUri);
  upload_response.clear_client_token();
  std::string upload_operation_body =
      CreateDoneOperation(kOperationName, upload_response).SerializeAsString();
  EXPECT_CALL(
      mock_http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
          "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
          HttpRequest::Method::kPost, _, _)))
      .WillOnce(
          Return(FakeHttpResponse(200, HeaderList(), upload_operation_body)));

  ExpectSuccessfulByteStreamUploadRequest(
      "https://bytestream.uri/upload/v1/media/"
      "CHECKPOINT_RESOURCE?upload_protocol=raw",
      fake_checkpoint);
  // With no client token available, SubmitAggregationResult falls back to the
  // authorization token in its URI.
  ExpectSuccessfulSubmitAggregationResultRequest(
      "https://aggregation.second.uri/v1/aggregations/"
      "AGGREGATION_SESSION_ID/clients/AUTHORIZATION_TOKEN:submit?%24alt=proto");

  EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
                                                 plan_duration, std::nullopt));
}
2256 
TEST_F(HttpFederatedProtocolTest,TestReportCompletedViaSecureAgg)2257 TEST_F(HttpFederatedProtocolTest, TestReportCompletedViaSecureAgg) {
2258   absl::Duration plan_duration = absl::Minutes(5);
2259   // Issue an eligibility eval checkin first.
2260   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
2261   // Issue a regular checkin
2262   ASSERT_OK(RunSuccessfulCheckin());
2263 
2264   StartSecureAggregationResponse start_secure_aggregation_response;
2265   start_secure_aggregation_response.set_client_token(kClientToken);
2266   auto masked_result_resource =
2267       start_secure_aggregation_response.mutable_masked_result_resource();
2268   masked_result_resource->set_resource_name("masked_resource");
2269   masked_result_resource->mutable_data_upload_forwarding_info()
2270       ->set_target_uri_prefix("https://bytestream.uri/");
2271 
2272   auto nonmasked_result_resource =
2273       start_secure_aggregation_response.mutable_nonmasked_result_resource();
2274   nonmasked_result_resource->set_resource_name("nonmasked_resource");
2275   nonmasked_result_resource->mutable_data_upload_forwarding_info()
2276       ->set_target_uri_prefix("https://bytestream.uri/");
2277 
2278   start_secure_aggregation_response.mutable_secagg_protocol_forwarding_info()
2279       ->set_target_uri_prefix("https://secure.aggregations.uri/");
2280   auto protocol_execution_info =
2281       start_secure_aggregation_response.mutable_protocol_execution_info();
2282   protocol_execution_info->set_minimum_surviving_clients_for_reconstruction(
2283       450);
2284   protocol_execution_info->set_expected_number_of_clients(500);
2285 
2286   auto secure_aggregands =
2287       start_secure_aggregation_response.mutable_secure_aggregands();
2288   SecureAggregandExecutionInfo secure_aggregand_execution_info;
2289   secure_aggregand_execution_info.set_modulus(9999);
2290   (*secure_aggregands)["secagg_tensor"] = secure_aggregand_execution_info;
2291 
2292   ExpectSuccessfulReportTaskResultRequest(
2293       "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
2294       "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
2295       kAggregationSessionId, kTaskName, plan_duration);
2296   EXPECT_CALL(mock_http_client_,
2297               PerformSingleRequest(SimpleHttpRequestMatcher(
2298                   "https://aggregation.uri/v1/secureaggregations/"
2299                   "AGGREGATION_SESSION_ID/clients/"
2300                   "AUTHORIZATION_TOKEN:start?%24alt=proto",
2301                   HttpRequest::Method::kPost, _,
2302                   StartSecureAggregationRequest().SerializeAsString())))
2303       .WillOnce(Return(FakeHttpResponse(
2304           200, HeaderList(),
2305           CreatePendingOperation("operations/foo#bar").SerializeAsString())));
2306   EXPECT_CALL(
2307       mock_http_client_,
2308       PerformSingleRequest(SimpleHttpRequestMatcher(
2309           "https://aggregation.uri/v1/operations/foo%23bar?%24alt=proto",
2310           HttpRequest::Method::kGet, _, "")))
2311       .WillOnce(Return(FakeHttpResponse(
2312           200, HeaderList(),
2313           CreateDoneOperation(kOperationName, start_secure_aggregation_response)
2314               .SerializeAsString())));
2315 
2316   // Create a fake checkpoint with 32 'X'.
2317   std::string checkpoint_str(32, 'X');
2318   ComputationResults results;
2319   results.emplace("tensorflow_checkpoint", checkpoint_str);
2320   results.emplace("secagg_tensor", QuantizedTensor());
2321 
2322   EXPECT_CALL(*mock_secagg_runner_factory_,
2323               CreateSecAggRunner(_, _, _, _, _, 500, 450))
2324       .WillOnce(WithArg<0>([&](auto send_to_server_impl) {
2325         auto mock_secagg_runner =
2326             std::make_unique<StrictMock<MockSecAggRunner>>();
2327         EXPECT_CALL(*mock_secagg_runner,
2328                     Run(UnorderedElementsAre(Pair(
2329                         "secagg_tensor", VariantWith<QuantizedTensor>(FieldsAre(
2330                                              IsEmpty(), 0, IsEmpty()))))))
2331             .WillOnce([=,
2332                        send_to_server_impl = std::move(send_to_server_impl)] {
2333               // SecAggSendToServerBase::Send should use the client token. This
2334               // needs to be tested here since `send_to_server_impl` should not
2335               // be used outside of Run.
2336               EXPECT_CALL(
2337                   mock_http_client_,
2338                   PerformSingleRequest(SimpleHttpRequestMatcher(
2339                       "https://secure.aggregations.uri/v1/secureaggregations/"
2340                       "AGGREGATION_SESSION_ID/clients/"
2341                       "CLIENT_TOKEN:abort?%24alt=proto",
2342                       _, _, _)))
2343                   .WillOnce(Return(CreateEmptySuccessHttpResponse()));
2344               secagg::ClientToServerWrapperMessage abort_message;
2345               abort_message.mutable_abort();
2346               send_to_server_impl->Send(&abort_message);
2347 
2348               return absl::OkStatus();
2349             });
2350         return mock_secagg_runner;
2351       }));
2352 
2353   EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
2354                                                  plan_duration, std::nullopt));
2355 }
2356 
2357 // TODO(team): Remove this test once client_token is always populated in
2358 // StartSecureAggregationResponse.
TEST_F(HttpFederatedProtocolTest,TestReportCompletedViaSecureAggWithoutClientToken)2359 TEST_F(HttpFederatedProtocolTest,
2360        TestReportCompletedViaSecureAggWithoutClientToken) {
2361   absl::Duration plan_duration = absl::Minutes(5);
2362   // Issue an eligibility eval checkin first.
2363   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
2364   // Issue a regular checkin
2365   ASSERT_OK(RunSuccessfulCheckin());
2366 
2367   StartSecureAggregationResponse start_secure_aggregation_response;
2368   // Don't set client_token.
2369   auto masked_result_resource =
2370       start_secure_aggregation_response.mutable_masked_result_resource();
2371   masked_result_resource->set_resource_name("masked_resource");
2372   masked_result_resource->mutable_data_upload_forwarding_info()
2373       ->set_target_uri_prefix("https://bytestream.uri/");
2374 
2375   auto nonmasked_result_resource =
2376       start_secure_aggregation_response.mutable_nonmasked_result_resource();
2377   nonmasked_result_resource->set_resource_name("nonmasked_resource");
2378   nonmasked_result_resource->mutable_data_upload_forwarding_info()
2379       ->set_target_uri_prefix("https://bytestream.uri/");
2380 
2381   start_secure_aggregation_response.mutable_secagg_protocol_forwarding_info()
2382       ->set_target_uri_prefix("https://secure.aggregations.uri/");
2383   auto protocol_execution_info =
2384       start_secure_aggregation_response.mutable_protocol_execution_info();
2385   protocol_execution_info->set_minimum_surviving_clients_for_reconstruction(
2386       450);
2387   protocol_execution_info->set_expected_number_of_clients(500);
2388 
2389   auto secure_aggregands =
2390       start_secure_aggregation_response.mutable_secure_aggregands();
2391   SecureAggregandExecutionInfo secure_aggregand_execution_info;
2392   secure_aggregand_execution_info.set_modulus(9999);
2393   (*secure_aggregands)["secagg_tensor"] = secure_aggregand_execution_info;
2394 
2395   ExpectSuccessfulReportTaskResultRequest(
2396       "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
2397       "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
2398       kAggregationSessionId, kTaskName, plan_duration);
2399   EXPECT_CALL(mock_http_client_,
2400               PerformSingleRequest(SimpleHttpRequestMatcher(
2401                   "https://aggregation.uri/v1/secureaggregations/"
2402                   "AGGREGATION_SESSION_ID/clients/"
2403                   "AUTHORIZATION_TOKEN:start?%24alt=proto",
2404                   HttpRequest::Method::kPost, _,
2405                   StartSecureAggregationRequest().SerializeAsString())))
2406       .WillOnce(Return(FakeHttpResponse(
2407           200, HeaderList(),
2408           CreateDoneOperation(kOperationName, start_secure_aggregation_response)
2409               .SerializeAsString())));
2410 
2411   // Create a fake checkpoint with 32 'X'.
2412   std::string checkpoint_str(32, 'X');
2413   ComputationResults results;
2414   results.emplace("tensorflow_checkpoint", checkpoint_str);
2415   results.emplace("secagg_tensor", QuantizedTensor());
2416 
2417   EXPECT_CALL(*mock_secagg_runner_factory_,
2418               CreateSecAggRunner(_, _, _, _, _, _, _))
2419       .WillOnce(WithArg<0>([&](auto send_to_server_impl) {
2420         auto mock_secagg_runner =
2421             std::make_unique<StrictMock<MockSecAggRunner>>();
2422         EXPECT_CALL(*mock_secagg_runner, Run(_))
2423             .WillOnce([=,
2424                        send_to_server_impl = std::move(send_to_server_impl)] {
2425               // SecAggSendToServerBase::Send should reuse the authorization
2426               // token. This needs to be tested here since `send_to_server_impl`
2427               // should not be used outside of Run.
2428               EXPECT_CALL(
2429                   mock_http_client_,
2430                   PerformSingleRequest(SimpleHttpRequestMatcher(
2431                       "https://secure.aggregations.uri/v1/secureaggregations/"
2432                       "AGGREGATION_SESSION_ID/clients/"
2433                       "AUTHORIZATION_TOKEN:abort?%24alt=proto",
2434                       _, _, _)))
2435                   .WillOnce(Return(CreateEmptySuccessHttpResponse()));
2436               secagg::ClientToServerWrapperMessage abort_message;
2437               abort_message.mutable_abort();
2438               send_to_server_impl->Send(&abort_message);
2439 
2440               return absl::OkStatus();
2441             });
2442         return mock_secagg_runner;
2443       }));
2444 
2445   EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
2446                                                  plan_duration, std::nullopt));
2447 }
2448 
TEST_F(HttpFederatedProtocolTest,TestReportCompletedViaSecureAggReportTaskResultFailed)2449 TEST_F(HttpFederatedProtocolTest,
2450        TestReportCompletedViaSecureAggReportTaskResultFailed) {
2451   absl::Duration plan_duration = absl::Minutes(5);
2452   // Issue an eligibility eval checkin first.
2453   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
2454   // Issue a regular checkin
2455   ASSERT_OK(RunSuccessfulCheckin());
2456 
2457   StartSecureAggregationResponse start_secure_aggregation_response;
2458   start_secure_aggregation_response.set_client_token(kClientToken);
2459   auto masked_result_resource =
2460       start_secure_aggregation_response.mutable_masked_result_resource();
2461   masked_result_resource->set_resource_name("masked_resource");
2462   masked_result_resource->mutable_data_upload_forwarding_info()
2463       ->set_target_uri_prefix("https://bytestream.uri/");
2464 
2465   auto nonmasked_result_resource =
2466       start_secure_aggregation_response.mutable_nonmasked_result_resource();
2467   nonmasked_result_resource->set_resource_name("nonmasked_resource");
2468   nonmasked_result_resource->mutable_data_upload_forwarding_info()
2469       ->set_target_uri_prefix("https://bytestream.uri/");
2470 
2471   start_secure_aggregation_response.mutable_secagg_protocol_forwarding_info()
2472       ->set_target_uri_prefix("https://secure.aggregations.uri/");
2473   auto protocol_execution_info =
2474       start_secure_aggregation_response.mutable_protocol_execution_info();
2475   protocol_execution_info->set_minimum_surviving_clients_for_reconstruction(
2476       450);
2477   protocol_execution_info->set_expected_number_of_clients(500);
2478 
2479   auto secure_aggregands =
2480       start_secure_aggregation_response.mutable_secure_aggregands();
2481   SecureAggregandExecutionInfo secure_aggregand_execution_info;
2482   secure_aggregand_execution_info.set_modulus(9999);
2483   (*secure_aggregands)["secagg_tensor"] = secure_aggregand_execution_info;
2484 
2485   // Mock a failed ReportTaskResult request.
2486   EXPECT_CALL(mock_http_client_,
2487               PerformSingleRequest(SimpleHttpRequestMatcher(
2488                   "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
2489                   "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
2490                   HttpRequest::Method::kPost, _,
2491                   ReportTaskResultRequestMatcher(
2492                       EqualsProto(GetExpectedReportTaskResultRequest(
2493                           kAggregationSessionId, kTaskName,
2494                           google::rpc::Code::OK, plan_duration))))))
2495       .WillOnce(Return(FakeHttpResponse(503, HeaderList())));
2496   EXPECT_CALL(mock_log_manager_,
2497               LogDiag(ProdDiagCode::HTTP_REPORT_TASK_RESULT_REQUEST_FAILED));
2498   EXPECT_CALL(mock_http_client_,
2499               PerformSingleRequest(SimpleHttpRequestMatcher(
2500                   "https://aggregation.uri/v1/secureaggregations/"
2501                   "AGGREGATION_SESSION_ID/clients/"
2502                   "AUTHORIZATION_TOKEN:start?%24alt=proto",
2503                   HttpRequest::Method::kPost, _,
2504                   StartSecureAggregationRequest().SerializeAsString())))
2505       .WillOnce(Return(FakeHttpResponse(
2506           200, HeaderList(),
2507           CreatePendingOperation("operations/foo#bar").SerializeAsString())));
2508   EXPECT_CALL(
2509       mock_http_client_,
2510       PerformSingleRequest(SimpleHttpRequestMatcher(
2511           "https://aggregation.uri/v1/operations/foo%23bar?%24alt=proto",
2512           HttpRequest::Method::kGet, _, "")))
2513       .WillOnce(Return(FakeHttpResponse(
2514           200, HeaderList(),
2515           CreateDoneOperation(kOperationName, start_secure_aggregation_response)
2516               .SerializeAsString())));
2517 
2518   // Create a fake checkpoint with 32 'X'.
2519   std::string checkpoint_str(32, 'X');
2520   ComputationResults results;
2521   results.emplace("tensorflow_checkpoint", checkpoint_str);
2522   results.emplace("secagg_tensor", QuantizedTensor());
2523 
2524   MockSecAggRunner* mock_secagg_runner = new StrictMock<MockSecAggRunner>();
2525   EXPECT_CALL(*mock_secagg_runner_factory_,
2526               CreateSecAggRunner(_, _, _, _, _, 500, 450))
2527       .WillOnce(Return(ByMove(absl::WrapUnique(mock_secagg_runner))));
2528   EXPECT_CALL(*mock_secagg_runner,
2529               Run(UnorderedElementsAre(
2530                   Pair("secagg_tensor", VariantWith<QuantizedTensor>(FieldsAre(
2531                                             IsEmpty(), 0, IsEmpty()))))))
2532       .WillOnce(Return(absl::OkStatus()));
2533 
2534   EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
2535                                                  plan_duration, std::nullopt));
2536 }
2537 
TEST_F(HttpFederatedProtocolTest,TestReportCompletedStartSecAggFailed)2538 TEST_F(HttpFederatedProtocolTest, TestReportCompletedStartSecAggFailed) {
2539   absl::Duration plan_duration = absl::Minutes(5);
2540   // Issue an eligibility eval checkin first.
2541   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
2542   // Issue a regular checkin.
2543   ASSERT_OK(RunSuccessfulCheckin());
2544   ExpectSuccessfulReportTaskResultRequest(
2545       "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
2546       "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
2547       kAggregationSessionId, kTaskName, plan_duration);
2548   EXPECT_CALL(mock_http_client_,
2549               PerformSingleRequest(SimpleHttpRequestMatcher(
2550                   "https://aggregation.uri/v1/secureaggregations/"
2551                   "AGGREGATION_SESSION_ID/clients/"
2552                   "AUTHORIZATION_TOKEN:start?%24alt=proto",
2553                   HttpRequest::Method::kPost, _,
2554                   StartSecureAggregationRequest().SerializeAsString())))
2555       .WillOnce(Return(FakeHttpResponse(
2556           200, HeaderList(),
2557           CreateErrorOperation(kOperationName, absl::StatusCode::kInternal,
2558                                "Request failed.")
2559               .SerializeAsString())));
2560 
2561   // Create a fake checkpoint with 32 'X'.
2562   std::string checkpoint_str(32, 'X');
2563   ComputationResults results;
2564   results.emplace("tensorflow_checkpoint", checkpoint_str);
2565   results.emplace("secagg_tensor", QuantizedTensor());
2566 
2567   EXPECT_THAT(federated_protocol_->ReportCompleted(std::move(results),
2568                                                    plan_duration, std::nullopt),
2569               IsCode(absl::StatusCode::kInternal));
2570 }
2571 
// Verifies that an HTTP 403 returned directly by the StartSecureAggregation
// request, with no operation involved, makes ReportCompleted fail with
// PERMISSION_DENIED.
TEST_F(HttpFederatedProtocolTest,
       TestReportCompletedStartSecAggFailedImmediately) {
  const absl::Duration plan_duration = absl::Minutes(5);
  // A task must be assigned (eligibility eval checkin, then a regular checkin)
  // before any result can be reported.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());
  ExpectSuccessfulReportTaskResultRequest(
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
      kAggregationSessionId, kTaskName, plan_duration);

  const std::string start_secagg_uri =
      "https://aggregation.uri/v1/secureaggregations/"
      "AGGREGATION_SESSION_ID/clients/"
      "AUTHORIZATION_TOKEN:start?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  start_secagg_uri, HttpRequest::Method::kPost, _,
                  StartSecureAggregationRequest().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(403, HeaderList(), "")));

  // Results: a 32-byte fake checkpoint plus one secagg tensor.
  ComputationResults results;
  results.emplace("tensorflow_checkpoint", std::string(32, 'X'));
  results.emplace("secagg_tensor", QuantizedTensor());

  const absl::Status status = federated_protocol_->ReportCompleted(
      std::move(results), plan_duration, std::nullopt);
  EXPECT_THAT(status, IsCode(absl::StatusCode::kPermissionDenied));
}
2602 
TEST_F(HttpFederatedProtocolTest,TestReportCompletedReportTaskResultFailed)2603 TEST_F(HttpFederatedProtocolTest, TestReportCompletedReportTaskResultFailed) {
2604   // Issue an eligibility eval checkin first.
2605   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
2606   // Issue a regular checkin.
2607   ASSERT_OK(RunSuccessfulCheckin());
2608 
2609   // Create a fake checkpoint with 32 'X'.
2610   std::string checkpoint_str(32, 'X');
2611   ComputationResults results;
2612   results.emplace("tensorflow_checkpoint", checkpoint_str);
2613   absl::Duration plan_duration = absl::Minutes(5);
2614 
2615   // Mock a failed ReportTaskResult request.
2616   ReportTaskResultResponse report_task_result_response;
2617   EXPECT_CALL(mock_http_client_,
2618               PerformSingleRequest(SimpleHttpRequestMatcher(
2619                   "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
2620                   "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
2621                   HttpRequest::Method::kPost, _,
2622                   ReportTaskResultRequestMatcher(
2623                       EqualsProto(GetExpectedReportTaskResultRequest(
2624                           kAggregationSessionId, kTaskName,
2625                           google::rpc::Code::OK, plan_duration))))))
2626       .WillOnce(Return(FakeHttpResponse(503, HeaderList())));
2627   EXPECT_CALL(mock_log_manager_,
2628               LogDiag(ProdDiagCode::HTTP_REPORT_TASK_RESULT_REQUEST_FAILED));
2629 
2630   ExpectSuccessfulStartAggregationDataUploadRequest(
2631       "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
2632       "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
2633       kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
2634   ExpectSuccessfulByteStreamUploadRequest(
2635       "https://bytestream.uri/upload/v1/media/"
2636       "CHECKPOINT_RESOURCE?upload_protocol=raw",
2637       checkpoint_str);
2638   ExpectSuccessfulSubmitAggregationResultRequest(
2639       "https://aggregation.second.uri/v1/aggregations/"
2640       "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto");
2641 
2642   // Despite the ReportTaskResult request failed, we still consider the overall
2643   // ReportCompleted succeeded because the rest of the steps succeeds, and the
2644   // ReportTaskResult is a just a metric reporting on a best effort basis.
2645   EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
2646                                                  plan_duration, std::nullopt));
2647 }
2648 
// Verifies that an HTTP 503 on the StartAggregationDataUpload request itself,
// before any operation polling, makes ReportCompleted fail with UNAVAILABLE.
// The error message must name the failed request and the HTTP code.
TEST_F(HttpFederatedProtocolTest,
       TestReportCompletedStartAggregationFailedImmediately) {
  // A task must be assigned (eligibility eval checkin, then a regular checkin)
  // before any result can be reported.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // The only result is a 32-byte fake TensorFlow checkpoint.
  constexpr size_t kTFCheckpointSize = 32;
  ComputationResults results;
  results.emplace("tensorflow_checkpoint",
                  std::string(kTFCheckpointSize, 'X'));
  const absl::Duration plan_duration = absl::Minutes(5);

  ExpectSuccessfulReportTaskResultRequest(
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
      kAggregationSessionId, kTaskName, plan_duration);

  const std::string start_upload_uri =
      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  start_upload_uri, HttpRequest::Method::kPost, _,
                  StartAggregationDataUploadRequest().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList())));

  const absl::Status status = federated_protocol_->ReportCompleted(
      std::move(results), plan_duration, std::nullopt);
  ASSERT_THAT(status, IsCode(absl::StatusCode::kUnavailable));
  EXPECT_THAT(status.message(),
              HasSubstr("StartAggregationDataUpload request failed"));
  EXPECT_THAT(status.message(), HasSubstr("503"));
}
2682 
// Verifies that when StartAggregationDataUpload returns a pending operation
// and polling that operation fails with HTTP 401, ReportCompleted fails with
// UNAUTHENTICATED. The error message must name the failed request and the
// HTTP code.
TEST_F(HttpFederatedProtocolTest,
       TestReportCompletedStartAggregationFailedDuringPolling) {
  // A task must be assigned (eligibility eval checkin, then a regular checkin)
  // before any result can be reported.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  // The only result is a 32-byte fake TensorFlow checkpoint.
  constexpr size_t kTFCheckpointSize = 32;
  ComputationResults results;
  results.emplace("tensorflow_checkpoint",
                  std::string(kTFCheckpointSize, 'X'));
  const absl::Duration plan_duration = absl::Minutes(5);

  ExpectSuccessfulReportTaskResultRequest(
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
      kAggregationSessionId, kTaskName, plan_duration);

  // The initial request only yields a pending operation...
  const Operation pending_operation =
      CreatePendingOperation("operations/foo#bar");
  const std::string start_upload_uri =
      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  start_upload_uri, HttpRequest::Method::kPost, _,
                  StartAggregationDataUploadRequest().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), pending_operation.SerializeAsString())));

  // ...and polling it is rejected. The '#' in the operation name is encoded as
  // "%23" in the polling URI.
  const std::string poll_uri =
      "https://aggregation.uri/v1/operations/foo%23bar?%24alt=proto";
  EXPECT_CALL(
      mock_http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          poll_uri, HttpRequest::Method::kGet, _,
          GetOperationRequestMatcher(EqualsProto(GetOperationRequest())))))
      .WillOnce(Return(FakeHttpResponse(401, HeaderList())));

  const absl::Status status = federated_protocol_->ReportCompleted(
      std::move(results), plan_duration, std::nullopt);
  ASSERT_THAT(status, IsCode(absl::StatusCode::kUnauthenticated));
  EXPECT_THAT(status.message(),
              HasSubstr("StartAggregationDataUpload request failed"));
  EXPECT_THAT(status.message(), HasSubstr("401"));
}
2727 
TEST_F(HttpFederatedProtocolTest,TestReportCompletedUploadFailed)2728 TEST_F(HttpFederatedProtocolTest, TestReportCompletedUploadFailed) {
2729   // Issue an eligibility eval checkin first.
2730   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
2731   // Issue a regular checkin.
2732   ASSERT_OK(RunSuccessfulCheckin());
2733 
2734   std::string checkpoint_str;
2735   const size_t kTFCheckpointSize = 32;
2736   checkpoint_str.resize(kTFCheckpointSize, 'X');
2737   ComputationResults results;
2738   results.emplace("tensorflow_checkpoint", checkpoint_str);
2739   absl::Duration plan_duration = absl::Minutes(5);
2740 
2741   ExpectSuccessfulReportTaskResultRequest(
2742       "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
2743       "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
2744       kAggregationSessionId, kTaskName, plan_duration);
2745   ExpectSuccessfulStartAggregationDataUploadRequest(
2746       "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
2747       "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
2748       kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
2749   EXPECT_CALL(mock_http_client_,
2750               PerformSingleRequest(SimpleHttpRequestMatcher(
2751                   StrEq("https://bytestream.uri/upload/v1/media/"
2752                         "CHECKPOINT_RESOURCE?upload_protocol=raw"),
2753                   HttpRequest::Method::kPost, _, std::string(checkpoint_str))))
2754       .WillOnce(Return(FakeHttpResponse(501, HeaderList())));
2755   ExpectSuccessfulAbortAggregationRequest("https://aggregation.second.uri");
2756   absl::Status report_result = federated_protocol_->ReportCompleted(
2757       std::move(results), plan_duration, std::nullopt);
2758   ASSERT_THAT(report_result, IsCode(absl::StatusCode::kUnimplemented));
2759   EXPECT_THAT(report_result.message(), HasSubstr("Data upload failed"));
2760   EXPECT_THAT(report_result.message(), HasSubstr("501"));
2761 }
2762 
TEST_F(HttpFederatedProtocolTest,TestReportCompletedUploadAbortedByServer)2763 TEST_F(HttpFederatedProtocolTest, TestReportCompletedUploadAbortedByServer) {
2764   // Issue an eligibility eval checkin first.
2765   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
2766   // Issue a regular checkin.
2767   ASSERT_OK(RunSuccessfulCheckin());
2768 
2769   std::string checkpoint_str;
2770   const size_t kTFCheckpointSize = 32;
2771   checkpoint_str.resize(kTFCheckpointSize, 'X');
2772   ComputationResults results;
2773   results.emplace("tensorflow_checkpoint", checkpoint_str);
2774   absl::Duration plan_duration = absl::Minutes(5);
2775 
2776   ExpectSuccessfulReportTaskResultRequest(
2777       "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
2778       "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
2779       kAggregationSessionId, kTaskName, plan_duration);
2780   ExpectSuccessfulStartAggregationDataUploadRequest(
2781       "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
2782       "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
2783       kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
2784   EXPECT_CALL(mock_http_client_,
2785               PerformSingleRequest(SimpleHttpRequestMatcher(
2786                   StrEq("https://bytestream.uri/upload/v1/media/"
2787                         "CHECKPOINT_RESOURCE?upload_protocol=raw"),
2788                   HttpRequest::Method::kPost, _, std::string(checkpoint_str))))
2789       .WillOnce(Return(FakeHttpResponse(
2790           409, HeaderList(),
2791           CreateErrorOperation(kOperationName, absl::StatusCode::kAborted,
2792                                "The client update is no longer needed.")
2793               .SerializeAsString())));
2794   absl::Status report_result = federated_protocol_->ReportCompleted(
2795       std::move(results), plan_duration, std::nullopt);
2796   ASSERT_THAT(report_result, IsCode(absl::StatusCode::kAborted));
2797   EXPECT_THAT(report_result.message(), HasSubstr("Data upload failed"));
2798   EXPECT_THAT(report_result.message(), HasSubstr("409"));
2799 }
2800 
TEST_F(HttpFederatedProtocolTest,TestReportCompletedUploadInterrupted)2801 TEST_F(HttpFederatedProtocolTest, TestReportCompletedUploadInterrupted) {
2802   // Issue an eligibility eval checkin first.
2803   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
2804   // Issue a regular checkin.
2805   ASSERT_OK(RunSuccessfulCheckin());
2806 
2807   std::string checkpoint_str;
2808   const size_t kTFCheckpointSize = 32;
2809   checkpoint_str.resize(kTFCheckpointSize, 'X');
2810   ComputationResults results;
2811   results.emplace("tensorflow_checkpoint", checkpoint_str);
2812   absl::Duration plan_duration = absl::Minutes(5);
2813 
2814   ExpectSuccessfulReportTaskResultRequest(
2815       "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
2816       "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
2817       kAggregationSessionId, kTaskName, plan_duration);
2818   ExpectSuccessfulStartAggregationDataUploadRequest(
2819       "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
2820       "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
2821       kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
2822   absl::Notification request_issued;
2823   absl::Notification request_cancelled;
2824 
2825   // Make HttpClient::PerformRequests() block until the counter is decremented.
2826   EXPECT_CALL(mock_http_client_,
2827               PerformSingleRequest(SimpleHttpRequestMatcher(
2828                   StrEq("https://bytestream.uri/upload/v1/media/"
2829                         "CHECKPOINT_RESOURCE?upload_protocol=raw"),
2830                   HttpRequest::Method::kPost, _, std::string(checkpoint_str))))
2831       .WillOnce([&request_issued, &request_cancelled](
2832                     MockableHttpClient::SimpleHttpRequest ignored) {
2833         request_issued.Notify();
2834         request_cancelled.WaitForNotification();
2835         return FakeHttpResponse(503, HeaderList(), "");
2836       });
2837   // Make should_abort return false until we know that the request was issued
2838   // (i.e. once InterruptibleRunner has actually started running the code it
2839   // was given), and then make it return true, triggering an abort sequence and
2840   // unblocking the PerformRequests()() call we caused to block above.
2841   EXPECT_CALL(mock_should_abort_, Call()).WillRepeatedly([&request_issued] {
2842     return request_issued.HasBeenNotified();
2843   });
2844 
2845   // When the HttpClient receives a HttpRequestHandle::Cancel call, we let the
2846   // request complete.
2847   mock_http_client_.SetCancellationListener([&request_cancelled]() {
2848     if (!request_cancelled.HasBeenNotified()) {
2849       request_cancelled.Notify();
2850     }
2851   });
2852 
2853   EXPECT_CALL(mock_log_manager_,
2854               LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP));
2855   ExpectSuccessfulAbortAggregationRequest("https://aggregation.second.uri");
2856   absl::Status report_result = federated_protocol_->ReportCompleted(
2857       std::move(results), plan_duration, std::nullopt);
2858   ASSERT_THAT(report_result, IsCode(absl::StatusCode::kCancelled));
2859   EXPECT_THAT(report_result.message(), HasSubstr("Data upload failed"));
2860 }
2861 
// Verifies that when the upload succeeds but the SubmitAggregationResult
// request fails with HTTP 409, ReportCompleted fails with ABORTED. The error
// message must name the failed request and the HTTP code.
TEST_F(HttpFederatedProtocolTest,
       TestReportCompletedSubmitAggregationResultFailed) {
  // A task must be assigned (eligibility eval checkin, then a regular checkin)
  // before any result can be reported.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin());

  const absl::Duration plan_duration = absl::Minutes(5);
  // The only result is a 32-byte fake TensorFlow checkpoint.
  std::string checkpoint_str(32, 'X');
  ComputationResults results;
  results.emplace("tensorflow_checkpoint", checkpoint_str);

  ExpectSuccessfulReportTaskResultRequest(
      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
      kAggregationSessionId, kTaskName, plan_duration);
  ExpectSuccessfulStartAggregationDataUploadRequest(
      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
      kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
  ExpectSuccessfulByteStreamUploadRequest(
      "https://bytestream.uri/upload/v1/media/"
      "CHECKPOINT_RESOURCE?upload_protocol=raw",
      checkpoint_str);

  // The final submit step, addressed with the client token, is rejected.
  SubmitAggregationResultRequest expected_submit_request;
  expected_submit_request.set_resource_name(kResourceName);
  const std::string submit_uri =
      "https://aggregation.second.uri/v1/aggregations/"
      "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto";
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  submit_uri, HttpRequest::Method::kPost, _,
                  expected_submit_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(409, HeaderList())));

  const absl::Status status = federated_protocol_->ReportCompleted(
      std::move(results), plan_duration, std::nullopt);
  ASSERT_THAT(status, IsCode(absl::StatusCode::kAborted));
  EXPECT_THAT(status.message(), HasSubstr("SubmitAggregationResult failed"));
  EXPECT_THAT(status.message(), HasSubstr("409"));
}
2907 
TEST_F(HttpFederatedProtocolTest,TestReportNotCompletedSuccess)2908 TEST_F(HttpFederatedProtocolTest, TestReportNotCompletedSuccess) {
2909   // Issue an eligibility eval checkin first.
2910   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
2911   // Issue a regular checkin.
2912   ASSERT_OK(RunSuccessfulCheckin());
2913   absl::Duration plan_duration = absl::Minutes(5);
2914   ReportTaskResultResponse response;
2915   EXPECT_CALL(mock_http_client_,
2916               PerformSingleRequest(SimpleHttpRequestMatcher(
2917                   "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
2918                   "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
2919                   HttpRequest::Method::kPost, _,
2920                   ReportTaskResultRequestMatcher(
2921                       EqualsProto(GetExpectedReportTaskResultRequest(
2922                           kAggregationSessionId, kTaskName,
2923                           ::google::rpc::Code::INTERNAL, plan_duration))))))
2924       .WillOnce(Return(
2925           FakeHttpResponse(200, HeaderList(), response.SerializeAsString())));
2926 
2927   ASSERT_OK(federated_protocol_->ReportNotCompleted(
2928       engine::PhaseOutcome::ERROR, plan_duration, std::nullopt));
2929 
2930   ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
2931 }
2932 
TEST_F(HttpFederatedProtocolTest,TestReportNotCompletedError)2933 TEST_F(HttpFederatedProtocolTest, TestReportNotCompletedError) {
2934   // Issue an eligibility eval checkin first.
2935   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
2936   // Issue a regular checkin.
2937   ASSERT_OK(RunSuccessfulCheckin());
2938   ReportTaskResultResponse response;
2939   EXPECT_CALL(mock_http_client_,
2940               PerformSingleRequest(SimpleHttpRequestMatcher(
2941                   "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
2942                   "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
2943                   HttpRequest::Method::kPost, _, _)))
2944       .WillOnce(Return(FakeHttpResponse(503, HeaderList())));
2945 
2946   absl::Status status = federated_protocol_->ReportNotCompleted(
2947       engine::PhaseOutcome::ERROR, absl::Minutes(5), std::nullopt);
2948   EXPECT_THAT(status, IsCode(UNAVAILABLE));
2949   EXPECT_THAT(
2950       status.message(),
2951       AllOf(HasSubstr("ReportTaskResult request failed:"), HasSubstr("503")));
2952   ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
2953 }
2954 
TEST_F(HttpFederatedProtocolTest,TestReportNotCompletedPermanentError)2955 TEST_F(HttpFederatedProtocolTest, TestReportNotCompletedPermanentError) {
2956   // Issue an eligibility eval checkin first.
2957   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
2958   // Issue a regular checkin.
2959   ASSERT_OK(RunSuccessfulCheckin());
2960   ReportTaskResultResponse response;
2961   EXPECT_CALL(mock_http_client_,
2962               PerformSingleRequest(SimpleHttpRequestMatcher(
2963                   "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
2964                   "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
2965                   HttpRequest::Method::kPost, _, _)))
2966       .WillOnce(Return(FakeHttpResponse(404, HeaderList())));
2967 
2968   absl::Status status = federated_protocol_->ReportNotCompleted(
2969       engine::PhaseOutcome::ERROR, absl::Minutes(5), std::nullopt);
2970   EXPECT_THAT(status, IsCode(NOT_FOUND));
2971   EXPECT_THAT(
2972       status.message(),
2973       AllOf(HasSubstr("ReportTaskResult request failed:"), HasSubstr("404")));
2974   ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
2975 }
2976 
TEST_F(HttpFederatedProtocolTest,TestClientDecodedResourcesEnabledDeclaresSupport)2977 TEST_F(HttpFederatedProtocolTest,
2978        TestClientDecodedResourcesEnabledDeclaresSupport) {
2979   EligibilityEvalTaskRequest expected_eligibility_request;
2980   expected_eligibility_request.mutable_client_version()->set_version_code(
2981       kClientVersion);
2982   expected_eligibility_request.mutable_attestation_measurement()->set_value(
2983       kAttestationMeasurement);
2984   // Make sure gzip support is declared in the eligibility eval checkin request.
2985   expected_eligibility_request.mutable_resource_capabilities()
2986       ->add_supported_compression_formats(
2987           ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
2988   expected_eligibility_request.mutable_eligibility_eval_task_capabilities()
2989       ->set_supports_multiple_task_assignment(false);
2990 
2991   // Issue an eligibility eval checkin so we can validate the field is set.
2992   Resource eligibility_plan_resource;
2993   eligibility_plan_resource.mutable_inline_resource()->set_data(kPlan);
2994   Resource checkpoint_resource;
2995   checkpoint_resource.mutable_inline_resource()->set_data(kInitCheckpoint);
2996 
2997   EligibilityEvalTaskResponse eval_task_response =
2998       GetFakeEnabledEligibilityEvalTaskResponse(eligibility_plan_resource,
2999                                                 checkpoint_resource,
3000                                                 kEligibilityEvalExecutionId);
3001   const std::string eligibility_request_uri =
3002       "https://initial.uri/v1/eligibilityevaltasks/"
3003       "TEST%2FPOPULATION:request?%24alt=proto";
3004   EXPECT_CALL(mock_http_client_,
3005               PerformSingleRequest(SimpleHttpRequestMatcher(
3006                   eligibility_request_uri, HttpRequest::Method::kPost, _,
3007                   EligibilityEvalTaskRequestMatcher(
3008                       EqualsProto(expected_eligibility_request)))))
3009       .WillOnce(Return(FakeHttpResponse(
3010           200, HeaderList(), eval_task_response.SerializeAsString())));
3011 
3012   ASSERT_OK(federated_protocol_->EligibilityEvalCheckin(
3013       mock_eet_received_callback_.AsStdFunction()));
3014 
3015   // Now issue a regular checkin and make sure the field is set there too.
3016   const std::string plan_uri = "https://fake.uri/plan";
3017   Resource plan_resource;
3018   plan_resource.set_uri(plan_uri);
3019   StartTaskAssignmentResponse task_assignment_response =
3020       GetFakeTaskAssignmentResponse(plan_resource, checkpoint_resource,
3021                                     kFederatedSelectUriTemplate,
3022                                     kAggregationSessionId, 0);
3023   const std::string request_uri =
3024       "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
3025       "taskassignments/ELIGIBILITY%2FSESSION%23ID:start?%24alt=proto";
3026   TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
3027   StartTaskAssignmentRequest expected_request;
3028   expected_request.mutable_client_version()->set_version_code(kClientVersion);
3029   *expected_request.mutable_task_eligibility_info() = expected_eligibility_info;
3030   // Make sure gzip support is declared in the regular checkin request.
3031   expected_request.mutable_resource_capabilities()
3032       ->add_supported_compression_formats(
3033           ResourceCompressionFormat::RESOURCE_COMPRESSION_FORMAT_GZIP);
3034 
3035   EXPECT_CALL(
3036       mock_http_client_,
3037       PerformSingleRequest(SimpleHttpRequestMatcher(
3038           request_uri, HttpRequest::Method::kPost, _,
3039           StartTaskAssignmentRequestMatcher(EqualsProto(expected_request)))))
3040       .WillOnce(Return(FakeHttpResponse(
3041           200, HeaderList(),
3042           CreateDoneOperation(kOperationName, task_assignment_response)
3043               .SerializeAsString())));
3044 
3045   EXPECT_CALL(mock_http_client_,
3046               PerformSingleRequest(SimpleHttpRequestMatcher(
3047                   plan_uri, HttpRequest::Method::kGet, _, "")))
3048       .WillOnce(Return(FakeHttpResponse(200, HeaderList(), kPlan)));
3049 
3050   std::string report_eet_request_uri =
3051       "https://initial.uri/v1/populations/TEST%2FPOPULATION/"
3052       "eligibilityevaltasks/"
3053       "ELIGIBILITY%2FSESSION%23ID:reportresult?%24alt=proto";
3054   ExpectSuccessfulReportEligibilityEvalTaskResultRequest(report_eet_request_uri,
3055                                                          absl::OkStatus());
3056 
3057   ASSERT_OK(federated_protocol_->Checkin(
3058       expected_eligibility_info, mock_task_received_callback_.AsStdFunction()));
3059 }
3060 
3061 }  // anonymous namespace
3062 }  // namespace fcp::client::http
3063