1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/interruptible_runner.h"
17
18 #include <functional>
19
20 #include "gmock/gmock.h"
21 #include "gtest/gtest.h"
22 #include "absl/status/status.h"
23 #include "absl/synchronization/blocking_counter.h"
24 #include "absl/time/time.h"
25 #include "fcp/client/diag_codes.pb.h"
26 #include "fcp/client/test_helpers.h"
27 #include "fcp/testing/testing.h"
28
29 namespace fcp {
30 namespace client {
31 namespace {
32
33 using ::fcp::client::ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION;
34 using ::fcp::client::ProdDiagCode::
35 BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT;
36 using ::fcp::client::ProdDiagCode::
37 BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED;
38 using ::fcp::client::ProdDiagCode::
39 BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT;
40 using ::testing::StrictMock;
41
getDiagnosticsConfig()42 static InterruptibleRunner::DiagnosticsConfig getDiagnosticsConfig() {
43 return InterruptibleRunner::DiagnosticsConfig{
44 .interrupted = BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION,
45 .interrupt_timeout = BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT,
46 .interrupted_extended =
47 BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED,
48 .interrupt_timeout_extended =
49 BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT};
50 }
51
52 // Tests the case where runnable finishes before the future times out (and we'd
53 // call should_abort).
TEST(InterruptibleRunnerTest,TestNormalNoAbortCheck)54 TEST(InterruptibleRunnerTest, TestNormalNoAbortCheck) {
55 int should_abort_calls = 0;
56 int abort_function_calls = 0;
57 std::function<bool()> should_abort = [&should_abort_calls]() {
58 should_abort_calls++;
59 return false;
60 };
61 std::function<void()> abort_function = [&abort_function_calls]() {
62 abort_function_calls++;
63 };
64
65 InterruptibleRunner interruptibleRunner(
66 /*log_manager=*/nullptr, should_abort,
67 InterruptibleRunner::TimingConfig{
68 .polling_period = absl::InfiniteDuration(),
69 .graceful_shutdown_period = absl::InfiniteDuration(),
70 .extended_shutdown_period = absl::InfiniteDuration()},
71 getDiagnosticsConfig());
72 absl::Status status = interruptibleRunner.Run(
73 []() { return absl::OkStatus(); }, abort_function);
74 EXPECT_THAT(status, IsCode(OK));
75 EXPECT_EQ(should_abort_calls, 1);
76 EXPECT_EQ(abort_function_calls, 0);
77
78 // Test that the Status returned by the runnable is returned as is.
79 status = interruptibleRunner.Run([]() { return absl::DataLossError(""); },
80 abort_function);
81 EXPECT_THAT(status, IsCode(DATA_LOSS));
82 }
83
84 // Tests the case where should_abort prevents us from even kicking off the run.
TEST(InterruptibleRunnerTest,TestNormalAbortBeforeRun)85 TEST(InterruptibleRunnerTest, TestNormalAbortBeforeRun) {
86 int should_abort_calls = 0;
87 int abort_function_calls = 0;
88 int runnable_calls = 0;
89 std::function<bool()> should_abort = [&should_abort_calls]() {
90 should_abort_calls++;
91 return true;
92 };
93 std::function<void()> abort_function = [&abort_function_calls]() {
94 abort_function_calls++;
95 };
96
97 InterruptibleRunner interruptibleRunner(
98 /*log_manager=*/nullptr, should_abort,
99 InterruptibleRunner::TimingConfig{
100 .polling_period = absl::InfiniteDuration(),
101 .graceful_shutdown_period = absl::InfiniteDuration(),
102 .extended_shutdown_period = absl::InfiniteDuration()},
103 getDiagnosticsConfig());
104 absl::Status status = interruptibleRunner.Run(
105 [&runnable_calls]() {
106 runnable_calls++;
107 return absl::OkStatus();
108 },
109 abort_function);
110 EXPECT_THAT(status, IsCode(CANCELLED));
111 EXPECT_EQ(abort_function_calls, 0);
112 EXPECT_EQ(runnable_calls, 0);
113 }
114
115 // Tests the case where the future wait times out once, we call should_abort,
116 // which says to continue, and then the future returns.
TEST(InterruptibleRunnerTest,TestNormalWithAbortCheckButNoAbort)117 TEST(InterruptibleRunnerTest, TestNormalWithAbortCheckButNoAbort) {
118 int should_abort_calls = 0;
119 int abort_function_calls = 0;
120 absl::BlockingCounter counter_should_abort(1);
121 absl::BlockingCounter counter_did_abort(1);
122 std::function<bool()> should_abort =
123 [&should_abort_calls, &counter_should_abort, &counter_did_abort]() {
124 should_abort_calls++;
125 if (should_abort_calls == 2) {
126 counter_should_abort.DecrementCount();
127 counter_did_abort.Wait();
128 }
129 return false;
130 };
131 std::function<void()> abort_function = [&abort_function_calls]() {
132 abort_function_calls++;
133 };
134
135 InterruptibleRunner interruptibleRunner(
136 nullptr, should_abort,
137 InterruptibleRunner::TimingConfig{
138 .polling_period = absl::ZeroDuration(),
139 .graceful_shutdown_period = absl::InfiniteDuration(),
140 .extended_shutdown_period = absl::InfiniteDuration()},
141 getDiagnosticsConfig());
142 absl::Status status = interruptibleRunner.Run(
143 [&counter_should_abort, &counter_did_abort]() {
144 // Block until should_abort has been called.
145 counter_should_abort.Wait();
146 // Tell should_abort to return false.
147 counter_did_abort.DecrementCount();
148 return absl::OkStatus();
149 },
150 abort_function);
151 EXPECT_THAT(status, IsCode(OK));
152 EXPECT_GE(should_abort_calls, 2);
153 EXPECT_EQ(abort_function_calls, 0);
154
155 status = interruptibleRunner.Run([]() { return absl::DataLossError(""); },
156 abort_function);
157 EXPECT_THAT(status, IsCode(DATA_LOSS));
158 }
159
160 // Tests the case where the runnable gets aborted and behaves nicely (aborts
161 // within the grace period).
TEST(InterruptibleRunnerTest,TestAbortInGracePeriod)162 TEST(InterruptibleRunnerTest, TestAbortInGracePeriod) {
163 StrictMock<MockLogManager> log_manager;
164 int should_abort_calls = 0;
165 int abort_function_calls = 0;
166 absl::BlockingCounter counter_should_abort(1);
167 absl::BlockingCounter counter_did_abort(1);
168
169 std::function<bool()> should_abort = [&should_abort_calls]() {
170 should_abort_calls++;
171 return should_abort_calls >= 2;
172 };
173 std::function<void()> abort_function =
174 [&abort_function_calls, &counter_should_abort, &counter_did_abort]() {
175 abort_function_calls++;
176 // Signal runnable to abort.
177 counter_should_abort.DecrementCount();
178 // Wait for runnable to have aborted.
179 counter_did_abort.Wait();
180 };
181
182 InterruptibleRunner interruptibleRunner(
183 &log_manager, should_abort,
184 InterruptibleRunner::TimingConfig{
185 .polling_period = absl::ZeroDuration(),
186 .graceful_shutdown_period = absl::InfiniteDuration(),
187 .extended_shutdown_period = absl::InfiniteDuration()},
188 getDiagnosticsConfig());
189 // Tests that abort works.
190 EXPECT_CALL(log_manager, LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION))
191 .Times(testing::Exactly(1));
192 absl::Status status = interruptibleRunner.Run(
193 [&counter_should_abort, &counter_did_abort]() {
194 counter_should_abort.Wait();
195 counter_did_abort.DecrementCount();
196 return absl::OkStatus();
197 },
198 abort_function);
199 EXPECT_THAT(status, IsCode(CANCELLED));
200 EXPECT_EQ(should_abort_calls, 2);
201 EXPECT_EQ(abort_function_calls, 1);
202 }
203
204 // Tests the case where abort does not happen within the grace period.
205 // This is achieved by only letting the runnable finish once the grace period
206 // wait fails and a timeout diag code is logged, by taking an action on the
207 // LogManager mock.
TEST(InterruptibleRunnerTest,TestAbortInExtendedGracePeriod)208 TEST(InterruptibleRunnerTest, TestAbortInExtendedGracePeriod) {
209 StrictMock<MockLogManager> log_manager;
210 int should_abort_calls = 0;
211 int abort_function_calls = 0;
212
213 absl::BlockingCounter counter_should_abort(1);
214 absl::BlockingCounter counter_did_abort(1);
215
216 std::function<bool()> should_abort = [&should_abort_calls]() {
217 should_abort_calls++;
218 return should_abort_calls >= 2;
219 };
220 std::function<void()> abort_function = [&abort_function_calls]() {
221 abort_function_calls++;
222 };
223
224 InterruptibleRunner interruptibleRunner(
225 &log_manager, should_abort,
226 InterruptibleRunner::TimingConfig{
227 .polling_period = absl::ZeroDuration(),
228 .graceful_shutdown_period = absl::ZeroDuration(),
229 .extended_shutdown_period = absl::InfiniteDuration()},
230 getDiagnosticsConfig());
231 EXPECT_CALL(log_manager,
232 LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT))
233 .WillOnce(
234 [&counter_should_abort, &counter_did_abort](ProdDiagCode ignored) {
235 counter_should_abort.DecrementCount();
236 counter_did_abort.Wait();
237 return absl::OkStatus();
238 });
239 EXPECT_CALL(
240 log_manager,
241 LogDiag(BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED))
242 .Times(testing::Exactly(1));
243 absl::Status status = interruptibleRunner.Run(
244 [&counter_should_abort, &counter_did_abort]() {
245 counter_should_abort.Wait();
246 counter_did_abort.DecrementCount();
247 return absl::OkStatus();
248 },
249 abort_function);
250
251 EXPECT_THAT(status, IsCode(CANCELLED));
252 EXPECT_EQ(should_abort_calls, 2);
253 EXPECT_EQ(abort_function_calls, 1);
254 }
255
256 } // namespace
257 } // namespace client
258 } // namespace fcp
259