xref: /aosp_15_r20/external/federated-compute/fcp/client/interruptible_runner_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/client/interruptible_runner.h"
17 
18 #include <functional>
19 
20 #include "gmock/gmock.h"
21 #include "gtest/gtest.h"
22 #include "absl/status/status.h"
23 #include "absl/synchronization/blocking_counter.h"
24 #include "absl/time/time.h"
25 #include "fcp/client/diag_codes.pb.h"
26 #include "fcp/client/test_helpers.h"
27 #include "fcp/testing/testing.h"
28 
29 namespace fcp {
30 namespace client {
31 namespace {
32 
33 using ::fcp::client::ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION;
34 using ::fcp::client::ProdDiagCode::
35     BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT;
36 using ::fcp::client::ProdDiagCode::
37     BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED;
38 using ::fcp::client::ProdDiagCode::
39     BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT;
40 using ::testing::StrictMock;
41 
getDiagnosticsConfig()42 static InterruptibleRunner::DiagnosticsConfig getDiagnosticsConfig() {
43   return InterruptibleRunner::DiagnosticsConfig{
44       .interrupted = BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION,
45       .interrupt_timeout = BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT,
46       .interrupted_extended =
47           BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED,
48       .interrupt_timeout_extended =
49           BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT};
50 }
51 
52 // Tests the case where runnable finishes before the future times out (and we'd
53 // call should_abort).
TEST(InterruptibleRunnerTest,TestNormalNoAbortCheck)54 TEST(InterruptibleRunnerTest, TestNormalNoAbortCheck) {
55   int should_abort_calls = 0;
56   int abort_function_calls = 0;
57   std::function<bool()> should_abort = [&should_abort_calls]() {
58     should_abort_calls++;
59     return false;
60   };
61   std::function<void()> abort_function = [&abort_function_calls]() {
62     abort_function_calls++;
63   };
64 
65   InterruptibleRunner interruptibleRunner(
66       /*log_manager=*/nullptr, should_abort,
67       InterruptibleRunner::TimingConfig{
68           .polling_period = absl::InfiniteDuration(),
69           .graceful_shutdown_period = absl::InfiniteDuration(),
70           .extended_shutdown_period = absl::InfiniteDuration()},
71       getDiagnosticsConfig());
72   absl::Status status = interruptibleRunner.Run(
73       []() { return absl::OkStatus(); }, abort_function);
74   EXPECT_THAT(status, IsCode(OK));
75   EXPECT_EQ(should_abort_calls, 1);
76   EXPECT_EQ(abort_function_calls, 0);
77 
78   // Test that the Status returned by the runnable is returned as is.
79   status = interruptibleRunner.Run([]() { return absl::DataLossError(""); },
80                                    abort_function);
81   EXPECT_THAT(status, IsCode(DATA_LOSS));
82 }
83 
84 // Tests the case where should_abort prevents us from even kicking off the run.
TEST(InterruptibleRunnerTest,TestNormalAbortBeforeRun)85 TEST(InterruptibleRunnerTest, TestNormalAbortBeforeRun) {
86   int should_abort_calls = 0;
87   int abort_function_calls = 0;
88   int runnable_calls = 0;
89   std::function<bool()> should_abort = [&should_abort_calls]() {
90     should_abort_calls++;
91     return true;
92   };
93   std::function<void()> abort_function = [&abort_function_calls]() {
94     abort_function_calls++;
95   };
96 
97   InterruptibleRunner interruptibleRunner(
98       /*log_manager=*/nullptr, should_abort,
99       InterruptibleRunner::TimingConfig{
100           .polling_period = absl::InfiniteDuration(),
101           .graceful_shutdown_period = absl::InfiniteDuration(),
102           .extended_shutdown_period = absl::InfiniteDuration()},
103       getDiagnosticsConfig());
104   absl::Status status = interruptibleRunner.Run(
105       [&runnable_calls]() {
106         runnable_calls++;
107         return absl::OkStatus();
108       },
109       abort_function);
110   EXPECT_THAT(status, IsCode(CANCELLED));
111   EXPECT_EQ(abort_function_calls, 0);
112   EXPECT_EQ(runnable_calls, 0);
113 }
114 
115 // Tests the case where the future wait times out once, we call should_abort,
116 // which says to continue, and then the future returns.
TEST(InterruptibleRunnerTest,TestNormalWithAbortCheckButNoAbort)117 TEST(InterruptibleRunnerTest, TestNormalWithAbortCheckButNoAbort) {
118   int should_abort_calls = 0;
119   int abort_function_calls = 0;
120   absl::BlockingCounter counter_should_abort(1);
121   absl::BlockingCounter counter_did_abort(1);
122   std::function<bool()> should_abort =
123       [&should_abort_calls, &counter_should_abort, &counter_did_abort]() {
124         should_abort_calls++;
125         if (should_abort_calls == 2) {
126           counter_should_abort.DecrementCount();
127           counter_did_abort.Wait();
128         }
129         return false;
130       };
131   std::function<void()> abort_function = [&abort_function_calls]() {
132     abort_function_calls++;
133   };
134 
135   InterruptibleRunner interruptibleRunner(
136       nullptr, should_abort,
137       InterruptibleRunner::TimingConfig{
138           .polling_period = absl::ZeroDuration(),
139           .graceful_shutdown_period = absl::InfiniteDuration(),
140           .extended_shutdown_period = absl::InfiniteDuration()},
141       getDiagnosticsConfig());
142   absl::Status status = interruptibleRunner.Run(
143       [&counter_should_abort, &counter_did_abort]() {
144         // Block until should_abort has been called.
145         counter_should_abort.Wait();
146         // Tell should_abort to return false.
147         counter_did_abort.DecrementCount();
148         return absl::OkStatus();
149       },
150       abort_function);
151   EXPECT_THAT(status, IsCode(OK));
152   EXPECT_GE(should_abort_calls, 2);
153   EXPECT_EQ(abort_function_calls, 0);
154 
155   status = interruptibleRunner.Run([]() { return absl::DataLossError(""); },
156                                    abort_function);
157   EXPECT_THAT(status, IsCode(DATA_LOSS));
158 }
159 
160 // Tests the case where the runnable gets aborted and behaves nicely (aborts
161 // within the grace period).
TEST(InterruptibleRunnerTest,TestAbortInGracePeriod)162 TEST(InterruptibleRunnerTest, TestAbortInGracePeriod) {
163   StrictMock<MockLogManager> log_manager;
164   int should_abort_calls = 0;
165   int abort_function_calls = 0;
166   absl::BlockingCounter counter_should_abort(1);
167   absl::BlockingCounter counter_did_abort(1);
168 
169   std::function<bool()> should_abort = [&should_abort_calls]() {
170     should_abort_calls++;
171     return should_abort_calls >= 2;
172   };
173   std::function<void()> abort_function =
174       [&abort_function_calls, &counter_should_abort, &counter_did_abort]() {
175         abort_function_calls++;
176         // Signal runnable to abort.
177         counter_should_abort.DecrementCount();
178         // Wait for runnable to have aborted.
179         counter_did_abort.Wait();
180       };
181 
182   InterruptibleRunner interruptibleRunner(
183       &log_manager, should_abort,
184       InterruptibleRunner::TimingConfig{
185           .polling_period = absl::ZeroDuration(),
186           .graceful_shutdown_period = absl::InfiniteDuration(),
187           .extended_shutdown_period = absl::InfiniteDuration()},
188       getDiagnosticsConfig());
189   // Tests that abort works.
190   EXPECT_CALL(log_manager, LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION))
191       .Times(testing::Exactly(1));
192   absl::Status status = interruptibleRunner.Run(
193       [&counter_should_abort, &counter_did_abort]() {
194         counter_should_abort.Wait();
195         counter_did_abort.DecrementCount();
196         return absl::OkStatus();
197       },
198       abort_function);
199   EXPECT_THAT(status, IsCode(CANCELLED));
200   EXPECT_EQ(should_abort_calls, 2);
201   EXPECT_EQ(abort_function_calls, 1);
202 }
203 
204 // Tests the case where abort does not happen within the grace period.
205 // This is achieved by only letting the runnable finish once the grace period
206 // wait fails and a timeout diag code is logged, by taking an action on the
207 // LogManager mock.
TEST(InterruptibleRunnerTest,TestAbortInExtendedGracePeriod)208 TEST(InterruptibleRunnerTest, TestAbortInExtendedGracePeriod) {
209   StrictMock<MockLogManager> log_manager;
210   int should_abort_calls = 0;
211   int abort_function_calls = 0;
212 
213   absl::BlockingCounter counter_should_abort(1);
214   absl::BlockingCounter counter_did_abort(1);
215 
216   std::function<bool()> should_abort = [&should_abort_calls]() {
217     should_abort_calls++;
218     return should_abort_calls >= 2;
219   };
220   std::function<void()> abort_function = [&abort_function_calls]() {
221     abort_function_calls++;
222   };
223 
224   InterruptibleRunner interruptibleRunner(
225       &log_manager, should_abort,
226       InterruptibleRunner::TimingConfig{
227           .polling_period = absl::ZeroDuration(),
228           .graceful_shutdown_period = absl::ZeroDuration(),
229           .extended_shutdown_period = absl::InfiniteDuration()},
230       getDiagnosticsConfig());
231   EXPECT_CALL(log_manager,
232               LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT))
233       .WillOnce(
234           [&counter_should_abort, &counter_did_abort](ProdDiagCode ignored) {
235             counter_should_abort.DecrementCount();
236             counter_did_abort.Wait();
237             return absl::OkStatus();
238           });
239   EXPECT_CALL(
240       log_manager,
241       LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED))
242       .Times(testing::Exactly(1));
243   absl::Status status = interruptibleRunner.Run(
244       [&counter_should_abort, &counter_did_abort]() {
245         counter_should_abort.Wait();
246         counter_did_abort.DecrementCount();
247         return absl::OkStatus();
248       },
249       abort_function);
250 
251   EXPECT_THAT(status, IsCode(CANCELLED));
252   EXPECT_EQ(should_abort_calls, 2);
253   EXPECT_EQ(abort_function_calls, 1);
254 }
255 
256 }  // namespace
257 }  // namespace client
258 }  // namespace fcp
259