xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/python_raii.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/utils/pybind.h>
2 #include <optional>
3 #include <tuple>
4 
5 namespace torch::impl {
6 
// Holder that defers construction of a C++ RAII guard until `enter()` is
// called and destroys it again in `exit()`.
//
// Template parameters:
//   GuardT - the guard type; it is constructed from `Args...`.
//   Args   - the constructor argument types, stored by value in a tuple so
//            the guard can be built later (and rebuilt on a later `enter()`).
template <typename GuardT, typename... Args>
struct RAIIContextManager {
  // Only records the constructor arguments; no guard exists yet.
  explicit RAIIContextManager(Args&&... args)
      : args_(std::forward<Args>(args)...) {}

  // Constructs the guard in place from the stored arguments. The tuple is
  // passed as an lvalue, so every element is copied into the lambda's
  // by-value parameters and the stored originals remain intact.
  void enter() {
    std::apply(
        [this](Args... ctor_args) {
          guard_.emplace(std::forward<Args>(ctor_args)...);
        },
        args_);
  }

  // Destroys the guard if one is alive; otherwise does nothing.
  void exit() {
    guard_.reset();
  }

 private:
  // Empty until `enter()` runs and again after `exit()`.
  std::optional<GuardT> guard_;
  // Arguments captured by the constructor for use in `enter()`.
  std::tuple<Args...> args_;
};
27 
28 // Turns a C++ RAII guard into a Python context manager.
29 // See _ExcludeDispatchKeyGuard in python_dispatch.cpp for example.
30 template <typename GuardT, typename... GuardArgs>
py_context_manager(const py::module & m,const char * name)31 void py_context_manager(const py::module& m, const char* name) {
32   using ContextManagerT = RAIIContextManager<GuardT, GuardArgs...>;
33   py::class_<ContextManagerT>(m, name)
34       .def(py::init<GuardArgs...>())
35       .def("__enter__", [](ContextManagerT& guard) { guard.enter(); })
36       .def(
37           "__exit__",
38           [](ContextManagerT& guard,
39              const py::object& exc_type,
40              const py::object& exc_value,
41              const py::object& traceback) { guard.exit(); });
42 }
43 
// Holder that acquires a C++ RAII guard immediately on construction and
// releases it either in `exit()` or, if `exit()` is never called, when this
// object itself is destroyed.
//
// Template parameters:
//   GuardT - the guard type; it is constructed from `Args...`.
//   Args   - the constructor argument types. They are forwarded straight to
//            the guard and are not stored, so they do not need to be
//            default-constructible or copyable.
template <typename GuardT, typename... Args>
struct DeprecatedRAIIContextManager {
  // Constructs the guard right away from the given arguments.
  explicit DeprecatedRAIIContextManager(Args&&... args) {
    guard_.emplace(std::forward<Args>(args)...);
  }

  // Intentionally a no-op: the guard is already alive since construction.
  void enter() {}

  // Destroys the guard if it is still alive; later calls do nothing.
  void exit() {
    guard_.reset();
  }

 private:
  // Holds the live guard; empty once `exit()` has been called.
  std::optional<GuardT> guard_;
};
60 
61 // Definition: a "Python RAII guard" is an object in Python that acquires
62 // a resource on init and releases the resource on deletion.
63 //
64 // This API turns a C++ RAII guard into an object can be used either as a
65 // Python context manager or as a "Python RAII guard".
66 //
67 // Please prefer `py_context_manager` to this API if you are binding a new
68 // RAII guard into Python because "Python RAII guards" don't work as expected
69 // in Python (Python makes no guarantees about when an object gets deleted)
70 template <typename GuardT, typename... GuardArgs>
py_context_manager_DEPRECATED(const py::module & m,const char * name)71 void py_context_manager_DEPRECATED(const py::module& m, const char* name) {
72   using ContextManagerT = DeprecatedRAIIContextManager<GuardT, GuardArgs...>;
73   py::class_<ContextManagerT>(m, name)
74       .def(py::init<GuardArgs...>())
75       .def("__enter__", [](ContextManagerT& guard) { guard.enter(); })
76       .def(
77           "__exit__",
78           [](ContextManagerT& guard,
79              const py::object& exc_type,
80              const py::object& exc_value,
81              const py::object& traceback) { guard.exit(); });
82 }
83 
84 } // namespace torch::impl
85