1 /** 2 * This file is adapted from PyTorch/XLA 3 * https://github.com/pytorch/xla/blob/master/third_party/xla_client/multi_wait.h 4 */ 5 6 #pragma once 7 8 #include <condition_variable> 9 #include <exception> 10 #include <functional> 11 #include <memory> 12 #include <mutex> 13 14 #include <c10/macros/Export.h> 15 16 namespace torch { 17 namespace lazy { 18 19 // Support waiting for a number of tasks to complete. 20 class TORCH_API MultiWait { 21 public: MultiWait(size_t count)22 explicit MultiWait(size_t count) : count_(count) {} 23 24 // Signal the completion of a single task. 25 void Done(); 26 27 // Waits until at least count (passed as constructor value) completions 28 // happened. 29 void Wait(); 30 31 // Same as above, but waits up to wait_seconds. 32 void Wait(double wait_seconds); 33 34 // Resets the threshold counter for the MultiWait object. The completed count 35 // is also reset to zero. 36 void Reset(size_t count); 37 38 // Creates a completer functor which signals the mult wait object once func 39 // has completed. Handles exceptions by signaling the multi wait with the 40 // proper status value. This API returns a function which captures a MultiWait 41 // reference, so care must be taken such that the reference remains valid for 42 // the whole lifetime of the returned function. 43 std::function<void()> Completer(std::function<void()> func); 44 45 // Similar as the above API, but with explicit capture of the MultiWait shared 46 // pointer. 47 static std::function<void()> Completer( 48 std::shared_ptr<MultiWait> mwait, 49 std::function<void()> func); 50 51 private: 52 void Complete(const std::function<void()>& func); 53 54 std::mutex mutex_; 55 std::condition_variable cv_; 56 size_t count_ = 0; 57 size_t completed_count_ = 0; 58 std::exception_ptr exptr_; 59 }; 60 61 } // namespace lazy 62 } // namespace torch 63