xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/python_ivalue.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/ivalue.h>
3 #include <pybind11/pybind11.h>
4 #include <torch/csrc/jit/python/pybind_utils.h>
5 #include <torch/csrc/python_headers.h>
6 #include <torch/csrc/utils/pybind.h>
7 
8 namespace py = pybind11;
9 
10 namespace c10::ivalue {
11 
12 // concrete ivalue Holder that hold a py::object
13 struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder {
14  public:
createfinal15   static c10::intrusive_ptr<PyObjectHolder> create(py::object py_obj) {
16     return c10::make_intrusive<ConcretePyObjectHolder>(std::move(py_obj));
17   }
18 
createfinal19   static c10::intrusive_ptr<PyObjectHolder> create(const py::handle& handle) {
20     py::gil_scoped_acquire ag;
21     return c10::make_intrusive<ConcretePyObjectHolder>(
22         handle.cast<py::object>());
23   }
24 
getPyObjectfinal25   PyObject* getPyObject() override {
26     return py_obj_.ptr();
27   }
28 
tryToInferTypefinal29   InferredType tryToInferType() override {
30     pybind11::gil_scoped_acquire ag;
31     return torch::jit::tryToInferType(py_obj_);
32   }
33 
34   IValue toIValue(const TypePtr& type, std::optional<int32_t> N = std::nullopt)
35       override {
36     pybind11::gil_scoped_acquire ag;
37     return torch::jit::toIValue(py_obj_, type, N);
38   }
39 
toStrfinal40   std::string toStr() override {
41     pybind11::gil_scoped_acquire ag;
42     return py::str(py_obj_);
43   }
44 
extractTensorsfinal45   std::vector<at::Tensor> extractTensors() override {
46     // We could implement this entirely in C++ via pybind11 but it turns out to
47     // be substantially slower. Namely, the total time taken by markCompleted on
48     // a CUDAFuture is 21.5us with this implementation, but goes up to 58.7us
49     // when using C++. The reason is unclear.
50     try {
51       pybind11::gil_scoped_acquire ag;
52       static py::object& extractorFn = *new py::object(
53           py::module::import("torch._jit_internal").attr("_extract_tensors"));
54       return extractorFn(py_obj_).cast<std::vector<at::Tensor>>();
55     } catch (py::error_already_set& e) {
56       auto err = std::runtime_error(
57           c10::str("Cannot extract tensors from value: ", e.what()));
58       {
59         pybind11::gil_scoped_acquire ag;
60         e.restore();
61         PyErr_Clear();
62       }
63       throw std::runtime_error(err);
64     }
65   }
66 
67   // Note [Destructing py::object]
68   // ~~~~~~~~~~~~~~~~~~~~~~~~~~
69   //
70   // (1) Why py_obj_ = py::none(); does not work. Because we also need to
71   // acquire GIL when destructing py::object of None that de-references None.
72   // https://docs.python.org/3/c-api/none.html#c.Py_RETURN_NONE
73   //
74   // https://stackoverflow.com/questions/15287590/why-should-py-increfpy-none-be-required-before-returning-py-none-in-c
75   //
76   // (2) Why we need to call dec_ref() explicitly. Because py::object of
77   // nullptr, on destruction, effectively does nothing because of it calls
78   // Py_XDECREF(NULL) underlying.
79   // https://docs.python.org/3/c-api/refcounting.html#c.Py_XDECREF
~ConcretePyObjectHolderfinal80   ~ConcretePyObjectHolder() override {
81     pybind11::gil_scoped_acquire ag;
82     py_obj_.dec_ref();
83     // explicitly setting PyObject* to nullptr to prevent py::object's dtor to
84     // decref on the PyObject again.
85     py_obj_.ptr() = nullptr;
86   }
87 
88   // explicit construction to avoid errornous implicit conversion and
89   // copy-initialization
ConcretePyObjectHolderfinal90   explicit ConcretePyObjectHolder(py::object py_obj)
91       : py_obj_(std::move(py_obj)) {}
92 
93  private:
94   py::object py_obj_;
95 };
96 
97 } // namespace c10::ivalue
98