// Source: /aosp_15_r20/external/pytorch/torch/csrc/monitor/counters.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <bitset>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <limits>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/macros/Macros.h>

#include <torch/csrc/monitor/events.h>
12 
13 namespace torch {
14 namespace monitor {
15 
// NUM_AGGREGATIONS is the width of the std::bitset used to record which
// aggregations a Stat has enabled. Aggregation enumerator values are used
// directly as bit positions, so this must stay strictly greater than the
// largest enumerator value (currently 6).
constexpr int NUM_AGGREGATIONS = 7;
17 
18 // Aggregation is the list of possible aggregations for Stats.
19 // These use bitwise flags so they can be efficiently stored.
20 enum class C10_API_ENUM Aggregation {
21   // NONE means no aggregations are set.
22   NONE = 0,
23   // VALUE exports the most recently set value.
24   VALUE = 1,
25   // MEAN computes the mean of the set values within the window. Zero if no
26   // values.
27   MEAN = 2,
28   // COUNT tracks the number of times a value is set within the window.
29   COUNT = 3,
30   // SUM computes the sum of the values set within the window.
31   SUM = 4,
32   // MIN computes the minimum of the values set within the window. Zero if no
33   // values.
34   MAX = 5,
35   // MAX computes the maximum of the values set within the window. Zero if no
36   // values.
37   MIN = 6,
38 };
39 
40 struct TORCH_API AggregationHash {
41   template <typename T>
operatorAggregationHash42   std::size_t operator()(T t) const {
43     return static_cast<std::size_t>(t);
44   }
45 };
46 
47 // aggregationName returns the human readable name corresponding to the
48 // aggregation.
49 TORCH_API const char* aggregationName(Aggregation agg);
50 
// Forward declaration: the detail::registerStat / detail::unregisterStat
// declarations below take Stat<...>* before the class template is defined
// later in this header.
template <typename T>
class Stat;
53 
54 namespace {
55 template <typename T>
merge(T & list)56 inline std::bitset<NUM_AGGREGATIONS> merge(T& list) {
57   std::bitset<NUM_AGGREGATIONS> a;
58   for (Aggregation b : list) {
59     a.set(static_cast<int>(b));
60   }
61   return a;
62 }
63 } // namespace
64 
65 namespace detail {
66 void TORCH_API registerStat(Stat<double>* stat);
67 void TORCH_API registerStat(Stat<int64_t>* stat);
68 void TORCH_API unregisterStat(Stat<double>* stat);
69 void TORCH_API unregisterStat(Stat<int64_t>* stat);
70 } // namespace detail
71 
72 // Stat is used to compute summary statistics in a performant way over fixed
73 // intervals. Stat logs the statistics as an Event once every `windowSize`
74 // duration. When the window closes the stats are logged via the event handlers
75 // as a `torch.monitor.Stat` event.
76 //
77 // `windowSize` should be set to something relatively high to avoid a huge
78 // number of events being logged. Ex: 60s. Stat uses millisecond precision.
79 //
80 // If maxSamples is set, the stat will cap the number of samples per window by
81 // discarding `add` calls once `maxSamples` adds have occurred. If it's not set,
82 // all `add` calls during the window will be included.
83 // This is an optional field to make aggregations more directly comparable
84 // across windows when the number of samples might vary.
85 //
86 // Stats support double and int64_t data types depending on what needs to be
87 // logged and needs to be templatized with one of them.
88 //
89 // When the Stat is destructed it will log any remaining data even if the window
90 // hasn't elapsed.
91 template <typename T>
92 class Stat {
93  private:
94   struct Values {
95     T value{0};
96     T sum{0};
97     T min{0};
98     T max{0};
99     int64_t count{0};
100   };
101 
102  public:
103   Stat(
104       std::string name,
105       std::initializer_list<Aggregation> aggregations,
106       std::chrono::milliseconds windowSize,
107       int64_t maxSamples = std::numeric_limits<int64_t>::max())
name_(std::move (name))108       : name_(std::move(name)),
109         aggregations_(merge(aggregations)),
110         windowSize_(windowSize),
111         maxSamples_(maxSamples) {
112     detail::registerStat(this);
113   }
114 
115   Stat(
116       std::string name,
117       std::vector<Aggregation> aggregations,
118       std::chrono::milliseconds windowSize,
119       int64_t maxSamples = std::numeric_limits<int64_t>::max())
name_(std::move (name))120       : name_(std::move(name)),
121         aggregations_(merge(aggregations)),
122         windowSize_(windowSize),
123         maxSamples_(maxSamples) {
124     detail::registerStat(this);
125   }
126 
~Stat()127   virtual ~Stat() {
128     {
129       // on destruction log if there's unlogged data
130       std::lock_guard<std::mutex> guard(mu_);
131       logLocked();
132     }
133     detail::unregisterStat(this);
134   }
135 
136   // add adds the value v to the current window.
add(T v)137   void add(T v) {
138     std::lock_guard<std::mutex> guard(mu_);
139     maybeLogLocked();
140 
141     if (alreadyLogged()) {
142       return;
143     }
144 
145     if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
146       current_.value = v;
147     }
148     if (aggregations_.test(static_cast<int>(Aggregation::MEAN)) ||
149         aggregations_.test(static_cast<int>(Aggregation::SUM))) {
150       current_.sum += v;
151     }
152 
153     if (aggregations_.test(static_cast<int>(Aggregation::MAX))) {
154       if (current_.max < v || current_.count == 0) {
155         current_.max = v;
156       }
157     }
158     if (aggregations_.test(static_cast<int>(Aggregation::MIN))) {
159       if (current_.min > v || current_.count == 0) {
160         current_.min = v;
161       }
162     }
163 
164     current_.count += 1;
165     maybeLogLocked();
166   }
167 
name()168   const std::string& name() const noexcept {
169     return name_;
170   }
171 
172   // count returns the number of items in the current open window.
count()173   int64_t count() noexcept {
174     std::lock_guard<std::mutex> guard(mu_);
175 
176     return current_.count;
177   }
178 
get()179   std::unordered_map<Aggregation, T, AggregationHash> get() noexcept {
180     std::lock_guard<std::mutex> guard(mu_);
181     return getLocked();
182   }
183 
184  protected:
currentWindowId()185   virtual uint64_t currentWindowId() const {
186     std::chrono::milliseconds now =
187         std::chrono::duration_cast<std::chrono::milliseconds>(
188             std::chrono::steady_clock::now().time_since_epoch());
189 
190     // always returns a currentWindowId of at least 1 to avoid 0 window issues
191     return (now / windowSize_) + 1;
192   }
193 
194  private:
alreadyLogged()195   bool alreadyLogged() {
196     return lastLoggedWindowId_ == currentWindowId();
197   }
198 
maybeLogLocked()199   void maybeLogLocked() {
200     auto windowId = currentWindowId();
201     bool shouldLog = windowId_ != windowId || current_.count >= maxSamples_;
202     if (shouldLog && !alreadyLogged()) {
203       logLocked();
204       lastLoggedWindowId_ = windowId_;
205       windowId_ = windowId;
206     }
207   }
208 
logLocked()209   void logLocked() {
210     prev_ = current_;
211     current_ = Values();
212 
213     // don't log event if there's no data
214     if (prev_.count == 0) {
215       return;
216     }
217 
218     Event e;
219     e.name = "torch.monitor.Stat";
220     e.timestamp = std::chrono::system_clock::now();
221 
222     auto stats = getLocked();
223     e.data.reserve(stats.size());
224     for (auto& kv : stats) {
225       std::stringstream key;
226       key << name_;
227       key << ".";
228       key << aggregationName(kv.first);
229       e.data[key.str()] = kv.second;
230     }
231 
232     logEvent(e);
233   }
234 
getLocked()235   std::unordered_map<Aggregation, T, AggregationHash> getLocked()
236       const noexcept {
237     std::unordered_map<Aggregation, T, AggregationHash> out;
238     out.reserve(aggregations_.count());
239 
240     if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
241       out.emplace(Aggregation::VALUE, prev_.value);
242     }
243     if (aggregations_.test(static_cast<int>(Aggregation::MEAN))) {
244       if (prev_.count == 0) {
245         out.emplace(Aggregation::MEAN, 0);
246       } else {
247         out.emplace(Aggregation::MEAN, prev_.sum / prev_.count);
248       }
249     }
250     if (aggregations_.test(static_cast<int>(Aggregation::COUNT))) {
251       out.emplace(Aggregation::COUNT, prev_.count);
252     }
253     if (aggregations_.test(static_cast<int>(Aggregation::SUM))) {
254       out.emplace(Aggregation::SUM, prev_.sum);
255     }
256     if (aggregations_.test(static_cast<int>(Aggregation::MAX))) {
257       out.emplace(Aggregation::MAX, prev_.max);
258     }
259     if (aggregations_.test(static_cast<int>(Aggregation::MIN))) {
260       out.emplace(Aggregation::MIN, prev_.min);
261     }
262 
263     return out;
264   }
265 
266   const std::string name_;
267   const std::bitset<NUM_AGGREGATIONS> aggregations_;
268 
269   std::mutex mu_;
270   Values current_;
271   Values prev_;
272 
273   uint64_t windowId_{0};
274   uint64_t lastLoggedWindowId_{0};
275   const std::chrono::milliseconds windowSize_;
276   const int64_t maxSamples_;
277 };
278 } // namespace monitor
279 } // namespace torch
280