xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/logging.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include <torch/csrc/Export.h>
10 
11 namespace torch::jit::logging {
12 
13 class LoggerBase {
14  public:
15   TORCH_API virtual void addStatValue(
16       const std::string& stat_name,
17       int64_t val) = 0;
18   virtual ~LoggerBase() = default;
19 };
20 
21 TORCH_API LoggerBase* getLogger();
22 TORCH_API LoggerBase* setLogger(LoggerBase* logger);
23 
24 // No-op logger. This is the default and is meant to incur almost no runtime
25 // overhead.
26 
27 class NoopLogger : public LoggerBase {
28  public:
addStatValue(const std::string & stat_name,int64_t val)29   void addStatValue(
30       const std::string& stat_name [[maybe_unused]],
31       int64_t val [[maybe_unused]]) override {}
32   ~NoopLogger() override = default;
33 };
34 
35 // Trivial locking logger. Pass in an instance of this to setLogger() to use it.
36 // This keeps track of the sum of all statistics.
37 //
38 // NOTE: this is not written in a scalable way and should probably only be used
39 // in the single-threaded case or for testing.
40 class TORCH_API LockingLogger : public LoggerBase {
41  public:
42   void addStatValue(const std::string& stat_name, int64_t val) override;
43   virtual int64_t getCounterValue(const std::string& name) const;
44   enum class AggregationType { SUM = 0, AVG = 1 };
45   void setAggregationType(const std::string& stat_name, AggregationType type);
46   ~LockingLogger() override = default;
47 
48  private:
49   mutable std::mutex m;
50   struct RawCounter {
RawCounterRawCounter51     RawCounter() : sum(0), count(0) {}
52     int64_t sum;
53     size_t count;
54   };
55   std::unordered_map<std::string, RawCounter> raw_counters;
56   std::unordered_map<std::string, AggregationType> agg_types;
57 };
58 
// Wraps a clock reading in a named struct so that callers can treat the
// timer's internals as opaque and simply hand the value back to the API.
struct JITTimePoint {
  // Reading taken from std::chrono::high_resolution_clock.
  std::chrono::time_point<std::chrono::high_resolution_clock> point{};
};
63 
64 TORCH_API JITTimePoint timePoint();
65 TORCH_API void recordDurationSince(
66     const std::string& name,
67     const JITTimePoint& tp);
68 
namespace runtime_counters {

// Statistic names reported by the JIT runtime. They are declared
// `inline constexpr` (C++17) so that each name has exactly one definition
// program-wide. A plain namespace-scope `constexpr` variable in a header gets
// internal linkage, which means one separate copy per translation unit.
inline constexpr const char* GRAPH_EXECUTORS_CONSTRUCTED =
    "pytorch_runtime.graph_executors_constructed";
inline constexpr const char* GRAPH_EXECUTOR_INVOCATIONS =
    "pytorch_runtime.graph_executor_invocations";
inline constexpr const char* EXECUTION_PLAN_CACHE_HIT =
    "pytorch_runtime.execution_plan_cache_hit";
inline constexpr const char* EXECUTION_PLAN_CACHE_MISS =
    "pytorch_runtime.execution_plan_cache_miss";

// Returns every runtime counter name declared above, in declaration order.
// A new vector is built on each call, so callers own and may modify the
// result.
inline std::vector<const char*> allRuntimeCounters() {
  return {
      GRAPH_EXECUTORS_CONSTRUCTED,
      GRAPH_EXECUTOR_INVOCATIONS,
      EXECUTION_PLAN_CACHE_HIT,
      EXECUTION_PLAN_CACHE_MISS};
}

} // namespace runtime_counters
88 
89 } // namespace torch::jit::logging
90