#include <torch/csrc/autograd/profiler_kineto.h>
#include <torch/torch.h>

#include <atomic>
#include <cstdio>
#include <string>
#include <thread>
#include <vector>
5
6 using namespace torch::autograd::profiler;
7
// Echoes `text` to stdout wrapped in ANSI bright-blue escape codes so the
// C++-side trace lines stand out from the Python test output.
void blueprint(const std::string& text) {
  const std::string colored = "\33[94m" + text + "\33[0m\n";
  fwrite(colored.data(), 1, colored.size(), stdout);
}
11
/**
 * We're emulating a C++ training engine calling into Python to allow Python
 * code controlling how profiling should be done.
 */
class ProfilerEventHandler
    : public std::enable_shared_from_this<ProfilerEventHandler> {
 public:
  // Process-wide handler instance; the C++ "engine" dispatches its profiling
  // callbacks through whatever object Python registered here.
  static std::shared_ptr<ProfilerEventHandler> Handler;

  // Installs (or replaces) the global handler.
  static void Register(const std::shared_ptr<ProfilerEventHandler>& handler) {
    Handler = handler;
  }

  // Polymorphic base: virtual dtor so delete-through-base is safe.
  virtual ~ProfilerEventHandler() = default;

  // Invoked once per training iteration (by the id-0 worker thread).
  virtual void onIterationStart(int) {}

  // Simulates one unit of training work for (iteration, thread_id).
  virtual void emulateTraining(int, int) {}
};
std::shared_ptr<ProfilerEventHandler> ProfilerEventHandler::Handler;
30
// pybind11 trampoline: routes the virtual callbacks to a Python subclass of
// ProfilerEventHandler when one overrides them, falling back to the C++
// default (no-op) implementations otherwise.
class ProfilerEventHandlerTrampoline : public ProfilerEventHandler {
 public:
  virtual void onIterationStart(int iteration) override {
    PYBIND11_OVERRIDE(void, ProfilerEventHandler, onIterationStart, iteration);
  }
  virtual void emulateTraining(int iteration, int thread_id) override {
    PYBIND11_OVERRIDE(
        void, ProfilerEventHandler, emulateTraining, iteration, thread_id);
  }
};
41
42 /**
43 * This is the entry point for the C++ training engine.
44 */
start_threads(int thread_count,int iteration_count,bool attach)45 void start_threads(int thread_count, int iteration_count, bool attach) {
46 blueprint("start_cpp_threads called");
47
48 static std::atomic<int> barrier = 0;
49 barrier = 0;
50 thread_local bool enabled_in_main_thread = false;
51
52 std::vector<std::thread> threads;
53 for (int id = 0; id < thread_count; id++) {
54 blueprint("starting thread " + std::to_string(id));
55 threads.emplace_back([thread_count, iteration_count, id, attach]() {
56 for (int iteration = 0; iteration < iteration_count; iteration++) {
57 if (id == 0) {
58 ProfilerEventHandler::Handler->onIterationStart(iteration);
59 }
60
61 // this barrier makes sure all child threads will be turned on
62 // with profiling when main thread is enabled
63 ++barrier;
64 while (barrier % thread_count) {
65 std::this_thread::yield();
66 }
67
68 if (id > 0 && attach) {
69 bool enabled = isProfilerEnabledInMainThread();
70 if (enabled != enabled_in_main_thread) {
71 if (enabled) {
72 enableProfilerInChildThread();
73 } else {
74 disableProfilerInChildThread();
75 }
76 enabled_in_main_thread = enabled;
77 }
78 }
79
80 ProfilerEventHandler::Handler->emulateTraining(iteration, id);
81 }
82 });
83 }
84 for (auto& t : threads) {
85 t.join();
86 }
87 }
88
// Python bindings: expose ProfilerEventHandler (subclassable from Python via
// the trampoline) plus the start_threads engine entry point.
PYBIND11_MODULE(profiler_test_cpp_thread_lib, m) {
  py::class_<
      ProfilerEventHandler,
      ProfilerEventHandlerTrampoline,
      std::shared_ptr<ProfilerEventHandler>>(m, "ProfilerEventHandler")
      .def(py::init<>())
      .def_static("Register", &ProfilerEventHandler::Register)
      .def(
          "onIterationStart",
          &ProfilerEventHandler::onIterationStart,
          // Release the GIL while the C++ side runs so worker threads can
          // make progress; pybind11 reacquires it for the Python override.
          py::call_guard<py::gil_scoped_release>())
      .def(
          "emulateTraining",
          &ProfilerEventHandler::emulateTraining,
          py::call_guard<py::gil_scoped_release>());

  // Blocks until all spawned threads join; GIL released for the duration.
  m.def(
      "start_threads",
      &start_threads,
      py::call_guard<py::gil_scoped_release>());
};
110