// xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/multi_wait.h
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/**
 * This file is adapted from PyTorch/XLA
 * https://github.com/pytorch/xla/blob/master/third_party/xla_client/multi_wait.h
 */

#pragma once

#include <condition_variable>
#include <cstddef>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>

#include <c10/macros/Export.h>

16 namespace torch {
17 namespace lazy {
18 
19 // Support waiting for a number of tasks to complete.
20 class TORCH_API MultiWait {
21  public:
MultiWait(size_t count)22   explicit MultiWait(size_t count) : count_(count) {}
23 
24   // Signal the completion of a single task.
25   void Done();
26 
27   // Waits until at least count (passed as constructor value) completions
28   // happened.
29   void Wait();
30 
31   // Same as above, but waits up to wait_seconds.
32   void Wait(double wait_seconds);
33 
34   // Resets the threshold counter for the MultiWait object. The completed count
35   // is also reset to zero.
36   void Reset(size_t count);
37 
38   // Creates a completer functor which signals the mult wait object once func
39   // has completed. Handles exceptions by signaling the multi wait with the
40   // proper status value. This API returns a function which captures a MultiWait
41   // reference, so care must be taken such that the reference remains valid for
42   // the whole lifetime of the returned function.
43   std::function<void()> Completer(std::function<void()> func);
44 
45   // Similar as the above API, but with explicit capture of the MultiWait shared
46   // pointer.
47   static std::function<void()> Completer(
48       std::shared_ptr<MultiWait> mwait,
49       std::function<void()> func);
50 
51  private:
52   void Complete(const std::function<void()>& func);
53 
54   std::mutex mutex_;
55   std::condition_variable cv_;
56   size_t count_ = 0;
57   size_t completed_count_ = 0;
58   std::exception_ptr exptr_;
59 };
60 
61 } // namespace lazy
62 } // namespace torch
63