1 #include <torch/torch.h> 2 #include <chrono> 3 #include <condition_variable> 4 #include <mutex> 5 6 class Baton { 7 public: post()8 void post() { 9 std::unique_lock<std::mutex> l(lock_); 10 done_ = true; 11 cv_.notify_all(); 12 } wait()13 void wait() { 14 std::unique_lock<std::mutex> l(lock_); 15 while (!done_) { 16 cv_.wait(l); 17 } 18 } 19 20 private: 21 std::mutex lock_; 22 std::condition_variable cv_; 23 bool done_{false}; 24 }; 25 AtLaunch_Base(int32_t numIters)26void AtLaunch_Base(int32_t numIters) { 27 struct Helper { 28 explicit Helper(int32_t lim) : limit_(lim) {} 29 void operator()() { 30 if (++val_ == limit_) { 31 done.post(); 32 } else { 33 at::launch([this]() { (*this)(); }); 34 } 35 } 36 int val_{0}; 37 int limit_; 38 Baton done; 39 }; 40 Helper h(numIters); 41 auto start = std::chrono::system_clock::now(); 42 h(); 43 h.done.wait(); 44 std::cout << "NoData " 45 << static_cast<double>( 46 std::chrono::duration_cast<std::chrono::microseconds>( 47 std::chrono::system_clock::now() - start) 48 .count()) / 49 static_cast<double>(numIters) 50 << " usec/each\n"; 51 } 52 AtLaunch_WithData(int32_t numIters,int32_t vecSize)53void AtLaunch_WithData(int32_t numIters, int32_t vecSize) { 54 struct Helper { 55 explicit Helper(int32_t lim) : limit_(lim) {} 56 void operator()(std::vector<int32_t> v) { 57 if (++val_ == limit_) { 58 done.post(); 59 } else { 60 at::launch([this, v = std::move(v)]() { (*this)(v); }); 61 } 62 } 63 int val_{0}; 64 int limit_; 65 Baton done; 66 }; 67 Helper h(numIters); 68 std::vector<int32_t> v(vecSize, 0); 69 auto start = std::chrono::system_clock::now(); 70 h(v); 71 h.done.wait(); 72 std::cout << "WithData(" << vecSize << "): " 73 << static_cast<double>( 74 std::chrono::duration_cast<std::chrono::microseconds>( 75 std::chrono::system_clock::now() - start) 76 .count()) / 77 static_cast<double>(numIters) 78 << " usec/each\n"; 79 } 80 main(int argc,char ** argv)81int main(int argc, char** argv) { 82 int32_t N = 1000000; 83 AtLaunch_Base(N); 84 AtLaunch_WithData(N, 0); 85 AtLaunch_WithData(N, 4); 86 AtLaunch_WithData(N, 256); 87 return 0; 88 } 89