1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/preemption/preemption_notifier.h"
16 
17 #include <atomic>
18 #include <csignal>
19 #include <functional>
20 #include <memory>
21 #include <utility>
22 
23 #include "absl/synchronization/notification.h"
24 #include "absl/time/clock.h"
25 #include "absl/time/time.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/platform.h"
30 #include "tensorflow/core/platform/statusor.h"
31 #if defined(PLATFORM_GOOGLE)
32 #include "thread/executor.h"
33 #include "thread/signal.h"
34 #endif
35 
36 namespace tensorflow {
37 
38 namespace {
39 constexpr absl::Duration kListenInterval = absl::Seconds(1);
40 constexpr absl::Time kUnsetDeathTime = absl::InfinitePast();
41 static std::atomic_bool sigterm_received(false);
42 
43 class SigtermNotifier : public PreemptionNotifier {
44  public:
45   explicit SigtermNotifier(Env* env);
~SigtermNotifier()46   ~SigtermNotifier() override {
47     // Trigger shutdown logic in listener thread.
48     shutdown_notification_.Notify();
49   }
50 
51  private:
52   void StartListenerThread();
53   absl::Notification shutdown_notification_;
54   std::unique_ptr<Thread> preempt_listener_thread_;
55 };
56 
SigtermNotifier(Env * env)57 SigtermNotifier::SigtermNotifier(Env* env) : PreemptionNotifier(env) {
58   sigterm_received.store(false);
59   StartListenerThread();
60 #if defined(PLATFORM_GOOGLE)
61   thread::signal::Token unused_token;
62 
63   thread::signal::AddHandler(
64       SIGTERM, thread::Executor::DefaultExecutor(),
65       []() { sigterm_received.store(true); },
66       /*flags=*/0,  // Don't override existing signal handlers.
67       &unused_token);
68 #else
69   std::signal(SIGTERM, [](int signal) { sigterm_received.store(true); });
70 #endif
71 }
72 
StartListenerThread()73 void SigtermNotifier::StartListenerThread() {
74   preempt_listener_thread_.reset(
75       GetEnv()->StartThread({}, "PreemptionNotifier_Listen", [this]() {
76         // Poll for SIGTERM receipt every kListenInterval.
77         while (!sigterm_received.load()) {
78           if (shutdown_notification_.WaitForNotificationWithTimeout(
79                   kListenInterval)) {
80             // Shutdown:
81             // 1) Cancel any pending callbacks and blocking WillBePreemptedAt()
82             // calls.
83             NotifyRegisteredListeners(
84                 errors::Cancelled("Preemption notifier is being deleted."));
85             // 2) Exit listener thread.
86             return;
87           }
88         }
89         const absl::Time death_time = absl::Now();
90         LOG(WARNING) << "SIGTERM caught at " << death_time;
91         // Notify registered listeners.
92         NotifyRegisteredListeners(death_time);
93       }));
94 }
95 
96 }  // namespace
97 
WillBePreemptedAt()98 StatusOr<absl::Time> PreemptionNotifier::WillBePreemptedAt() {
99   absl::Notification n;
100   StatusOr<absl::Time> result;
101   WillBePreemptedAtAsync([&n, &result](StatusOr<absl::Time> async_result) {
102     result = async_result;
103     n.Notify();
104   });
105   n.WaitForNotification();
106   return result;
107 }
108 
WillBePreemptedAtAsync(PreemptTimeCallback callback)109 void PreemptionNotifier::WillBePreemptedAtAsync(PreemptTimeCallback callback) {
110   mutex_lock l(mu_);
111   if (death_time_ == kUnsetDeathTime) {
112     // Did not receive preemption notice yet.
113     callbacks_.push_back(std::move(callback));
114   } else {
115     // Already received preemption notice, respond immediately.
116     callback(death_time_);
117   }
118 }
119 
NotifyRegisteredListeners(StatusOr<absl::Time> death_time)120 void PreemptionNotifier::NotifyRegisteredListeners(
121     StatusOr<absl::Time> death_time) {
122   mutex_lock l(mu_);
123   if (death_time.ok()) {
124     death_time_ = death_time.value();
125   }
126   for (const auto& callback : callbacks_) {
127     callback(death_time);
128   }
129   callbacks_.clear();
130 }
131 
132 REGISTER_PREEMPTION_NOTIFIER(
__anonefced4000602(Env* env) 133     "sigterm", [](Env* env) -> std::unique_ptr<PreemptionNotifier> {
134       return std::make_unique<SigtermNotifier>(env);
135     });
136 }  // namespace tensorflow
137