1 #include <utility>
2
3 #include <c10/util/WaitCounter.h>
4
5 #include <torch/csrc/utils/pybind.h>
6 #include <torch/csrc/utils/python_arg_parser.h>
7 #include <torch/csrc/utils/python_numbers.h>
8 #include <torch/csrc/utils/python_strings.h>
9
10 #include <pybind11/chrono.h>
11 #include <pybind11/functional.h>
12 #include <pybind11/operators.h>
13 #include <pybind11/stl.h>
14
15 #include <torch/csrc/monitor/counters.h>
16 #include <torch/csrc/monitor/events.h>
17
18 namespace pybind11 {
19 namespace detail {
20 template <>
21 struct type_caster<torch::monitor::data_value_t> {
22 public:
23 PYBIND11_TYPE_CASTER(torch::monitor::data_value_t, _("data_value_t"));
24
25 // Python -> C++
loadpybind11::detail::type_caster26 bool load(handle src, bool) {
27 PyObject* source = src.ptr();
28 if (THPUtils_checkLong(source)) {
29 this->value = THPUtils_unpackLong(source);
30 } else if (THPUtils_checkDouble(source)) {
31 this->value = THPUtils_unpackDouble(source);
32 } else if (THPUtils_checkString(source)) {
33 this->value = THPUtils_unpackString(source);
34 } else if (PyBool_Check(source)) {
35 this->value = THPUtils_unpackBool(source);
36 } else {
37 return false;
38 }
39 return !PyErr_Occurred();
40 }
41
42 // C++ -> Python
castpybind11::detail::type_caster43 static handle cast(
44 torch::monitor::data_value_t src,
45 return_value_policy /* policy */,
46 handle /* parent */) {
47 if (std::holds_alternative<double>(src)) {
48 return PyFloat_FromDouble(std::get<double>(src));
49 } else if (std::holds_alternative<int64_t>(src)) {
50 return THPUtils_packInt64(std::get<int64_t>(src));
51 } else if (std::holds_alternative<bool>(src)) {
52 if (std::get<bool>(src)) {
53 Py_RETURN_TRUE;
54 } else {
55 Py_RETURN_FALSE;
56 }
57 } else if (std::holds_alternative<std::string>(src)) {
58 std::string str = std::get<std::string>(src);
59 return THPUtils_packString(str);
60 }
61 throw std::runtime_error("unknown data_value_t type");
62 }
63 };
64 } // namespace detail
65 } // namespace pybind11
66
67 namespace torch {
68 namespace monitor {
69
70 namespace {
71 class PythonEventHandler : public EventHandler {
72 public:
PythonEventHandler(std::function<void (const Event &)> handler)73 explicit PythonEventHandler(std::function<void(const Event&)> handler)
74 : handler_(std::move(handler)) {}
75
handle(const Event & e)76 void handle(const Event& e) override {
77 handler_(e);
78 }
79
80 private:
81 std::function<void(const Event&)> handler_;
82 };
83 } // namespace
84
initMonitorBindings(PyObject * module)85 void initMonitorBindings(PyObject* module) {
86 auto rootModule = py::handle(module).cast<py::module>();
87
88 auto m = rootModule.def_submodule("_monitor");
89
90 py::enum_<Aggregation>(
91 m,
92 "Aggregation",
93 R"DOC(
94 These are types of aggregations that can be used to accumulate stats.
95 )DOC")
96 .value(
97 "VALUE",
98 Aggregation::NONE,
99 R"DOC(
100 VALUE returns the last value to be added.
101 )DOC")
102 .value(
103 "MEAN",
104 Aggregation::MEAN,
105 R"DOC(
106 MEAN computes the arithmetic mean of all the added values.
107 )DOC")
108 .value(
109 "COUNT",
110 Aggregation::COUNT,
111 R"DOC(
112 COUNT returns the total number of added values.
113 )DOC")
114 .value(
115 "SUM",
116 Aggregation::SUM,
117 R"DOC(
118 SUM returns the sum of the added values.
119 )DOC")
120 .value(
121 "MAX",
122 Aggregation::MAX,
123 R"DOC(
124 MAX returns the max of the added values.
125 )DOC")
126 .value(
127 "MIN",
128 Aggregation::MIN,
129 R"DOC(
130 MIN returns the min of the added values.
131 )DOC")
132 .export_values();
133
134 py::class_<Stat<double>>(
135 m,
136 "Stat",
137 R"DOC(
138 Stat is used to compute summary statistics in a performant way over
139 fixed intervals. Stat logs the statistics as an Event once every
140 ``window_size`` duration. When the window closes the stats are logged
141 via the event handlers as a ``torch.monitor.Stat`` event.
142
143 ``window_size`` should be set to something relatively high to avoid a
144 huge number of events being logged. Ex: 60s. Stat uses millisecond
145 precision.
146
147 If ``max_samples`` is set, the stat will cap the number of samples per
148 window by discarding `add` calls once ``max_samples`` adds have
149 occurred. If it's not set, all ``add`` calls during the window will be
150 included. This is an optional field to make aggregations more directly
151 comparable across windows when the number of samples might vary.
152
153 When the Stat is destructed it will log any remaining data even if the
154 window hasn't elapsed.
155 )DOC")
156 .def(
157 py::init<
158 std::string,
159 std::vector<Aggregation>,
160 std::chrono::milliseconds,
161 int64_t>(),
162 py::arg("name"),
163 py::arg("aggregations"),
164 py::arg("window_size"),
165 py::arg("max_samples") = std::numeric_limits<int64_t>::max(),
166 R"DOC(
167 Constructs the ``Stat``.
168 )DOC")
169 .def(
170 "add",
171 &Stat<double>::add,
172 py::arg("v"),
173 R"DOC(
174 Adds a value to the stat to be aggregated according to the
175 configured stat type and aggregations.
176 )DOC")
177 .def(
178 "get",
179 &Stat<double>::get,
180 R"DOC(
181 Returns the current value of the stat, primarily for testing
182 purposes. If the stat has logged and no additional values have been
183 added this will be zero.
184 )DOC")
185 .def_property_readonly(
186 "name",
187 &Stat<double>::name,
188 R"DOC(
189 The name of the stat that was set during creation.
190 )DOC")
191 .def_property_readonly(
192 "count",
193 &Stat<double>::count,
194 R"DOC(
195 Number of data points that have currently been collected. Resets
196 once the event has been logged.
197 )DOC");
198
199 py::class_<Event>(
200 m,
201 "Event",
202 R"DOC(
203 Event represents a specific typed event to be logged. This can represent
204 high-level data points such as loss or accuracy per epoch or more
205 low-level aggregations such as through the Stats provided through this
206 library.
207
208 All Events of the same type should have the same name so downstream
209 handlers can correctly process them.
210 )DOC")
211 .def(
212 py::init([](const std::string& name,
213 std::chrono::system_clock::time_point timestamp,
214 std::unordered_map<std::string, data_value_t> data) {
215 Event e;
216 e.name = name;
217 e.timestamp = timestamp;
218 e.data = std::move(data);
219 return e;
220 }),
221 py::arg("name"),
222 py::arg("timestamp"),
223 py::arg("data"),
224 R"DOC(
225 Constructs the ``Event``.
226 )DOC")
227 .def_readwrite(
228 "name",
229 &Event::name,
230 R"DOC(
231 The name of the ``Event``.
232 )DOC")
233 .def_readwrite(
234 "timestamp",
235 &Event::timestamp,
236 R"DOC(
237 The timestamp when the ``Event`` happened.
238 )DOC")
239 .def_readwrite(
240 "data",
241 &Event::data,
242 R"DOC(
243 The structured data contained within the ``Event``.
244 )DOC");
245
246 m.def(
247 "log_event",
248 &logEvent,
249 py::arg("event"),
250 R"DOC(
251 log_event logs the specified event to all of the registered event
252 handlers. It's up to the event handlers to log the event out to the
253 corresponding event sink.
254
255 If there are no event handlers registered this method is a no-op.
256 )DOC");
257
258 py::class_<data_value_t> dataClass(
259 m,
260 "data_value_t",
261 R"DOC(
262 data_value_t is one of ``str``, ``float``, ``int``, ``bool``.
263 )DOC");
264
265 py::implicitly_convertible<std::string, data_value_t>();
266 py::implicitly_convertible<double, data_value_t>();
267 py::implicitly_convertible<int64_t, data_value_t>();
268 py::implicitly_convertible<bool, data_value_t>();
269
270 py::class_<PythonEventHandler, std::shared_ptr<PythonEventHandler>>
271 eventHandlerClass(m, "EventHandlerHandle", R"DOC(
272 EventHandlerHandle is a wrapper type returned by
273 ``register_event_handler`` used to unregister the handler via
274 ``unregister_event_handler``. This cannot be directly initialized.
275 )DOC");
276 m.def(
277 "register_event_handler",
278 [](std::function<void(const Event&)> f) {
279 auto handler = std::make_shared<PythonEventHandler>(std::move(f));
280 registerEventHandler(handler);
281 return handler;
282 },
283 py::arg("callback"),
284 R"DOC(
285 register_event_handler registers a callback to be called whenever an
286 event is logged via ``log_event``. These handlers should avoid blocking
287 the main thread since that may interfere with training as they run
288 during the ``log_event`` call.
289 )DOC");
290 m.def(
291 "unregister_event_handler",
292 [](const std::shared_ptr<PythonEventHandler>& handler) {
293 unregisterEventHandler(handler);
294 },
295 py::arg("handler"),
296 R"DOC(
297 unregister_event_handler unregisters the ``EventHandlerHandle`` returned
298 after calling ``register_event_handler``. After this returns the event
299 handler will no longer receive events.
300 )DOC");
301
302 struct WaitCounterTracker {
303 explicit WaitCounterTracker(const c10::monitor::WaitCounterHandle& h)
304 : handle{h} {}
305 c10::monitor::WaitCounterHandle handle;
306 std::optional<c10::monitor::WaitCounterHandle::WaitGuard> guard;
307 };
308 py::class_<WaitCounterTracker, std::shared_ptr<WaitCounterTracker>>(
309 m, "_WaitCounterTracker")
310 .def(
311 "__enter__",
312 [](const std::shared_ptr<WaitCounterTracker>& self) {
313 self->guard.emplace(self->handle.start());
314 })
315 .def(
316 "__exit__",
317 [](const std::shared_ptr<WaitCounterTracker>& self,
318 const pybind11::args&) { self->guard.reset(); });
319
320 py::class_<c10::monitor::WaitCounterHandle>(
321 m,
322 "_WaitCounter",
323 R"DOC(
324 WaitCounter represents a named duration counter.
325 Multiple units of work can be tracked by the same WaitCounter. Depending
326 on the backend, the WaitCounter may track the number of units of work,
327 their duration etc.
328 )DOC")
329 .def(
330 py::init([](const std::string& key) {
331 return std::make_unique<c10::monitor::WaitCounterHandle>(key);
332 }),
333 py::arg("key"))
334 .def(
335 "guard",
336 [](const c10::monitor::WaitCounterHandle* self) {
337 return std::make_shared<WaitCounterTracker>(*self);
338 },
339 R"DOC(
340 Creates a guard that manages a single unit of work.
341 )DOC");
342 }
343
344 } // namespace monitor
345 } // namespace torch
346