xref: /aosp_15_r20/external/pytorch/c10/util/Gauge.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/Gauge.h>
2 
3 #include <c10/util/Synchronized.h>
4 
5 #include <memory>
6 #include <string>
7 #include <string_view>
8 #include <unordered_map>
9 #include <vector>
10 
11 namespace c10::monitor {
12 
13 namespace detail {
14 namespace {
15 using GaugeBackendFactories =
16     std::vector<std::shared_ptr<GaugeBackendFactoryIf>>;
17 
gaugeBackendFactories()18 Synchronized<GaugeBackendFactories>& gaugeBackendFactories() {
19   static auto instance = new Synchronized<GaugeBackendFactories>();
20   return *instance;
21 }
22 } // namespace
23 
24 class GaugeImpl {
25  public:
getInstance(std::string_view key)26   static GaugeImpl& getInstance(std::string_view key) {
27     static auto& implMapSynchronized = *new Synchronized<
28         std::unordered_map<std::string, std::unique_ptr<GaugeImpl>>>();
29 
30     return *implMapSynchronized.withLock([&](auto& implMap) {
31       if (auto implIt = implMap.find(std::string(key));
32           implIt != implMap.end()) {
33         return implIt->second.get();
34       }
35 
36       auto [implIt, emplaceSuccess] = implMap.emplace(
37           std::string{key}, std::unique_ptr<GaugeImpl>(new GaugeImpl(key)));
38 
39       assert(emplaceSuccess);
40 
41       return implIt->second.get();
42     });
43   }
44 
record(int64_t value)45   void record(int64_t value) {
46     for (auto& backend : backends_) {
47       backend->record(value);
48     }
49   }
50 
51  private:
GaugeImpl(std::string_view key)52   explicit GaugeImpl(std::string_view key) {
53     auto factoriesCopy = gaugeBackendFactories().withLock(
54         [](auto& factories) { return factories; });
55     for (const auto& factory : factoriesCopy) {
56       if (auto backend = factory->create(key)) {
57         backends_.push_back(std::move(backend));
58       }
59     }
60   }
61 
62   SmallVector<std::unique_ptr<GaugeBackendIf>> backends_;
63 };
64 
registerGaugeBackend(std::unique_ptr<GaugeBackendFactoryIf> backend)65 void registerGaugeBackend(std::unique_ptr<GaugeBackendFactoryIf> backend) {
66   gaugeBackendFactories().withLock(
67       [&](auto& backends) { backends.push_back(std::move(backend)); });
68 }
69 
70 } // namespace detail
71 
GaugeHandle(std::string_view key)72 GaugeHandle::GaugeHandle(std::string_view key)
73     : impl_(detail::GaugeImpl::getInstance(key)) {}
74 
record(int64_t value)75 void GaugeHandle::record(int64_t value) {
76   impl_.record(value);
77 }
78 
79 } // namespace c10::monitor
80