xref: /aosp_15_r20/external/pytorch/binaries/at_launch_benchmark.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include "ATen/Parallel.h"
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include "c10/util/Flags.h"
4*da0073e9SAndroid Build Coastguard Worker #include "caffe2/core/init.h"
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #include <atomic>
7*da0073e9SAndroid Build Coastguard Worker #include <chrono>
8*da0073e9SAndroid Build Coastguard Worker #include <condition_variable>
9*da0073e9SAndroid Build Coastguard Worker #include <iostream>
10*da0073e9SAndroid Build Coastguard Worker #include <mutex>
11*da0073e9SAndroid Build Coastguard Worker #include <ctime>
12*da0073e9SAndroid Build Coastguard Worker 
13*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_int(iter, 10e4, "Number of at::launch iterations (tasks)");
14*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_int(warmup_iter, 10, "Number of warmup iterations")
15*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_int(inter_op_threads, 0, "Number of inter-op threads");
16*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_int(benchmark_iter, 3, "Number of times to run benchmark")
17*da0073e9SAndroid Build Coastguard Worker 
namespace {
// Number of tasks in the current benchmark round. The main thread writes it at
// the start of each round; worker threads read it inside the innermost task.
// It must be atomic: a worker from the previous round may still be about to
// read it (after its ++counter) when the main thread has already moved on and
// rewrites it for the next round. With a plain int that is a data race.
std::atomic<int> iter{0};
// Number of innermost tasks that have completed in the current round.
std::atomic<int> counter{0};
// Signalled by the task whose increment completes the round.
std::condition_variable cv;
// Protects the wait/notify handshake on `cv`.
std::mutex mutex;
} // namespace
24*da0073e9SAndroid Build Coastguard Worker 
launch_tasks()25*da0073e9SAndroid Build Coastguard Worker  void launch_tasks() {
26*da0073e9SAndroid Build Coastguard Worker   at::launch([]() {
27*da0073e9SAndroid Build Coastguard Worker     at::launch([](){
28*da0073e9SAndroid Build Coastguard Worker       at::launch([]() {
29*da0073e9SAndroid Build Coastguard Worker         auto cur_ctr = ++counter;
30*da0073e9SAndroid Build Coastguard Worker         if (cur_ctr == iter) {
31*da0073e9SAndroid Build Coastguard Worker           std::unique_lock<std::mutex> lk(mutex);
32*da0073e9SAndroid Build Coastguard Worker           cv.notify_one();
33*da0073e9SAndroid Build Coastguard Worker         }
34*da0073e9SAndroid Build Coastguard Worker       });
35*da0073e9SAndroid Build Coastguard Worker     });
36*da0073e9SAndroid Build Coastguard Worker   });
37*da0073e9SAndroid Build Coastguard Worker }
38*da0073e9SAndroid Build Coastguard Worker 
launch_tasks_and_wait(int tasks_num)39*da0073e9SAndroid Build Coastguard Worker void launch_tasks_and_wait(int tasks_num) {
40*da0073e9SAndroid Build Coastguard Worker   iter = tasks_num;
41*da0073e9SAndroid Build Coastguard Worker   counter = 0;
42*da0073e9SAndroid Build Coastguard Worker   for (auto idx = 0; idx < iter; ++idx) {
43*da0073e9SAndroid Build Coastguard Worker     launch_tasks();
44*da0073e9SAndroid Build Coastguard Worker   }
45*da0073e9SAndroid Build Coastguard Worker   {
46*da0073e9SAndroid Build Coastguard Worker     std::unique_lock<std::mutex> lk(mutex);
47*da0073e9SAndroid Build Coastguard Worker     while (counter < iter) {
48*da0073e9SAndroid Build Coastguard Worker       cv.wait(lk);
49*da0073e9SAndroid Build Coastguard Worker     }
50*da0073e9SAndroid Build Coastguard Worker   }
51*da0073e9SAndroid Build Coastguard Worker }
52*da0073e9SAndroid Build Coastguard Worker 
main(int argc,char ** argv)53*da0073e9SAndroid Build Coastguard Worker int main(int argc, char** argv) {
54*da0073e9SAndroid Build Coastguard Worker   if (!c10::ParseCommandLineFlags(&argc, &argv)) {
55*da0073e9SAndroid Build Coastguard Worker     std::cout << "Failed to parse command line flags" << std::endl;
56*da0073e9SAndroid Build Coastguard Worker     return -1;
57*da0073e9SAndroid Build Coastguard Worker   }
58*da0073e9SAndroid Build Coastguard Worker   caffe2::unsafeRunCaffe2InitFunction("registerThreadPools");
59*da0073e9SAndroid Build Coastguard Worker   at::init_num_threads();
60*da0073e9SAndroid Build Coastguard Worker 
61*da0073e9SAndroid Build Coastguard Worker   if (FLAGS_inter_op_threads > 0) {
62*da0073e9SAndroid Build Coastguard Worker     at::set_num_interop_threads(FLAGS_inter_op_threads);
63*da0073e9SAndroid Build Coastguard Worker   }
64*da0073e9SAndroid Build Coastguard Worker 
65*da0073e9SAndroid Build Coastguard Worker   typedef std::chrono::high_resolution_clock clock;
66*da0073e9SAndroid Build Coastguard Worker   typedef std::chrono::milliseconds ms;
67*da0073e9SAndroid Build Coastguard Worker 
68*da0073e9SAndroid Build Coastguard Worker   std::cout << "Launching " << FLAGS_warmup_iter << " warmup tasks using "
69*da0073e9SAndroid Build Coastguard Worker             << at::get_num_interop_threads() << " threads "
70*da0073e9SAndroid Build Coastguard Worker             << std::endl;
71*da0073e9SAndroid Build Coastguard Worker 
72*da0073e9SAndroid Build Coastguard Worker   std::chrono::time_point<clock> start_time = clock::now();
73*da0073e9SAndroid Build Coastguard Worker   launch_tasks_and_wait(FLAGS_warmup_iter);
74*da0073e9SAndroid Build Coastguard Worker   auto duration = static_cast<float>(
75*da0073e9SAndroid Build Coastguard Worker       std::chrono::duration_cast<ms>(clock::now() - start_time).count());
76*da0073e9SAndroid Build Coastguard Worker 
77*da0073e9SAndroid Build Coastguard Worker   std::cout << "Warmup time: " << duration << " ms." << std::endl;
78*da0073e9SAndroid Build Coastguard Worker 
79*da0073e9SAndroid Build Coastguard Worker   std::cout << "Launching " << FLAGS_iter << " tasks using "
80*da0073e9SAndroid Build Coastguard Worker             << at::get_num_interop_threads() << " threads "
81*da0073e9SAndroid Build Coastguard Worker             << std::endl;
82*da0073e9SAndroid Build Coastguard Worker 
83*da0073e9SAndroid Build Coastguard Worker   for (auto bench_iter = 0; bench_iter < FLAGS_benchmark_iter; ++bench_iter) {
84*da0073e9SAndroid Build Coastguard Worker     start_time = clock::now();
85*da0073e9SAndroid Build Coastguard Worker     launch_tasks_and_wait(FLAGS_iter);
86*da0073e9SAndroid Build Coastguard Worker     duration = static_cast<float>(
87*da0073e9SAndroid Build Coastguard Worker         std::chrono::duration_cast<ms>(clock::now() - start_time).count());
88*da0073e9SAndroid Build Coastguard Worker 
89*da0073e9SAndroid Build Coastguard Worker     std::cout << "Time to run " << iter << " iterations "
90*da0073e9SAndroid Build Coastguard Worker               << (duration/1000.0) << " s." << std::endl;
91*da0073e9SAndroid Build Coastguard Worker   }
92*da0073e9SAndroid Build Coastguard Worker 
93*da0073e9SAndroid Build Coastguard Worker   return 0;
94*da0073e9SAndroid Build Coastguard Worker }
95