1*da0073e9SAndroid Build Coastguard Worker #include "ATen/Parallel.h"
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include "c10/util/Flags.h"
4*da0073e9SAndroid Build Coastguard Worker #include "caffe2/core/init.h"
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker #include <atomic>
7*da0073e9SAndroid Build Coastguard Worker #include <chrono>
8*da0073e9SAndroid Build Coastguard Worker #include <condition_variable>
9*da0073e9SAndroid Build Coastguard Worker #include <iostream>
10*da0073e9SAndroid Build Coastguard Worker #include <mutex>
11*da0073e9SAndroid Build Coastguard Worker #include <ctime>
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_int(iter, 10e4, "Number of at::launch iterations (tasks)");
14*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_int(warmup_iter, 10, "Number of warmup iterations")
15*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_int(inter_op_threads, 0, "Number of inter-op threads");
16*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_int(benchmark_iter, 3, "Number of times to run benchmark")
17*da0073e9SAndroid Build Coastguard Worker
namespace {
// Number of task chains the current round must complete. The main thread
// stores it at the start of every round, while pool threads load it inside the
// innermost task right after bumping `counter`. A task that has not yet done
// that load can still be running when the main thread begins the next round,
// so the variable is atomic to keep that overlapping access well-defined.
std::atomic<int> iter{0};
// Number of task chains that have finished in the current round.
std::atomic<int> counter{0};
// Notified by the task whose increment makes `counter` equal to `iter`.
std::condition_variable cv;
// Held around both the wait on `cv` and the matching notify.
std::mutex mutex;
} // namespace
24*da0073e9SAndroid Build Coastguard Worker
launch_tasks()25*da0073e9SAndroid Build Coastguard Worker void launch_tasks() {
26*da0073e9SAndroid Build Coastguard Worker at::launch([]() {
27*da0073e9SAndroid Build Coastguard Worker at::launch([](){
28*da0073e9SAndroid Build Coastguard Worker at::launch([]() {
29*da0073e9SAndroid Build Coastguard Worker auto cur_ctr = ++counter;
30*da0073e9SAndroid Build Coastguard Worker if (cur_ctr == iter) {
31*da0073e9SAndroid Build Coastguard Worker std::unique_lock<std::mutex> lk(mutex);
32*da0073e9SAndroid Build Coastguard Worker cv.notify_one();
33*da0073e9SAndroid Build Coastguard Worker }
34*da0073e9SAndroid Build Coastguard Worker });
35*da0073e9SAndroid Build Coastguard Worker });
36*da0073e9SAndroid Build Coastguard Worker });
37*da0073e9SAndroid Build Coastguard Worker }
38*da0073e9SAndroid Build Coastguard Worker
launch_tasks_and_wait(int tasks_num)39*da0073e9SAndroid Build Coastguard Worker void launch_tasks_and_wait(int tasks_num) {
40*da0073e9SAndroid Build Coastguard Worker iter = tasks_num;
41*da0073e9SAndroid Build Coastguard Worker counter = 0;
42*da0073e9SAndroid Build Coastguard Worker for (auto idx = 0; idx < iter; ++idx) {
43*da0073e9SAndroid Build Coastguard Worker launch_tasks();
44*da0073e9SAndroid Build Coastguard Worker }
45*da0073e9SAndroid Build Coastguard Worker {
46*da0073e9SAndroid Build Coastguard Worker std::unique_lock<std::mutex> lk(mutex);
47*da0073e9SAndroid Build Coastguard Worker while (counter < iter) {
48*da0073e9SAndroid Build Coastguard Worker cv.wait(lk);
49*da0073e9SAndroid Build Coastguard Worker }
50*da0073e9SAndroid Build Coastguard Worker }
51*da0073e9SAndroid Build Coastguard Worker }
52*da0073e9SAndroid Build Coastguard Worker
main(int argc,char ** argv)53*da0073e9SAndroid Build Coastguard Worker int main(int argc, char** argv) {
54*da0073e9SAndroid Build Coastguard Worker if (!c10::ParseCommandLineFlags(&argc, &argv)) {
55*da0073e9SAndroid Build Coastguard Worker std::cout << "Failed to parse command line flags" << std::endl;
56*da0073e9SAndroid Build Coastguard Worker return -1;
57*da0073e9SAndroid Build Coastguard Worker }
58*da0073e9SAndroid Build Coastguard Worker caffe2::unsafeRunCaffe2InitFunction("registerThreadPools");
59*da0073e9SAndroid Build Coastguard Worker at::init_num_threads();
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker if (FLAGS_inter_op_threads > 0) {
62*da0073e9SAndroid Build Coastguard Worker at::set_num_interop_threads(FLAGS_inter_op_threads);
63*da0073e9SAndroid Build Coastguard Worker }
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker typedef std::chrono::high_resolution_clock clock;
66*da0073e9SAndroid Build Coastguard Worker typedef std::chrono::milliseconds ms;
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker std::cout << "Launching " << FLAGS_warmup_iter << " warmup tasks using "
69*da0073e9SAndroid Build Coastguard Worker << at::get_num_interop_threads() << " threads "
70*da0073e9SAndroid Build Coastguard Worker << std::endl;
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker std::chrono::time_point<clock> start_time = clock::now();
73*da0073e9SAndroid Build Coastguard Worker launch_tasks_and_wait(FLAGS_warmup_iter);
74*da0073e9SAndroid Build Coastguard Worker auto duration = static_cast<float>(
75*da0073e9SAndroid Build Coastguard Worker std::chrono::duration_cast<ms>(clock::now() - start_time).count());
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker std::cout << "Warmup time: " << duration << " ms." << std::endl;
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker std::cout << "Launching " << FLAGS_iter << " tasks using "
80*da0073e9SAndroid Build Coastguard Worker << at::get_num_interop_threads() << " threads "
81*da0073e9SAndroid Build Coastguard Worker << std::endl;
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker for (auto bench_iter = 0; bench_iter < FLAGS_benchmark_iter; ++bench_iter) {
84*da0073e9SAndroid Build Coastguard Worker start_time = clock::now();
85*da0073e9SAndroid Build Coastguard Worker launch_tasks_and_wait(FLAGS_iter);
86*da0073e9SAndroid Build Coastguard Worker duration = static_cast<float>(
87*da0073e9SAndroid Build Coastguard Worker std::chrono::duration_cast<ms>(clock::now() - start_time).count());
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker std::cout << "Time to run " << iter << " iterations "
90*da0073e9SAndroid Build Coastguard Worker << (duration/1000.0) << " s." << std::endl;
91*da0073e9SAndroid Build Coastguard Worker }
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker return 0;
94*da0073e9SAndroid Build Coastguard Worker }
95