1 #pragma once
2
3 #include <chrono>
4 #include <mutex>
5 #include <string>
6 #include <unordered_map>
7 #include <vector>
8
9 #include <torch/csrc/Export.h>
10
11 namespace torch::jit::logging {
12
13 class LoggerBase {
14 public:
15 TORCH_API virtual void addStatValue(
16 const std::string& stat_name,
17 int64_t val) = 0;
18 virtual ~LoggerBase() = default;
19 };
20
21 TORCH_API LoggerBase* getLogger();
22 TORCH_API LoggerBase* setLogger(LoggerBase* logger);
23
24 // No-op logger. This is the default and is meant to incur almost no runtime
25 // overhead.
26
27 class NoopLogger : public LoggerBase {
28 public:
addStatValue(const std::string & stat_name,int64_t val)29 void addStatValue(
30 const std::string& stat_name [[maybe_unused]],
31 int64_t val [[maybe_unused]]) override {}
32 ~NoopLogger() override = default;
33 };
34
// Trivial locking logger. Pass an instance of this to setLogger() to use it.
// It keeps a running sum and sample count for each statistic, and each
// statistic can be aggregated as a SUM or an AVG (see setAggregationType()).
37 //
38 // NOTE: this is not written in a scalable way and should probably only be used
39 // in the single-threaded case or for testing.
40 class TORCH_API LockingLogger : public LoggerBase {
41 public:
42 void addStatValue(const std::string& stat_name, int64_t val) override;
43 virtual int64_t getCounterValue(const std::string& name) const;
44 enum class AggregationType { SUM = 0, AVG = 1 };
45 void setAggregationType(const std::string& stat_name, AggregationType type);
46 ~LockingLogger() override = default;
47
48 private:
49 mutable std::mutex m;
50 struct RawCounter {
RawCounterRawCounter51 RawCounter() : sum(0), count(0) {}
52 int64_t sum;
53 size_t count;
54 };
55 std::unordered_map<std::string, RawCounter> raw_counters;
56 std::unordered_map<std::string, AggregationType> agg_types;
57 };
58
// A named wrapper around a clock reading, so callers of timePoint() and
// recordDurationSince() handle a JITTimePoint rather than a raw std::chrono
// type.
//
// NOTE(review): std::chrono::high_resolution_clock is not guaranteed to be
// steady (some standard libraries alias it to system_clock), so measured
// durations can be affected by wall-clock adjustments. Switching to
// steady_clock would also require updating the timePoint() implementation.
struct JITTimePoint {
  std::chrono::high_resolution_clock::time_point point;
};
63
64 TORCH_API JITTimePoint timePoint();
65 TORCH_API void recordDurationSince(
66 const std::string& name,
67 const JITTimePoint& tp);
68
namespace runtime_counters {
// Names of the built-in runtime counters.
//
// Declared `inline constexpr` (C++17) so each constant has external linkage
// and a single definition shared by every translation unit that includes
// this header, instead of a separate internal-linkage copy per TU. This also
// means the inline allRuntimeCounters() below refers to the same entities in
// every TU.
inline constexpr const char* GRAPH_EXECUTORS_CONSTRUCTED =
    "pytorch_runtime.graph_executors_constructed";
inline constexpr const char* GRAPH_EXECUTOR_INVOCATIONS =
    "pytorch_runtime.graph_executor_invocations";
inline constexpr const char* EXECUTION_PLAN_CACHE_HIT =
    "pytorch_runtime.execution_plan_cache_hit";
inline constexpr const char* EXECUTION_PLAN_CACHE_MISS =
    "pytorch_runtime.execution_plan_cache_miss";

/// Returns every counter name declared above, in declaration order.
/// A fresh vector is built on each call.
inline std::vector<const char*> allRuntimeCounters() {
  return {
      GRAPH_EXECUTORS_CONSTRUCTED,
      GRAPH_EXECUTOR_INVOCATIONS,
      EXECUTION_PLAN_CACHE_HIT,
      EXECUTION_PLAN_CACHE_MISS};
}

} // namespace runtime_counters
88
89 } // namespace torch::jit::logging
90