xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/metrics.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /**
2  * This file is adapted from PyTorch/XLA
3  * https://github.com/pytorch/xla/blob/master/third_party/xla_client/metrics.h
4  */
5 
6 #pragma once
7 
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include <c10/macros/Export.h>
17 
18 namespace torch {
19 namespace lazy {
20 
21 struct TORCH_API Sample {
22   Sample() = default;
SampleSample23   Sample(int64_t timestamp_ns, double value)
24       : timestamp_ns(timestamp_ns), value(value) {}
25 
26   int64_t timestamp_ns = 0;
27   double value = 0;
28 };
29 
// Callable that turns a sample value into the text used to display it.
using MetricReprFn = std::function<std::string(double)>;
31 
32 // Class used to collect time-stamped numeric samples. The samples are stored in
33 // a circular buffer whose size can be configured at constructor time.
34 class TORCH_API MetricData {
35  public:
36   // Creates a new MetricData object with the internal circular buffer storing
37   // max_samples samples. The repr_fn argument allow to specify a function which
38   // pretty-prints a sample value.
39   MetricData(MetricReprFn repr_fn, size_t max_samples);
40 
41   // Returns the total values of all the samples being posted to this metric.
42   double Accumulator() const;
43 
44   size_t TotalSamples() const;
45 
46   void AddSample(int64_t timestamp_ns, double value);
47 
48   // Returns a vector with all the current samples, from the oldest to the
49   // newer. If accumulator is not nullptr, it will receive the current value of
50   // the metrics' accumulator (the sum of all posted values). If total_samples
51   // is not nullptr, it will receive the count of the posted values.
52   std::vector<Sample> Samples(double* accumulator, size_t* total_samples) const;
53 
Repr(double value)54   std::string Repr(double value) const {
55     return repr_fn_(value);
56   }
57 
58   void Reset();
59 
IsValid()60   bool IsValid() const {
61     return TotalSamples() > 0;
62   }
63 
64  private:
65   mutable std::mutex lock_;
66   MetricReprFn repr_fn_;
67   size_t count_ = 0;
68   std::vector<Sample> samples_;
69   double accumulator_ = 0.0;
70 };
71 
72 // Counters are a very lightweight form of metrics which do not need to track
73 // sample time.
74 class TORCH_API CounterData {
75  public:
CounterData()76   CounterData() : value_(0) {}
77 
AddValue(int64_t value)78   void AddValue(int64_t value) {
79     value_ += value;
80   }
81 
Value()82   int64_t Value() const {
83     return value_;
84   }
85 
Reset()86   void Reset() {
87     value_ = 0;
88   }
89 
IsValid()90   bool IsValid() const {
91     return value_ > 0;
92   }
93 
94  private:
95   std::atomic<int64_t> value_;
96 };
97 
// Registry that maps names to MetricData and CounterData objects.
class TORCH_API MetricsArena {
 public:
  // Returns a pointer to an arena instance; the RegisterMetric comment below
  // calls it "the global arena", so presumably this is a process-wide
  // singleton (confirm against the definition).
  static MetricsArena* Get();

  // Reset every registered counter / metric. The exact reset semantics are
  // defined out of line.
  void ResetCounters();
  void ResetMetrics();

  // Registers a new metric in the global arena.
  // NOTE(review): `data` looks like an in/out slot for the shared MetricData
  // handle; confirm how an already-registered name is handled.
  void RegisterMetric(
      const std::string& name,
      MetricReprFn repr_fn,
      size_t max_samples,
      std::shared_ptr<MetricData>* data);

  // Counter equivalent of RegisterMetric.
  void RegisterCounter(
      const std::string& name,
      std::shared_ptr<CounterData>* data);

  // Invokes metric_func with a name and a MetricData pointer; presumably once
  // per registered metric.
  void ForEachMetric(
      const std::function<void(const std::string&, MetricData*)>& metric_func);

  // Invokes counter_func with a name and a CounterData pointer; presumably
  // once per registered counter.
  void ForEachCounter(
      const std::function<void(const std::string&, CounterData*)>&
          counter_func);

  // Names of the metrics known to this arena.
  std::vector<std::string> GetMetricNames();

  // Looks up a metric by name. The free function GetMetric() in this header
  // documents a nullptr result for unknown names; presumably the same applies
  // here.
  MetricData* GetMetric(const std::string& name);

  // Names of the counters known to this arena.
  std::vector<std::string> GetCounterNames();

  // Looks up a counter by name; see the note on GetMetric.
  CounterData* GetCounter(const std::string& name);

 private:
  std::mutex lock_;
  // Name-keyed storage; std::map keeps the entries ordered by name.
  std::map<std::string, std::shared_ptr<MetricData>> metrics_;
  std::map<std::string, std::shared_ptr<CounterData>> counters_;
};
136 
// Representation functions usable as a MetricReprFn.

// Emits the value in a to_string() conversion.
TORCH_API std::string MetricFnValue(double value);
// Emits the value in a humanized bytes representation.
TORCH_API std::string MetricFnBytes(double value);
// Emits the value in a humanized time representation. The value is expressed in
// nanoseconds EPOCH time.
TORCH_API std::string MetricFnTime(double value);
144 
145 // The typical use of a Metric is one in which it gets created either in a
146 // global scope context:
147 //   static Metric* metric = new Metric("RpcCount");
148 // Or within a function scope:
149 //   void MyFunction(...) {
150 //     static Metric* metric = new Metric("RpcCount");
151 //     ...
152 //     metric->AddSample(ts_nanos, some_value);
153 //   }
154 class TORCH_API Metric {
155  public:
156   explicit Metric(
157       std::string name,
158       MetricReprFn repr_fn = MetricFnValue,
159       size_t max_samples = 0);
160 
Name()161   const std::string& Name() const {
162     return name_;
163   }
164 
165   double Accumulator() const;
166 
167   void AddSample(int64_t timestamp_ns, double value);
168 
169   void AddSample(double value);
170 
171   std::vector<Sample> Samples(double* accumulator, size_t* total_samples) const;
172 
173   std::string Repr(double value) const;
174 
175  private:
176   MetricData* GetData() const;
177 
178   std::string name_;
179   MetricReprFn repr_fn_;
180   size_t max_samples_;
181   mutable std::shared_ptr<MetricData> data_ptr_;
182   mutable std::atomic<MetricData*> data_;
183 };
184 
185 // A Counter is a lightweight form of metric which tracks an integer value which
186 // can increase or decrease.
187 // A typical use is as:
188 //   static Counter* counter = new Counter("MyCounter");
189 //   ...
190 //   counter->AddValue(+1);
191 class TORCH_API Counter {
192  public:
193   explicit Counter(std::string name);
194 
AddValue(int64_t value)195   void AddValue(int64_t value) {
196     GetData()->AddValue(value);
197   }
198 
Value()199   int64_t Value() const {
200     return GetData()->Value();
201   }
202 
203  private:
204   CounterData* GetData() const;
205 
206   std::string name_;
207   mutable std::shared_ptr<CounterData> data_ptr_;
208   mutable std::atomic<CounterData*> data_;
209 };
210 
// Adds `value` to the Counter called `name`.
//
// The Counter is a function-local static, so it is created the first time
// control passes through the expansion site and is never deleted. `name` is
// therefore only evaluated on that first pass, while `value` is evaluated on
// every pass. The local is deliberately not spelled with a double underscore:
// such identifiers are reserved for the implementation.
#define TORCH_LAZY_COUNTER(name, value)                  \
  do {                                                   \
    static ::torch::lazy::Counter* torch_lazy_counter_ = \
        new ::torch::lazy::Counter(name);                \
    torch_lazy_counter_->AddValue(value);                \
  } while (0)
217 
// Bumps by one the counter whose name is `ns` joined with the name of the
// enclosing function (__func__) via c10::str. c10::str is not declared by any
// header included here, so the including file must make it visible.
#define TORCH_LAZY_FN_COUNTER(ns) \
  TORCH_LAZY_COUNTER(c10::str(ns, __func__), 1)
219 
// Posts `value` as a sample of the Metric called `name`, rendered with
// MetricFnValue.
//
// The Metric is a function-local static, so it is created the first time
// control passes through the expansion site and is never deleted. `name` is
// therefore only evaluated on that first pass. Every name is qualified from
// the global namespace so the macro expands the same way inside any namespace,
// and the local avoids the reserved double-underscore spelling.
#define TORCH_LAZY_VALUE_METRIC(name, value)                           \
  do {                                                                 \
    static ::torch::lazy::Metric* torch_lazy_metric_ =                 \
        new ::torch::lazy::Metric(name, ::torch::lazy::MetricFnValue); \
    torch_lazy_metric_->AddSample(value);                              \
  } while (0)
226 
// Creates a report with the current metrics statistics.
TORCH_API std::string CreateMetricReport();

// Creates a report with the selected metrics statistics. counter_names and
// metric_names list the counters and metrics to include.
TORCH_API std::string CreateMetricReport(
    const std::vector<std::string>& counter_names,
    const std::vector<std::string>& metric_names);
234 
// Returns the currently registered metric names. Note that the list can grow
// over time, since metrics are usually initialized lazily (they are static
// function variables).
TORCH_API std::vector<std::string> GetMetricNames();
239 
// Retrieves the metric data of a given metric, or nullptr if no such metric
// exists.
TORCH_API MetricData* GetMetric(const std::string& name);
243 
// Returns the currently registered counter names. Note that the list can grow
// over time, since counters are usually initialized lazily (they are static
// function variables).
TORCH_API std::vector<std::string> GetCounterNames();
248 
// Retrieves the counter data of a given counter, or nullptr if no such counter
// exists.
TORCH_API CounterData* GetCounter(const std::string& name);
252 
// Retrieves the current EPOCH time in nanoseconds.
TORCH_API int64_t NowNs();
255 
256 // Scope based utility class TORCH_API to measure the time the code takes within
257 // a given C++ scope.
258 class TORCH_API TimedSection {
259  public:
TimedSection(Metric * metric)260   explicit TimedSection(Metric* metric) : metric_(metric), start_(NowNs()) {}
261 
~TimedSection()262   ~TimedSection() {
263     int64_t now = NowNs();
264     metric_->AddSample(now, now - start_);
265   }
266 
Elapsed()267   double Elapsed() const {
268     return 1e-9 * static_cast<double>(NowNs() - start_);
269   }
270 
271  private:
272   Metric* metric_;
273   int64_t start_;
274 };
275 
// Times the remainder of the enclosing scope and posts the duration to the
// Metric called `name`, rendered with MetricFnTime.
//
// The expansion declares two variables in the enclosing scope: the static
// `timed_metric` and the `timed_section` guard. It can therefore be used at
// most once per scope, and it cannot be the sole body of an unbraced
// if/else/loop. The variable names are kept stable because code after the
// macro can refer to them (for example timed_section.Elapsed()). Every type is
// qualified from the global namespace, matching TORCH_LAZY_COUNTER.
#define TORCH_LAZY_TIMED(name)                                      \
  static ::torch::lazy::Metric* timed_metric =                      \
      new ::torch::lazy::Metric(name, ::torch::lazy::MetricFnTime); \
  ::torch::lazy::TimedSection timed_section(timed_metric)
280 
// Bumps the per-function counter for `ns`, then times the rest of the
// enclosing scope under the "LazyTracing" metric. This expands to two
// statements, so do not use it as the sole body of an unbraced if/else/loop.
#define TORCH_LAZY_FN_COUNTER_TIMED_TRACING(ns) \
  TORCH_LAZY_FN_COUNTER(ns);                    \
  TORCH_LAZY_TIMED("LazyTracing")
284 
285 } // namespace lazy
286 } // namespace torch
287