#include <torch/csrc/utils/pybind.h>
#include <optional>
#include <tuple>
#include <utility>
4
5 namespace torch::impl {
6
/// Adapts a C++ RAII guard type `GuardT` to explicit enter/exit calls.
///
/// Construction only records the arguments; no guard exists yet. Each
/// `enter()` builds a fresh `GuardT` from copies of the recorded arguments,
/// and `exit()` destroys it. Because the arguments are retained, a single
/// instance can be entered and exited repeatedly. Calling `enter()` while a
/// guard is still alive destroys that guard before constructing the new one
/// (this is how std::optional::emplace behaves).
template <typename GuardT, typename... Args>
struct RAIIContextManager {
  explicit RAIIContextManager(Args&&... args)
      : args_(std::forward<Args>(args)...) {}

  /// Constructs the guard from the stored constructor arguments.
  void enter() {
    // Every tuple element is copied into `saved`, then forwarded to GuardT's
    // constructor (as an rvalue for non-reference Args). The stored tuple is
    // left intact for the next enter().
    std::apply(
        [this](Args... saved) {
          guard_.emplace(std::forward<Args>(saved)...);
        },
        args_);
  }

  /// Destroys the live guard, if any; does nothing otherwise.
  void exit() {
    guard_.reset();
  }

 private:
  // Engaged only between enter() and exit().
  std::optional<GuardT> guard_;
  // Arguments captured at construction and replayed on every enter().
  std::tuple<Args...> args_;
};
27
28 // Turns a C++ RAII guard into a Python context manager.
29 // See _ExcludeDispatchKeyGuard in python_dispatch.cpp for example.
30 template <typename GuardT, typename... GuardArgs>
py_context_manager(const py::module & m,const char * name)31 void py_context_manager(const py::module& m, const char* name) {
32 using ContextManagerT = RAIIContextManager<GuardT, GuardArgs...>;
33 py::class_<ContextManagerT>(m, name)
34 .def(py::init<GuardArgs...>())
35 .def("__enter__", [](ContextManagerT& guard) { guard.enter(); })
36 .def(
37 "__exit__",
38 [](ContextManagerT& guard,
39 const py::object& exc_type,
40 const py::object& exc_value,
41 const py::object& traceback) { guard.exit(); });
42 }
43
/// Backing object for `py_context_manager_DEPRECATED`.
///
/// Unlike RAIIContextManager, the guard is constructed eagerly in the
/// constructor. `enter()` does nothing and `exit()` destroys the guard. Once
/// `exit()` has run, the guard is never re-acquired: a later `enter()` is
/// still a no-op. If `exit()` is never called, the guard lives until this
/// object itself is destroyed.
///
/// The constructor arguments are intentionally not stored, because the guard
/// is never rebuilt. Not storing them also means `Args` need not be
/// default-constructible.
template <typename GuardT, typename... Args>
struct DeprecatedRAIIContextManager {
  // Build the guard directly inside the optional; no intermediate copy.
  explicit DeprecatedRAIIContextManager(Args&&... args)
      : guard_(std::in_place, std::forward<Args>(args)...) {}

  // The guard is already active from construction, so nothing to do here.
  void enter() {}

  // Releases the guard if it is still held; a no-op otherwise.
  void exit() {
    guard_ = std::nullopt;
  }

 private:
  // Engaged from construction until the first exit().
  std::optional<GuardT> guard_;
};
60
61 // Definition: a "Python RAII guard" is an object in Python that acquires
62 // a resource on init and releases the resource on deletion.
63 //
64 // This API turns a C++ RAII guard into an object can be used either as a
65 // Python context manager or as a "Python RAII guard".
66 //
67 // Please prefer `py_context_manager` to this API if you are binding a new
68 // RAII guard into Python because "Python RAII guards" don't work as expected
69 // in Python (Python makes no guarantees about when an object gets deleted)
70 template <typename GuardT, typename... GuardArgs>
py_context_manager_DEPRECATED(const py::module & m,const char * name)71 void py_context_manager_DEPRECATED(const py::module& m, const char* name) {
72 using ContextManagerT = DeprecatedRAIIContextManager<GuardT, GuardArgs...>;
73 py::class_<ContextManagerT>(m, name)
74 .def(py::init<GuardArgs...>())
75 .def("__enter__", [](ContextManagerT& guard) { guard.enter(); })
76 .def(
77 "__exit__",
78 [](ContextManagerT& guard,
79 const py::object& exc_type,
80 const py::object& exc_value,
81 const py::object& traceback) { guard.exit(); });
82 }
83
84 } // namespace torch::impl
85