1 #pragma once 2 #include <ATen/core/ivalue.h> 3 #include <pybind11/pybind11.h> 4 #include <torch/csrc/jit/python/pybind_utils.h> 5 #include <torch/csrc/python_headers.h> 6 #include <torch/csrc/utils/pybind.h> 7 8 namespace py = pybind11; 9 10 namespace c10::ivalue { 11 12 // concrete ivalue Holder that hold a py::object 13 struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder { 14 public: createfinal15 static c10::intrusive_ptr<PyObjectHolder> create(py::object py_obj) { 16 return c10::make_intrusive<ConcretePyObjectHolder>(std::move(py_obj)); 17 } 18 createfinal19 static c10::intrusive_ptr<PyObjectHolder> create(const py::handle& handle) { 20 py::gil_scoped_acquire ag; 21 return c10::make_intrusive<ConcretePyObjectHolder>( 22 handle.cast<py::object>()); 23 } 24 getPyObjectfinal25 PyObject* getPyObject() override { 26 return py_obj_.ptr(); 27 } 28 tryToInferTypefinal29 InferredType tryToInferType() override { 30 pybind11::gil_scoped_acquire ag; 31 return torch::jit::tryToInferType(py_obj_); 32 } 33 34 IValue toIValue(const TypePtr& type, std::optional<int32_t> N = std::nullopt) 35 override { 36 pybind11::gil_scoped_acquire ag; 37 return torch::jit::toIValue(py_obj_, type, N); 38 } 39 toStrfinal40 std::string toStr() override { 41 pybind11::gil_scoped_acquire ag; 42 return py::str(py_obj_); 43 } 44 extractTensorsfinal45 std::vector<at::Tensor> extractTensors() override { 46 // We could implement this entirely in C++ via pybind11 but it turns out to 47 // be substantially slower. Namely, the total time taken by markCompleted on 48 // a CUDAFuture is 21.5us with this implementation, but goes up to 58.7us 49 // when using C++. The reason is unclear. 50 try { 51 pybind11::gil_scoped_acquire ag; 52 static py::object& extractorFn = *new py::object( 53 py::module::import("torch._jit_internal").attr("_extract_tensors")); 54 return extractorFn(py_obj_).cast<std::vector<at::Tensor>>(); 55 } catch (py::error_already_set& e) { 56 auto err = std::runtime_error( 57 c10::str("Cannot extract tensors from value: ", e.what())); 58 { 59 pybind11::gil_scoped_acquire ag; 60 e.restore(); 61 PyErr_Clear(); 62 } 63 throw std::runtime_error(err); 64 } 65 } 66 67 // Note [Destructing py::object] 68 // ~~~~~~~~~~~~~~~~~~~~~~~~~~ 69 // 70 // (1) Why py_obj_ = py::none(); does not work. Because we also need to 71 // acquire GIL when destructing py::object of None that de-references None. 72 // https://docs.python.org/3/c-api/none.html#c.Py_RETURN_NONE 73 // 74 // https://stackoverflow.com/questions/15287590/why-should-py-increfpy-none-be-required-before-returning-py-none-in-c 75 // 76 // (2) Why we need to call dec_ref() explicitly. Because py::object of 77 // nullptr, on destruction, effectively does nothing because of it calls 78 // Py_XDECREF(NULL) underlying. 79 // https://docs.python.org/3/c-api/refcounting.html#c.Py_XDECREF ~ConcretePyObjectHolderfinal80 ~ConcretePyObjectHolder() override { 81 pybind11::gil_scoped_acquire ag; 82 py_obj_.dec_ref(); 83 // explicitly setting PyObject* to nullptr to prevent py::object's dtor to 84 // decref on the PyObject again. 85 py_obj_.ptr() = nullptr; 86 } 87 88 // explicit construction to avoid errornous implicit conversion and 89 // copy-initialization ConcretePyObjectHolderfinal90 explicit ConcretePyObjectHolder(py::object py_obj) 91 : py_obj_(std::move(py_obj)) {} 92 93 private: 94 py::object py_obj_; 95 }; 96 97 } // namespace c10::ivalue 98