1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker * Copyright 2020 Google LLC
3*14675a02SAndroid Build Coastguard Worker *
4*14675a02SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker *
8*14675a02SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker *
10*14675a02SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker */
16*14675a02SAndroid Build Coastguard Worker #include "fcp/client/interruptible_runner.h"
17*14675a02SAndroid Build Coastguard Worker
18*14675a02SAndroid Build Coastguard Worker #include <functional>
19*14675a02SAndroid Build Coastguard Worker
20*14675a02SAndroid Build Coastguard Worker #include "gmock/gmock.h"
21*14675a02SAndroid Build Coastguard Worker #include "gtest/gtest.h"
22*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
23*14675a02SAndroid Build Coastguard Worker #include "absl/synchronization/blocking_counter.h"
24*14675a02SAndroid Build Coastguard Worker #include "absl/time/time.h"
25*14675a02SAndroid Build Coastguard Worker #include "fcp/client/diag_codes.pb.h"
26*14675a02SAndroid Build Coastguard Worker #include "fcp/client/test_helpers.h"
27*14675a02SAndroid Build Coastguard Worker #include "fcp/testing/testing.h"
28*14675a02SAndroid Build Coastguard Worker
29*14675a02SAndroid Build Coastguard Worker namespace fcp {
30*14675a02SAndroid Build Coastguard Worker namespace client {
31*14675a02SAndroid Build Coastguard Worker namespace {
32*14675a02SAndroid Build Coastguard Worker
33*14675a02SAndroid Build Coastguard Worker using ::fcp::client::ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION;
34*14675a02SAndroid Build Coastguard Worker using ::fcp::client::ProdDiagCode::
35*14675a02SAndroid Build Coastguard Worker BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT;
36*14675a02SAndroid Build Coastguard Worker using ::fcp::client::ProdDiagCode::
37*14675a02SAndroid Build Coastguard Worker BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED;
38*14675a02SAndroid Build Coastguard Worker using ::fcp::client::ProdDiagCode::
39*14675a02SAndroid Build Coastguard Worker BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT;
40*14675a02SAndroid Build Coastguard Worker using ::testing::StrictMock;
41*14675a02SAndroid Build Coastguard Worker
getDiagnosticsConfig()42*14675a02SAndroid Build Coastguard Worker static InterruptibleRunner::DiagnosticsConfig getDiagnosticsConfig() {
43*14675a02SAndroid Build Coastguard Worker return InterruptibleRunner::DiagnosticsConfig{
44*14675a02SAndroid Build Coastguard Worker .interrupted = BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION,
45*14675a02SAndroid Build Coastguard Worker .interrupt_timeout = BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT,
46*14675a02SAndroid Build Coastguard Worker .interrupted_extended =
47*14675a02SAndroid Build Coastguard Worker BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED,
48*14675a02SAndroid Build Coastguard Worker .interrupt_timeout_extended =
49*14675a02SAndroid Build Coastguard Worker BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT};
50*14675a02SAndroid Build Coastguard Worker }
51*14675a02SAndroid Build Coastguard Worker
52*14675a02SAndroid Build Coastguard Worker // Tests the case where runnable finishes before the future times out (and we'd
53*14675a02SAndroid Build Coastguard Worker // call should_abort).
TEST(InterruptibleRunnerTest,TestNormalNoAbortCheck)54*14675a02SAndroid Build Coastguard Worker TEST(InterruptibleRunnerTest, TestNormalNoAbortCheck) {
55*14675a02SAndroid Build Coastguard Worker int should_abort_calls = 0;
56*14675a02SAndroid Build Coastguard Worker int abort_function_calls = 0;
57*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort = [&should_abort_calls]() {
58*14675a02SAndroid Build Coastguard Worker should_abort_calls++;
59*14675a02SAndroid Build Coastguard Worker return false;
60*14675a02SAndroid Build Coastguard Worker };
61*14675a02SAndroid Build Coastguard Worker std::function<void()> abort_function = [&abort_function_calls]() {
62*14675a02SAndroid Build Coastguard Worker abort_function_calls++;
63*14675a02SAndroid Build Coastguard Worker };
64*14675a02SAndroid Build Coastguard Worker
65*14675a02SAndroid Build Coastguard Worker InterruptibleRunner interruptibleRunner(
66*14675a02SAndroid Build Coastguard Worker /*log_manager=*/nullptr, should_abort,
67*14675a02SAndroid Build Coastguard Worker InterruptibleRunner::TimingConfig{
68*14675a02SAndroid Build Coastguard Worker .polling_period = absl::InfiniteDuration(),
69*14675a02SAndroid Build Coastguard Worker .graceful_shutdown_period = absl::InfiniteDuration(),
70*14675a02SAndroid Build Coastguard Worker .extended_shutdown_period = absl::InfiniteDuration()},
71*14675a02SAndroid Build Coastguard Worker getDiagnosticsConfig());
72*14675a02SAndroid Build Coastguard Worker absl::Status status = interruptibleRunner.Run(
73*14675a02SAndroid Build Coastguard Worker []() { return absl::OkStatus(); }, abort_function);
74*14675a02SAndroid Build Coastguard Worker EXPECT_THAT(status, IsCode(OK));
75*14675a02SAndroid Build Coastguard Worker EXPECT_EQ(should_abort_calls, 1);
76*14675a02SAndroid Build Coastguard Worker EXPECT_EQ(abort_function_calls, 0);
77*14675a02SAndroid Build Coastguard Worker
78*14675a02SAndroid Build Coastguard Worker // Test that the Status returned by the runnable is returned as is.
79*14675a02SAndroid Build Coastguard Worker status = interruptibleRunner.Run([]() { return absl::DataLossError(""); },
80*14675a02SAndroid Build Coastguard Worker abort_function);
81*14675a02SAndroid Build Coastguard Worker EXPECT_THAT(status, IsCode(DATA_LOSS));
82*14675a02SAndroid Build Coastguard Worker }
83*14675a02SAndroid Build Coastguard Worker
84*14675a02SAndroid Build Coastguard Worker // Tests the case where should_abort prevents us from even kicking off the run.
TEST(InterruptibleRunnerTest,TestNormalAbortBeforeRun)85*14675a02SAndroid Build Coastguard Worker TEST(InterruptibleRunnerTest, TestNormalAbortBeforeRun) {
86*14675a02SAndroid Build Coastguard Worker int should_abort_calls = 0;
87*14675a02SAndroid Build Coastguard Worker int abort_function_calls = 0;
88*14675a02SAndroid Build Coastguard Worker int runnable_calls = 0;
89*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort = [&should_abort_calls]() {
90*14675a02SAndroid Build Coastguard Worker should_abort_calls++;
91*14675a02SAndroid Build Coastguard Worker return true;
92*14675a02SAndroid Build Coastguard Worker };
93*14675a02SAndroid Build Coastguard Worker std::function<void()> abort_function = [&abort_function_calls]() {
94*14675a02SAndroid Build Coastguard Worker abort_function_calls++;
95*14675a02SAndroid Build Coastguard Worker };
96*14675a02SAndroid Build Coastguard Worker
97*14675a02SAndroid Build Coastguard Worker InterruptibleRunner interruptibleRunner(
98*14675a02SAndroid Build Coastguard Worker /*log_manager=*/nullptr, should_abort,
99*14675a02SAndroid Build Coastguard Worker InterruptibleRunner::TimingConfig{
100*14675a02SAndroid Build Coastguard Worker .polling_period = absl::InfiniteDuration(),
101*14675a02SAndroid Build Coastguard Worker .graceful_shutdown_period = absl::InfiniteDuration(),
102*14675a02SAndroid Build Coastguard Worker .extended_shutdown_period = absl::InfiniteDuration()},
103*14675a02SAndroid Build Coastguard Worker getDiagnosticsConfig());
104*14675a02SAndroid Build Coastguard Worker absl::Status status = interruptibleRunner.Run(
105*14675a02SAndroid Build Coastguard Worker [&runnable_calls]() {
106*14675a02SAndroid Build Coastguard Worker runnable_calls++;
107*14675a02SAndroid Build Coastguard Worker return absl::OkStatus();
108*14675a02SAndroid Build Coastguard Worker },
109*14675a02SAndroid Build Coastguard Worker abort_function);
110*14675a02SAndroid Build Coastguard Worker EXPECT_THAT(status, IsCode(CANCELLED));
111*14675a02SAndroid Build Coastguard Worker EXPECT_EQ(abort_function_calls, 0);
112*14675a02SAndroid Build Coastguard Worker EXPECT_EQ(runnable_calls, 0);
113*14675a02SAndroid Build Coastguard Worker }
114*14675a02SAndroid Build Coastguard Worker
115*14675a02SAndroid Build Coastguard Worker // Tests the case where the future wait times out once, we call should_abort,
116*14675a02SAndroid Build Coastguard Worker // which says to continue, and then the future returns.
TEST(InterruptibleRunnerTest,TestNormalWithAbortCheckButNoAbort)117*14675a02SAndroid Build Coastguard Worker TEST(InterruptibleRunnerTest, TestNormalWithAbortCheckButNoAbort) {
118*14675a02SAndroid Build Coastguard Worker int should_abort_calls = 0;
119*14675a02SAndroid Build Coastguard Worker int abort_function_calls = 0;
120*14675a02SAndroid Build Coastguard Worker absl::BlockingCounter counter_should_abort(1);
121*14675a02SAndroid Build Coastguard Worker absl::BlockingCounter counter_did_abort(1);
122*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort =
123*14675a02SAndroid Build Coastguard Worker [&should_abort_calls, &counter_should_abort, &counter_did_abort]() {
124*14675a02SAndroid Build Coastguard Worker should_abort_calls++;
125*14675a02SAndroid Build Coastguard Worker if (should_abort_calls == 2) {
126*14675a02SAndroid Build Coastguard Worker counter_should_abort.DecrementCount();
127*14675a02SAndroid Build Coastguard Worker counter_did_abort.Wait();
128*14675a02SAndroid Build Coastguard Worker }
129*14675a02SAndroid Build Coastguard Worker return false;
130*14675a02SAndroid Build Coastguard Worker };
131*14675a02SAndroid Build Coastguard Worker std::function<void()> abort_function = [&abort_function_calls]() {
132*14675a02SAndroid Build Coastguard Worker abort_function_calls++;
133*14675a02SAndroid Build Coastguard Worker };
134*14675a02SAndroid Build Coastguard Worker
135*14675a02SAndroid Build Coastguard Worker InterruptibleRunner interruptibleRunner(
136*14675a02SAndroid Build Coastguard Worker nullptr, should_abort,
137*14675a02SAndroid Build Coastguard Worker InterruptibleRunner::TimingConfig{
138*14675a02SAndroid Build Coastguard Worker .polling_period = absl::ZeroDuration(),
139*14675a02SAndroid Build Coastguard Worker .graceful_shutdown_period = absl::InfiniteDuration(),
140*14675a02SAndroid Build Coastguard Worker .extended_shutdown_period = absl::InfiniteDuration()},
141*14675a02SAndroid Build Coastguard Worker getDiagnosticsConfig());
142*14675a02SAndroid Build Coastguard Worker absl::Status status = interruptibleRunner.Run(
143*14675a02SAndroid Build Coastguard Worker [&counter_should_abort, &counter_did_abort]() {
144*14675a02SAndroid Build Coastguard Worker // Block until should_abort has been called.
145*14675a02SAndroid Build Coastguard Worker counter_should_abort.Wait();
146*14675a02SAndroid Build Coastguard Worker // Tell should_abort to return false.
147*14675a02SAndroid Build Coastguard Worker counter_did_abort.DecrementCount();
148*14675a02SAndroid Build Coastguard Worker return absl::OkStatus();
149*14675a02SAndroid Build Coastguard Worker },
150*14675a02SAndroid Build Coastguard Worker abort_function);
151*14675a02SAndroid Build Coastguard Worker EXPECT_THAT(status, IsCode(OK));
152*14675a02SAndroid Build Coastguard Worker EXPECT_GE(should_abort_calls, 2);
153*14675a02SAndroid Build Coastguard Worker EXPECT_EQ(abort_function_calls, 0);
154*14675a02SAndroid Build Coastguard Worker
155*14675a02SAndroid Build Coastguard Worker status = interruptibleRunner.Run([]() { return absl::DataLossError(""); },
156*14675a02SAndroid Build Coastguard Worker abort_function);
157*14675a02SAndroid Build Coastguard Worker EXPECT_THAT(status, IsCode(DATA_LOSS));
158*14675a02SAndroid Build Coastguard Worker }
159*14675a02SAndroid Build Coastguard Worker
160*14675a02SAndroid Build Coastguard Worker // Tests the case where the runnable gets aborted and behaves nicely (aborts
161*14675a02SAndroid Build Coastguard Worker // within the grace period).
TEST(InterruptibleRunnerTest,TestAbortInGracePeriod)162*14675a02SAndroid Build Coastguard Worker TEST(InterruptibleRunnerTest, TestAbortInGracePeriod) {
163*14675a02SAndroid Build Coastguard Worker StrictMock<MockLogManager> log_manager;
164*14675a02SAndroid Build Coastguard Worker int should_abort_calls = 0;
165*14675a02SAndroid Build Coastguard Worker int abort_function_calls = 0;
166*14675a02SAndroid Build Coastguard Worker absl::BlockingCounter counter_should_abort(1);
167*14675a02SAndroid Build Coastguard Worker absl::BlockingCounter counter_did_abort(1);
168*14675a02SAndroid Build Coastguard Worker
169*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort = [&should_abort_calls]() {
170*14675a02SAndroid Build Coastguard Worker should_abort_calls++;
171*14675a02SAndroid Build Coastguard Worker return should_abort_calls >= 2;
172*14675a02SAndroid Build Coastguard Worker };
173*14675a02SAndroid Build Coastguard Worker std::function<void()> abort_function =
174*14675a02SAndroid Build Coastguard Worker [&abort_function_calls, &counter_should_abort, &counter_did_abort]() {
175*14675a02SAndroid Build Coastguard Worker abort_function_calls++;
176*14675a02SAndroid Build Coastguard Worker // Signal runnable to abort.
177*14675a02SAndroid Build Coastguard Worker counter_should_abort.DecrementCount();
178*14675a02SAndroid Build Coastguard Worker // Wait for runnable to have aborted.
179*14675a02SAndroid Build Coastguard Worker counter_did_abort.Wait();
180*14675a02SAndroid Build Coastguard Worker };
181*14675a02SAndroid Build Coastguard Worker
182*14675a02SAndroid Build Coastguard Worker InterruptibleRunner interruptibleRunner(
183*14675a02SAndroid Build Coastguard Worker &log_manager, should_abort,
184*14675a02SAndroid Build Coastguard Worker InterruptibleRunner::TimingConfig{
185*14675a02SAndroid Build Coastguard Worker .polling_period = absl::ZeroDuration(),
186*14675a02SAndroid Build Coastguard Worker .graceful_shutdown_period = absl::InfiniteDuration(),
187*14675a02SAndroid Build Coastguard Worker .extended_shutdown_period = absl::InfiniteDuration()},
188*14675a02SAndroid Build Coastguard Worker getDiagnosticsConfig());
189*14675a02SAndroid Build Coastguard Worker // Tests that abort works.
190*14675a02SAndroid Build Coastguard Worker EXPECT_CALL(log_manager, LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION))
191*14675a02SAndroid Build Coastguard Worker .Times(testing::Exactly(1));
192*14675a02SAndroid Build Coastguard Worker absl::Status status = interruptibleRunner.Run(
193*14675a02SAndroid Build Coastguard Worker [&counter_should_abort, &counter_did_abort]() {
194*14675a02SAndroid Build Coastguard Worker counter_should_abort.Wait();
195*14675a02SAndroid Build Coastguard Worker counter_did_abort.DecrementCount();
196*14675a02SAndroid Build Coastguard Worker return absl::OkStatus();
197*14675a02SAndroid Build Coastguard Worker },
198*14675a02SAndroid Build Coastguard Worker abort_function);
199*14675a02SAndroid Build Coastguard Worker EXPECT_THAT(status, IsCode(CANCELLED));
200*14675a02SAndroid Build Coastguard Worker EXPECT_EQ(should_abort_calls, 2);
201*14675a02SAndroid Build Coastguard Worker EXPECT_EQ(abort_function_calls, 1);
202*14675a02SAndroid Build Coastguard Worker }
203*14675a02SAndroid Build Coastguard Worker
204*14675a02SAndroid Build Coastguard Worker // Tests the case where abort does not happen within the grace period.
205*14675a02SAndroid Build Coastguard Worker // This is achieved by only letting the runnable finish once the grace period
206*14675a02SAndroid Build Coastguard Worker // wait fails and a timeout diag code is logged, by taking an action on the
207*14675a02SAndroid Build Coastguard Worker // LogManager mock.
TEST(InterruptibleRunnerTest,TestAbortInExtendedGracePeriod)208*14675a02SAndroid Build Coastguard Worker TEST(InterruptibleRunnerTest, TestAbortInExtendedGracePeriod) {
209*14675a02SAndroid Build Coastguard Worker StrictMock<MockLogManager> log_manager;
210*14675a02SAndroid Build Coastguard Worker int should_abort_calls = 0;
211*14675a02SAndroid Build Coastguard Worker int abort_function_calls = 0;
212*14675a02SAndroid Build Coastguard Worker
213*14675a02SAndroid Build Coastguard Worker absl::BlockingCounter counter_should_abort(1);
214*14675a02SAndroid Build Coastguard Worker absl::BlockingCounter counter_did_abort(1);
215*14675a02SAndroid Build Coastguard Worker
216*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort = [&should_abort_calls]() {
217*14675a02SAndroid Build Coastguard Worker should_abort_calls++;
218*14675a02SAndroid Build Coastguard Worker return should_abort_calls >= 2;
219*14675a02SAndroid Build Coastguard Worker };
220*14675a02SAndroid Build Coastguard Worker std::function<void()> abort_function = [&abort_function_calls]() {
221*14675a02SAndroid Build Coastguard Worker abort_function_calls++;
222*14675a02SAndroid Build Coastguard Worker };
223*14675a02SAndroid Build Coastguard Worker
224*14675a02SAndroid Build Coastguard Worker InterruptibleRunner interruptibleRunner(
225*14675a02SAndroid Build Coastguard Worker &log_manager, should_abort,
226*14675a02SAndroid Build Coastguard Worker InterruptibleRunner::TimingConfig{
227*14675a02SAndroid Build Coastguard Worker .polling_period = absl::ZeroDuration(),
228*14675a02SAndroid Build Coastguard Worker .graceful_shutdown_period = absl::ZeroDuration(),
229*14675a02SAndroid Build Coastguard Worker .extended_shutdown_period = absl::InfiniteDuration()},
230*14675a02SAndroid Build Coastguard Worker getDiagnosticsConfig());
231*14675a02SAndroid Build Coastguard Worker EXPECT_CALL(log_manager,
232*14675a02SAndroid Build Coastguard Worker LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT))
233*14675a02SAndroid Build Coastguard Worker .WillOnce(
234*14675a02SAndroid Build Coastguard Worker [&counter_should_abort, &counter_did_abort](ProdDiagCode ignored) {
235*14675a02SAndroid Build Coastguard Worker counter_should_abort.DecrementCount();
236*14675a02SAndroid Build Coastguard Worker counter_did_abort.Wait();
237*14675a02SAndroid Build Coastguard Worker return absl::OkStatus();
238*14675a02SAndroid Build Coastguard Worker });
239*14675a02SAndroid Build Coastguard Worker EXPECT_CALL(
240*14675a02SAndroid Build Coastguard Worker log_manager,
241*14675a02SAndroid Build Coastguard Worker LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED))
242*14675a02SAndroid Build Coastguard Worker .Times(testing::Exactly(1));
243*14675a02SAndroid Build Coastguard Worker absl::Status status = interruptibleRunner.Run(
244*14675a02SAndroid Build Coastguard Worker [&counter_should_abort, &counter_did_abort]() {
245*14675a02SAndroid Build Coastguard Worker counter_should_abort.Wait();
246*14675a02SAndroid Build Coastguard Worker counter_did_abort.DecrementCount();
247*14675a02SAndroid Build Coastguard Worker return absl::OkStatus();
248*14675a02SAndroid Build Coastguard Worker },
249*14675a02SAndroid Build Coastguard Worker abort_function);
250*14675a02SAndroid Build Coastguard Worker
251*14675a02SAndroid Build Coastguard Worker EXPECT_THAT(status, IsCode(CANCELLED));
252*14675a02SAndroid Build Coastguard Worker EXPECT_EQ(should_abort_calls, 2);
253*14675a02SAndroid Build Coastguard Worker EXPECT_EQ(abort_function_calls, 1);
254*14675a02SAndroid Build Coastguard Worker }
255*14675a02SAndroid Build Coastguard Worker
256*14675a02SAndroid Build Coastguard Worker } // namespace
257*14675a02SAndroid Build Coastguard Worker } // namespace client
258*14675a02SAndroid Build Coastguard Worker } // namespace fcp
259