xref: /aosp_15_r20/external/pytorch/c10/util/WaitCounter.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/WaitCounter.h>
2 
3 #include <c10/util/Synchronized.h>
4 
5 #include <chrono>
6 #include <memory>
7 #include <string_view>
8 #include <unordered_map>
9 #include <vector>
10 
11 namespace c10::monitor {
12 
13 namespace detail {
14 namespace {
15 using WaitCounterBackendFactories =
16     std::vector<std::shared_ptr<WaitCounterBackendFactoryIf>>;
17 
waitCounterBackendFactories()18 Synchronized<WaitCounterBackendFactories>& waitCounterBackendFactories() {
19   static auto instance = new Synchronized<WaitCounterBackendFactories>();
20   return *instance;
21 }
22 } // namespace
23 
24 class WaitCounterImpl {
25  public:
getInstance(std::string_view key)26   static WaitCounterImpl& getInstance(std::string_view key) {
27     static auto& implMapSynchronized = *new Synchronized<
28         std::unordered_map<std::string, std::unique_ptr<WaitCounterImpl>>>();
29 
30     return *implMapSynchronized.withLock([&](auto& implMap) {
31       if (auto implIt = implMap.find(std::string(key));
32           implIt != implMap.end()) {
33         return implIt->second.get();
34       }
35 
36       auto [implIt, emplaceSuccess] = implMap.emplace(
37           std::string{key},
38           std::unique_ptr<WaitCounterImpl>(new WaitCounterImpl(key)));
39 
40       assert(emplaceSuccess);
41 
42       return implIt->second.get();
43     });
44   }
45 
start()46   SmallVector<intptr_t> start() noexcept {
47     auto now = std::chrono::steady_clock::now();
48     SmallVector<intptr_t> ctxs;
49     ctxs.reserve(backends_.size());
50     for (const auto& backend : backends_) {
51       ctxs.push_back(backend->start(now));
52     }
53     return ctxs;
54   }
55 
stop(SmallVector<intptr_t> && ctxs)56   void stop(SmallVector<intptr_t>&& ctxs) noexcept {
57     auto now = std::chrono::steady_clock::now();
58     assert(ctxs.size() == backends_.size());
59     for (size_t i = 0; i < ctxs.size(); ++i) {
60       backends_[i]->stop(now, ctxs[i]);
61     }
62   }
63 
64  private:
WaitCounterImpl(std::string_view key)65   explicit WaitCounterImpl(std::string_view key) {
66     auto factoriesCopy = waitCounterBackendFactories().withLock(
67         [](auto& factories) { return factories; });
68     for (const auto& factory : factoriesCopy) {
69       if (auto backend = factory->create(key)) {
70         backends_.push_back(std::move(backend));
71       }
72     }
73   }
74 
75   SmallVector<std::unique_ptr<WaitCounterBackendIf>> backends_;
76 };
77 
registerWaitCounterBackend(std::unique_ptr<WaitCounterBackendFactoryIf> factory)78 void registerWaitCounterBackend(
79     std::unique_ptr<WaitCounterBackendFactoryIf> factory) {
80   waitCounterBackendFactories().withLock(
81       [&](auto& factories) { factories.push_back(std::move(factory)); });
82 }
83 } // namespace detail
84 
WaitCounterHandle(std::string_view key)85 WaitCounterHandle::WaitCounterHandle(std::string_view key)
86     : impl_(detail::WaitCounterImpl::getInstance(key)) {}
87 
start()88 WaitCounterHandle::WaitGuard WaitCounterHandle::start() {
89   return WaitCounterHandle::WaitGuard(*this, impl_.start());
90 }
91 
stop(SmallVector<intptr_t> && ctxs)92 void WaitCounterHandle::stop(SmallVector<intptr_t>&& ctxs) {
93   return impl_.stop(std::move(ctxs));
94 }
95 } // namespace c10::monitor
96