1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/preemption/preemption_notifier.h"
16
17 #include <atomic>
18 #include <csignal>
19 #include <functional>
20 #include <memory>
21 #include <utility>
22
23 #include "absl/synchronization/notification.h"
24 #include "absl/time/clock.h"
25 #include "absl/time/time.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/platform.h"
30 #include "tensorflow/core/platform/statusor.h"
31 #if defined(PLATFORM_GOOGLE)
32 #include "thread/executor.h"
33 #include "thread/signal.h"
34 #endif
35
36 namespace tensorflow {
37
38 namespace {
39 constexpr absl::Duration kListenInterval = absl::Seconds(1);
40 constexpr absl::Time kUnsetDeathTime = absl::InfinitePast();
41 static std::atomic_bool sigterm_received(false);
42
43 class SigtermNotifier : public PreemptionNotifier {
44 public:
45 explicit SigtermNotifier(Env* env);
~SigtermNotifier()46 ~SigtermNotifier() override {
47 // Trigger shutdown logic in listener thread.
48 shutdown_notification_.Notify();
49 }
50
51 private:
52 void StartListenerThread();
53 absl::Notification shutdown_notification_;
54 std::unique_ptr<Thread> preempt_listener_thread_;
55 };
56
// Resets the process-wide SIGTERM flag, starts the listener thread, and then
// installs a SIGTERM handler that sets the flag. The handler is not uninstalled
// anywhere in this file, including the destructor.
SigtermNotifier::SigtermNotifier(Env* env) : PreemptionNotifier(env) {
  // `sigterm_received` is shared by all instances. Clear any SIGTERM recorded
  // earlier (e.g. observed through a previous SigtermNotifier) before the new
  // listener thread starts polling it.
  sigterm_received.store(false);
  StartListenerThread();
#if defined(PLATFORM_GOOGLE)
  // The registration token is discarded, so the handler stays installed.
  thread::signal::Token unused_token;

  thread::signal::AddHandler(
      SIGTERM, thread::Executor::DefaultExecutor(),
      []() { sigterm_received.store(true); },
      /*flags=*/0, // Don't override existing signal handlers.
      &unused_token);
#else
  // NOTE(review): unlike the branch above, std::signal replaces any previously
  // installed SIGTERM handler rather than chaining to it. The handler runs in
  // signal context and therefore only stores to the (lock-free) atomic flag.
  std::signal(SIGTERM, [](int signal) { sigterm_received.store(true); });
#endif
}
72
StartListenerThread()73 void SigtermNotifier::StartListenerThread() {
74 preempt_listener_thread_.reset(
75 GetEnv()->StartThread({}, "PreemptionNotifier_Listen", [this]() {
76 // Poll for SIGTERM receipt every kListenInterval.
77 while (!sigterm_received.load()) {
78 if (shutdown_notification_.WaitForNotificationWithTimeout(
79 kListenInterval)) {
80 // Shutdown:
81 // 1) Cancel any pending callbacks and blocking WillBePreemptedAt()
82 // calls.
83 NotifyRegisteredListeners(
84 errors::Cancelled("Preemption notifier is being deleted."));
85 // 2) Exit listener thread.
86 return;
87 }
88 }
89 const absl::Time death_time = absl::Now();
90 LOG(WARNING) << "SIGTERM caught at " << death_time;
91 // Notify registered listeners.
92 NotifyRegisteredListeners(death_time);
93 }));
94 }
95
96 } // namespace
97
WillBePreemptedAt()98 StatusOr<absl::Time> PreemptionNotifier::WillBePreemptedAt() {
99 absl::Notification n;
100 StatusOr<absl::Time> result;
101 WillBePreemptedAtAsync([&n, &result](StatusOr<absl::Time> async_result) {
102 result = async_result;
103 n.Notify();
104 });
105 n.WaitForNotification();
106 return result;
107 }
108
WillBePreemptedAtAsync(PreemptTimeCallback callback)109 void PreemptionNotifier::WillBePreemptedAtAsync(PreemptTimeCallback callback) {
110 mutex_lock l(mu_);
111 if (death_time_ == kUnsetDeathTime) {
112 // Did not receive preemption notice yet.
113 callbacks_.push_back(std::move(callback));
114 } else {
115 // Already received preemption notice, respond immediately.
116 callback(death_time_);
117 }
118 }
119
NotifyRegisteredListeners(StatusOr<absl::Time> death_time)120 void PreemptionNotifier::NotifyRegisteredListeners(
121 StatusOr<absl::Time> death_time) {
122 mutex_lock l(mu_);
123 if (death_time.ok()) {
124 death_time_ = death_time.value();
125 }
126 for (const auto& callback : callbacks_) {
127 callback(death_time);
128 }
129 callbacks_.clear();
130 }
131
// Registers a factory for SigtermNotifier under the name "sigterm".
// Presumably this lets callers create it by name through the
// PreemptionNotifier registry (macro defined in preemption_notifier.h) —
// confirm against the header.
REGISTER_PREEMPTION_NOTIFIER(
    "sigterm", [](Env* env) -> std::unique_ptr<PreemptionNotifier> {
      return std::make_unique<SigtermNotifier>(env);
    });
136 } // namespace tensorflow
137