1 /** 2 * This file is adapted from PyTorch/XLA 3 * https://github.com/pytorch/xla/blob/master/third_party/xla_client/metrics.h 4 */ 5 6 #pragma once 7 8 #include <atomic> 9 #include <functional> 10 #include <map> 11 #include <memory> 12 #include <mutex> 13 #include <string> 14 #include <vector> 15 16 #include <c10/macros/Export.h> 17 18 namespace torch { 19 namespace lazy { 20 21 struct TORCH_API Sample { 22 Sample() = default; SampleSample23 Sample(int64_t timestamp_ns, double value) 24 : timestamp_ns(timestamp_ns), value(value) {} 25 26 int64_t timestamp_ns = 0; 27 double value = 0; 28 }; 29 30 using MetricReprFn = std::function<std::string(double)>; 31 32 // Class used to collect time-stamped numeric samples. The samples are stored in 33 // a circular buffer whose size can be configured at constructor time. 34 class TORCH_API MetricData { 35 public: 36 // Creates a new MetricData object with the internal circular buffer storing 37 // max_samples samples. The repr_fn argument allow to specify a function which 38 // pretty-prints a sample value. 39 MetricData(MetricReprFn repr_fn, size_t max_samples); 40 41 // Returns the total values of all the samples being posted to this metric. 42 double Accumulator() const; 43 44 size_t TotalSamples() const; 45 46 void AddSample(int64_t timestamp_ns, double value); 47 48 // Returns a vector with all the current samples, from the oldest to the 49 // newer. If accumulator is not nullptr, it will receive the current value of 50 // the metrics' accumulator (the sum of all posted values). If total_samples 51 // is not nullptr, it will receive the count of the posted values. 52 std::vector<Sample> Samples(double* accumulator, size_t* total_samples) const; 53 Repr(double value)54 std::string Repr(double value) const { 55 return repr_fn_(value); 56 } 57 58 void Reset(); 59 IsValid()60 bool IsValid() const { 61 return TotalSamples() > 0; 62 } 63 64 private: 65 mutable std::mutex lock_; 66 MetricReprFn repr_fn_; 67 size_t count_ = 0; 68 std::vector<Sample> samples_; 69 double accumulator_ = 0.0; 70 }; 71 72 // Counters are a very lightweight form of metrics which do not need to track 73 // sample time. 74 class TORCH_API CounterData { 75 public: CounterData()76 CounterData() : value_(0) {} 77 AddValue(int64_t value)78 void AddValue(int64_t value) { 79 value_ += value; 80 } 81 Value()82 int64_t Value() const { 83 return value_; 84 } 85 Reset()86 void Reset() { 87 value_ = 0; 88 } 89 IsValid()90 bool IsValid() const { 91 return value_ > 0; 92 } 93 94 private: 95 std::atomic<int64_t> value_; 96 }; 97 98 class TORCH_API MetricsArena { 99 public: 100 static MetricsArena* Get(); 101 102 void ResetCounters(); 103 void ResetMetrics(); 104 105 // Registers a new metric in the global arena. 106 void RegisterMetric( 107 const std::string& name, 108 MetricReprFn repr_fn, 109 size_t max_samples, 110 std::shared_ptr<MetricData>* data); 111 112 void RegisterCounter( 113 const std::string& name, 114 std::shared_ptr<CounterData>* data); 115 116 void ForEachMetric( 117 const std::function<void(const std::string&, MetricData*)>& metric_func); 118 119 void ForEachCounter( 120 const std::function<void(const std::string&, CounterData*)>& 121 counter_func); 122 123 std::vector<std::string> GetMetricNames(); 124 125 MetricData* GetMetric(const std::string& name); 126 127 std::vector<std::string> GetCounterNames(); 128 129 CounterData* GetCounter(const std::string& name); 130 131 private: 132 std::mutex lock_; 133 std::map<std::string, std::shared_ptr<MetricData>> metrics_; 134 std::map<std::string, std::shared_ptr<CounterData>> counters_; 135 }; 136 137 // Emits the value in a to_string() conversion. 138 TORCH_API std::string MetricFnValue(double value); 139 // Emits the value in a humanized bytes representation. 140 TORCH_API std::string MetricFnBytes(double value); 141 // Emits the value in a humanized time representation. The value is expressed in 142 // nanoseconds EPOCH time. 143 TORCH_API std::string MetricFnTime(double value); 144 145 // The typical use of a Metric is one in which it gets created either in a 146 // global scope context: 147 // static Metric* metric = new Metric("RpcCount"); 148 // Or within a function scope: 149 // void MyFunction(...) { 150 // static Metric* metric = new Metric("RpcCount"); 151 // ... 152 // metric->AddSample(ts_nanos, some_value); 153 // } 154 class TORCH_API Metric { 155 public: 156 explicit Metric( 157 std::string name, 158 MetricReprFn repr_fn = MetricFnValue, 159 size_t max_samples = 0); 160 Name()161 const std::string& Name() const { 162 return name_; 163 } 164 165 double Accumulator() const; 166 167 void AddSample(int64_t timestamp_ns, double value); 168 169 void AddSample(double value); 170 171 std::vector<Sample> Samples(double* accumulator, size_t* total_samples) const; 172 173 std::string Repr(double value) const; 174 175 private: 176 MetricData* GetData() const; 177 178 std::string name_; 179 MetricReprFn repr_fn_; 180 size_t max_samples_; 181 mutable std::shared_ptr<MetricData> data_ptr_; 182 mutable std::atomic<MetricData*> data_; 183 }; 184 185 // A Counter is a lightweight form of metric which tracks an integer value which 186 // can increase or decrease. 187 // A typical use is as: 188 // static Counter* counter = new Counter("MyCounter"); 189 // ... 190 // counter->AddValue(+1); 191 class TORCH_API Counter { 192 public: 193 explicit Counter(std::string name); 194 AddValue(int64_t value)195 void AddValue(int64_t value) { 196 GetData()->AddValue(value); 197 } 198 Value()199 int64_t Value() const { 200 return GetData()->Value(); 201 } 202 203 private: 204 CounterData* GetData() const; 205 206 std::string name_; 207 mutable std::shared_ptr<CounterData> data_ptr_; 208 mutable std::atomic<CounterData*> data_; 209 }; 210 211 #define TORCH_LAZY_COUNTER(name, value) \ 212 do { \ 213 static ::torch::lazy::Counter* __counter = \ 214 new ::torch::lazy::Counter(name); \ 215 __counter->AddValue(value); \ 216 } while (0) 217 218 #define TORCH_LAZY_FN_COUNTER(ns) TORCH_LAZY_COUNTER(c10::str(ns, __func__), 1) 219 220 #define TORCH_LAZY_VALUE_METRIC(name, value) \ 221 do { \ 222 static ::torch::lazy::Metric* __metric = \ 223 new ::torch::lazy::Metric(name, torch::lazy::MetricFnValue); \ 224 __metric->AddSample(value); \ 225 } while (0) 226 227 // Creates a report with the current metrics statistics. 228 TORCH_API std::string CreateMetricReport(); 229 230 // Creates a report with the selected metrics statistics. 231 TORCH_API std::string CreateMetricReport( 232 const std::vector<std::string>& counter_names, 233 const std::vector<std::string>& metric_names); 234 235 // Returns the currently registered metric names. Note that the list can grow 236 // since metrics are usually function intialized (they are static function 237 // variables). 238 TORCH_API std::vector<std::string> GetMetricNames(); 239 240 // Retrieves the metric data of a given metric, or nullptr if such metric does 241 // not exist. 242 TORCH_API MetricData* GetMetric(const std::string& name); 243 244 // Returns the currently registered counter names. Note that the list can grow 245 // since counters are usually function intialized (they are static function 246 // variables). 247 TORCH_API std::vector<std::string> GetCounterNames(); 248 249 // Retrieves the counter data of a given counter, or nullptr if such counter 250 // does not exist. 251 TORCH_API CounterData* GetCounter(const std::string& name); 252 253 // Retrieves the current EPOCH time in nanoseconds. 254 TORCH_API int64_t NowNs(); 255 256 // Scope based utility class TORCH_API to measure the time the code takes within 257 // a given C++ scope. 258 class TORCH_API TimedSection { 259 public: TimedSection(Metric * metric)260 explicit TimedSection(Metric* metric) : metric_(metric), start_(NowNs()) {} 261 ~TimedSection()262 ~TimedSection() { 263 int64_t now = NowNs(); 264 metric_->AddSample(now, now - start_); 265 } 266 Elapsed()267 double Elapsed() const { 268 return 1e-9 * static_cast<double>(NowNs() - start_); 269 } 270 271 private: 272 Metric* metric_; 273 int64_t start_; 274 }; 275 276 #define TORCH_LAZY_TIMED(name) \ 277 static torch::lazy::Metric* timed_metric = \ 278 new torch::lazy::Metric(name, torch::lazy::MetricFnTime); \ 279 torch::lazy::TimedSection timed_section(timed_metric) 280 281 #define TORCH_LAZY_FN_COUNTER_TIMED_TRACING(ns) \ 282 TORCH_LAZY_FN_COUNTER(ns); \ 283 TORCH_LAZY_TIMED("LazyTracing") 284 285 } // namespace lazy 286 } // namespace torch 287