xref: /aosp_15_r20/external/federated-compute/fcp/client/interruptible_runner_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2020 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker #include "fcp/client/interruptible_runner.h"
17*14675a02SAndroid Build Coastguard Worker 
18*14675a02SAndroid Build Coastguard Worker #include <functional>
19*14675a02SAndroid Build Coastguard Worker 
20*14675a02SAndroid Build Coastguard Worker #include "gmock/gmock.h"
21*14675a02SAndroid Build Coastguard Worker #include "gtest/gtest.h"
22*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
23*14675a02SAndroid Build Coastguard Worker #include "absl/synchronization/blocking_counter.h"
24*14675a02SAndroid Build Coastguard Worker #include "absl/time/time.h"
25*14675a02SAndroid Build Coastguard Worker #include "fcp/client/diag_codes.pb.h"
26*14675a02SAndroid Build Coastguard Worker #include "fcp/client/test_helpers.h"
27*14675a02SAndroid Build Coastguard Worker #include "fcp/testing/testing.h"
28*14675a02SAndroid Build Coastguard Worker 
29*14675a02SAndroid Build Coastguard Worker namespace fcp {
30*14675a02SAndroid Build Coastguard Worker namespace client {
31*14675a02SAndroid Build Coastguard Worker namespace {
32*14675a02SAndroid Build Coastguard Worker 
33*14675a02SAndroid Build Coastguard Worker using ::fcp::client::ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION;
34*14675a02SAndroid Build Coastguard Worker using ::fcp::client::ProdDiagCode::
35*14675a02SAndroid Build Coastguard Worker     BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT;
36*14675a02SAndroid Build Coastguard Worker using ::fcp::client::ProdDiagCode::
37*14675a02SAndroid Build Coastguard Worker     BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED;
38*14675a02SAndroid Build Coastguard Worker using ::fcp::client::ProdDiagCode::
39*14675a02SAndroid Build Coastguard Worker     BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT;
40*14675a02SAndroid Build Coastguard Worker using ::testing::StrictMock;
41*14675a02SAndroid Build Coastguard Worker 
getDiagnosticsConfig()42*14675a02SAndroid Build Coastguard Worker static InterruptibleRunner::DiagnosticsConfig getDiagnosticsConfig() {
43*14675a02SAndroid Build Coastguard Worker   return InterruptibleRunner::DiagnosticsConfig{
44*14675a02SAndroid Build Coastguard Worker       .interrupted = BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION,
45*14675a02SAndroid Build Coastguard Worker       .interrupt_timeout = BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT,
46*14675a02SAndroid Build Coastguard Worker       .interrupted_extended =
47*14675a02SAndroid Build Coastguard Worker           BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED,
48*14675a02SAndroid Build Coastguard Worker       .interrupt_timeout_extended =
49*14675a02SAndroid Build Coastguard Worker           BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT};
50*14675a02SAndroid Build Coastguard Worker }
51*14675a02SAndroid Build Coastguard Worker 
52*14675a02SAndroid Build Coastguard Worker // Tests the case where runnable finishes before the future times out (and we'd
53*14675a02SAndroid Build Coastguard Worker // call should_abort).
TEST(InterruptibleRunnerTest,TestNormalNoAbortCheck)54*14675a02SAndroid Build Coastguard Worker TEST(InterruptibleRunnerTest, TestNormalNoAbortCheck) {
55*14675a02SAndroid Build Coastguard Worker   int should_abort_calls = 0;
56*14675a02SAndroid Build Coastguard Worker   int abort_function_calls = 0;
57*14675a02SAndroid Build Coastguard Worker   std::function<bool()> should_abort = [&should_abort_calls]() {
58*14675a02SAndroid Build Coastguard Worker     should_abort_calls++;
59*14675a02SAndroid Build Coastguard Worker     return false;
60*14675a02SAndroid Build Coastguard Worker   };
61*14675a02SAndroid Build Coastguard Worker   std::function<void()> abort_function = [&abort_function_calls]() {
62*14675a02SAndroid Build Coastguard Worker     abort_function_calls++;
63*14675a02SAndroid Build Coastguard Worker   };
64*14675a02SAndroid Build Coastguard Worker 
65*14675a02SAndroid Build Coastguard Worker   InterruptibleRunner interruptibleRunner(
66*14675a02SAndroid Build Coastguard Worker       /*log_manager=*/nullptr, should_abort,
67*14675a02SAndroid Build Coastguard Worker       InterruptibleRunner::TimingConfig{
68*14675a02SAndroid Build Coastguard Worker           .polling_period = absl::InfiniteDuration(),
69*14675a02SAndroid Build Coastguard Worker           .graceful_shutdown_period = absl::InfiniteDuration(),
70*14675a02SAndroid Build Coastguard Worker           .extended_shutdown_period = absl::InfiniteDuration()},
71*14675a02SAndroid Build Coastguard Worker       getDiagnosticsConfig());
72*14675a02SAndroid Build Coastguard Worker   absl::Status status = interruptibleRunner.Run(
73*14675a02SAndroid Build Coastguard Worker       []() { return absl::OkStatus(); }, abort_function);
74*14675a02SAndroid Build Coastguard Worker   EXPECT_THAT(status, IsCode(OK));
75*14675a02SAndroid Build Coastguard Worker   EXPECT_EQ(should_abort_calls, 1);
76*14675a02SAndroid Build Coastguard Worker   EXPECT_EQ(abort_function_calls, 0);
77*14675a02SAndroid Build Coastguard Worker 
78*14675a02SAndroid Build Coastguard Worker   // Test that the Status returned by the runnable is returned as is.
79*14675a02SAndroid Build Coastguard Worker   status = interruptibleRunner.Run([]() { return absl::DataLossError(""); },
80*14675a02SAndroid Build Coastguard Worker                                    abort_function);
81*14675a02SAndroid Build Coastguard Worker   EXPECT_THAT(status, IsCode(DATA_LOSS));
82*14675a02SAndroid Build Coastguard Worker }
83*14675a02SAndroid Build Coastguard Worker 
84*14675a02SAndroid Build Coastguard Worker // Tests the case where should_abort prevents us from even kicking off the run.
TEST(InterruptibleRunnerTest,TestNormalAbortBeforeRun)85*14675a02SAndroid Build Coastguard Worker TEST(InterruptibleRunnerTest, TestNormalAbortBeforeRun) {
86*14675a02SAndroid Build Coastguard Worker   int should_abort_calls = 0;
87*14675a02SAndroid Build Coastguard Worker   int abort_function_calls = 0;
88*14675a02SAndroid Build Coastguard Worker   int runnable_calls = 0;
89*14675a02SAndroid Build Coastguard Worker   std::function<bool()> should_abort = [&should_abort_calls]() {
90*14675a02SAndroid Build Coastguard Worker     should_abort_calls++;
91*14675a02SAndroid Build Coastguard Worker     return true;
92*14675a02SAndroid Build Coastguard Worker   };
93*14675a02SAndroid Build Coastguard Worker   std::function<void()> abort_function = [&abort_function_calls]() {
94*14675a02SAndroid Build Coastguard Worker     abort_function_calls++;
95*14675a02SAndroid Build Coastguard Worker   };
96*14675a02SAndroid Build Coastguard Worker 
97*14675a02SAndroid Build Coastguard Worker   InterruptibleRunner interruptibleRunner(
98*14675a02SAndroid Build Coastguard Worker       /*log_manager=*/nullptr, should_abort,
99*14675a02SAndroid Build Coastguard Worker       InterruptibleRunner::TimingConfig{
100*14675a02SAndroid Build Coastguard Worker           .polling_period = absl::InfiniteDuration(),
101*14675a02SAndroid Build Coastguard Worker           .graceful_shutdown_period = absl::InfiniteDuration(),
102*14675a02SAndroid Build Coastguard Worker           .extended_shutdown_period = absl::InfiniteDuration()},
103*14675a02SAndroid Build Coastguard Worker       getDiagnosticsConfig());
104*14675a02SAndroid Build Coastguard Worker   absl::Status status = interruptibleRunner.Run(
105*14675a02SAndroid Build Coastguard Worker       [&runnable_calls]() {
106*14675a02SAndroid Build Coastguard Worker         runnable_calls++;
107*14675a02SAndroid Build Coastguard Worker         return absl::OkStatus();
108*14675a02SAndroid Build Coastguard Worker       },
109*14675a02SAndroid Build Coastguard Worker       abort_function);
110*14675a02SAndroid Build Coastguard Worker   EXPECT_THAT(status, IsCode(CANCELLED));
111*14675a02SAndroid Build Coastguard Worker   EXPECT_EQ(abort_function_calls, 0);
112*14675a02SAndroid Build Coastguard Worker   EXPECT_EQ(runnable_calls, 0);
113*14675a02SAndroid Build Coastguard Worker }
114*14675a02SAndroid Build Coastguard Worker 
115*14675a02SAndroid Build Coastguard Worker // Tests the case where the future wait times out once, we call should_abort,
116*14675a02SAndroid Build Coastguard Worker // which says to continue, and then the future returns.
TEST(InterruptibleRunnerTest,TestNormalWithAbortCheckButNoAbort)117*14675a02SAndroid Build Coastguard Worker TEST(InterruptibleRunnerTest, TestNormalWithAbortCheckButNoAbort) {
118*14675a02SAndroid Build Coastguard Worker   int should_abort_calls = 0;
119*14675a02SAndroid Build Coastguard Worker   int abort_function_calls = 0;
120*14675a02SAndroid Build Coastguard Worker   absl::BlockingCounter counter_should_abort(1);
121*14675a02SAndroid Build Coastguard Worker   absl::BlockingCounter counter_did_abort(1);
122*14675a02SAndroid Build Coastguard Worker   std::function<bool()> should_abort =
123*14675a02SAndroid Build Coastguard Worker       [&should_abort_calls, &counter_should_abort, &counter_did_abort]() {
124*14675a02SAndroid Build Coastguard Worker         should_abort_calls++;
125*14675a02SAndroid Build Coastguard Worker         if (should_abort_calls == 2) {
126*14675a02SAndroid Build Coastguard Worker           counter_should_abort.DecrementCount();
127*14675a02SAndroid Build Coastguard Worker           counter_did_abort.Wait();
128*14675a02SAndroid Build Coastguard Worker         }
129*14675a02SAndroid Build Coastguard Worker         return false;
130*14675a02SAndroid Build Coastguard Worker       };
131*14675a02SAndroid Build Coastguard Worker   std::function<void()> abort_function = [&abort_function_calls]() {
132*14675a02SAndroid Build Coastguard Worker     abort_function_calls++;
133*14675a02SAndroid Build Coastguard Worker   };
134*14675a02SAndroid Build Coastguard Worker 
135*14675a02SAndroid Build Coastguard Worker   InterruptibleRunner interruptibleRunner(
136*14675a02SAndroid Build Coastguard Worker       nullptr, should_abort,
137*14675a02SAndroid Build Coastguard Worker       InterruptibleRunner::TimingConfig{
138*14675a02SAndroid Build Coastguard Worker           .polling_period = absl::ZeroDuration(),
139*14675a02SAndroid Build Coastguard Worker           .graceful_shutdown_period = absl::InfiniteDuration(),
140*14675a02SAndroid Build Coastguard Worker           .extended_shutdown_period = absl::InfiniteDuration()},
141*14675a02SAndroid Build Coastguard Worker       getDiagnosticsConfig());
142*14675a02SAndroid Build Coastguard Worker   absl::Status status = interruptibleRunner.Run(
143*14675a02SAndroid Build Coastguard Worker       [&counter_should_abort, &counter_did_abort]() {
144*14675a02SAndroid Build Coastguard Worker         // Block until should_abort has been called.
145*14675a02SAndroid Build Coastguard Worker         counter_should_abort.Wait();
146*14675a02SAndroid Build Coastguard Worker         // Tell should_abort to return false.
147*14675a02SAndroid Build Coastguard Worker         counter_did_abort.DecrementCount();
148*14675a02SAndroid Build Coastguard Worker         return absl::OkStatus();
149*14675a02SAndroid Build Coastguard Worker       },
150*14675a02SAndroid Build Coastguard Worker       abort_function);
151*14675a02SAndroid Build Coastguard Worker   EXPECT_THAT(status, IsCode(OK));
152*14675a02SAndroid Build Coastguard Worker   EXPECT_GE(should_abort_calls, 2);
153*14675a02SAndroid Build Coastguard Worker   EXPECT_EQ(abort_function_calls, 0);
154*14675a02SAndroid Build Coastguard Worker 
155*14675a02SAndroid Build Coastguard Worker   status = interruptibleRunner.Run([]() { return absl::DataLossError(""); },
156*14675a02SAndroid Build Coastguard Worker                                    abort_function);
157*14675a02SAndroid Build Coastguard Worker   EXPECT_THAT(status, IsCode(DATA_LOSS));
158*14675a02SAndroid Build Coastguard Worker }
159*14675a02SAndroid Build Coastguard Worker 
160*14675a02SAndroid Build Coastguard Worker // Tests the case where the runnable gets aborted and behaves nicely (aborts
161*14675a02SAndroid Build Coastguard Worker // within the grace period).
TEST(InterruptibleRunnerTest,TestAbortInGracePeriod)162*14675a02SAndroid Build Coastguard Worker TEST(InterruptibleRunnerTest, TestAbortInGracePeriod) {
163*14675a02SAndroid Build Coastguard Worker   StrictMock<MockLogManager> log_manager;
164*14675a02SAndroid Build Coastguard Worker   int should_abort_calls = 0;
165*14675a02SAndroid Build Coastguard Worker   int abort_function_calls = 0;
166*14675a02SAndroid Build Coastguard Worker   absl::BlockingCounter counter_should_abort(1);
167*14675a02SAndroid Build Coastguard Worker   absl::BlockingCounter counter_did_abort(1);
168*14675a02SAndroid Build Coastguard Worker 
169*14675a02SAndroid Build Coastguard Worker   std::function<bool()> should_abort = [&should_abort_calls]() {
170*14675a02SAndroid Build Coastguard Worker     should_abort_calls++;
171*14675a02SAndroid Build Coastguard Worker     return should_abort_calls >= 2;
172*14675a02SAndroid Build Coastguard Worker   };
173*14675a02SAndroid Build Coastguard Worker   std::function<void()> abort_function =
174*14675a02SAndroid Build Coastguard Worker       [&abort_function_calls, &counter_should_abort, &counter_did_abort]() {
175*14675a02SAndroid Build Coastguard Worker         abort_function_calls++;
176*14675a02SAndroid Build Coastguard Worker         // Signal runnable to abort.
177*14675a02SAndroid Build Coastguard Worker         counter_should_abort.DecrementCount();
178*14675a02SAndroid Build Coastguard Worker         // Wait for runnable to have aborted.
179*14675a02SAndroid Build Coastguard Worker         counter_did_abort.Wait();
180*14675a02SAndroid Build Coastguard Worker       };
181*14675a02SAndroid Build Coastguard Worker 
182*14675a02SAndroid Build Coastguard Worker   InterruptibleRunner interruptibleRunner(
183*14675a02SAndroid Build Coastguard Worker       &log_manager, should_abort,
184*14675a02SAndroid Build Coastguard Worker       InterruptibleRunner::TimingConfig{
185*14675a02SAndroid Build Coastguard Worker           .polling_period = absl::ZeroDuration(),
186*14675a02SAndroid Build Coastguard Worker           .graceful_shutdown_period = absl::InfiniteDuration(),
187*14675a02SAndroid Build Coastguard Worker           .extended_shutdown_period = absl::InfiniteDuration()},
188*14675a02SAndroid Build Coastguard Worker       getDiagnosticsConfig());
189*14675a02SAndroid Build Coastguard Worker   // Tests that abort works.
190*14675a02SAndroid Build Coastguard Worker   EXPECT_CALL(log_manager, LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION))
191*14675a02SAndroid Build Coastguard Worker       .Times(testing::Exactly(1));
192*14675a02SAndroid Build Coastguard Worker   absl::Status status = interruptibleRunner.Run(
193*14675a02SAndroid Build Coastguard Worker       [&counter_should_abort, &counter_did_abort]() {
194*14675a02SAndroid Build Coastguard Worker         counter_should_abort.Wait();
195*14675a02SAndroid Build Coastguard Worker         counter_did_abort.DecrementCount();
196*14675a02SAndroid Build Coastguard Worker         return absl::OkStatus();
197*14675a02SAndroid Build Coastguard Worker       },
198*14675a02SAndroid Build Coastguard Worker       abort_function);
199*14675a02SAndroid Build Coastguard Worker   EXPECT_THAT(status, IsCode(CANCELLED));
200*14675a02SAndroid Build Coastguard Worker   EXPECT_EQ(should_abort_calls, 2);
201*14675a02SAndroid Build Coastguard Worker   EXPECT_EQ(abort_function_calls, 1);
202*14675a02SAndroid Build Coastguard Worker }
203*14675a02SAndroid Build Coastguard Worker 
204*14675a02SAndroid Build Coastguard Worker // Tests the case where abort does not happen within the grace period.
205*14675a02SAndroid Build Coastguard Worker // This is achieved by only letting the runnable finish once the grace period
206*14675a02SAndroid Build Coastguard Worker // wait fails and a timeout diag code is logged, by taking an action on the
207*14675a02SAndroid Build Coastguard Worker // LogManager mock.
TEST(InterruptibleRunnerTest,TestAbortInExtendedGracePeriod)208*14675a02SAndroid Build Coastguard Worker TEST(InterruptibleRunnerTest, TestAbortInExtendedGracePeriod) {
209*14675a02SAndroid Build Coastguard Worker   StrictMock<MockLogManager> log_manager;
210*14675a02SAndroid Build Coastguard Worker   int should_abort_calls = 0;
211*14675a02SAndroid Build Coastguard Worker   int abort_function_calls = 0;
212*14675a02SAndroid Build Coastguard Worker 
213*14675a02SAndroid Build Coastguard Worker   absl::BlockingCounter counter_should_abort(1);
214*14675a02SAndroid Build Coastguard Worker   absl::BlockingCounter counter_did_abort(1);
215*14675a02SAndroid Build Coastguard Worker 
216*14675a02SAndroid Build Coastguard Worker   std::function<bool()> should_abort = [&should_abort_calls]() {
217*14675a02SAndroid Build Coastguard Worker     should_abort_calls++;
218*14675a02SAndroid Build Coastguard Worker     return should_abort_calls >= 2;
219*14675a02SAndroid Build Coastguard Worker   };
220*14675a02SAndroid Build Coastguard Worker   std::function<void()> abort_function = [&abort_function_calls]() {
221*14675a02SAndroid Build Coastguard Worker     abort_function_calls++;
222*14675a02SAndroid Build Coastguard Worker   };
223*14675a02SAndroid Build Coastguard Worker 
224*14675a02SAndroid Build Coastguard Worker   InterruptibleRunner interruptibleRunner(
225*14675a02SAndroid Build Coastguard Worker       &log_manager, should_abort,
226*14675a02SAndroid Build Coastguard Worker       InterruptibleRunner::TimingConfig{
227*14675a02SAndroid Build Coastguard Worker           .polling_period = absl::ZeroDuration(),
228*14675a02SAndroid Build Coastguard Worker           .graceful_shutdown_period = absl::ZeroDuration(),
229*14675a02SAndroid Build Coastguard Worker           .extended_shutdown_period = absl::InfiniteDuration()},
230*14675a02SAndroid Build Coastguard Worker       getDiagnosticsConfig());
231*14675a02SAndroid Build Coastguard Worker   EXPECT_CALL(log_manager,
232*14675a02SAndroid Build Coastguard Worker               LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT))
233*14675a02SAndroid Build Coastguard Worker       .WillOnce(
234*14675a02SAndroid Build Coastguard Worker           [&counter_should_abort, &counter_did_abort](ProdDiagCode ignored) {
235*14675a02SAndroid Build Coastguard Worker             counter_should_abort.DecrementCount();
236*14675a02SAndroid Build Coastguard Worker             counter_did_abort.Wait();
237*14675a02SAndroid Build Coastguard Worker             return absl::OkStatus();
238*14675a02SAndroid Build Coastguard Worker           });
239*14675a02SAndroid Build Coastguard Worker   EXPECT_CALL(
240*14675a02SAndroid Build Coastguard Worker       log_manager,
241*14675a02SAndroid Build Coastguard Worker       LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED))
242*14675a02SAndroid Build Coastguard Worker       .Times(testing::Exactly(1));
243*14675a02SAndroid Build Coastguard Worker   absl::Status status = interruptibleRunner.Run(
244*14675a02SAndroid Build Coastguard Worker       [&counter_should_abort, &counter_did_abort]() {
245*14675a02SAndroid Build Coastguard Worker         counter_should_abort.Wait();
246*14675a02SAndroid Build Coastguard Worker         counter_did_abort.DecrementCount();
247*14675a02SAndroid Build Coastguard Worker         return absl::OkStatus();
248*14675a02SAndroid Build Coastguard Worker       },
249*14675a02SAndroid Build Coastguard Worker       abort_function);
250*14675a02SAndroid Build Coastguard Worker 
251*14675a02SAndroid Build Coastguard Worker   EXPECT_THAT(status, IsCode(CANCELLED));
252*14675a02SAndroid Build Coastguard Worker   EXPECT_EQ(should_abort_calls, 2);
253*14675a02SAndroid Build Coastguard Worker   EXPECT_EQ(abort_function_calls, 1);
254*14675a02SAndroid Build Coastguard Worker }
255*14675a02SAndroid Build Coastguard Worker 
256*14675a02SAndroid Build Coastguard Worker }  // namespace
257*14675a02SAndroid Build Coastguard Worker }  // namespace client
258*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
259