xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
17 
18 #include <functional>
19 #include <iterator>
20 #include <list>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/base/call_once.h"
27 #include "absl/base/thread_annotations.h"
28 #include "absl/numeric/bits.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/synchronization/mutex.h"
31 #include "absl/time/time.h"
32 #include "tensorflow/core/platform/env.h"
33 
34 namespace xla {
35 namespace {
36 
37 absl::Mutex mu(absl::kConstInit);
38 absl::CondVar* ready;
39 absl::once_flag init_flag;
40 std::list<SlowOperationAlarm*>* outstanding_alarms ABSL_PT_GUARDED_BY(mu) =
41     nullptr;
42 
43 }  // namespace
44 
AlarmLoop()45 void SlowOperationAlarm::AlarmLoop() {
46   while (true) {
47     absl::MutexLock lock(&mu);
48 
49     // Fire any alarms which are ready.
50     absl::Time now = absl::Now();
51     for (auto it = outstanding_alarms->begin();
52          it != outstanding_alarms->end();) {
53       auto next = std::next(it);
54       auto* alarm = *it;
55       // Fire the alarm if applicable.
56       if (alarm->deadline() <= now) {
57         outstanding_alarms->erase(it);
58         const int64_t count =
59             alarm->counter() == nullptr ? 0 : alarm->counter()->fetch_add(1);
60         // If the alarm has a counter, only fire if the count is a power of 2.
61         if (count == 0 || absl::has_single_bit<uint64_t>(count) == 0) {
62           alarm->fired_.store(true);
63           // We fire alarms with LOG(ERROR) because otherwise it might not show
64           // up without --logtostderr.
65           LOG(ERROR) << alarm->msg();
66         }
67       }
68       it = next;
69     }
70 
71     auto next_alarm = absl::c_min_element(
72         *outstanding_alarms,
73         [](const SlowOperationAlarm* a, const SlowOperationAlarm* b) {
74           return a->deadline() < b->deadline();
75         });
76     const absl::Time deadline = next_alarm != outstanding_alarms->end()
77                                     ? (*next_alarm)->deadline()
78                                     : absl::InfiniteFuture();
79 
80     ready->WaitWithDeadline(&mu, deadline);
81   }
82 }
83 
ScheduleAlarm(SlowOperationAlarm * alarm)84 void SlowOperationAlarm::ScheduleAlarm(SlowOperationAlarm* alarm) {
85   absl::call_once(init_flag, [] {
86     ready = new absl::CondVar();
87     outstanding_alarms = new std::list<SlowOperationAlarm*>();
88     (void)tensorflow::Env::Default()->StartThread(
89         tensorflow::ThreadOptions(), "SlowOperationAlarm", [] { AlarmLoop(); });
90   });
91 
92   absl::MutexLock lock(&mu);
93   outstanding_alarms->push_back(alarm);
94   ready->Signal();
95 }
96 
UnscheduleAlarm(const SlowOperationAlarm * alarm)97 void SlowOperationAlarm::UnscheduleAlarm(const SlowOperationAlarm* alarm) {
98   absl::MutexLock lock(&mu);
99   CHECK(outstanding_alarms != nullptr);
100   auto it = absl::c_find(*outstanding_alarms, alarm);
101   if (it != outstanding_alarms->end()) {
102     outstanding_alarms->erase(it);
103   }
104 }
SlowOperationAlarm(absl::Duration timeout,std::string msg,std::atomic<int64_t> * counter,absl::string_view context)105 SlowOperationAlarm::SlowOperationAlarm(
106     absl::Duration timeout, std::string msg,
107     std::atomic<int64_t>* counter /*=nullptr*/,
108     absl::string_view context /*=""*/)
109     : SlowOperationAlarm(
110           timeout,                                 //
111           [msg = std::move(msg)] { return msg; },  //
112           counter, std::move(context)) {}
113 
SlowOperationAlarm(absl::Duration timeout,std::function<std::string ()> msg_fn,std::atomic<int64_t> * counter,absl::string_view context)114 SlowOperationAlarm::SlowOperationAlarm(
115     absl::Duration timeout, std::function<std::string()> msg_fn,
116     std::atomic<int64_t>* counter /*=nullptr*/,
117     absl::string_view context /*=""*/)
118     : start_(absl::Now()),
119       deadline_(start_ + timeout),
120       context_(std::move(context)),
121       msg_fn_(std::move(msg_fn)),
122       counter_(counter) {
123   ScheduleAlarm(this);
124 }
125 
~SlowOperationAlarm()126 SlowOperationAlarm::~SlowOperationAlarm() {
127   UnscheduleAlarm(this);
128 
129   absl::Time now = absl::Now();
130   if (deadline() <= now) {
131     absl::Duration duration = now - start_;
132     if (context_.empty()) {
133       LOG(ERROR) << "The operation took " << absl::FormatDuration(duration)
134                  << "\n"
135                  << msg_fn_();
136     } else {
137       LOG(ERROR) << "[" << context_ << "] The operation took "
138                  << absl::FormatDuration(duration) << "\n"
139                  << msg_fn_();
140     }
141   }
142 }
143 
SlowCompilationAlarm(absl::string_view context)144 std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm(
145     absl::string_view context) {
146   // Pass a counter to these alarms so they only log once every power-of-two
147   // occurrences.
148   static auto* counter = new std::atomic<int64_t>(0);
149 
150   const char* separator = "\n********************************";
151 
152   std::string context_msg;
153   if (!context.empty()) {
154     context_msg = absl::StrCat("[", context, "] ");
155   }
156 
157 #if NDEBUG
158   return std::make_unique<SlowOperationAlarm>(
159       absl::Duration(absl::Minutes(2)),
160       absl::StrCat(
161           separator, "\n", context_msg,
162           "Very slow compile?  If you want to file a bug, run with envvar "
163           "XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.",
164           separator),
165       counter);
166 #else
167   return std::make_unique<SlowOperationAlarm>(
168       absl::Duration(absl::Seconds(10)),
169       absl::StrCat(
170           separator, "\n", context_msg,
171           "Slow compile?  XLA was built without compiler optimizations, "
172           "which can be slow.  Try rebuilding with -c opt.",
173           separator),
174       counter);
175 #endif
176 }
177 
178 }  // namespace xla
179