xref: /aosp_15_r20/external/pytorch/torch/csrc/monitor/python_init.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <utility>
2 
3 #include <c10/util/WaitCounter.h>
4 
5 #include <torch/csrc/utils/pybind.h>
6 #include <torch/csrc/utils/python_arg_parser.h>
7 #include <torch/csrc/utils/python_numbers.h>
8 #include <torch/csrc/utils/python_strings.h>
9 
10 #include <pybind11/chrono.h>
11 #include <pybind11/functional.h>
12 #include <pybind11/operators.h>
13 #include <pybind11/stl.h>
14 
15 #include <torch/csrc/monitor/counters.h>
16 #include <torch/csrc/monitor/events.h>
17 
18 namespace pybind11 {
19 namespace detail {
20 template <>
21 struct type_caster<torch::monitor::data_value_t> {
22  public:
23   PYBIND11_TYPE_CASTER(torch::monitor::data_value_t, _("data_value_t"));
24 
25   // Python -> C++
loadpybind11::detail::type_caster26   bool load(handle src, bool) {
27     PyObject* source = src.ptr();
28     if (THPUtils_checkLong(source)) {
29       this->value = THPUtils_unpackLong(source);
30     } else if (THPUtils_checkDouble(source)) {
31       this->value = THPUtils_unpackDouble(source);
32     } else if (THPUtils_checkString(source)) {
33       this->value = THPUtils_unpackString(source);
34     } else if (PyBool_Check(source)) {
35       this->value = THPUtils_unpackBool(source);
36     } else {
37       return false;
38     }
39     return !PyErr_Occurred();
40   }
41 
42   // C++ -> Python
castpybind11::detail::type_caster43   static handle cast(
44       torch::monitor::data_value_t src,
45       return_value_policy /* policy */,
46       handle /* parent */) {
47     if (std::holds_alternative<double>(src)) {
48       return PyFloat_FromDouble(std::get<double>(src));
49     } else if (std::holds_alternative<int64_t>(src)) {
50       return THPUtils_packInt64(std::get<int64_t>(src));
51     } else if (std::holds_alternative<bool>(src)) {
52       if (std::get<bool>(src)) {
53         Py_RETURN_TRUE;
54       } else {
55         Py_RETURN_FALSE;
56       }
57     } else if (std::holds_alternative<std::string>(src)) {
58       std::string str = std::get<std::string>(src);
59       return THPUtils_packString(str);
60     }
61     throw std::runtime_error("unknown data_value_t type");
62   }
63 };
64 } // namespace detail
65 } // namespace pybind11
66 
67 namespace torch {
68 namespace monitor {
69 
70 namespace {
71 class PythonEventHandler : public EventHandler {
72  public:
PythonEventHandler(std::function<void (const Event &)> handler)73   explicit PythonEventHandler(std::function<void(const Event&)> handler)
74       : handler_(std::move(handler)) {}
75 
handle(const Event & e)76   void handle(const Event& e) override {
77     handler_(e);
78   }
79 
80  private:
81   std::function<void(const Event&)> handler_;
82 };
83 } // namespace
84 
initMonitorBindings(PyObject * module)85 void initMonitorBindings(PyObject* module) {
86   auto rootModule = py::handle(module).cast<py::module>();
87 
88   auto m = rootModule.def_submodule("_monitor");
89 
90   py::enum_<Aggregation>(
91       m,
92       "Aggregation",
93       R"DOC(
94         These are types of aggregations that can be used to accumulate stats.
95       )DOC")
96       .value(
97           "VALUE",
98           Aggregation::NONE,
99           R"DOC(
100             VALUE returns the last value to be added.
101           )DOC")
102       .value(
103           "MEAN",
104           Aggregation::MEAN,
105           R"DOC(
106             MEAN computes the arithmetic mean of all the added values.
107           )DOC")
108       .value(
109           "COUNT",
110           Aggregation::COUNT,
111           R"DOC(
112             COUNT returns the total number of added values.
113           )DOC")
114       .value(
115           "SUM",
116           Aggregation::SUM,
117           R"DOC(
118             SUM returns the sum of the added values.
119           )DOC")
120       .value(
121           "MAX",
122           Aggregation::MAX,
123           R"DOC(
124             MAX returns the max of the added values.
125           )DOC")
126       .value(
127           "MIN",
128           Aggregation::MIN,
129           R"DOC(
130             MIN returns the min of the added values.
131           )DOC")
132       .export_values();
133 
134   py::class_<Stat<double>>(
135       m,
136       "Stat",
137       R"DOC(
138         Stat is used to compute summary statistics in a performant way over
139         fixed intervals. Stat logs the statistics as an Event once every
140         ``window_size`` duration. When the window closes the stats are logged
141         via the event handlers as a ``torch.monitor.Stat`` event.
142 
143         ``window_size`` should be set to something relatively high to avoid a
144         huge number of events being logged. Ex: 60s. Stat uses millisecond
145         precision.
146 
147         If ``max_samples`` is set, the stat will cap the number of samples per
148         window by discarding `add` calls once ``max_samples`` adds have
149         occurred. If it's not set, all ``add`` calls during the window will be
150         included. This is an optional field to make aggregations more directly
151         comparable across windows when the number of samples might vary.
152 
153         When the Stat is destructed it will log any remaining data even if the
154         window hasn't elapsed.
155       )DOC")
156       .def(
157           py::init<
158               std::string,
159               std::vector<Aggregation>,
160               std::chrono::milliseconds,
161               int64_t>(),
162           py::arg("name"),
163           py::arg("aggregations"),
164           py::arg("window_size"),
165           py::arg("max_samples") = std::numeric_limits<int64_t>::max(),
166           R"DOC(
167            Constructs the ``Stat``.
168           )DOC")
169       .def(
170           "add",
171           &Stat<double>::add,
172           py::arg("v"),
173           R"DOC(
174             Adds a value to the stat to be aggregated according to the
175             configured stat type and aggregations.
176           )DOC")
177       .def(
178           "get",
179           &Stat<double>::get,
180           R"DOC(
181             Returns the current value of the stat, primarily for testing
182             purposes. If the stat has logged and no additional values have been
183             added this will be zero.
184           )DOC")
185       .def_property_readonly(
186           "name",
187           &Stat<double>::name,
188           R"DOC(
189             The name of the stat that was set during creation.
190           )DOC")
191       .def_property_readonly(
192           "count",
193           &Stat<double>::count,
194           R"DOC(
195             Number of data points that have currently been collected. Resets
196             once the event has been logged.
197           )DOC");
198 
199   py::class_<Event>(
200       m,
201       "Event",
202       R"DOC(
203         Event represents a specific typed event to be logged. This can represent
204         high-level data points such as loss or accuracy per epoch or more
205         low-level aggregations such as through the Stats provided through this
206         library.
207 
208         All Events of the same type should have the same name so downstream
209         handlers can correctly process them.
210       )DOC")
211       .def(
212           py::init([](const std::string& name,
213                       std::chrono::system_clock::time_point timestamp,
214                       std::unordered_map<std::string, data_value_t> data) {
215             Event e;
216             e.name = name;
217             e.timestamp = timestamp;
218             e.data = std::move(data);
219             return e;
220           }),
221           py::arg("name"),
222           py::arg("timestamp"),
223           py::arg("data"),
224           R"DOC(
225            Constructs the ``Event``.
226           )DOC")
227       .def_readwrite(
228           "name",
229           &Event::name,
230           R"DOC(
231             The name of the ``Event``.
232           )DOC")
233       .def_readwrite(
234           "timestamp",
235           &Event::timestamp,
236           R"DOC(
237             The timestamp when the ``Event`` happened.
238           )DOC")
239       .def_readwrite(
240           "data",
241           &Event::data,
242           R"DOC(
243             The structured data contained within the ``Event``.
244           )DOC");
245 
246   m.def(
247       "log_event",
248       &logEvent,
249       py::arg("event"),
250       R"DOC(
251         log_event logs the specified event to all of the registered event
252         handlers. It's up to the event handlers to log the event out to the
253         corresponding event sink.
254 
255         If there are no event handlers registered this method is a no-op.
256       )DOC");
257 
258   py::class_<data_value_t> dataClass(
259       m,
260       "data_value_t",
261       R"DOC(
262         data_value_t is one of ``str``, ``float``, ``int``, ``bool``.
263       )DOC");
264 
265   py::implicitly_convertible<std::string, data_value_t>();
266   py::implicitly_convertible<double, data_value_t>();
267   py::implicitly_convertible<int64_t, data_value_t>();
268   py::implicitly_convertible<bool, data_value_t>();
269 
270   py::class_<PythonEventHandler, std::shared_ptr<PythonEventHandler>>
271       eventHandlerClass(m, "EventHandlerHandle", R"DOC(
272         EventHandlerHandle is a wrapper type returned by
273         ``register_event_handler`` used to unregister the handler via
274         ``unregister_event_handler``. This cannot be directly initialized.
275       )DOC");
276   m.def(
277       "register_event_handler",
278       [](std::function<void(const Event&)> f) {
279         auto handler = std::make_shared<PythonEventHandler>(std::move(f));
280         registerEventHandler(handler);
281         return handler;
282       },
283       py::arg("callback"),
284       R"DOC(
285         register_event_handler registers a callback to be called whenever an
286         event is logged via ``log_event``. These handlers should avoid blocking
287         the main thread since that may interfere with training as they run
288         during the ``log_event`` call.
289       )DOC");
290   m.def(
291       "unregister_event_handler",
292       [](const std::shared_ptr<PythonEventHandler>& handler) {
293         unregisterEventHandler(handler);
294       },
295       py::arg("handler"),
296       R"DOC(
297         unregister_event_handler unregisters the ``EventHandlerHandle`` returned
298         after calling ``register_event_handler``. After this returns the event
299         handler will no longer receive events.
300       )DOC");
301 
302   struct WaitCounterTracker {
303     explicit WaitCounterTracker(const c10::monitor::WaitCounterHandle& h)
304         : handle{h} {}
305     c10::monitor::WaitCounterHandle handle;
306     std::optional<c10::monitor::WaitCounterHandle::WaitGuard> guard;
307   };
308   py::class_<WaitCounterTracker, std::shared_ptr<WaitCounterTracker>>(
309       m, "_WaitCounterTracker")
310       .def(
311           "__enter__",
312           [](const std::shared_ptr<WaitCounterTracker>& self) {
313             self->guard.emplace(self->handle.start());
314           })
315       .def(
316           "__exit__",
317           [](const std::shared_ptr<WaitCounterTracker>& self,
318              const pybind11::args&) { self->guard.reset(); });
319 
320   py::class_<c10::monitor::WaitCounterHandle>(
321       m,
322       "_WaitCounter",
323       R"DOC(
324         WaitCounter represents a named duration counter.
325         Multiple units of work can be tracked by the same WaitCounter. Depending
326         on the backend, the WaitCounter may track the number of units of work,
327         their duration etc.
328       )DOC")
329       .def(
330           py::init([](const std::string& key) {
331             return std::make_unique<c10::monitor::WaitCounterHandle>(key);
332           }),
333           py::arg("key"))
334       .def(
335           "guard",
336           [](const c10::monitor::WaitCounterHandle* self) {
337             return std::make_shared<WaitCounterTracker>(*self);
338           },
339           R"DOC(
340             Creates a guard that manages a single unit of work.
341           )DOC");
342 }
343 
344 } // namespace monitor
345 } // namespace torch
346