xref: /aosp_15_r20/external/pytorch/test/profiler/test_cpp_thread.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 
#include <torch/csrc/autograd/profiler_kineto.h>
#include <torch/torch.h>
#include <atomic>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
5 
6 using namespace torch::autograd::profiler;
7 
/**
 * Prints `text` followed by a newline to stdout, wrapped in the ANSI escape
 * sequences ESC[94m (bright blue foreground) and ESC[0m (reset).
 */
void blueprint(const std::string& text) {
  const char* const message = text.c_str();
  printf("\33[94m%s\33[0m\n", message);
}
11 
/**
 * We're emulating a C++ training engine that calls into Python so that
 * Python code can control how profiling is done.
 *
 * The engine talks to a single, globally registered handler whose virtual
 * hooks are no-ops unless a subclass overrides them.
 */
class ProfilerEventHandler
    : public std::enable_shared_from_this<ProfilerEventHandler> {
 public:
  /// The handler installed by the most recent Register() call; empty until
  /// Register() has been called.
  static std::shared_ptr<ProfilerEventHandler> Handler;

  /// Installs `handler` as the globally visible handler.
  static void Register(const std::shared_ptr<ProfilerEventHandler>& handler) {
    Handler = handler;
  }

  virtual ~ProfilerEventHandler() = default;

  /// Hook taking the iteration index; does nothing by default.
  virtual void onIterationStart(int /*iteration*/) {}

  /// Hook taking the iteration index and the worker thread id; does nothing
  /// by default.
  virtual void emulateTraining(int /*iteration*/, int /*thread_id*/) {}
};

std::shared_ptr<ProfilerEventHandler> ProfilerEventHandler::Handler;
30 
31 class ProfilerEventHandlerTrampoline : public ProfilerEventHandler {
32  public:
onIterationStart(int iteration)33   virtual void onIterationStart(int iteration) override {
34     PYBIND11_OVERRIDE(void, ProfilerEventHandler, onIterationStart, iteration);
35   }
emulateTraining(int iteration,int thread_id)36   virtual void emulateTraining(int iteration, int thread_id) override {
37     PYBIND11_OVERRIDE(
38         void, ProfilerEventHandler, emulateTraining, iteration, thread_id);
39   }
40 };
41 
42 /**
43  * This is the entry point for the C++ training engine.
44  */
start_threads(int thread_count,int iteration_count,bool attach)45 void start_threads(int thread_count, int iteration_count, bool attach) {
46   blueprint("start_cpp_threads called");
47 
48   static std::atomic<int> barrier = 0;
49   barrier = 0;
50   thread_local bool enabled_in_main_thread = false;
51 
52   std::vector<std::thread> threads;
53   for (int id = 0; id < thread_count; id++) {
54     blueprint("starting thread " + std::to_string(id));
55     threads.emplace_back([thread_count, iteration_count, id, attach]() {
56       for (int iteration = 0; iteration < iteration_count; iteration++) {
57         if (id == 0) {
58           ProfilerEventHandler::Handler->onIterationStart(iteration);
59         }
60 
61         // this barrier makes sure all child threads will be turned on
62         // with profiling when main thread is enabled
63         ++barrier;
64         while (barrier % thread_count) {
65           std::this_thread::yield();
66         }
67 
68         if (id > 0 && attach) {
69           bool enabled = isProfilerEnabledInMainThread();
70           if (enabled != enabled_in_main_thread) {
71             if (enabled) {
72               enableProfilerInChildThread();
73             } else {
74               disableProfilerInChildThread();
75             }
76             enabled_in_main_thread = enabled;
77           }
78         }
79 
80         ProfilerEventHandler::Handler->emulateTraining(iteration, id);
81       }
82     });
83   }
84   for (auto& t : threads) {
85     t.join();
86   }
87 }
88 
/**
 * Python bindings for the module `profiler_test_cpp_thread_lib`.
 *
 * Exposes ProfilerEventHandler (with ProfilerEventHandlerTrampoline as its
 * trampoline and std::shared_ptr as its holder) and the free function
 * start_threads. Everything except the constructor and Register is bound
 * with a call guard that releases the GIL for the duration of the call.
 */
PYBIND11_MODULE(profiler_test_cpp_thread_lib, m) {
  using ReleaseGil = py::call_guard<py::gil_scoped_release>;

  py::class_<
      ProfilerEventHandler,
      ProfilerEventHandlerTrampoline,
      std::shared_ptr<ProfilerEventHandler>>
      handler_class(m, "ProfilerEventHandler");

  handler_class.def(py::init<>());
  handler_class.def_static("Register", &ProfilerEventHandler::Register);
  handler_class.def(
      "onIterationStart",
      &ProfilerEventHandler::onIterationStart,
      ReleaseGil());
  handler_class.def(
      "emulateTraining", &ProfilerEventHandler::emulateTraining, ReleaseGil());

  m.def("start_threads", &start_threads, ReleaseGil());
}
110