xref: /aosp_15_r20/external/pytorch/test/cpp/api/parallel_benchmark.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/torch.h>
2 #include <chrono>
3 #include <condition_variable>
4 #include <mutex>
5 
// One-shot signal between threads: wait() blocks until post() has been called.
class Baton {
 public:
  // Sets the done flag and wakes every thread currently blocked in wait().
  // The notification is issued while the mutex is still held.
  void post() {
    std::lock_guard<std::mutex> guard(lock_);
    done_ = true;
    cv_.notify_all();
  }

  // Blocks the caller until the done flag is set. Returns immediately when
  // post() has already run; spurious wake-ups are absorbed by the predicate.
  void wait() {
    std::unique_lock<std::mutex> guard(lock_);
    cv_.wait(guard, [this] { return done_; });
  }

 private:
  std::mutex lock_; // guards done_
  std::condition_variable cv_; // signalled from post()
  bool done_{false}; // set by post(); nothing in this class resets it
};
25 
AtLaunch_Base(int32_t numIters)26 void AtLaunch_Base(int32_t numIters) {
27   struct Helper {
28     explicit Helper(int32_t lim) : limit_(lim) {}
29     void operator()() {
30       if (++val_ == limit_) {
31         done.post();
32       } else {
33         at::launch([this]() { (*this)(); });
34       }
35     }
36     int val_{0};
37     int limit_;
38     Baton done;
39   };
40   Helper h(numIters);
41   auto start = std::chrono::system_clock::now();
42   h();
43   h.done.wait();
44   std::cout << "NoData "
45             << static_cast<double>(
46                    std::chrono::duration_cast<std::chrono::microseconds>(
47                        std::chrono::system_clock::now() - start)
48                        .count()) /
49           static_cast<double>(numIters)
50             << " usec/each\n";
51 }
52 
AtLaunch_WithData(int32_t numIters,int32_t vecSize)53 void AtLaunch_WithData(int32_t numIters, int32_t vecSize) {
54   struct Helper {
55     explicit Helper(int32_t lim) : limit_(lim) {}
56     void operator()(std::vector<int32_t> v) {
57       if (++val_ == limit_) {
58         done.post();
59       } else {
60         at::launch([this, v = std::move(v)]() { (*this)(v); });
61       }
62     }
63     int val_{0};
64     int limit_;
65     Baton done;
66   };
67   Helper h(numIters);
68   std::vector<int32_t> v(vecSize, 0);
69   auto start = std::chrono::system_clock::now();
70   h(v);
71   h.done.wait();
72   std::cout << "WithData(" << vecSize << "): "
73             << static_cast<double>(
74                    std::chrono::duration_cast<std::chrono::microseconds>(
75                        std::chrono::system_clock::now() - start)
76                        .count()) /
77           static_cast<double>(numIters)
78             << " usec/each\n";
79 }
80 
main(int argc,char ** argv)81 int main(int argc, char** argv) {
82   int32_t N = 1000000;
83   AtLaunch_Base(N);
84   AtLaunch_WithData(N, 0);
85   AtLaunch_WithData(N, 4);
86   AtLaunch_WithData(N, 256);
87   return 0;
88 }
89